[PaddlePaddle/PaddleOCR]多字体训练,如何解决不懂字体间相似字符造成的干扰?

2024-05-14 656 views
5

请提供下述完整信息以便快速定位问题/Please provide the following information to quickly locate the problem

  • 系统环境/System Environment:win10
  • 版本号/Version:Paddle: PaddleOCR:2.5 问题相关组件/Related components:
  • 运行指令/Command Code:
  • 完整报错/Complete Error Message:

多字体训练,如何解决不懂字体间相似字符造成的干扰?

回答

7

根据不同字体生成一批数据,微调

4

同一批数据包含不同字体,会造成干扰吧?

5

使用 combined loss 可以有效降低相似字符问题,具体可搜索官方文档。

3

具体要怎么设置?目前用的是默认的 Loss: name: CTCLoss

知识蒸馏任务中,损失函数配置如下所示。

Loss:
  name: CombinedLoss
  loss_config_list:
  - DistillationDMLLoss:                       # 蒸馏的DML损失函数,继承自标准的DMLLoss
      weight: 1.0                              # 权重
      act: "softmax"                           # 激活函数,对输入使用激活函数处理,可以为softmax, sigmoid或者为None,默认为None
      use_log: true                            # 对输入计算log,如果函数已经
      model_name_pairs:                        # 用于计算DML loss的子网络名称对,如果希望计算其他子网络的DML loss,可以在列表下面继续填充
      - ["Student", "Teacher"]
      key: head_out                            # 取子网络输出dict中,该key对应的tensor
      multi_head: True                         # 是否为多头结构
      dis_head: ctc                            # 指定用于计算损失函数的head
      name: dml_ctc                            # 蒸馏loss的前缀名称,避免不同loss之间的命名冲突
  - DistillationDMLLoss:                       # 蒸馏的DML损失函数,继承自标准的DMLLoss
      weight: 0.5                              # 权重
      act: "softmax"                           # 激活函数,对输入使用激活函数处理,可以为softmax, sigmoid或者为None,默认为None
      use_log: true                            # 对输入计算log,如果函数已经
      model_name_pairs:                        # 用于计算DML loss的子网络名称对,如果希望计算其他子网络的DML loss,可以在列表下面继续填充
      - ["Student", "Teacher"]
      key: head_out                            # 取子网络输出dict中,该key对应的tensor
      multi_head: True                         # 是否为多头结构
      dis_head: sar                            # 指定用于计算损失函数的head
      name: dml_sar                            # 蒸馏loss的前缀名称,避免不同loss之间的命名冲突
  - DistillationDistanceLoss:                  # 蒸馏的距离损失函数
      weight: 1.0                              # 权重
      mode: "l2"                               # 距离计算方法,目前支持l1, l2, smooth_l1
      model_name_pairs:                        # 用于计算distance loss的子网络名称对
      - ["Student", "Teacher"]
      key: backbone_out                        # 取子网络输出dict中,该key对应的tensor
  - DistillationCTCLoss:                       # 基于蒸馏的CTC损失函数,继承自标准的CTC loss
      weight: 1.0                              # 损失函数的权重,loss_config_list中,每个损失函数的配置都必须包含该字段
      model_name_list: ["Student", "Teacher"]  # 对于蒸馏模型的预测结果,提取这两个子网络的输出,与gt计算CTC loss
      key: head_out                            # 取子网络输出dict中,该key对应的tensor
  - DistillationSARLoss:                       # 基于蒸馏的SAR损失函数,继承自标准的SARLoss
      weight: 1.0                              # 损失函数的权重,loss_config_list中,每个损失函数的配置都必须包含该字段
      model_name_list: ["Student", "Teacher"]  # 对于蒸馏模型的预测结果,提取这两个子网络的输出,与gt计算CTC loss
      key: head_out                            # 取子网络输出dict中,该key对应的tensor
      multi_head: True                         # 是否为多头结构,为true时,取出其中的SAR分支计算损失函数
9

使用 combined loss 可以有效降低相似字符问题,具体可搜索官方文档。

导出的center loss文件train_center.pkl怎么使用?加入yml文件报错了。

Loss: name: CTCLoss

Loss: name: CombinedLoss loss_config_list:

  • CTCLoss: use_focal_loss: false weight: 1.0
  • CenterLoss: weight: 0.05 num_classes: 63466 feat_dim: 512 center_file_path: C:/F/pycharm2020.2/PaddleOCR-release-2.5/train_center.pkl you can also try to add ace loss on your own dataset
  • ACELoss: weight: 0.1

    报错如下

    Traceback (most recent call last): File "./tools/train.py", line 191, in main(config, device, logger, vdl_writer) File "./tools/train.py", line 164, in main program.train(config, train_dataloader, valid_dataloader, device, model, File "C:\F\pycharm2020.2\PaddleOCR-release-2.5\tools\program.py", line 268, in train loss = loss_class(preds, batch) File "C:\Program Files\Python38\lib\site-packages\paddle\fluid\dygraph\layers.py", line 930, in call return self._dygraph_call_func(*inputs, *kwargs) File "C:\Program Files\Python38\lib\site-packages\paddle\fluid\dygraph\layers.py", line 915, in _dygraph_call_func outputs = self.forward(inputs, kwargs) File "C:\F\pycharm2020.2\PaddleOCR-release-2.5\ppocr\losses\combined_loss.py", line 55, in forward loss = loss_func(input, batch, kargs) File "C:\F\pycharm2020.2\PaddleOCR-release-2.5\ppocr\losses\center_loss.py", line 50, in call assert isinstance(predicts, (list, tuple)) AssertionError

image

2

是用的这个配置,但是我的其他配置是用rec_r34_vd_none_bilstm_ctc.yml

结果就报上面的错误了,不知道怎么解决

6
  1. 更新你的代码,你使用的是 2.5
  2. 把你的 configuration 和官方给的示例对齐,确保能跑后,再进行调整
1

2.6版 有的包装不了,没法用,polygon3之类 官方给的是mobilenet的,没有resnet34+centerloss的示例

4
  • resnet 和 mobilenet 的差异仅仅是 backbone
  • 尝试解决安装问题
1

resnet 和 mobilenet 可能有大区别。之前问distillation的问题,官方的人说要改源代码,才能让resnet34支持distillation, 但不知道怎么改。

SVTR模型,你训练没有?哪种模型准确率最高?

3
  • 你没使用 distillation 呀,所以在 configuration 里没差异的,对比一下就知道了
  • 场景不同精度不同
8

你使用center loss是在v2模型基础上用的?v3用不了吧? 能提高多少准确率?

9

自己训练的模型,没有使用预训练。基本不提升。

6

基本不提升,那没什么作用啊