from models_and_datas import get_models_and_datas
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch, os, json
from run import simple_model_name_to_ckpt
from peft import PeftModel
# Parameter names in ``named_parameters()`` order. param_list() fills this the
# first time it loads a checkpoint (only while it is still empty); nothing in
# this file clears it afterwards.
parameter_names = []

def dropper(cd, drop_rate=0.8):
    """Zero out all but the largest-magnitude entries of ``cd``.

    The threshold is the k-th smallest absolute value of ``cd`` with
    ``k = int(numel * drop_rate)``. Only entries whose magnitude is strictly
    greater than the threshold are kept. With the default ``drop_rate=0.8``,
    roughly the top 20% of entries (by magnitude) survive.

    Args:
        cd: tensor of any shape (e.g. one layer of a task vector).
        drop_rate: fraction of entries to drop, in ``[0, 1]``.

    Returns:
        A new tensor with the same shape and dtype as ``cd``.

    Raises:
        ValueError: if ``drop_rate`` is outside ``[0, 1]``.
    """
    if not 0.0 <= drop_rate <= 1.0:
        raise ValueError(f"drop_rate must be in [0, 1], got {drop_rate}")
    flat = cd.reshape(-1).abs()
    k = int(flat.shape[0] * drop_rate)
    if k < 1:
        # torch.kthvalue is 1-indexed and rejects k=0, which happens for empty
        # or 1-element tensors. No entry lies at or below a "0th" threshold,
        # so everything is kept.
        return cd.clone()
    threshold = flat.kthvalue(k).values
    return cd * (cd.abs() > threshold)


def param_list(ckpt, args):
    """Load a causal-LM checkpoint and return its parameters as detached tensors.

    Checkpoints whose name/path contains ``'LoRA'`` are treated as PEFT
    adapters: the adapter is applied to ``args.base_model`` and merged into it.
    Every other checkpoint is loaded directly. Both paths load in
    ``torch.bfloat16`` so that LoRA-merged and fully fine-tuned models yield
    parameters of the same dtype (previously the LoRA path loaded in the
    default fp32, doubling memory and mixing dtypes downstream).

    On the first call (while the module-level ``parameter_names`` is empty),
    the parameter names of the loaded model are appended to it.

    Args:
        ckpt: checkpoint path or hub id.
        args: namespace providing ``base_model`` (only used for LoRA checkpoints).

    Returns:
        list of detached parameter tensors in ``named_parameters()`` order.
    """
    global parameter_names
    torch.cuda.empty_cache()
    if 'LoRA' in ckpt:
        base = AutoModelForCausalLM.from_pretrained(args.base_model, torch_dtype=torch.bfloat16)
        model = PeftModel.from_pretrained(base, ckpt).merge_and_unload()
        del base
    else:
        model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.bfloat16)

    named = list(model.named_parameters())
    if not parameter_names:
        parameter_names += [name for name, _ in named]
    params = [param.detach() for _, param in named]

    # Drop every reference to the module so only the returned tensors survive.
    del model, named
    torch.cuda.empty_cache()
    return params

def task_vector(model, base_model, args, use_base_model=1, drop=0):
    """Return per-parameter ``model - use_base_model * base_model`` tensors.

    Args:
        model: checkpoint of the fine-tuned model (loaded via ``param_list``).
        base_model: checkpoint of the base model.
        args: namespace forwarded to ``param_list``.
        use_base_model: scale applied to the base parameters before the
            subtraction. When it is ``0``, the raw fine-tuned weights are
            returned and the base checkpoint is not loaded at all.
        drop: if truthy, each resulting tensor is sparsified with ``dropper``.

    Returns:
        list of tensors in ``named_parameters()`` order.

    Raises:
        ValueError: if the two checkpoints have different parameter counts
            (``zip`` would otherwise silently truncate).
    """
    finetuned = param_list(model, args)
    if use_base_model:
        base = param_list(base_model, args)
        if len(finetuned) != len(base):
            raise ValueError(
                f"{model} has {len(finetuned)} parameters but {base_model} has {len(base)}"
            )
        diffs = [i - j * use_base_model for i, j in zip(finetuned, base)]
    else:
        # Subtracting 0 * base changes nothing, so skip loading a whole
        # checkpoint just to multiply it by zero.
        diffs = finetuned

    if drop:
        return [dropper(d) for d in diffs]
    return diffs
   
def _upcast(x):
    """Return ``x`` in at least float32 so reductions do not round to half precision."""
    return x.float() if x.dtype in (torch.float16, torch.bfloat16) else x


@torch.no_grad()
def conflict_between(a, b, ta, tb):
    """Compare two models layer by layer and accumulate conflict statistics.

    Args:
        a, b: lists of full weight tensors of the two models.
        ta, tb: lists of task-vector tensors of the two models, aligned
            layer by layer with ``a``/``b``.

    Returns:
        ``(arch, total)``, two dicts of Python numbers summed over layers.

        ``arch``:
            sign_conflict_num: entries where ``ta`` and ``tb`` have different
                signs (zero counts as non-negative). The count is symmetric in
                the two models.
            sign_conflict_magnitude: sum of ``|ta| + |tb|`` over those entries.
            diff_magnitude: sum of ``|ta - tb|``.
            cosine_sim: per-layer cosine similarity of ``a`` and ``b``, summed
                (not averaged) over layers. A layer where either weight has
                zero norm contributes 0 instead of NaN.
        ``total``:
            parameter_num: number of task-vector entries.
            magnitude: sum of ``|a| + |b|``.
            tv_magnitude: sum of ``|ta| + |tb|``.

    Half-precision inputs are upcast to float32 per layer. Summing millions of
    bf16 values into a bf16 result keeps only ~3 significant digits, which
    would make a cosine of 0.999 indistinguishable from 1.0.

    Raises:
        ValueError: if the four lists have different lengths.
    """
    if not len(a) == len(b) == len(ta) == len(tb):
        raise ValueError(
            f"layer count mismatch: a={len(a)}, b={len(b)}, ta={len(ta)}, tb={len(tb)}"
        )

    arch = {
        'sign_conflict_num': 0,
        'sign_conflict_magnitude': 0,
        'diff_magnitude': 0,
        'cosine_sim': 0,
    }
    total = {
        'parameter_num': 0,
        'magnitude': 0,
        'tv_magnitude': 0,
    }

    for ti, tj, i, j in zip(ta, tb, a, b):
        ti, tj, i, j = _upcast(ti), _upcast(tj), _upcast(i), _upcast(j)
        tv_abs = ti.abs() + tj.abs()
        # Symmetric: counts both (ti >= 0, tj < 0) and (ti < 0, tj >= 0).
        conflict = (ti >= 0) != (tj >= 0)

        total['parameter_num'] += ti.numel()
        total['magnitude'] += (i.abs() + j.abs()).sum().item()
        total['tv_magnitude'] += tv_abs.sum().item()

        arch['sign_conflict_num'] += conflict.sum().item()
        arch['sign_conflict_magnitude'] += tv_abs[conflict].sum().item()
        arch['diff_magnitude'] += (ti - tj).abs().sum().item()

        norm_prod = (i * i).sum().sqrt() * (j * j).sum().sqrt()
        if norm_prod > 0:
            arch['cosine_sim'] += ((i * j).sum() / norm_prod).item()

    return arch, total

if __name__ == '__main__':
    import argparse
    import itertools
    import numpy as np
    parser = argparse.ArgumentParser()
    parser.add_argument("--models_and_datas", type=str, default="models_and_datas_full_qwen_3b")    
    parser.add_argument('--branches', nargs='+')
    parser.add_argument("--method", type=str, default="ties")
    parser.add_argument("--output_folder", type=str, default="output/test0123")
    parser.add_argument("--device", type=str, default='cpu')
    parser.add_argument("--base_model", type=str, default="meta-llama/Meta-Llama-3-8B")
    parser.add_argument("--no_eval", action='store_true', default=False)
    parser.add_argument("--sequential", action='store_true', default=False)
    parser.add_argument("--myllama", action='store_true', default=True)
    parser.add_argument("--run_a_on_b", action='store_true', default=False)
    parser.add_argument("--base_test", action='store_true', default=False)
    parser.add_argument("--smaller_batch", type=int, default=1)
    parser.add_argument("--checkpoints2",action='store_true', default=False)
    parser.add_argument("--prefer_merge",action='store_true', default=False)
    parser.add_argument("--doppel_merge",action='store_true', default=False)
    parser.add_argument("--doppel_linear",action='store_true', default=False)
    parser.add_argument("--shigure_merge",action='store_true', default=False)
    parser.add_argument("--count", action='store_true', default=False)
    parser.add_argument("--svd_rank", type=int, default=-1)
    parser.add_argument("--twin_merge",action='store_true', default=False)
    parser.add_argument("--no_storage", action='store_true', default=False)
    parser.add_argument("--density", type=float, default=np.nan)
    parser.add_argument("--steps_rank", type=int, default=1)
    parser.add_argument('--weight_list', type=str, default='')
    args = parser.parse_args()
    models_and_datas = get_models_and_datas(args.models_and_datas)

    # Every entry except 'base' is a fine-tuned branch to compare.
    real_branches = [name for name in models_and_datas if name != 'base']

    base_model = models_and_datas['base']['model'][0]
    args.base_model = base_model
    out_dir = f'output/0516conflict{args.models_and_datas}'

    # Load the base once and each fine-tuned checkpoint once. Raw weights,
    # task vectors and dropped task vectors are all derived from that single
    # load, instead of calling task_vector() three times per branch (which
    # reloaded both the branch and the base checkpoint each time).
    base_params = param_list(base_model, args)
    models, tvs, tv20s, ckpts = [], [], [], []
    for branch in real_branches:
        ckpt = simple_model_name_to_ckpt(models_and_datas[branch]['model'][0])
        weights = param_list(ckpt, args)
        if len(weights) != len(base_params):
            raise ValueError(
                f'{ckpt} has {len(weights)} parameters but {base_model} has {len(base_params)}'
            )
        tv = [w - b for w, b in zip(weights, base_params)]
        ckpts.append(ckpt)
        models.append(weights)
        tvs.append(tv)
        tv20s.append([dropper(t) for t in tv])
    # Release the base copy before solve_doppel_merge loads more checkpoints.
    del base_params

    os.makedirs(f'{out_dir}/', exist_ok=True)
    for i, j in itertools.combinations(range(len(real_branches)), 2):
        arch, total = conflict_between(models[i], models[j], tvs[i], tvs[j])
        arch20, total20 = conflict_between(models[i], models[j], tv20s[i], tv20s[j])
        data = {
            'arch': arch,
            'arch20': arch20,
            'total': total,
            'total20': total20,
        }
        with open(f'{out_dir}/{real_branches[i]}-{real_branches[j]}.json', 'w') as f:
            json.dump(data, f)

    from prefer_merge import solve_doppel_merge, aggregate
    args.doppel_merge = True
    args.doppel_linear = True
    for method in ['linear', 'ties', 'task_arithmetic']:
        args.method = method
        args.count = False

        merged, _overhead = solve_doppel_merge(ckpts[0], ckpts, args.base_model, args)
        # Detached, like param_list(), so no autograd graph is built below.
        m_model = [param.detach() for param in merged.parameters()]
        del merged

        # The base is reloaded per method: aggregate() is defined elsewhere and
        # may mutate its inputs, so a cached copy is not reused here.
        m_tv = aggregate(m_model, param_list(base_model, args), lambda x, y: x - y)
        for i, branch in enumerate(real_branches):
            arch, total = conflict_between(models[i], m_model, tvs[i], m_tv)
            data = {
                'arch': arch,
                'total': total,
            }
            with open(f'{out_dir}/{branch}-{method}.json', 'w') as f:
                json.dump(data, f)
