对 ChatGLM-6B 做 LoRA fine tuning
hatGLM-6B 是一个支持中英双语的对话语言模型,基于 GLM (General Language Model)。它只有 62 亿个参数,量化后最低 (INT4 量化) 只需要 6GB 的显存,完全可以部署到消费级显卡上。在实际使用这个模型一段时间以后,我们发现模型的对话表现能力确实非常不错。那么,基于这个模型做 Fine-tuning 就非常有价值了。
声明:
本文提供的所有技术信息,都基于 THUDM/chatglm-6b 的历史版本: 096f3de6b4959ce38bef7bb05f3129c931a3084e
。
源码地址:
搭建依赖环境
安装 PyTorch 环境:
|
按照 ChatGLM-6B 的官方指导,安装软件依赖环境:
|
为了做 LoRA,还要安装 peft
|
加载模型和 Tokenizer
|
正如前面声明所述,本文使用的历史版本号是 096f3de6b4959ce38bef7bb05f3129c931a3084e
。如果开发者需要其他版本号,只需要更改 revision
值,并重新训练即可。
分析模型结构
模型加载完后,我们可以打印这个 model
和 tokenizer
,建立对模型的基本认知。
首先打印model
:
|
得到如下结果:
|
简单分析这个模型结构,至少可以得到如下一些信息:
- 模型使用了 Transformer 结构,因此可以使用 LoRA 进行 Fine-tuning
- 从 Word Embedding 层可以看出,词汇表大小是
150528
- LoRA 可以操作的目标是:
query_key_value
再打印tokenizer
:
|
得到如下结果(为了便于阅读,已对结果做了分行处理):
|
这里有几个可以关注的点:
- 词汇表大小
vocab_size
是150344
- 不是一个 fast Tokenizer(
is_fast
的值是False
) - 特殊 token 包括:
bos
eos
pad
和mask
为什么 model 中的词汇表大小是 150528
,而 tokenizer
中定义的词汇表大小却是 150344
呢?读者可以带着这个疑问去读一读模型项目的源码,看看能不能找到答案。
配置 LoRA
借助 peft 库,我们可以很方便地对模型注入 LoRA。
|
打印可训练的参数量:
|
得到如下结果:
|
可以看到,总的参数量是 6,258,876,416
,可训练的参数量是 3,670,016
,占比 0.0586%
左右。训练参数量只是百万级别的,可谓相当友好了!另外需要注意的一点是,ChatGLM-6B 是一个因果语言模型 (Causal Language Model),因此我们这里选择的任务类型是 CAUSAL_LM
。
构建数据集
定义常量
构建之前,我们先定义几个特殊 Token 常量:
|
将这几个值打印出来:
|
得到如下结果:
|
我们也可以直接用这个常量结果替换动态计算的部分。常量修改后的结果变成:
|
除了上面定义的 Token 常量,我们还需要定义模型训练绑定的设备名,以及最大输入长度和最大输出长度等,如下:
|
开发者可以结合自己的显卡性能和要处理的数据集特点来确定这些最大长度。
测试 Tokenizer 的编解码
我们可以先做个简单的测试:
|
输出结果是:
|
从这个结果可以看出,“AI探险家”这几个字的裸编码是 [26738, 98715, 83920]
。为什么是这样呢?我们可以对每一个数值再解码,看看输出结果:
|
输出结果是:
|
观察这个结果,读者应该能对词汇表建立基本的认知了。读者如果有兴趣,还可以分别针对 “A” “I” “探” “险” 这几个字分别编码,看看编码结果是什么。另外,当 add_special_tokens = True
时,编码结果会在末尾添加 150001
和 150004
,也就是 gmask
和 bos
。请注意,我们的训练数据,要按照如下编码要求进行构造:
|
因此,前半部分文本的编码可以直接让 add_special_tokens = True
,后半部分文本的编码则让 add_special_tokens = False
,最后再拼接一个 eop
。
定义 Prompt
我们 Fine-tuning 的任务是问答任务(简称 QA),因此一个简单的 Prompt 是这样的:
|
{}
里填入 QA 训练集的问题文本。在显存有限的情况下,如果不对长文本做限制处理,很容易出现类似 CUDA out of memory
这样的报错。处理长文本,在给定编码后的数组上限时,可能存在这么几种方式:
- 截断末尾超出部分的编码
- 截断前面超出部分的编码
- 丢掉训练样本
每一种方式都有各自的优劣,开发者可以根据自身数据的特点自行选择一种处理方式。当然,如果你的显存够大,也可以不处理。本文以上述第一种方式进行处理。 为了不把 PROMPT_PATTERN
中的 \n答:
这几个字截断掉,我们将整个 PROMPT_PATTERN
拆成两部分:
|
基于这份 Prompt 模板,我们定义下面三个辅助方法:
|
值得注意的两点:
- 从
create_prompt_ids
这个函数实现可以看出,我们编码分隔符SEP_PATTERN
时自动添加了前面所述的 2 个特殊 Token。 - 对
create_inputs_and_labels
的函数实现中,我们将labels
无需处理的部分用数值-100
来表示。因为ChatGLMForConditionalGeneration
内部在计算损失函数的时候,用的是torch.nn.CrossEntropyLoss
。该函数的参数之一ignore_index
默认值是-100
。这就让我们在计算损失函数时,无需考虑非标识部分的数值。
构建 Attention Mask 和 Position IDs
|
在这个通用实现中,我们针对 mask
和 gmask
两种情况做了区分,同时也对是否执行 position_encoding_2d
分情况处理。本文的 QA 任务采用的是 gmask
,并且使用 position_encoding_2d = True
。
我们可以构建下面的问答,来验证下这几个函数的输出:
|
输出结果(为了便于阅读,已对输出进行格式化操作):
|
结合论文观察数据,基本符合预期。
创建数据集
我们先定义具有如下格式的训练数据:
|
定义好格式后,我们先创建一个 QADataset
类,如下:
|
然后创建一个 Data Collator:
|
开始训练
|
预测
|
保存训练模型
|
重载训练后的模型
|
文章来源:https://aizpy.com/2023/03/30/chatglm-6b-lora/
布施恩德可便相知重
微信扫一扫打赏
支付宝扫一扫打赏