import glob, os, yaml, subprocess
from evaluations import load_ckpt, test, test_glue
import argparse
from models_and_datas import get_models_and_datas
import numpy as np
import time, shutil, ast, torch
from evaluations_t5 import test_t5
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,AutoModelForSeq2SeqLM
from run import signator

def get_task_vector(params, base_params):
    """Return the element-wise differences ``params[k] - base_params[k]``.

    The two sequences are walked in lockstep with ``zip``, so the result is
    as long as the shorter of the two inputs.
    """
    return [tuned - base for tuned, base in zip(params, base_params)]
# Small constant added to (or used as a floor for) denominators so that
# zero-length vectors do not cause divisions by zero.
eps = 1e-9


def slerp(base_tensor, tensors):
    """Merge several tensors with equal-weight spherical interpolation.

    Args:
        base_tensor: Optional tensor subtracted from every input before the
            merge and added back to the result. Pass ``None`` to merge the
            inputs directly.
        tensors: Non-empty list of tensors that ``torch.stack`` can stack
            (i.e. all of the same shape).

    Returns:
        A tensor with the shape of one input tensor.

    Notes:
        * With exactly two inputs the result is the plain equal-weight
          linear average (plus ``base_tensor``); no spherical step is done.
        * When every input equals ``base_tensor`` the result is
          ``base_tensor`` itself rather than NaN.
    """
    # torch.stack allocates a new tensor, so the in-place subtraction below
    # never touches the caller's tensors.
    stacked = torch.stack(tensors, dim=0)
    if base_tensor is not None:
        stacked -= base_tensor

    # Equal weights normalised to sum to one, placed on the data's device
    # (works even when base_tensor is None).
    weights = torch.ones(stacked.shape[0]).to(stacked.device)
    weights = weights / weights.sum()

    if stacked.shape[0] == 2:
        # Two inputs: fall back to linear interpolation.
        res = (stacked[0] * weights[0] + stacked[1] * weights[1]).view(
            stacked.shape[1:]
        )
        if base_tensor is not None:
            res = res + base_tensor
        return res

    flat = stacked.view(stacked.shape[0], -1)

    # Project to unit hypersphere
    norms = torch.norm(flat, dim=-1, keepdim=True)
    unit_tensors = flat / (norms + eps)

    mean = (unit_tensors * weights.view(-1, 1)).sum(0)
    # Floor the norm instead of dividing by it directly: an all-zero mean
    # (e.g. all task vectors are zero) would otherwise turn into NaN.
    mean = mean / torch.clamp(torch.norm(mean), min=eps)

    # Project to tangent space
    dots = (unit_tensors * mean).sum(-1, keepdim=True)
    tangent_vectors = unit_tensors - dots * mean

    # Interpolate
    tangent_result = (tangent_vectors * weights.view(-1, 1)).sum(0)

    # Project back to sphere using exponential map
    tangent_norm = torch.norm(tangent_result) + eps
    result = mean * torch.cos(tangent_norm) + tangent_result * (
        torch.sin(tangent_norm) / tangent_norm
    )

    # Restore the magnitude: scale by the weighted average input norm.
    avg_norm = (norms.squeeze(-1) * weights).sum()
    result = result * avg_norm
    result = result.view(stacked.shape[1:])

    if base_tensor is not None:
        result = result + base_tensor

    return result

def _ties_elect(base, task_vectors):
    """Add sign-elected task vectors to ``base`` (TIES-style merge step).

    For every position the elected sign is the sign of the summed task
    vectors; the entries agreeing with that sign are averaged (the module
    level ``eps`` keeps the division safe when no entry agrees).
    """
    total = sum(task_vectors)
    return (base + (total > 0) * sum([tv * (tv > 0) for tv in task_vectors]) / (sum([(tv > 0) for tv in task_vectors]) + eps)
            + (total < 0) * sum([tv * (tv < 0) for tv in task_vectors]) / (sum([(tv < 0) for tv in task_vectors]) + eps))


def merge_W(w_list, method, args):
    """Merge one parameter tensor across models.

    Args:
        w_list: List of tensors for the same parameter; ``w_list[0]`` is the
            base model, the remaining entries are the models to merge.
            The input tensors are never modified.
        method: ``"<pre>_<merge>"`` string. Supported values:
            ``_SUM``      linear average,
            ``_TA``       task arithmetic (task vectors scaled by 0.4, averaged),
            ``TIES_``     TIES-merging (keep top 20% magnitudes, sign election),
            ``DARE_SUM``  DARE (drop 90%, rescale) + linear average,
            ``DARE_TIES`` DARE + sign election,
            ``SLERP_``    delegate to ``slerp``.
        args: Unused here; kept for interface compatibility.

    Returns:
        The merged tensor.

    Raises:
        Exception: ``'Method Invalid'`` when ``method`` has no ``_`` separator
            or names an unknown merge step.
    """
    parts = method.split('_')
    if len(parts) < 2:
        # Previously surfaced as an unrelated IndexError.
        raise Exception('Method Invalid')
    m1, m2 = parts[0], parts[1]

    base = w_list[0]
    finetuned = w_list[1:]

    if m1 == 'DARE':
        drop_percent = 90
        keep_prob = 1 - drop_percent / 100
        rescaled = []
        for w in finetuned:
            # Random keep-mask; surviving deltas are rescaled by 1/keep_prob.
            mask = (torch.randint(0, 100, w.shape) >= drop_percent).to(base.device)
            # Build a new tensor instead of updating ``w`` in place.
            rescaled.append(base + mask * (w - base) / keep_prob)
        finetuned = rescaled
    elif m1 == 'TIES':
        keep_quantile = 1 - 20 / 100
        trimmed = []
        for w in finetuned:
            delta = w - base
            magnitudes = delta.reshape(-1).abs()
            k = int(magnitudes.shape[0] * keep_quantile)
            # kthvalue needs k >= 1; for tiny tensors (k == 0) nothing is
            # trimmed instead of raising.
            if k >= 1:
                threshold = magnitudes.kthvalue(k).values
                delta = delta * (delta.abs() > threshold)
            trimmed.append(delta)
        return _ties_elect(base, trimmed)
    elif m1 == 'SLERP':
        return slerp(base, finetuned)

    if m2 == 'SUM':
        acc = torch.zeros_like(base)
        for w in finetuned:
            acc += w
        return acc / len(finetuned)
    elif m2 == 'TA':
        acc = torch.zeros_like(base)
        for w in finetuned:
            acc += (w - base) * 0.4
        return base + acc / len(finetuned)
    elif m2 == 'TIES':
        return _ties_elect(base, [w - base for w in finetuned])
    else:
        raise Exception('Method Invalid')


def merge_params(param_list, method, args):
    """Merge parameter lists position by position.

    ``param_list[0]`` is the base model's parameter list and determines how
    many positions are merged. For each position the tensors of all models
    are moved to ``args.device``, combined with ``merge_W`` and the result is
    moved back to the CPU.
    """
    merged = []
    for position, _ in enumerate(param_list[0]):
        on_device = [model_params[position].to(args.device) for model_params in param_list]
        merged.append(merge_W(on_device, method, args).to('cpu'))
    return merged

def merge_model(model_list, method, args):
    """Merge the parameters of several models into the first one.

    ``model_list[0]`` is the base model. The merged tensors produced by
    ``merge_params`` are written into the base model's state dict in the
    order of its ``named_parameters()``; the base model is updated through
    ``load_state_dict`` and returned (no new model object is created).
    """
    detached_per_model = [
        [weight.detach() for _, weight in net.named_parameters()]
        for net in model_list
    ]
    merged_tensors = merge_params(detached_per_model, method, args)

    base_model = model_list[0]
    state = base_model.state_dict()
    # Merged tensors are positional: pair them with the base model's
    # parameter names in iteration order.
    for (param_name, _), tensor in zip(base_model.named_parameters(), merged_tensors):
        state[param_name] = tensor
    base_model.load_state_dict(state)
    return base_model


def acree_merge(list_of_ckpts, output_path, args):
    """Placeholder for a checkpoint-level merge; not implemented.

    Raises:
        NotImplementedError: always. (The bare ``NotImplementedError``
            expression used before was a no-op, so calls silently returned
            ``None``.)
    """
    raise NotImplementedError


def model_name_to_ckpt(model_name):
    """Resolve a model name to a checkpoint directory on disk.

    Reads the module-level ``args`` namespace (``checkpoints2``,
    ``steps_rank``, ``myllama``) that the ``__main__`` block creates, so it
    only works when this file is run as a script.

    * Names containing ``'my-'`` resolve to
      ``../llama_on_glue/checkpoints[2]/hub/<name>/snapshots/<entry>/checkpoint-<step>``
      where ``<step>`` is the ``args.steps_rank``-th largest step found.
    * Other names resolve to
      ``[../llama_on_glue/]checkpoints[2]/hub/models--<name>/snapshots/<entry>``.

    ``<entry>`` is the first name returned by ``os.listdir``; its order is
    not defined, so this is only deterministic with a single snapshot.

    Raises:
        FileNotFoundError: a 'my-' snapshot holds no ``checkpoint-<N>`` entry
            (or, from ``os.listdir``, a directory does not exist).
    """
    if 'my-' in model_name:
        path = 'checkpoints/hub/' + model_name.replace('/', '--')
        if args.checkpoints2:
            path = path.replace('checkpoints', 'checkpoints2')
        path += '/snapshots'
        path = '../llama_on_glue/' + path
        path = path + '/' + os.listdir(path)[0]

        # Only consider "checkpoint-<N>" entries; anything else in the
        # directory (logs, configs, ...) used to crash the int() conversion.
        prefix = 'checkpoint-'
        steps = sorted(
            int(entry[len(prefix):])
            for entry in os.listdir(path)
            if entry.startswith(prefix) and entry[len(prefix):].isdecimal()
        )
        if not steps:
            raise FileNotFoundError(f'no checkpoint-<N> directories found in {path}')

        path = path + '/checkpoint-' + str(steps[-args.steps_rank])
        print(path)
    else:
        path = 'checkpoints/hub/models--' + model_name.replace('/', '--')
        path += '/snapshots'
        if args.checkpoints2:
            path = path.replace('checkpoints', 'checkpoints2')
        if args.myllama:
            path = '../llama_on_glue/' + path
        path = path + '/' + os.listdir(path)[0]
        print(path)

    return path

def model_names_to_output(model_names, args):
    """Build the output directory for a merged checkpoint and create it.

    The directory name combines ``args.method``, the part of the first model
    name before ``/`` and the ``signator`` signature of every model name
    (joined with ``'+'``, or with ``'0'`` and a ``sequential_`` prefix when
    ``args.sequential`` is set). Further options adjust the path:
    ``args.checkpoints2``, ``args.density``, ``args.myllama`` and
    ``args.no_storage`` (which inserts a timestamp directory and records it
    in ``args.erase_path``).

    Returns:
        The directory path with a trailing ``/``.
    """
    owner = model_names[0].split('/')[0]
    signatures = [signator(name) for name in model_names]

    if args.sequential:
        out_dir = f'checkpoints/merged/sequential_{args.method}__{owner}' + '0'.join(signatures)
    else:
        out_dir = f'checkpoints/merged/{args.method}__{owner}' + '+'.join(signatures)

    if args.checkpoints2:
        out_dir = out_dir.replace('checkpoints', 'checkpoints2')

    density_given = not np.isnan(args.density)
    if density_given:
        out_dir = f'{out_dir}--d={args.density}'

    if args.myllama:
        out_dir = f'../llama_on_glue/{out_dir}'

    if args.no_storage:
        # Unique timestamped sub-directory; its root is remembered so it can
        # be erased later.
        stamp = time.time()
        out_dir = out_dir.replace('merged', f'merged/{stamp}')
        root_end = out_dir.find('merged')
        args.erase_path = out_dir[:root_end] + f'merged/{stamp}'

    os.makedirs(out_dir, exist_ok=True)
    return f'{out_dir}/'
    
    
if __name__ == '__main__':
    # Script entry point: parse the CLI options, load a single model or merge
    # several models, evaluate the result with test_t5 on every dataset of
    # the selected branches and write one text file per dataset.
    #
    # NOTE: ``args`` must stay a module-level global because
    # model_name_to_ckpt reads it directly.
    os.makedirs('partial', exist_ok=True)

    # Example invocations (--models_and_datas is required as well):
    # python run.py --branches code --device cuda:0
    # python run.py --branches guard korean --device cuda:0 --method=dare_linear
    # python run.py --branches guard korean italian --device cpu --method=linear --sequential
    parser = argparse.ArgumentParser()
    parser.add_argument('--branches', nargs='+', required=True)
    parser.add_argument("--method", type=str, default="ties")
    parser.add_argument("--output_folder", type=str, default="output/test0123")
    # required=True means the default below is never used.
    parser.add_argument("--device", type=str, required=True,default='cpu')
    parser.add_argument("--base_model", type=str, default="meta-llama/Meta-Llama-3-8B")
    # --no_eval is parsed but not consulted by the evaluation loop below.
    parser.add_argument("--no_eval", action='store_true', default=False)
    # --sequential only affects model_names_to_output; sequential merging is
    # not implemented in this script.
    parser.add_argument("--sequential", action='store_true', default=False)
    # store_true with default=True: this option is always True.
    parser.add_argument("--myllama", action='store_true', default=True)
    parser.add_argument("--run_a_on_b", action='store_true', default=False)
    parser.add_argument("--base_test", action='store_true', default=False)
    parser.add_argument("--smaller_batch", type=int, default=1)
    parser.add_argument("--models_and_datas", type=str, required=True)
    parser.add_argument("--checkpoints2",action='store_true', default=False)
    parser.add_argument("--no_storage", action='store_true', default=False)
    parser.add_argument("--density", type=float, default=np.nan)
    parser.add_argument("--steps_rank", type=int, default=1)
    parser.add_argument('--weight_list', type=str, default='')
    args = parser.parse_args()

    models_and_datas = get_models_and_datas(args.models_and_datas)
    if len(args.weight_list):
        # literal_eval only accepts Python literals, so this is safe to use
        # on the command-line string.
        args.weight_list = ast.literal_eval(args.weight_list)
    else:
        args.weight_list = [np.nan for i in args.branches]
    # Validate with parser.error instead of ``assert``: asserts are removed
    # when Python runs with -O.
    if len(args.weight_list) != len(args.branches):
        parser.error('Number of Branches not equal to Number of Weights!')

    model_names = []
    datasets = []
    print(args.branches)

    for branch in args.branches:
        model_names += models_and_datas[branch]['model']
        datasets += models_and_datas[branch]['datasets']
    if len(args.branches) != len(model_names):
        parser.error('Some branches not recognized!')

    if args.base_test:
        # This mode was already disabled (it aborted unconditionally before
        # it could select models_and_datas['base']['model']).
        parser.error('--base_test is currently disabled')

    def _result_folder(names):
        # Folder that holds the per-dataset result files. The suffixes are
        # inserted by replacing every occurrence of the method name in the
        # path; this exact format is kept so existing result files are found.
        folder = f'{args.output_folder}/{args.method}/{"-".join(map(signator,names))}'
        if not np.isnan(args.weight_list[0]):
            folder = folder.replace(f'{args.method}',f'{args.method}--weight_list={args.weight_list}')
        if not np.isnan(args.density):
            folder = folder.replace(f'{args.method}',f'{args.method}--density={args.density}')
        # steps_rank is an int, so this suffix is always appended.
        if not np.isnan(args.steps_rank):
            folder = folder.replace(f'{args.method}',f'{args.method}--steps_rank={args.steps_rank}')
        return folder

    if len(model_names) == 1:
        # no need to merge
        output_ckpt = model_name_to_ckpt(model_names[0])
        model = AutoModelForSeq2SeqLM.from_pretrained(output_ckpt).to(args.device)
    elif args.run_a_on_b:
        # test first model on dataset b
        output_ckpt = model_name_to_ckpt(model_names[0])
        model = AutoModelForSeq2SeqLM.from_pretrained(output_ckpt).to(args.device)
        datasets = datasets[1:]
    else:
        # The first checkpoint is the base model. It is loaded by name, so it
        # is not resolved through model_name_to_ckpt (that result used to be
        # computed and then overwritten).
        ckpts = [args.base_model] + [model_name_to_ckpt(name) for name in model_names]
        model_names = [args.base_model] + model_names

        if args.no_storage:
            # Skip the expensive load/merge when every dataset already has a
            # result file.
            model_names = [model_name.replace('/','--') for model_name in model_names]
            folder_file = _result_folder(model_names)
            need_to_run = 0
            for data in datasets:
                txt_file = folder_file +f'/{data}.txt'
                if os.path.exists(txt_file):
                    print(f'aleardy evaled on {data}!')
                else:
                    need_to_run = 1
            if not need_to_run:
                exit(0)

        # CLI method name -> "<pre>_<merge>" code understood by merge_W.
        transfer = {
            'ties':'TIES_',
            'dare_ties':'DARE_TIES',
            'dare_linear':'DARE_SUM',
            'linear':'_SUM',
            'task_arithmetic':'_TA',
            'slerp':'SLERP_',
            'multislerp':'SLERP_'
        }
        # Fail before loading any checkpoint rather than with a KeyError
        # after all of them have been loaded.
        if args.method not in transfer:
            parser.error(f'Unsupported merge method: {args.method}')

        model_list = []
        for ckpt in ckpts:
            print(ckpt)
            tmp_m = AutoModelForSeq2SeqLM.from_pretrained(ckpt)
            tmp_m.eval()
            model_list += [tmp_m]

        model = merge_model(model_list, transfer[args.method], args).to(args.device)
        # Drop the references to the fine-tuned models so their memory can be
        # reclaimed before evaluation (``model`` keeps the merged base alive).
        del model_list, tmp_m

    model_names = [model_name.replace('/','--') for model_name in model_names]
    folder_file = _result_folder(model_names)

    for data in datasets:
        txt_file = folder_file +f'/{data}.txt'
        if os.path.exists(txt_file):
            print(f'aleardy evaled on {data}!')
            continue

        print(txt_file)
        print(f'evaluating on {data}')

        answer,log = test_t5('', data, args.device, smaller_batch=args.smaller_batch, model=model)

        os.makedirs(folder_file, exist_ok=True)
        with open(txt_file,"w+") as f:
            f.write(str(answer))
            f.write('\n')
            # Writing the log is best effort, but must not swallow
            # KeyboardInterrupt/SystemExit as a bare ``except`` would.
            try:
                f.write(log)
            except Exception:
                f.write('cant write logs!')

    exit(0)
