import glob, os, yaml, subprocess
from evaluations import load_ckpt, test, test_glue
import argparse
from models_and_datas import get_models_and_datas
import numpy as np
import time, shutil, ast, torch
from prefer_merge import solve_prefer_merge, solve_doppel_merge, solve_shigure_merge, solve_twin_merge
from transformers import AutoModelForCausalLM, AutoTokenizer
import os,time
from peft import PeftModel

def signator(x):
    """Return a short tag for *x*, used in merged-checkpoint and result paths.

    Names mentioning 'defect_detection' map to 'defect' and names mentioning
    'clone_detection' map to 'clone' (checked in that order). Any other name
    becomes the last five characters of *x* after replacing '/' with '--'.
    """
    special_tags = (
        ('defect_detection', 'defect'),
        ('clone_detection', 'clone'),
    )
    for marker, tag in special_tags:
        if marker in x:
            return tag
    flattened = x.replace('/', '--')
    return flattened[-5:]

def acree_merge(list_of_ckpts, output_path, args):
    """Merge checkpoints with mergekit (``mergekit-yaml``) into *output_path*.

    Args:
        list_of_ckpts: Checkpoint paths. The FIRST entry is the base model; the
            remaining entries are the models to merge and are paired
            positionally with ``args.weight_list``.
        output_path: Directory that mergekit writes the merged model to.
        args: Namespace providing ``method`` (merge method), ``device``
            ('cpu' or e.g. 'cuda:0'), ``density`` and ``weight_list``. NaN
            values for density/weight fall back to the per-method defaults
            below.

    Returns:
        None. The YAML config is kept under ``temporate/`` and the merge runs
        as a shell subprocess whose exit status is not checked.

    Raises:
        ValueError: if ``args.method`` is not a supported merge method (raised
            while building the config, i.e. when there is at least one model
            to merge).

    NOTE: the merge is skipped whenever ``output_path/config.json`` already
    exists. That check does not compare merge parameters, so an existing merge
    built with different weights/density at the same path is silently reused.
    """
    merge_method = args.method
    device = args.device

    if os.path.exists(output_path + '/config.json'):
        return

    def choice_nan(base_v, possible_v):
        """Return *possible_v*, or *base_v* when *possible_v* is NaN."""
        return base_v if np.isnan(possible_v) else possible_v

    def parameterize(ckpt, weight):
        """Build the mergekit ``models`` entry for one checkpoint."""
        if merge_method in ['ties']:
            parameter = {'density': choice_nan(0.2, args.density), 'weight': choice_nan(1.0, weight)}
        elif merge_method in ['dare_ties']:
            parameter = {'density': choice_nan(0.1, args.density), 'weight': choice_nan(1.0, weight)}
        elif merge_method in ['linear']:
            parameter = {'weight': choice_nan(1.0, weight)}
        elif merge_method in ['task_arithmetic']:
            # task_arithmetic reuses --density as its lambda scaling factor.
            parameter = {'weight': choice_nan(1.0, weight), 'lambda': choice_nan(0.4, args.density)}
        elif merge_method in ['dare_linear']:
            parameter = {'density': choice_nan(0.1, args.density), 'weight': choice_nan(1.0, weight)}
        elif merge_method in ['slerp']:
            parameter = None
        elif merge_method in ['multislerp']:
            parameter = {'weight': choice_nan(1.0, weight)}
        else:
            # An explicit raise (not `assert`) so the check survives `python -O`.
            raise ValueError('Merge method not recognized!')
        if parameter is not None:
            return {'model': ckpt, 'parameters': parameter}
        return {'model': ckpt}

    py_object = {
        'models': [parameterize(ckpt, weight)
                   for ckpt, weight in zip(list_of_ckpts[1:], args.weight_list)],
        'merge_method': merge_method,
        'parameters': {'normalize': True, 'int8_mask': True},
        'dtype': 'bfloat16'}

    if merge_method in ['ties', 'dare_ties', 'dare_linear', 'task_arithmetic']:
        py_object['base_model'] = list_of_ckpts[0]
    elif merge_method in ['slerp']:
        # slerp uses the first merged checkpoint (not the pretrained base) as
        # base_model, with a fixed interpolation factor t=0.5.
        py_object['base_model'] = list_of_ckpts[1]
        py_object['parameters'] = {'t': [{'value': 0.5}]}

    # Device index after the last ':' ('cuda:10' -> '10'); the previous
    # device[-1] only kept the final character and picked the wrong GPU for
    # multi-digit indices.
    gpu_id = device.rsplit(':', 1)[-1]

    os.makedirs('temporate', exist_ok=True)
    sid = time.time()
    config_path = f'temporate/{sid}_{gpu_id}.yml'
    with open(config_path, 'w', encoding='utf-8') as file:
        yaml.dump(py_object, file)

    # Both branches must point mergekit at the config file written above.
    if device == 'cpu':
        acree_cmd = f'mergekit-yaml {config_path} {output_path} --low-cpu-memory'
    else:
        acree_cmd = (f'CUDA_VISIBLE_DEVICES={gpu_id} mergekit-yaml {config_path} {output_path} '
                     '--cuda --read-to-gpu --low-cpu-memory')

    # shell=True is needed for the inline CUDA_VISIBLE_DEVICES assignment.
    subprocess.call(acree_cmd, shell=True)

def model_name_to_ckpt(model_name, cli_args=None):
    """Resolve *model_name* to a local checkpoint directory and print it.

    Args:
        model_name: Model identifier such as 'org/name'.
        cli_args: Namespace with ``checkpoints2``, ``myllama`` and
            ``steps_rank``. Defaults to the module-level ``args`` parsed in
            ``__main__``, so ``map(model_name_to_ckpt, names)`` keeps working.

    Returns:
        - Names containing 'my-' or 'sft' (local fine-tuning runs):
          ``../llama_on_glue/<root>/hub/<org>--<name>/snapshots/<snapshot>/checkpoint-N``,
          where N is the ``steps_rank``-th largest step (1 = latest).
        - Other names:
          ``<root>/hub/models--<org>--<name>/snapshots/<snapshot>``, prefixed
          with ``../llama_on_glue/`` when ``myllama`` is set.

        ``<root>`` is 'checkpoints2' when ``checkpoints2`` is set, otherwise
        'checkpoints'. ``<snapshot>`` is the first entry of the snapshots
        directory in sorted order.

    Raises:
        FileNotFoundError: if a local run directory holds no ``checkpoint-N``
            entries (or a directory is missing).
    """
    if cli_args is None:
        cli_args = args  # module-level namespace set in __main__
    flat_name = model_name.replace('/', '--')
    # Only the leading directory is swapped; a model name containing
    # "checkpoints" is left intact.
    root = 'checkpoints2' if cli_args.checkpoints2 else 'checkpoints'
    prefix = 'checkpoint-'

    if 'my-' in model_name or 'sft' in model_name:
        path = f'../llama_on_glue/{root}/hub/{flat_name}/snapshots'
        path = path + '/' + sorted(os.listdir(path))[0]
        # Ignore anything that is not a checkpoint-N directory (logs, READMEs, ...).
        steps = sorted(
            int(entry[len(prefix):])
            for entry in os.listdir(path)
            if entry.startswith(prefix) and entry[len(prefix):].isdecimal()
        )
        if not steps:
            raise FileNotFoundError(f'no {prefix}N directories under {path}')
        path = f'{path}/{prefix}{steps[-cli_args.steps_rank]}'
    else:
        path = f'{root}/hub/models--{flat_name}/snapshots'
        if cli_args.myllama:
            path = '../llama_on_glue/' + path
        path = path + '/' + sorted(os.listdir(path))[0]

    print(path)
    return path

def simple_model_name_to_ckpt(model_name):
    """Resolve *model_name* to a checkpoint under ``../llama_on_glue/checkpoints/hub`` and print it.

    This variant reads no CLI options. It always looks under
    ``../llama_on_glue/checkpoints`` and always picks the latest step.

    Returns:
        - Names containing 'my-':
          ``../llama_on_glue/checkpoints/hub/<org>--<name>/snapshots/<snapshot>/checkpoint-N``
          with the largest N.
        - Other names:
          ``../llama_on_glue/checkpoints/hub/models--<org>--<name>/snapshots/<snapshot>``.

        ``<snapshot>`` is the first entry of the snapshots directory in sorted
        order.

    Raises:
        FileNotFoundError: if a 'my-' run directory holds no ``checkpoint-N``
            entries (or a directory is missing).
    """
    flat_name = model_name.replace('/', '--')
    prefix = 'checkpoint-'

    if 'my-' in model_name:
        path = f'../llama_on_glue/checkpoints/hub/{flat_name}/snapshots'
        path = path + '/' + sorted(os.listdir(path))[0]
        # Ignore anything that is not a checkpoint-N directory (logs, READMEs, ...).
        steps = sorted(
            int(entry[len(prefix):])
            for entry in os.listdir(path)
            if entry.startswith(prefix) and entry[len(prefix):].isdecimal()
        )
        if not steps:
            raise FileNotFoundError(f'no {prefix}N directories under {path}')
        path = f'{path}/{prefix}{steps[-1]}'
    else:
        path = f'../llama_on_glue/checkpoints/hub/models--{flat_name}/snapshots'
        path = path + '/' + sorted(os.listdir(path))[0]

    print(path)
    return path


def model_names_to_output(model_names, args):
    """Build and create the output directory for a merge of *model_names*.

    The directory name is derived from ``args.method``, the organisation part
    of the first model name and the ``signator`` tag of every model. The tags
    are joined with '+' normally, or with '0' when ``args.sequential`` is set.
    ``args.checkpoints2``, ``args.density`` (when not NaN) and ``args.myllama``
    further adjust the path.

    With ``args.no_storage`` the path is nested under a timestamped
    ``merged/<time>`` directory, and that directory is stored on
    ``args.erase_path`` (a side effect on *args*) so the caller can delete it
    later.

    Returns:
        The created directory path with a trailing '/'.
    """
    org = model_names[0].split('/')[0]
    tags = [signator(name) for name in model_names]

    if args.sequential:
        out = f'checkpoints/merged/sequential_{args.method}__' + org + '0'.join(tags)
    else:
        out = f'checkpoints/merged/{args.method}__' + org + '+'.join(tags)

    if args.checkpoints2:
        out = out.replace('checkpoints', 'checkpoints2')

    if not np.isnan(args.density):
        out += f'--d={args.density}'

    if args.myllama:
        out = '../llama_on_glue/' + out

    if args.no_storage:
        stamp = time.time()
        out = out.replace('merged', f'merged/{stamp}')
        args.erase_path = out[:out.find('merged')] + f'merged/{stamp}'

    os.makedirs(out, exist_ok=True)
    return out + '/'
    
    
if __name__ == '__main__':
    # Entry point: parse CLI options, optionally merge the selected branch
    # models, then evaluate the resulting model on each branch's datasets and
    # write one result file per dataset.
    os.makedirs('partial', exist_ok=True)

    # Example invocations:
    #   python run.py --branches code --device cuda:0
    #   python run.py --branches guard korean --device cuda:0 --method=dare_linear
    #   python run.py --branches guard korean italian --device cpu --method=linear --sequential
    parser = argparse.ArgumentParser()
    parser.add_argument('--branches', nargs='+', required=True)
    parser.add_argument("--method", type=str, default="ties")
    parser.add_argument("--output_folder", type=str, default="output/test0123")
    # NOTE(review): required=True makes the 'cpu' default unreachable.
    parser.add_argument("--device", type=str, required=True, default='cpu')
    parser.add_argument("--base_model", type=str, default="meta-llama/Meta-Llama-3-8B")
    parser.add_argument("--no_eval", action='store_true', default=False)
    # --sequential currently only changes the merged-output directory name
    # (see model_names_to_output); no iterative pairwise merge is performed.
    parser.add_argument("--sequential", action='store_true', default=False)
    # NOTE(review): store_true with default=True means this is always True.
    parser.add_argument("--myllama", action='store_true', default=True)
    parser.add_argument("--run_a_on_b", action='store_true', default=False)
    parser.add_argument("--base_test", action='store_true', default=False)
    parser.add_argument("--smaller_batch", type=int, default=1)
    parser.add_argument("--models_and_datas", type=str, required=True)
    parser.add_argument("--checkpoints2", action='store_true', default=False)
    parser.add_argument("--prefer_merge", action='store_true', default=False)
    parser.add_argument("--doppel_merge", action='store_true', default=False)
    parser.add_argument("--doppel_linear", action='store_true', default=False)
    parser.add_argument("--shigure_merge", action='store_true', default=False)
    parser.add_argument("--count", action='store_true', default=False)
    parser.add_argument("--svd_rank", type=int, default=-1)
    parser.add_argument("--twin_merge", action='store_true', default=False)
    parser.add_argument("--no_storage", action='store_true', default=False)
    parser.add_argument("--density", type=float, default=np.nan)
    parser.add_argument("--steps_rank", type=int, default=1)
    parser.add_argument('--weight_list', type=str, default='')
    args = parser.parse_args()

    models_and_datas = get_models_and_datas(args.models_and_datas)
    if args.weight_list:
        # literal_eval only accepts Python literals, so it is safe on CLI input.
        args.weight_list = ast.literal_eval(args.weight_list)
    else:
        args.weight_list = [np.nan for _ in args.branches]
    if len(args.weight_list) != len(args.branches):
        parser.error('Number of Branches not equal to Number of Weights!')

    model_names = []
    datasets = []
    print(args.branches)

    for branch in args.branches:
        model_names += models_and_datas[branch]['model']
        datasets += models_and_datas[branch]['datasets']
    if len(args.branches) != len(model_names):
        parser.error('Some branches not recognized!')

    if args.base_test:
        model_names = models_and_datas['base']['model']

    def result_folder(names):
        """Return the directory that holds the per-dataset result files for *names*.

        Used both by the --no_storage pre-check and by the eval loop, so the
        two always agree on where results live.
        """
        folder = f'{args.output_folder}/{args.method}/{"-".join(map(signator, names))}'
        # Each replace also rewrites the method tag inserted by the previous
        # one, so the order of these three steps determines the final name.
        if not np.isnan(args.weight_list[0]):
            folder = folder.replace(f'{args.method}', f'{args.method}--weight_list={args.weight_list}')
        if not np.isnan(args.density):
            folder = folder.replace(f'{args.method}', f'{args.method}--density={args.density}')
        if not np.isnan(args.steps_rank):
            folder = folder.replace(f'{args.method}', f'{args.method}--steps_rank={args.steps_rank}')
        return folder

    if len(model_names) == 1:
        # Single model: nothing to merge, evaluate the checkpoint directly.
        output_ckpt = model_name_to_ckpt(model_names[0])
    elif args.run_a_on_b:
        # Evaluate the first model on the datasets of the remaining branches.
        output_ckpt = model_name_to_ckpt(model_names[0])
        datasets = datasets[1:]
    elif not args.shigure_merge and not args.doppel_merge and not args.twin_merge:
        # mergekit-based merge: the first ckpt passed to acree_merge is the base.
        model_names = [args.base_model] + model_names
        ckpts = list(map(model_name_to_ckpt, model_names))
        output_ckpt = model_names_to_output(model_names, args)

        if args.no_storage:
            model_names = [model_name.replace('/', '--') for model_name in model_names]
            # Merge only if some dataset still lacks a result file. The file
            # names must match the '{i}.txt' files written by the eval loop.
            folder_file = result_folder(model_names)
            need_to_run = False
            for i, data in enumerate(datasets):
                txt_file = folder_file + f'/{i}.txt'
                if os.path.exists(txt_file):
                    print(f'aleardy evaled on {data}!')
                else:
                    need_to_run = True
            if need_to_run:
                acree_merge(ckpts, output_ckpt, args)
        else:
            acree_merge(ckpts, output_ckpt, args)
    else:
        # shigure / doppel / twin merges build the merged model during eval.
        model_names = [args.base_model] + model_names

    if not args.no_eval:
        ckpts = list(map(model_name_to_ckpt, model_names))
        model_names = [model_name.replace('/', '--') for model_name in model_names]
        folder_file = result_folder(model_names)
        cached = False
        cache_model = None

        if args.shigure_merge and len(args.branches) > 1:
            # The shigure merge does not depend on the target dataset, so it is
            # built once up front.
            model, overhead = solve_shigure_merge(None, ckpts[1:], args.base_model, args)
            tokenizer = AutoTokenizer.from_pretrained(ckpts[0])

        for i, data in enumerate(datasets):
            txt_file = folder_file + f'/{i}.txt'
            if os.path.exists(txt_file):
                print(f'aleardy evaled on {i}!')
                continue
            print(txt_file)
            print(f'evaluating on {i}')

            # NOTE(review): the branch-specific merges index args.branches by
            # the dataset index, which assumes one dataset per branch.
            if args.shigure_merge and len(args.branches) > 1:
                ckpt = model_name_to_ckpt(models_and_datas[args.branches[i]]['model'][0])
                answer, log = test_glue(ckpt, data, args.device, smaller_batch=args.smaller_batch, model=model, tokenizer=tokenizer)
                log = str(overhead)
            elif args.prefer_merge and len(args.branches) > 1:
                ckpt = model_name_to_ckpt(models_and_datas[args.branches[i]]['model'][0])
                model, overhead = solve_prefer_merge(ckpt, output_ckpt, ckpts[0], args)
                tokenizer = AutoTokenizer.from_pretrained(ckpts[0])
                answer, log = test_glue(ckpt, data, args.device, smaller_batch=args.smaller_batch, model=model, tokenizer=tokenizer)
                log = str(overhead)
            elif args.doppel_merge and len(args.branches) > 1:
                ckpt = model_name_to_ckpt(models_and_datas[args.branches[i]]['model'][0])
                if args.doppel_linear and cached:
                    # With --doppel_linear the first merged model is reused for every dataset.
                    model = cache_model
                else:
                    model, overhead = solve_doppel_merge(ckpt, ckpts[1:], args.base_model, args)
                    cached = True
                    cache_model = model
                tokenizer = AutoTokenizer.from_pretrained(args.base_model)
                answer, log = test_glue(ckpt, data, args.device, smaller_batch=args.smaller_batch, model=model, tokenizer=tokenizer)
            elif args.twin_merge and len(args.branches) > 1:
                ckpt = model_name_to_ckpt(models_and_datas[args.branches[i]]['model'][0])
                model, overhead = solve_twin_merge(ckpt, ckpts[1:], args.base_model, args)
                tokenizer = AutoTokenizer.from_pretrained(args.base_model)
                answer, log = test_glue(ckpt, data, args.device, smaller_batch=args.smaller_batch, model=model, tokenizer=tokenizer)
            elif 'LoRA' in output_ckpt:
                # LoRA adapter: fold it into the base model before evaluation.
                model_to_merge = PeftModel.from_pretrained(AutoModelForCausalLM.from_pretrained(args.base_model), output_ckpt)
                model = model_to_merge.merge_and_unload().to(args.device)
                tokenizer = AutoTokenizer.from_pretrained(args.base_model)
                answer, log = test_glue('Mistral', data, args.device, smaller_batch=args.smaller_batch, model=model, tokenizer=tokenizer)
            else:
                answer, log = test_glue(output_ckpt, data, args.device, smaller_batch=args.smaller_batch, tokenizer_ckpt=args.base_model)

            os.makedirs(folder_file, exist_ok=True)
            with open(txt_file, "w+") as f:
                f.write(str(answer))
                f.write('\n')
                try:
                    f.write(log)
                except (TypeError, UnicodeEncodeError):
                    # log is not always a writable str (e.g. None); keep the
                    # answer line rather than failing the whole run.
                    f.write('cant write logs!')

    if args.no_storage:
        # erase_path is only set by model_names_to_output, i.e. when a
        # mergekit merge directory was prepared; other paths have nothing to erase.
        erase_path = getattr(args, 'erase_path', None)
        if erase_path is not None:
            print(f'erasing {erase_path}')
            shutil.rmtree(erase_path)
