[THUDM/ChatGLM-6B][BUG/Help] <请教下提供的量化脚本和hf的load_in_8bit有什么区别>

2024-05-10 882 views
4

请教下提供的量化脚本和hf的load_in_8bit有什么区别

Environment
- OS: Linux
- Python: 3.9
- Transformers: 4.29.1
- PyTorch: 2.0
- CUDA Support (`python -c "import torch; print(torch.cuda.is_available())"`) : True

回答

7

链接打不开啊,还有就是我用官方的int8+lora会提示有问题,但是我看用load_in_8bit+lora就可以训,知道这是为啥吗?

8

@Tongjilibo 我验证下来,无论是lora还是qlora都是load官方全精度的模型,最后训练完后,把adapter和base model合并,再生成fp16,int8,int4的模型。 都放在这了:https://github.com/shuxueslpi/chatGLM-6B-QLoRA

9

@Tongjilibo 目前来看时因为官方的量化方式和huggingface load_in_8bit的量化方式不同,如果采用官方的量化时,全精度模型的linear层会被替换为新的linear量化层,这个时候用huggingface提供的lora层对官方量化的liear层进行替换时,是无法识别的。 具体来说,加载官方的8bit模型时,‘query_key_value’这个层会被替换为‘QuantizedLinear’,这个时候如果再用hugginface提供的lora将这个层替换为lora层时,huggingface无法将这个层识别为量化层,会将它替换为常规的linear层,并且将QuantizedLinear经过量化的权重数据赋值给这个新的层,这个时候前向传播的时候到了lora层,就会报错,因为没法用量化的权重数据和输入的浮点数进行运算。peft库lora源码的660行左右。输入x和self.weight的类型不同

def forward(self, x: torch.Tensor):
        previous_dtype = x.dtype
        if self.active_adapter not in self.lora_A.keys():
            return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
        if self.disable_adapters:
            if self.r[self.active_adapter] > 0 and self.merged:
                self.unmerge()
            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
        elif self.r[self.active_adapter] > 0 and not self.merged:
            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

            x = x.to(self.lora_A[self.active_adapter].weight.dtype)

            result += (
                self.lora_B[self.active_adapter](
                    self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
                )
                * self.scaling[self.active_adapter]
            )
        else:
            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

        result = result.to(previous_dtype)

        return result
0

@RuSignalFlag 好的,谢谢,我看了下peft里面是这样的逻辑,所以如果使用lora,比较方便的方式还是用load_in_8bit来做