-
git克隆仓库。
-
创建自己的二维数据。
-
在datasets路径执行combine命令。
修改该文件中的两个路径参数:python3 combine_A_and_B.py --use_AB
-
回到main路劲执行下面的命令。
-
训练cyclegan:
训练二维:python3 train.py --dataroot=../../data/preprocess_brain_slices/ct_and_mr --model=cycle_gan --input_nc=1 --output_nc=1 --dataset_mode=aligned --gpu_ids=0
训练三维:
python3 train.py --dataroot=../../data/preprocess_brain_volumes --model=cycle_gan3d --input_nc=1 --output_nc=1 --dataset_mode=unaligned3d --gpu_ids=0 --crop_size=96 --name=cyclegan_3d_brain
-
训练pix2pix:
python3 train.py --dataroot=../../data/preprocess_brain_slices/ct_and_mr --model=pix2pix --input_nc=1 --output_nc=1 --dataset_mode=aligned --name pix2pix_2d_brain --gpu_ids=1
-
1. 命令参数
1.1 base options
--dataroot
:path to images (should have subfolders trainA, trainB, valA, valB, etc)
例:--dataroot=../../data/preprocess_brain_volumes
那么训练时则会去找dataroot/trainA, dataroot/trainB。
测试时去找dataroot/testA, dataroot/testB。--name
:训练时会在checkpoints文件夹下创建name文件夹,并将训练时的东西保存在该文件夹下。--gpu_ids
:选择gpu序号--model
:选择模型。--input_nc
:输入通道数。--output_nc
:输出通道数。--dataset_mode
:选择dataset。
1.2 train options
- --continue_train: 有这个参数,则训练为继续训练。
- --epoch_count: 开始训练的编号。
- --n_epochs: 总共训练epoch数。
--n_epochs_decay
:number of epochs to linearly decay learning rate to zero。所以总共训练的epoch数为n_epochs+n_epochs_decay.
2. 训练
2.1 训练CycleGAN
训练二维模型:
python3 train.py --dataroot=../data/preprocess_brain_slices/ct_and_mr --model=cycle_gan --input_nc=1 --output_nc=1 --dataset_mode=aligned --gpu_ids=0 --name=cyclegan_2d_brain
训练三维模型:
python3 train.py --dataroot=../data/preprocess_brain_volumes --model=cycle_gan3d --input_nc=1 --output_nc=1 --dataset_mode=unaligned3d --gpu_ids=0 --crop_size=96 --name=cyclegan_3d_brain
继续训练三维模型:
三维显存耗费过大,所以将体数据剪裁为96x96x96。
python3 train.py --dataroot=../data/preprocess_brain_volumes --model=cycle_gan3d --input_nc=1 --output_nc=1 --dataset_mode=unaligned3d --gpu_ids=0 --crop_size=96 --name=cyclegan_3d_brain --continue_train --epoch_count=200 --n_epochs=1000
2.2 训练Pix2Pix
训练二维模型:
python3 train.py --dataroot=../data/preprocess_brain_slices/ct_and_mr --model=pix2pix --input_nc=1 --output_nc=1 --dataset_mode=aligned --name pix2pix_2d_brain --gpu_ids=1
2.3 训练ResViT
参考源码:https://github.com/icon-lab/ResViT/tree/main
首先进入到ResViT的项目文件夹:
然后下载预训练的ViT模型。
wget https://storage.googleapis.com/vit_models/imagenet21k/R50+ViT-B_16.npz &&
mkdir ../model/vit_checkpoint/imagenet21k &&
mv {MODEL_NAME}.npz ../model/vit_checkpoint/imagenet21k/R50-ViT-B_16.npz
预训练ART模块:
python3 train.py --dataroot ../data/preprocess_brain_slices/ct_and_mr --name ct_mr_pre_trained --gpu_ids 0 --model resvit_one --which_model_netG res_cnn --which_direction AtoB --lambda_A 100 --dataset_mode aligned --norm batch --pool_size 0 --output_nc 1 --input_nc 1 --loadSize 256 --fineSize 256 --niter 50 --niter_decay 50 --save_epoch_freq 5 --checkpoints_dir checkpoints/ --display_id 0 --lr 0.0002
微调:
python3 train.py --dataroot ../data/preprocess_brain_slices/ct_and_mr --name ct_mr_resvit --gpu_ids 0 --model resvit_one --which_model_netG resvit --which_direction AtoB --lambda_A 100 --dataset_mode aligned --norm batch --pool_size 0 --output_nc 1 --input_nc 1 --loadSize 256 --fineSize 256 --niter 25 --niter_decay 25 --save_epoch_freq 5 --checkpoints_dir checkpoints/ --display_id 0 --pre_trained_transformer 1 --pre_trained_resnet 1 --pre_trained_path checkpoints/ct_mr_pre_trained/latest_net_G.pth --lr 0.001
同样的训练pelvis,重复上两步。
预训练骨盆:
python3 train.py --dataroot ../data/preprocess_pelvis_slices/ct_and_mr --name ct_mr_pelvis_pre_trained --gpu_ids 0 --model resvit_one --which_model_netG res_cnn --which_direction AtoB --lambda_A 100 --dataset_mode aligned --norm batch --pool_size 0 --output_nc 1 --input_nc 1 --loadSize 256 --fineSize 256 --niter 50 --niter_decay 50 --save_epoch_freq 5 --checkpoints_dir checkpoints/ --display_id 0 --lr 0.0002
微调骨盆:
python3 train.py --dataroot ../data/preprocess_pelvis_slices/ct_and_mr --name ct_mr_pelvis_resvit --gpu_ids 0 --model resvit_one --which_model_netG resvit --which_direction AtoB --lambda_A 100 --dataset_mode aligned --norm batch --pool_size 0 --output_nc 1 --input_nc 1 --loadSize 256 --fineSize 256 --niter 25 --niter_decay 25 --save_epoch_freq 5 --checkpoints_dir checkpoints/ --display_id 0 --pre_trained_transformer 1 --pre_trained_resnet 1 --pre_trained_path checkpoints/ct_mr_pelvis_pre_trained/latest_net_G.pth --lr 0.001
2.3.Bugs.1 关于from scipy.misc import imresize 被弃用的解决办法
https://blog.csdn.net/Cr_NanMao/article/details/126238766
def scipy_misc_imresize(arr, size, interp='bilinear', mode=None):
im = Image.fromarray(arr, mode=mode)
ts = type(size)
if np.issubdtype(ts, np.signedinteger):
percent = size / 100.0
size = tuple((np.array(im.size)*percent).astype(int))
elif np.issubdtype(type(size), np.floating):
size = tuple((np.array(im.size)*size).astype(int))
else:
size = (size[1], size[0])
func = {'nearest': 0, 'lanczos': 1, 'bilinear': 2, 'bicubic': 3, 'cubic': 3}
imnew = im.resize(size, resample=func[interp]) # 调用PIL库中的resize函数
return np.array(imnew)
***将以上代码粘到需要运行的.py文件中,然后在imresize()处将“imresize”改成“scipy_misc_imresize”。
3. 测试
3.1 测试CycleGAN
测试三维模型:
python3 test_3d.py --dataroot ../data/preprocess_brain_volumes --model cycle_gan3d --input_nc 1 --output_nc 1 --dataset_mode unaligned3d --gpu_ids 0 --name cyclegan_3d_brain
3.2 测试ResViT
测试头颅:
python3 test.py --dataroot ../data/preprocess_brain_slices/ct_and_mr --name ct_mr_resvit --gpu_ids 0 --model resvit_one --which_model_netG resvit --dataset_mode aligned --norm batch --phase test --output_nc 1 --input_nc 1 --how_many 10000 --serial_batches --fineSize 256 --loadSize 256 --results_dir results/ --checkpoints_dir checkpoints/ --which_epoch latest
Bugs
Bugs.1 AttributeError: module 'torchvision.transforms' has no attribute 'InterpolationMode'
这个问题是因为torchvision的版本过低,我没用使用该仓库的环境。如果使用该仓库环境则服务器cuda版本与环境版本不一致,我导致to(device)卡住。
那么解决方法就是不使用 torchvision.transforms.InterpolationMode
而是使用 PIL.Image.BILINEAR
。
将报错的地方都修改即可。
Bugs.2 AttributeError: partially initialized module 'charset_normalizer' has no attribute 'md__mypyc' (most likely due to a circular import)
https://github.com/chenfei-wu/TaskMatrix/issues/242
pip install --force-reinstall charset-normalizer==3.1.0