Administrator
发布于 2025-04-19 / 7 阅读
0

CycleGAN官方仓库指南

csdn教程

github仓库

  1. git克隆仓库。

  2. 创建自己的二维数据。

  3. 在datasets路径执行combine命令。
    修改该文件中的两个路径参数:

    6803731fba7ba.png

    python3 combine_A_and_B.py --use_AB
    
  4. 回到main路劲执行下面的命令。

    1. 训练cyclegan:
      训练二维:

      python3 train.py --dataroot=../../data/preprocess_brain_slices/ct_and_mr --model=cycle_gan --input_nc=1 --output_nc=1 --dataset_mode=aligned --gpu_ids=0
      

      训练三维:

      python3 train.py --dataroot=../../data/preprocess_brain_volumes --model=cycle_gan3d --input_nc=1 --output_nc=1 --dataset_mode=unaligned3d --gpu_ids=0 --crop_size=96 --name=cyclegan_3d_brain
      
    2. 训练pix2pix:

      python3 train.py --dataroot=../../data/preprocess_brain_slices/ct_and_mr --model=pix2pix --input_nc=1 --output_nc=1 --dataset_mode=aligned --name pix2pix_2d_brain --gpu_ids=1
      

1. 命令参数

1.1 base options

  1. --datarootpath to images (should have subfolders trainA, trainB, valA, valB, etc)
    例:--dataroot=../../data/preprocess_brain_volumes
    那么训练时则会去找dataroot/trainA, dataroot/trainB。
    测试时去找dataroot/testA, dataroot/testB。
  2. --name:训练时会在checkpoints文件夹下创建name文件夹,并将训练时的东西保存在该文件夹下。
  3. --gpu_ids:选择gpu序号
  4. --model:选择模型。
  5. --input_nc:输入通道数。
  6. --output_nc:输出通道数。
  7. --dataset_mode:选择dataset。

1.2 train options

  1. --continue_train: 有这个参数,则训练为继续训练。
  2. --epoch_count: 开始训练的编号。
  3. --n_epochs: 总共训练epoch数。
  4. --n_epochs_decaynumber of epochs to linearly decay learning rate to zero。所以总共训练的epoch数为n_epochs+n_epochs_decay.

2. 训练

2.1 训练CycleGAN

训练二维模型:

python3 train.py --dataroot=../data/preprocess_brain_slices/ct_and_mr --model=cycle_gan --input_nc=1 --output_nc=1 --dataset_mode=aligned --gpu_ids=0 --name=cyclegan_2d_brain

训练三维模型:

python3 train.py --dataroot=../data/preprocess_brain_volumes --model=cycle_gan3d --input_nc=1 --output_nc=1 --dataset_mode=unaligned3d --gpu_ids=0 --crop_size=96 --name=cyclegan_3d_brain

继续训练三维模型:

三维显存耗费过大,所以将体数据剪裁为96x96x96。

python3 train.py --dataroot=../data/preprocess_brain_volumes --model=cycle_gan3d --input_nc=1 --output_nc=1 --dataset_mode=unaligned3d --gpu_ids=0 --crop_size=96 --name=cyclegan_3d_brain --continue_train --epoch_count=200 --n_epochs=1000

2.2 训练Pix2Pix

训练二维模型:

python3 train.py --dataroot=../data/preprocess_brain_slices/ct_and_mr --model=pix2pix --input_nc=1 --output_nc=1 --dataset_mode=aligned --name pix2pix_2d_brain --gpu_ids=1

2.3 训练ResViT

参考源码:https://github.com/icon-lab/ResViT/tree/main

首先进入到ResViT的项目文件夹:

然后下载预训练的ViT模型。

wget https://storage.googleapis.com/vit_models/imagenet21k/R50+ViT-B_16.npz &&
mkdir ../model/vit_checkpoint/imagenet21k &&
mv {MODEL_NAME}.npz ../model/vit_checkpoint/imagenet21k/R50-ViT-B_16.npz

预训练ART模块:

python3 train.py --dataroot ../data/preprocess_brain_slices/ct_and_mr --name ct_mr_pre_trained --gpu_ids 0 --model resvit_one --which_model_netG res_cnn --which_direction AtoB --lambda_A 100 --dataset_mode aligned --norm batch --pool_size 0 --output_nc 1 --input_nc 1 --loadSize 256 --fineSize 256 --niter 50 --niter_decay 50 --save_epoch_freq 5 --checkpoints_dir checkpoints/ --display_id 0 --lr 0.0002

微调:

python3 train.py --dataroot ../data/preprocess_brain_slices/ct_and_mr --name ct_mr_resvit --gpu_ids 0 --model resvit_one --which_model_netG resvit --which_direction AtoB --lambda_A 100 --dataset_mode aligned --norm batch --pool_size 0 --output_nc 1 --input_nc 1 --loadSize 256 --fineSize 256 --niter 25 --niter_decay 25 --save_epoch_freq 5 --checkpoints_dir checkpoints/ --display_id 0 --pre_trained_transformer 1 --pre_trained_resnet 1 --pre_trained_path checkpoints/ct_mr_pre_trained/latest_net_G.pth --lr 0.001

同样的训练pelvis,重复上两步。

预训练骨盆:

python3 train.py --dataroot ../data/preprocess_pelvis_slices/ct_and_mr --name ct_mr_pelvis_pre_trained --gpu_ids 0 --model resvit_one --which_model_netG res_cnn --which_direction AtoB --lambda_A 100 --dataset_mode aligned --norm batch --pool_size 0 --output_nc 1 --input_nc 1 --loadSize 256 --fineSize 256 --niter 50 --niter_decay 50 --save_epoch_freq 5 --checkpoints_dir checkpoints/ --display_id 0 --lr 0.0002

微调骨盆:

python3 train.py --dataroot ../data/preprocess_pelvis_slices/ct_and_mr --name ct_mr_pelvis_resvit --gpu_ids 0 --model resvit_one --which_model_netG resvit --which_direction AtoB --lambda_A 100 --dataset_mode aligned --norm batch --pool_size 0 --output_nc 1 --input_nc 1 --loadSize 256 --fineSize 256 --niter 25 --niter_decay 25 --save_epoch_freq 5 --checkpoints_dir checkpoints/ --display_id 0 --pre_trained_transformer 1 --pre_trained_resnet 1 --pre_trained_path checkpoints/ct_mr_pelvis_pre_trained/latest_net_G.pth --lr 0.001

2.3.Bugs.1 关于from scipy.misc import imresize 被弃用的解决办法

https://blog.csdn.net/Cr_NanMao/article/details/126238766

def scipy_misc_imresize(arr, size, interp='bilinear', mode=None):
   im = Image.fromarray(arr, mode=mode)
   ts = type(size)
   if np.issubdtype(ts, np.signedinteger):
      percent = size / 100.0
      size = tuple((np.array(im.size)*percent).astype(int))
   elif np.issubdtype(type(size), np.floating):
      size = tuple((np.array(im.size)*size).astype(int))
   else:
      size = (size[1], size[0])
   func = {'nearest': 0, 'lanczos': 1, 'bilinear': 2, 'bicubic': 3, 'cubic': 3}
   imnew = im.resize(size, resample=func[interp]) # 调用PIL库中的resize函数
   return np.array(imnew)

***将以上代码粘到需要运行的.py文件中,然后在imresize()处将“imresize”改成“scipy_misc_imresize”。


3. 测试

3.1 测试CycleGAN

测试三维模型:

python3 test_3d.py --dataroot ../data/preprocess_brain_volumes --model cycle_gan3d --input_nc 1 --output_nc 1 --dataset_mode unaligned3d --gpu_ids 0 --name cyclegan_3d_brain

3.2 测试ResViT

测试头颅:

python3 test.py --dataroot ../data/preprocess_brain_slices/ct_and_mr --name ct_mr_resvit --gpu_ids 0 --model resvit_one --which_model_netG resvit --dataset_mode aligned --norm batch --phase test --output_nc 1 --input_nc 1 --how_many 10000 --serial_batches --fineSize 256 --loadSize 256 --results_dir results/ --checkpoints_dir checkpoints/ --which_epoch latest

Bugs

Bugs.1 AttributeError: module 'torchvision.transforms' has no attribute 'InterpolationMode'

这个问题是因为torchvision的版本过低,我没用使用该仓库的环境。如果使用该仓库环境则服务器cuda版本与环境版本不一致,我导致to(device)卡住。

那么解决方法就是不使用 torchvision.transforms.InterpolationMode而是使用 PIL.Image.BILINEAR

将报错的地方都修改即可。

Bugs.2 AttributeError: partially initialized module 'charset_normalizer' has no attribute 'md__mypyc' (most likely due to a circular import)

https://github.com/chenfei-wu/TaskMatrix/issues/242

pip install --force-reinstall charset-normalizer==3.1.0