# python run_in_script_8X7.py --checkpoints2 --models_and_datas=models_and_datas_full_qwen_7b --smaller_batch=1 --output=03107b --tests=base8X7whole
import os,glob,threading,subprocess
import torch, itertools
from models_and_datas import get_models_and_datas
from huggingface_hub import login
# #
import time,argparse,random,json
from conflict import conflict_between

parser = argparse.ArgumentParser()
parser.add_argument("--models_and_datas", type=str, default="models_and_datas_full_qwen_3b")
parser.add_argument("--smaller_batch", type=int, default=1)
parser.add_argument("--tests", type=str, default='')
parser.add_argument("--output", type=str, default='0310_3B')
parser.add_argument("--checkpoints2", action='store_true', default=False)
parser.add_argument("--prefer_merge", action='store_true', default=False)
parser.add_argument("--doppel_merge", action='store_true', default=False)
parser.add_argument("--doppel_linear", action='store_true', default=False)
parser.add_argument("--shigure_merge", action='store_true', default=False)
parser.add_argument("--count", action='store_true', default=False)
parser.add_argument("--twin_merge",action='store_true', default=False)
args = parser.parse_args()

# Load the registry named on the command line. A failure is not fatal here,
# because the '88lora' and 'cut' suites load their own registries later.
# Report the failure instead of silently swallowing it: a bare `except:`
# would also eat KeyboardInterrupt/SystemExit.
try:
    models_and_datas = get_models_and_datas(args.models_and_datas)
except Exception as err:
    print(f"Warning: could not load models_and_datas={args.models_and_datas!r}: {err!r}")
    models_and_datas = None

gpus = torch.cuda.device_count()
print("Number of GPUs:", gpus)
for i in range(gpus):
    print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

# Default output folder; every test suite below overrides it.
# simple_run reads this global when it builds a command.
output_folder = "output/test0301_full"

# One slot per GPU index, holding the most recently launched job thread
# (None = nothing launched on that GPU yet).
threads = {k: None for k in range(gpus)}

def simple_run(bs, method, device, extra = ''):
    """Build one ``run.py`` / ``run_t5.py`` command line and run it via the shell.

    Args:
        bs: branch names, joined with spaces into ``--branches``.
        method: merge/eval method name passed as ``--method``.
        device: device string such as ``'cuda:0'``.
        extra: additional CLI text appended verbatim. Some callers put shell
            redirections here, which is why the command runs with ``shell=True``.

    Returns:
        The subprocess exit status.

    Note:
        Reads the module globals ``args``, ``output_folder`` and
        ``models_and_datas`` at call time. Callers run this in a thread,
        so it sees whatever those globals hold when the thread starts.
    """
    # Choose the script directly. A blanket str.replace('run', 'run_t5') on the
    # whole command would also rewrite any 'run' substring inside `extra`.
    script = 'run_t5.py' if 't5' in args.models_and_datas else 'run.py'
    cmd = f' python {script} --device {device} --branches {" ".join(bs)} --method {method}  --myllama '
    cmd += extra
    cmd += f'  --output_folder={output_folder}  --base_model={models_and_datas["base"]["model"][0]}'

    print(cmd)
    returncode = subprocess.call(cmd, shell=True)
    if returncode != 0:
        # Do not let a failed job disappear silently inside a worker thread.
        print(f"[simple_run] command exited with status {returncode}: {cmd}")
    return returncode

def get_running(target, args, extra = '', poll_interval = 1.0):
    """Start ``target`` on the first idle GPU, blocking until one is free.

    The job is started in a daemon thread as
    ``target(*args, f'cuda:{k}', extra)``, where ``k`` is the chosen GPU
    index. The thread is stored in the module-level ``threads`` table so
    later calls (and the final join) can tell whether GPU ``k`` is busy.

    Args:
        target: callable to run, e.g. ``simple_run``.
        args: positional arguments placed before the device string.
            This parameter shadows the global CLI ``args`` inside this function.
        extra: extra CLI text, passed as the last positional argument.
        poll_interval: seconds to sleep between scans while every GPU is busy.

    Raises:
        RuntimeError: if there are no GPU slots, since the wait could never end.
    """
    if gpus == 0:
        raise RuntimeError("get_running: no CUDA devices found; cannot schedule jobs")
    while True:
        for k in range(gpus):
            current = threads[k]
            if current is None or not current.is_alive():
                thread = threading.Thread(
                    target=target,
                    args=(*args, f'cuda:{k}', extra),
                    daemon=True,  # setDaemon() is deprecated since Python 3.10
                )
                threads[k] = thread
                thread.start()
                return
        # All GPUs are busy: back off instead of spinning a CPU core at 100%.
        time.sleep(poll_interval)

if 'conflict' in args.tests:
    # Pairwise conflict analysis between fine-tuned branches, each measured
    # against the shared base model. Writes one JSON file per unordered pair.
    branches = sorted(['sst2','cola','rte', 'qnli','mnli','mrpc','wnli','qqp'])
    output_folder = f"output/{args.output}_conflicts"
    os.makedirs(output_folder, exist_ok=True)

    base_model = models_and_datas['base']['model'][0]
    for branch_a, branch_b in itertools.combinations(branches, 2):
        model_a = models_and_datas[branch_a]['model'][0]
        model_b = models_and_datas[branch_b]['model'][0]
        # Compute before opening the file, so an exception does not leave an
        # empty, truncated result file behind.
        answer = conflict_between(model_a, model_b, base_model)
        with open(f'{output_folder}/{branch_a}_{branch_b}.txt', 'w') as f:
            json.dump(answer, f, indent=4)


if 'base' in args.tests:
    # Queue one single-branch job per branch, for each method (only 'simple').
    output_folder = f"output/{args.output}_simple"
    methods = ['simple']
    # Every registry key except the shared 'base' entry is a branch. Filter by
    # name rather than slicing off the first key, which silently assumed that
    # 'base' is always inserted first.
    branches = sorted(name for name in models_and_datas if name != 'base')
    base_extra = f'  --myllama   --no_storage --smaller_batch={args.smaller_batch} --models_and_datas={args.models_and_datas} '
    if args.checkpoints2:
        base_extra += ' --checkpoints2 '
    methods = sorted(methods)

    for branch in branches:
        for m in methods:
            extra = base_extra
            get_running(simple_run, ([branch], m,), extra)
            
    
    
    # for i in range(len(branches)):
    #     extra = base_extra + f'  --base_test  --models_and_datas={args.models_and_datas}  '
    #     get_running(simple_run, ([branches[i]],'base',), extra )
            
            
if '8X7' in args.tests:
    # Pairwise merges: every unordered pair of registry branches, under every method.
    output_folder = f"output/{args.output}_8X7"
    methods = ['ties', 'dare_ties', 'linear', 'slerp', 'task_arithmetic']
    base_extra = f'  --myllama   --no_storage --smaller_batch={args.smaller_batch} --models_and_datas={args.models_and_datas} '
    if args.checkpoints2:
        base_extra += ' --checkpoints2 '
    if args.prefer_merge:
        output_folder += 'prefer'
        base_extra += '  --prefer_merge '
    if args.doppel_merge:
        # --doppel_linear only takes effect together with --doppel_merge.
        if args.doppel_linear:
            methods = ['linear','ties', 'task_arithmetic']
            base_extra += '  --doppel_linear'
        base_extra += '  --doppel_merge '

    # All registry entries except the shared 'base' model.
    branches = sorted(name for name in models_and_datas.keys() if name != 'base')
    methods = sorted(methods)

    for left, right in itertools.combinations(branches, 2):
        for m in methods:
            extra = base_extra
            get_running(simple_run, ([left, right], m,), extra)
                
    
            
    # for i in range(len(branches)):
    #     for j in range(i+1,len(branches)):
    #         for m in ['run_a_on_b']:
    #             extra = base_extra + ' --run_a_on_b  '
    #             get_running(simple_run, ([branches[i],branches[j]],'rab',), extra ) 
    #             get_running(simple_run, ([branches[j],branches[i]],'rab',), extra ) 
    
if 'everything' in args.tests:
    # Merge every registry branch at once, once per method (only 'ties' here).
    output_folder = f"output/{args.output}_everything"
    methods = ['ties']
    # Every registry key except the shared 'base' entry is a branch. Filter by
    # name instead of assuming 'base' is the first key.
    branches = [name for name in models_and_datas if name != 'base']
    base_extra = f'  --myllama   --no_storage --smaller_batch={args.smaller_batch}   --models_and_datas={args.models_and_datas}   '
    if args.checkpoints2:
        base_extra += ' --checkpoints2 '
    if args.prefer_merge:
        output_folder = output_folder + 'prefer'
        base_extra += '  --prefer_merge '
    if args.doppel_merge:
        output_folder = output_folder + 'doppel'
        base_extra += '  --doppel_merge '
        methods = ['ties']

    branches = sorted(branches)
    methods = sorted(methods)

    os.makedirs(output_folder, exist_ok=True)
    for m in methods:
        extra = base_extra + ' '
        get_running(simple_run, (branches, m,), extra)

    
    
    # for i in range(len(branches)):
    #     for m in ['simple']:
    #         extra = base_extra
    #         get_running(simple_run, ([branches[i]],m,), extra )


if 'whole' in args.tests:
    # Merge all eight listed branches at once, once per merge method.
    output_folder = f"output/{args.output}_whole"
    methods = ['ties', 'dare_ties', 'linear', 'multislerp', 'task_arithmetic']
    branches = ['sst2','cola','rte', 'qnli','mnli','mrpc','wnli','qqp']
    base_extra = f'  --myllama   --no_storage --smaller_batch={args.smaller_batch}   --models_and_datas={args.models_and_datas}   '
    if args.checkpoints2:
        base_extra += ' --checkpoints2 '

    # (enabled, folder suffix, CLI flag, method list it forces or None).
    # Applied in this order; a later forced method list wins.
    merge_variants = (
        (args.prefer_merge, 'prefer', '  --prefer_merge ', None),
        (args.doppel_merge, 'doppel', '  --doppel_merge ', ['ties']),
        (args.shigure_merge, 'shigure', '  --shigure_merge ', ['shigure']),
        (args.twin_merge, 'twin', '  --twin_merge ', ['twin']),
    )
    for enabled, suffix, flag, forced_methods in merge_variants:
        if not enabled:
            continue
        output_folder += suffix
        base_extra += flag
        if forced_methods is not None:
            methods = forced_methods
    if args.count:
        base_extra += '  --count '

    branches = sorted(branches)
    methods = sorted(methods)

    os.makedirs(output_folder, exist_ok=True)
    extra = base_extra + ' '
    for m in methods:
        get_running(simple_run, (branches, m,), extra)


if 'svdtest' in args.tests:
    # Sweep the SVD rank (10..80) for each merge method over all eight listed branches.
    output_folder = f"output/{args.output}_whole"
    methods = ['ties', 'dare_ties', 'linear', 'multislerp', 'task_arithmetic']
    branches = ['sst2','cola','rte', 'qnli','mnli','mrpc','wnli','qqp']
    base_extra = f'  --myllama   --no_storage --smaller_batch={args.smaller_batch}   --models_and_datas={args.models_and_datas}   '
    if args.checkpoints2:
        base_extra += ' --checkpoints2 '
    if args.prefer_merge:
        output_folder = output_folder + 'prefer'
        base_extra += '  --prefer_merge '
    if args.doppel_merge:
        output_folder = output_folder + 'doppel'
        base_extra += '  --doppel_merge '
        methods = ['ties']
    if args.shigure_merge:
        output_folder = output_folder + 'shigure'
        base_extra += '  --shigure_merge '
        methods = ['shigure']
    if args.count:
        base_extra += '  --count '

    branches = sorted(branches)
    methods = sorted(methods)

    os.makedirs(output_folder, exist_ok=True)
    # Keep the un-suffixed folder, so each rank gets exactly one 'svd{k}'
    # suffix. Appending to output_folder itself accumulated suffixes
    # (…svd10svd20svd30…) across ranks and methods.
    svd_root = output_folder
    for m in methods:
        for k in [10,20,30,40,50,60,70,80]:
            extra = base_extra + f' --svd_rank={k}'
            output_folder = f'{svd_root}svd{k}'
            os.makedirs(output_folder, exist_ok=True)
            # NOTE(review): simple_run reads the global output_folder inside its
            # thread, and the next iteration rebinds it right after this call
            # returns. A slow-starting thread could therefore pick up the next
            # rank's folder. Passing the folder explicitly would remove the race.
            get_running(simple_run, (branches, m,), extra)

if 'mixedcg' in args.tests:
    # Cross pairs: each of the first four branches (sst2, cola, qqp, mnli) paired
    # with each of the last three (defect_detection, text_to_code, code_to_text),
    # under every method.
    output_folder = f"output/{args.output}_mixedcg"
    methods = ['ties', 'dare_ties', 'linear', 'slerp', 'task_arithmetic']
    branches = ['sst2','cola','qqp', 'mnli','defect_detection','text_to_code','code_to_text']
    base_extra = f'  --myllama   --no_storage --smaller_batch={args.smaller_batch} --models_and_datas={args.models_and_datas} '
    if args.checkpoints2:
        base_extra += ' --checkpoints2 '
    if args.prefer_merge:
        output_folder += 'prefer'
        base_extra += '  --prefer_merge '

    left_side, right_side = branches[:4], branches[4:7]
    for first, second, m in itertools.product(left_side, right_side, methods):
        get_running(simple_run, ([first, second], m,), base_extra)
                
if '4in1cg' in args.tests:
    # Merge the four code branches at once, once per method. Each method's
    # stdout is redirected by the shell into <output_folder>/full_<method>.txt.
    output_folder = f"output/{args.output}_whole"
    methods = ['ties', 'dare_ties', 'linear', 'multislerp', 'task_arithmetic']
    branches = ['code_to_text', 'text_to_code', 'defect_detection', 'clone_detection']
    base_extra = f'  --myllama   --no_storage --smaller_batch={args.smaller_batch}   --models_and_datas={args.models_and_datas}   '
    if args.checkpoints2:
        base_extra += ' --checkpoints2 '

    # (enabled, folder suffix, CLI flag, forced method list or None), in order.
    variant_table = [
        (args.prefer_merge, 'prefer', '  --prefer_merge ', None),
        (args.doppel_merge, 'doppel', '  --doppel_merge ', ['ties']),
        (args.shigure_merge, 'shigure', '  --shigure_merge ', ['shigure']),
    ]
    for on, tag, cli_flag, override in variant_table:
        if on:
            output_folder += tag
            base_extra += cli_flag
            if override is not None:
                methods = override
    if args.count:
        base_extra += '  --count '

    branches = sorted(branches)
    methods = sorted(methods)

    os.makedirs(output_folder, exist_ok=True)
    for m in methods:
        # Redirect relies on simple_run invoking the command with shell=True.
        extra = base_extra + f'  > {output_folder}/full_{m}.txt '
        get_running(simple_run, (branches, m,), extra)

if '16cg' in args.tests:
    # Merge every non-empty subset of the four code branches. With the 'dlc0'
    # tag, only single-branch subsets are used. The dlc1/dlc2 tags select the method.
    output_folder = f"output/{args.output}_16cg"
    methods = ['ties']
    if 'dlc1' in args.tests:
        methods = ['task_arithmetic']
    if 'dlc2' in args.tests:
        methods = ['dare_ties']

    code_branches = ['code_to_text', 'text_to_code', 'defect_detection', 'clone_detection']
    base_extra = f'  --myllama   --no_storage --smaller_batch={args.smaller_batch}   --models_and_datas={args.models_and_datas}   '
    if args.checkpoints2:
        base_extra += ' --checkpoints2 '
    if args.prefer_merge:
        output_folder += 'prefer'
        base_extra += '  --prefer_merge '

    code_branches = sorted(code_branches)
    methods = sorted(methods)
    subset_sizes = [1] if 'dlc0' in args.tests else [1, 2, 3, 4]

    for size in subset_sizes:
        for subset in itertools.combinations(code_branches, size):
            for m in methods:
                get_running(simple_run, (list(subset), m,), base_extra)
                    
                    
if 'random46' in args.tests:
    # For r in {4, 6}, queue 10 merges of r branches, each under every method:
    #   - 5 random picks of r-1 branches from the pool below, plus 'wnli';
    #   - 5 random picks of r branches from the pool (no 'wnli').
    for r in [4,6]:
        output_folder = f"output/{args.output}_random{r}"
        methods = sorted(['ties', 'dare_ties', 'linear', 'multislerp', 'task_arithmetic'])
        # 'wnli' is deliberately absent from the pool; it is appended explicitly below.
        pool = sorted(['sst2','cola','rte', 'qnli','mnli','mrpc','qqp'])
        base_extra = f'  --myllama   --no_storage --smaller_batch={args.smaller_batch}   --models_and_datas={args.models_and_datas}   '
        if args.checkpoints2:
            base_extra += ' --checkpoints2 '
        if args.prefer_merge:
            output_folder += 'prefer'
            base_extra += '  --prefer_merge '

        picks_plus_wnli = list(itertools.combinations(pool, r - 1))
        picks_without_wnli = list(itertools.combinations(pool, r))

        # Re-seeded for each r, so each size's selection is reproducible on its own.
        random.seed(42)
        random.shuffle(picks_plus_wnli)
        random.shuffle(picks_without_wnli)

        chosen = [list(combo) + ['wnli'] for combo in picks_plus_wnli[:5]]
        chosen += [list(combo) for combo in picks_without_wnli[:5]]

        for branch_list in chosen:
            for m in methods:
                get_running(simple_run, (branch_list, m,), base_extra)
                        

if 'increasing' in args.tests:
    # Merge growing prefixes of the sorted branch list (1, 2, ..., all 8
    # branches) with 'ties'.
    output_folder = f"output/{args.output}_whole"
    methods = ['ties']
    branches = ['sst2','cola','rte', 'qnli','mnli','mrpc','wnli','qqp']
    base_extra = f'  --myllama   --no_storage --smaller_batch={args.smaller_batch}   --models_and_datas={args.models_and_datas}   '
    if args.checkpoints2:
        base_extra += ' --checkpoints2 '
    if args.doppel_merge:
        output_folder += 'doppel'
        base_extra += '  --doppel_merge '
        methods = ['ties']
    if args.count:
        base_extra += '  --count '

    branches = sorted(branches)
    # NOTE(review): `methods` is computed but not used; the loop below always uses 'ties'.
    methods = sorted(methods)

    m = 'ties'
    extra = base_extra + ' '
    for prefix_len in range(1, len(branches) + 1):
        get_running(simple_run, (branches[:prefix_len], m,), extra)
        
if '88lora' in args.tests:
    # Iterate over registries '<prefix>1' .. '<prefix>8', where <prefix> is the
    # --models_and_datas value. For each registry: merge all branches under
    # each method, then run every branch alone with the 'simple' method.
    # Both flags are required. Check them once, before any job is queued, with
    # an explicit error: asserts are stripped under `python -O`.
    if not (args.doppel_merge and args.doppel_linear):
        parser.error("the '88lora' tests require --doppel_merge and --doppel_linear")
    nagato = args.models_and_datas
    for lora_set_id in range(1, 9):
        # simple_run reads args.models_and_datas and the global models_and_datas
        # when it builds each command, so both are rebound for each set.
        args.models_and_datas = f'{nagato}{lora_set_id}'
        models_and_datas = get_models_and_datas(args.models_and_datas)
        output_folder = f"output/{args.output}/lora{lora_set_id}/whole"
        methods = ['linear','ties', 'task_arithmetic']
        branches = sorted(name for name in models_and_datas if name != 'base')
        base_extra = f'  --myllama   --no_storage --smaller_batch={args.smaller_batch}   --models_and_datas={args.models_and_datas}   '
        base_extra += '  --doppel_merge '
        base_extra += '  --doppel_linear '
        extra = base_extra + ' '

        for m in methods:
            get_running(simple_run, (list(branches), m,), extra)

        for branch in branches:
            get_running(simple_run, ([branch], 'simple',), extra)


if 'cut' in args.tests:
    # For each listed registry: merge all its branches under each method, then
    # run every branch alone with the 'simple' method. LoRA registries also
    # get the doppel flags.
    cut_registries = [
        'models_and_datas_qwen_cut1', 'models_and_datas_qwen_cut2',
        'models_and_datas_qwen_cut3', 'models_and_datas_qwen_cut4',
        'models_and_datas_lora_cut_notcut', 'models_and_datas_lora_cut1', 'models_and_datas_lora_cut2',
        # NOTE(review): 'cut2_ut2' breaks the cut2_* naming pattern and may be a
        # typo for 'cut2_cut2'. Confirm against models_and_datas before changing it.
        'models_and_datas_lora_cut2_notcut', 'models_and_datas_lora_cut2_cut1', 'models_and_datas_lora_cut2_ut2',
    ]
    for nagato in cut_registries:
        # simple_run reads args.models_and_datas and the global models_and_datas
        # when it builds each command, so both are rebound for each registry.
        args.models_and_datas = nagato
        models_and_datas = get_models_and_datas(nagato)
        output_folder = f"output/{args.output}/{nagato}/whole"
        methods = ['linear','ties', 'task_arithmetic']
        branches = sorted(name for name in models_and_datas.keys() if name != 'base')
        base_extra = f'  --myllama   --no_storage --smaller_batch={args.smaller_batch}   --models_and_datas={args.models_and_datas}   '
        if 'lora' in nagato:
            base_extra += '  --doppel_merge '
            base_extra += '  --doppel_linear '
        job_extra = base_extra + ' '

        for m in methods:
            get_running(simple_run, (list(branches), m,), job_extra)

        for single in branches:
            get_running(simple_run, ([single], 'simple',), job_extra)

# Wait for the last job launched on each GPU. get_running only replaces a
# slot whose thread is no longer alive, so earlier jobs have already finished.
for worker in threads.values():
    if worker is not None and worker.is_alive():
        worker.join()


