【Finetune】(一)、transformers之BitFit微调

文章目录

  • 0、参数微调简介
  • 1、常见的微调方法
  • 2、代码实战
    • 2.1、导包
    • 2.2、加载数据集
    • 2.3、数据集处理
    • 2.4、创建模型
    • 2.5、BitFit微调*
    • 2.6、配置模型参数
    • 2.7、创建训练器
    • 2.8、模型训练
    • 2.9、模型推理

0、参数微调简介

 参数微调方法是仅对模型的一小部分的参数(这一小部分可能是模型自身的,也可能是外部引入的)进行训练,便可以为模型带来显著的性能变化,在一些场景下甚至不输于全量微调。
 由于训练一小部分参数,极大程度降低了训练大模型的算力需求,不需要多机多卡,单卡就可以完成对一些大模型的训练。不仅如此,少量的训练参数,对存储的要求同样降低很多,大多数的参数微调方法只需要保存训练部分的参数,与动辄几十GB的原始大模型相比,几乎可以忽略。

1、常见的微调方法

 常见的微调方法如图所示:
在这里插入图片描述

Lialin, Vladislav, Vijeta Deshpande, and Anna Rumshisky. “Scaling down to scale up: A guide to parameter-efficient fine-tuning.” arXiv preprint arXiv:2303.15647 (2023).

2、代码实战

  • 模型——bloom-389m-zh
  • 数据集——alpaca_data_zh

2.1、导包

from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer

2.2、加载数据集

ds = Dataset.load_from_disk("./alpaca_data_zh/")

2.3、数据集处理

tokenizer = AutoTokenizer.from_pretrained("../Model/bloom-389m-zh")
tokenizer
def process_func(example):
    MAX_LENGTH = 256
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")
    response = tokenizer(example["output"] + tokenizer.eos_token)
    input_ids = instruction["input_ids"] + response["input_ids"]
    attention_mask = instruction["attention_mask"] + response["attention_mask"]
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }
tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)
tokenized_ds

2.4、创建模型

model = AutoModelForCausalLM.from_pretrained("../Model/bloom-389m-zh",low_cpu_mem_usage=True)

2.5、BitFit微调*

#选择模型参数里面的所有bias部分
#非bias部分冻结
num_param = 0
for name,param in model.named_parameters():
    if 'bias' not in name:
        param.requires_grad = False
    else:
        num_param+=param.numel()
num_param

2.6、配置模型参数

args = TrainingArguments(
    output_dir="./chatbot",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    logging_steps=10,
    num_train_epochs=1
)

2.7、创建训练器

trainer = Trainer(
    args=args,
    model=model,
    train_dataset=tokenized_ds,
    data_collator=DataCollatorForSeq2Seq(tokenizer, padding=True, )
)

2.8、模型训练

trainer.train()

2.9、模型推理

from transformers import pipeline

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
ipt = "Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: "
pipe(ipt, max_length=256, do_sample=True, temperature=0.5)

http://www.niftyadmin.cn/n/5665971.html

相关文章

Windows本地制作java证书(与jeecgboot配置本地证书ssl问题)

1:JDK生成自签证书SSL,首先以管理员身份运行CMD窗口,执行命令 keytool -genkey -alias testhttps -keyalg RSA -keysize 2048 -validity 36500 -keystore "F:/ssl/testhttps.keystore"F:\ssl>keytool -genkey -alias testhttps -keyalg R…

软件测试技术之 GPU 单元测试是什么!

1 背景 测试是开发的一个非常重要的方面,可以在很大程度上决定一个应用程序的命运。良好的测试可以在早期捕获导致应用程序崩溃的问题,但较差的测试往往总是导致故障和停机。 单元测试用于测试各个代码组件,并确保代码按照预期的方式工作。单…

Qt安卓开发连接手机调试(红米K60为例)

1.前置条件 本人默认您已经完成Qt安卓环境的配置,若还没配置请参考链接文章:【Qt】最详细教程,如何从零配置Qt Android安卓环境_qt_七夕先生-开放原子开发者工作坊。准备一台目前主流在用的手机,其实自己用的就行(只要你不是某些…

Kubernetes1.24版本以上集群部署 初始化init报错:[kubelet-check] Initial timeout of 40s passed.

描述: 在安装Kubernetes1.28.2,初始化init时出现问题: [wait-control-plane] Waiting for the kubelet to boot up the control plane as static Pods from directory "/etc/kubernetes/manifests". This can take up to 4m0s [k…

活动系统开发之采用设计模式与非设计模式的区别-后台功能总结

1、数据库ER图 2、后台功能字段 题目功能字段 数据列表 编号题目名称选项数量状态 1启用0禁用创建时间修改时间保存 题目名称选项集 选项内容是否正确答案 1正确0错误启禁用删除素材图库功能字段 数据列表 编号原文件名称文件类型文件大小加密后文件名文件具体路径上传类型状态…

深度学习速通系列:TextCNN介绍

TextCNN是一种用于文本分类的卷积神经网络模型,由Yoon Kim在2014年的论文《Convolutional Neural Networks for Sentence Classification》中提出。它将卷积神经网络(CNN)应用于文本数据,通过使用不同大小的卷积核来提取文本中的局…

八股文-HashMap

是什么?谁发明的?用来做什么?特点是什么? 哈希表,JDK自带的存储容器,存储key-value数据,特点是访问快 为啥访问快?底层结构?原理? 底层采用数组链表/红黑树…

【制作100个unity游戏之32】unity开发属于自己的一个2d/3d桌面宠物,可以实时计算已经获取的工资

最终效果 文章目录 最终效果一、实现Windows消息弹窗二、将窗口扩展到工作区三、穿透能点击到其他区域四、模型交互1、我们可以新增ObjectDrag 代码控制人物拖拖动2、实现模型交互五、最终代码六、其他七、游玩地址使用Live2D实现桌宠七、源码参考完结一、实现Windows消息弹窗 …