import os

import argparse
import pdb
import os

def find_checkpoints(root_dir, name, marker='99'):
    """Collect checkpoint files stored two directory levels below ``root_dir``.

    The layout that is scanned is ``root_dir/<method>/<run>/<file>``. A file
    is selected when its name ends with ``.ckpt`` and contains both ``name``
    and ``marker`` as substrings. Files that sit directly in ``root_dir`` or
    directly in a method directory are ignored.

    Args:
        root_dir: Directory whose sub-directories each contain run
            directories.
        name: Substring the checkpoint file name must contain.
        marker: Second substring the file name must contain. Defaults to
            ``'99'``, the value that used to be hard-coded (presumably an
            epoch number -- TODO confirm).

    Returns:
        A list of matching paths, ordered by method directory, run directory
        and file name, so repeated calls return the same order.

    Raises:
        OSError: If ``root_dir`` or a directory below it cannot be listed.
    """
    ckpt_paths = []
    # os.listdir() returns entries in arbitrary order; sort every level so a
    # caller that picks the first match gets a reproducible result.
    for method_dir in sorted(os.listdir(root_dir)):
        method_path = os.path.join(root_dir, method_dir)
        if not os.path.isdir(method_path):
            continue
        for run_dir in sorted(os.listdir(method_path)):
            run_path = os.path.join(method_path, run_dir)
            if not os.path.isdir(run_path):
                continue
            for filename in sorted(os.listdir(run_path)):
                if not filename.endswith('.ckpt'):
                    continue
                # Plain substring tests: a short ``name`` also matches file
                # names that merely contain it.
                if name in filename and marker in filename:
                    ckpt_paths.append(os.path.join(run_path, filename))
    return ckpt_paths

# Usage example
# Suppose you want every .ckpt file whose file name contains 'method1'.

                
# l = ['byol','nnbyol','mocov2plus','mocov3','dino']
# l = ['byol','dino']

# Checkpoint-name lists, one list per backbone. Every entry is on its own
# line with a trailing comma: the previous layout was missing two commas, so
# adjacent string literals were silently concatenated into bogus names.
# Names appear to follow <method>-<dataset>_<regulariser>_<weight>[_<backbone>]
# -- TODO confirm against the training scripts.
resnet18s = [
    'byol-cifar100_none_0.01',
    'byol-cifar100_risp_0.001',
    'byol-cifar100_so_1e-06',
    'dino-cifar100_none_0.001',
    'dino-cifar100_risp_0.001',
    'dino-cifar100_so_1e-06',
    'mocov2plus-cifar100_none_0.001',
    'mocov2plus-cifar100_risp_0.001',
    'mocov2plus-cifar100_so_1e-06',
    'mocov3-cifar100_none_0.001',
    'mocov3-cifar100_risp_0.001',
    'mocov3-cifar100_so_1e-06',
    'nnbyol-cifar100_none_0.001',
    'nnbyol-cifar100_risp_0.001',
    'nnbyol-cifar100_so_1e-06',
]
resnet50s = [
    'byol-cifar100_none_0.001_resnet50',
    'byol-cifar100_risp_0.0001_resnet50',
    'byol-cifar100_so_1e-06_resnet50',
    'dino-cifar100_none_0.001_resnet50',
    'dino-cifar100_risp_0.001_resnet50',
    'dino-cifar100_so_1e-06_resnet50',
    'mocov2plus-cifar100_none_0.001_resnet50',
    'mocov2plus-cifar100_risp_0.001_resnet50',
    'mocov2plus-cifar100_so_1e-06_resnet50',
    'mocov3-cifar100_none_0.001_resnet50',
    'mocov3-cifar100_risp_0.001_resnet50',
    'mocov3-cifar100_so_1e-06_resnet50',
    'nnbyol-cifar100_none_0.001_resnet50',
    'nnbyol-cifar100_risp_0.001_resnet50',
    'nnbyol-cifar100_so_1e-06_resnet50',
]
wideresnets = [
    'byol-cifar100_none_0.01_wide_resnet28w2',
    'byol-cifar100_risp_0.0001_wide_resnet28w2',
    'byol-cifar100_so_1e-07_wide_resnet28w2',
    'dino-cifar100_none_1e-05_wide_resnet28w2',
    'dino-cifar100_risp_1e-05_wide_resnet28w2',
    'dino-cifar100_so_1e-06_wide_resnet28w2',
    'mocov2plus-cifar100_none_1e-05_wide_resnet28w2',
    'mocov2plus-cifar100_risp_1e-05_wide_resnet28w2',
    'mocov2plus-cifar100_so_1e-06_wide_resnet28w2',
    'mocov3-cifar100_none_1e-05_wide_resnet28w2',
    'mocov3-cifar100_risp_1e-05_wide_resnet28w2',
    'mocov3-cifar100_so_1e-06_wide_resnet28w2',
    'nnbyol-cifar100_none_0.001_wide_resnet28w2',
    'nnbyol-cifar100_risp_1e-05_wide_resnet28w2',
    'nnbyol-cifar100_so_1e-06_wide_resnet28w2',
]


# Sanity check: print how many names each backbone list holds.
print(len(resnet18s), len(resnet50s), len(wideresnets))

finetune_names = resnet18s + resnet50s + wideresnets
finetune_datasets = 'cifar100'

# NOTE: this assignment replaces the combined list built above, so the loop
# below only processes these two names.
finetune_names = ['byol-imagenet_none_0.0001_resnet50', 'byol-imagenet_risp_0.0001_resnet50']


# For every requested name: locate its checkpoint, report it, and launch the
# linear-evaluation script through the shell.
for checkpoint in finetune_names:
    res = find_checkpoints(root_dir='./trained_models', name=checkpoint)
    if not res:
        # Previously res[0] raised a bare IndexError here; skip this name so
        # the remaining runs can still be launched.
        print('WARNING: no checkpoint found for {}, skipping'.format(checkpoint))
        continue
    if len(res) > 1:
        # The match is substring based, so several files can qualify; only
        # the first one is used.
        print('WARNING: {} checkpoints match {}, using the first'.format(len(res), checkpoint))
    weight = res[0]
    name = checkpoint + '({})'.format(finetune_datasets)
    print(weight)
    print(name)
    # NOTE(review): the original code called .format(weight, name) on this
    # template, but the template has no {} fields, so neither value ever
    # reached the command line. The command text is kept unchanged here.
    # TODO: pass weight and name as overrides once the config keys expected
    # by main_linear.py are confirmed (name contains parentheses and will
    # need shell quoting).
    cmd = 'python3 -u main_linear.py  --config-path scripts/linear/cifar100 --config-name byol.yaml'
    print(cmd)
    status = os.system(cmd)
    if status != 0:
        print('WARNING: command exited with status {}'.format(status))

# l =['byol']
# base = l[0]
# for base in l:
#     for method in ['none','risp','so']:
       
#         # if method=='none' and base =='mocov2plus':
#         #     continue
#      #   cmd = 'python3 -u main_pretrain.py  --config-path scripts/pretrain/cifar --config-name {}.yaml ++regular_method={} ++regular_weight={}'.format(base,method,weight)
#         cmd = 'python3 -u main_linear.py  --config-path scripts/linear/imagenet-100 --config-name {}.yaml ++regular_method={} ++regular_weight={}'.format(base,method,weight)
#         print(cmd)
#         os.system(cmd)
