# from models_and_datas import models_and_datas
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch, os, json
from run import simple_model_name_to_ckpt

# Shared cache of parameter names. param_list() fills it from the first model
# it loads (only while the list is still empty), and conflict_between() zips
# it with the parameter tensors to label them. Nothing in the visible code
# resets it, so every model loaded afterwards is presumably expected to expose
# its parameters in the same order -- TODO confirm.
parameter_names = []

def model_name_to_ckpt(model_name):
    """Translate a model name into the path of its checkpoint directory.

    Names that do not contain ``'my-'`` are mapped, without touching the
    filesystem, to ``checkpoints/hub/models--<name>/snapshots`` where every
    ``/`` in the name is replaced by ``--``.

    Names containing ``'my-'`` are resolved on disk below
    ``../llama_on_glue/checkpoints/hub/<name>/snapshots``: the first snapshot
    directory (in sorted order) is entered and the ``checkpoint-<step>``
    entry with the highest numeric step is returned. The resolved path is
    also printed.

    Args:
        model_name: Model identifier, e.g. ``'org/model'``.

    Returns:
        The checkpoint path as a ``/``-separated string.

    Raises:
        OSError: If a directory that has to be listed does not exist.
        IndexError: If there is no snapshot directory, or the snapshot holds
            no ``checkpoint-<step>`` entry.
    """
    hub_name = model_name.replace('/', '--')
    if 'my-' not in model_name:
        return 'checkpoints/hub/models--' + hub_name + '/snapshots'

    snapshots_root = '../llama_on_glue/checkpoints/hub/' + hub_name + '/snapshots'

    # os.listdir() returns entries in arbitrary order and may include stray
    # files; keep directories only and sort so the choice is deterministic.
    snapshots = sorted(
        entry for entry in os.listdir(snapshots_root)
        if os.path.isdir(snapshots_root + '/' + entry)
    )
    if not snapshots:
        raise IndexError('no snapshot directory found in ' + snapshots_root)
    snapshot_dir = snapshots_root + '/' + snapshots[0]

    # Only 'checkpoint-<digits>' entries take part; anything else that lives
    # next to the checkpoints (logs, 'runs', ...) is ignored instead of
    # breaking the int() conversion.
    prefix = 'checkpoint-'
    steps = {}
    for entry in os.listdir(snapshot_dir):
        suffix = entry[len(prefix):]
        if entry.startswith(prefix) and suffix.isdecimal():
            steps[int(suffix)] = entry
    if not steps:
        raise IndexError('no checkpoint-<step> entry found in ' + snapshot_dir)

    path = snapshot_dir + '/' + steps[max(steps)]
    print(path)
    return path



def param_list(ckpt, args):
    """Load a checkpoint and return its parameters as detached tensors.

    Checkpoints whose path contains ``'LoRA'`` are treated as adapters: the
    base model named by ``args.base_model`` is loaded (no explicit dtype),
    the adapter is applied and merged into it. Every other checkpoint is
    loaded directly in ``torch.bfloat16``.

    The first call that finds the module-level ``parameter_names`` list empty
    fills it with this model's parameter names; later calls leave it alone.

    Args:
        ckpt: Path or identifier handed to ``from_pretrained``.
        args: Only read (``args.base_model``) in the LoRA branch.

    Returns:
        A list with one detached tensor per entry of
        ``model.named_parameters()``, in that order.
    """
    global parameter_names
    torch.cuda.empty_cache()

    if 'LoRA' not in ckpt:
        model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.bfloat16)
    else:
        # NOTE(review): PeftModel is not imported in the visible part of this
        # file -- confirm it is in scope (presumably `from peft import
        # PeftModel`), otherwise this branch raises NameError.
        backbone = AutoModelForCausalLM.from_pretrained(args.base_model)
        adapted = PeftModel.from_pretrained(backbone, ckpt)
        model = adapted.merge_and_unload()

    named = list(model.named_parameters())
    if not parameter_names:
        # Extend in place so the shared module-level list object is updated.
        parameter_names.extend(label for label, _ in named)
    tensors = [weight.detach() for _, weight in named]

    del named
    del model
    torch.cuda.empty_cache()
    return tensors

def task_vector(model, base_model, args, use_base_model=1):
    """Return the per-parameter difference between two checkpoints.

    Both checkpoints are loaded through ``param_list`` (``model`` first, then
    ``base_model``) and combined element-wise as
    ``model_param - base_param * use_base_model``. The base checkpoint is
    loaded even when ``use_base_model`` is 0.

    Args:
        model: Checkpoint whose parameters form the minuend.
        base_model: Checkpoint whose parameters are scaled and subtracted.
        args: Forwarded unchanged to ``param_list``.
        use_base_model: Scale applied to the base parameters (default 1).

    Returns:
        A list of tensors, one per parameter pair.
    """
    tuned_params = param_list(model, args)
    reference_params = param_list(base_model, args)
    return [
        tuned - reference * use_base_model
        for tuned, reference in zip(tuned_params, reference_params)
    ]
   
def conflict_between(model_a, model_b, base_model):
    """Average cosine similarity between two models, grouped by component.

    The three names are resolved with ``simple_model_name_to_ckpt`` and the
    parameter vectors are built with ``task_vector(..., use_base_model=0)``,
    i.e. the base parameters are multiplied by zero before being subtracted.

    Every parameter with at least two dimensions is compared along both axes:

    * ``conflict0``: cosine similarity taken over ``dim=0``, one value per
      index of axis 1, averaged over all matching parameters.
    * ``conflict1``: cosine similarity taken over ``dim=1``, one value per
      index of axis 0, averaged over all matching parameters.

    A parameter contributes to every group whose pattern is a substring of
    its name. The patterns cover the q/k/v projection weights and biases,
    the layer norms, ``.0.`` .. ``.35.`` (layer indices, hard-coded to 36),
    ``mlp`` and the token embedding. Parameters with fewer than two
    dimensions are skipped, so groups that only match such parameters stay
    at 0.

    Args:
        model_a: Name of the first model.
        model_b: Name of the second model.
        base_model: Name of the base model.

    Returns:
        ``{pattern: {'conflict0': float, 'conflict1': float}}`` for every
        pattern described above.
    """
    model_a = simple_model_name_to_ckpt(model_a)
    model_b = simple_model_name_to_ckpt(model_b)
    base_model = simple_model_name_to_ckpt(base_model)
    ta = task_vector(model_a, base_model, None, use_base_model=0)
    tb = task_vector(model_b, base_model, None, use_base_model=0)

    stat = (
        [f'self_attn.{ch}_proj.weight' for ch in 'qkv']
        + [f'self_attn.{ch}_proj.bias' for ch in 'qkv']
        + ['input_layernorm.weight', 'post_attention_layernorm.weight']
        + [f'.{i}.' for i in range(0, 36)]
        + ['mlp', 'layernorm.weight', 'embed_tokens.weight']
    )

    arch = {pattern: {'conflict0': 0, 'conflict1': 0} for pattern in stat}
    cnt = {pattern: {'cnt0': 0, 'cnt1': 0} for pattern in stat}

    cosine = torch.nn.functional.cosine_similarity
    for i, j, name in zip(ta, tb, parameter_names):
        if i.dim() < 2:
            continue
        matched = [pattern for pattern in stat if pattern in name]
        if not matched:
            continue

        # Compute each similarity once per parameter (not once per matching
        # pattern) and accumulate in float32 so that summing thousands of
        # low-precision values does not lose accuracy.
        sim0 = cosine(i, j, dim=0, eps=1e-9).sum(dtype=torch.float32).item()
        sim1 = cosine(i, j, dim=1, eps=1e-9).sum(dtype=torch.float32).item()

        for pattern in matched:
            arch[pattern]['conflict0'] += sim0
            cnt[pattern]['cnt0'] += i.shape[1]
            arch[pattern]['conflict1'] += sim1
            cnt[pattern]['cnt1'] += i.shape[0]

    # The small epsilon keeps groups without any matching 2-D parameter at 0
    # instead of dividing by zero.
    for pattern in stat:
        arch[pattern]['conflict0'] /= (cnt[pattern]['cnt0'] + 1e-6)
        arch[pattern]['conflict1'] /= (cnt[pattern]['cnt1'] + 1e-6)

    return arch
    # for i in quant.keys():
    #     print(f'{i} type: {quant[i].float()}')
    
# if __name__ == '__main__':
#     base_ckpt = models_and_datas['base']['model'][0]
    
  
    
#     cc = '../llama_on_glue/checkpoints/merged/ties__Qwentruct+tsst2+tcola+ctrte+tqnli+tmrpc+twnli+ctqqp+tmnli--d=0.2'
#     keys = list(models_and_datas.keys())[1:]
#     answer = {}
#     for i in range(len(keys)):
#         answer[keys[i]] = \
#             conflict_between(model_name_to_ckpt(models_and_datas[keys[i]]['model'][0]),cc,base_ckpt)
#     json_object = json.dumps(answer, indent=4)
#     with open("conflict.json", "w+") as outfile:
#         outfile.write(json_object)
#     # for i in range(len(keys)):
#     #     for j in range(i+1,len(keys)):
#     #         print (f'{keys[i]} conflict with {keys[j]}:')
#     #         conflict_between(model_name_to_ckpt(models_and_datas[keys[i]]['model'][0]),model_name_to_ckpt(models_and_datas[keys[j]]['model'][0]),base_ckpt)