官方微调

1. image-text retrival 微调

命令

1
2
3
python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
--config ./configs/retrieval_coco.yaml \
--output_dir output/retrieval_coco

解析

  • retrieval_coco.yaml 文件如下

    • pretrained:改为下载好的权重
    • image_root:改为存放数据集图像的文件夹路径,并且该文件夹里要有三个子文件夹:test2014,train2014,val2014.(目的是与annotation文件里对于图像的路径保持一致)
    • ann_root: 改为存放json文件的路径,其中包括coco_karpathy_train.json,coco_karpathy_test,coco_karpathy_val.json

    如下图所示,如果需要改变微调的超参可直接在其中修改。

  • 微调设置

    • 优化器:AdamW
    • 学习率使用余弦学习率安排
    1
    2
    3
    4
    5
    def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
    """Decay the learning rate"""
    lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr
    for param_group in optimizer.param_groups:
    param_group['lr'] = lr
    • 微调参数量:vision encoder和bert encoder

2. image-text captioning 微调

1
python -m torch.distributed.run --nproc_per_node=8 train_caption.py 

3. VQA 微调

1
python -m torch.distributed.run --nproc_per_node=16 train_vqa.py 
  • 微调设置
    • 微调参数:vision encoder和bert decoder

4. NLVR2 微调

1
python -m torch.distributed.run --nproc_per_node=16 train_nlvr.py 

执行效果