
# from mergekit.sparsify import  SparsificationMethod, sparsify
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch, os, json, threading
from evaluations import test_glue,load_ckpt
import argparse
parameter_names = []

def param_list(ckpt, args):
    """Load checkpoint ``ckpt`` and return its parameters as a list of tensors.

    The model is loaded in bfloat16; every parameter is detached and moved to
    ``args.device``, in ``model.named_parameters()`` order. If the module-level
    ``parameter_names`` list is still empty, it is filled with the matching
    parameter names first. The CUDA cache is emptied before loading and again
    after the model object is dropped.
    """
    global parameter_names
    torch.cuda.empty_cache()
    model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.bfloat16)
    if not parameter_names:
        parameter_names.extend(pname for pname, _ in model.named_parameters())
    tensors = [weight.detach().to(args.device) for _, weight in model.named_parameters()]
    del model
    torch.cuda.empty_cache()
    return tensors

def task_vector(model, base_model, args, biased=True):
    """Return per-parameter tensors for checkpoint ``model``.

    Args:
        model: checkpoint path of the (fine-tuned) model.
        base_model: checkpoint path of the reference model.
        args: namespace providing ``device`` (forwarded to ``param_list``).
        biased: if True (default), return ``model - base_model`` per parameter
            (a task vector); if False, return the raw ``model`` weights.

    Returns:
        List of tensors in ``named_parameters()`` order.
    """
    tuned = param_list(model, args)
    if not biased:
        # The base weights are not used in this mode, so skip loading a whole
        # second checkpoint just to discard it.
        return tuned
    base = param_list(base_model, args)
    return [t - b for t, b in zip(tuned, base)]
        

def merge_params(va, vb, args):
    """Merge two lists of tensors elementwise by sign agreement.

    For each positional pair ``(a, b)``: where ``a`` and ``b`` have the same
    sign the result takes ``b``; where the signs differ it keeps ``a``.

    Args:
        va, vb: lists of tensors, paired positionally, with matching shapes.
        args: unused.

    Returns:
        ``(params, ratio)`` where ``params`` is the list of merged tensors and
        ``ratio`` is a 0-dim float tensor with the fraction of elements whose
        signs differed (``tensor(0.)`` when there are no elements).
    """
    params = []
    disagree, total = 0, 0
    for a, b in zip(va, vb):
        same_sign = a.sign() == b.sign()
        # Select instead of multiplying by a 0/1 mask: ``inf * 0`` would be NaN.
        params.append(torch.where(same_sign, b, a))
        # Bool .sum() is int64; casting to int32 and accumulating over every
        # tensor of a multi-billion-parameter model could overflow.
        disagree += (~same_sign).sum()
        total += b.numel()
    if total == 0:
        return params, torch.tensor(0.0)
    return params, disagree / total

def doppel(params, v, vs, args):
    """Update ``params`` with sign-guided extreme values drawn from ``vs``.

    For every index ``i`` the running value starts at ``params[i]``. For each
    list ``vb`` in ``vs``, with ``a = v[i]`` and ``b = vb[i]``, elementwise:
      * if ``a >= 0``, ``b >= 0`` and ``b`` is greater than the running value,
        the running value becomes ``b``;
      * if ``a < 0``, ``b < 0`` and ``b`` is less than the running value,
        the running value becomes ``b``.

    Args:
        params: list of tensors; entries are replaced (the list is mutated),
            but the original tensors themselves are not modified.
        v: list of reference tensors whose signs gate the update.
        vs: list of tensor lists, each aligned with ``v``.
        args: unused.

    Returns:
        ``(params, 0)`` — the same list object, and a constant 0.
    """
    for i, a in enumerate(v):
        current = params[i]
        for vb in vs:
            b = vb[i]
            take_pos = (b > current) & (b >= 0) & (a >= 0)
            take_neg = (b < current) & (b < 0) & (a < 0)
            # Out-of-place select: avoids mutating the caller's tensor, and the
            # negative branch now moves to ``b`` exactly like the positive one.
            current = torch.where(take_pos | take_neg, b, current)
        params[i] = current
    return params, 0
         
def apply_model(v, ckpt, args):
    """Load ``ckpt`` onto ``args.device`` and add ``v`` to its parameters.

    Args:
        v: list of tensors aligned positionally with ``model.named_parameters()``
            (as produced by ``param_list`` / ``task_vector``), on ``args.device``.
        ckpt: checkpoint path of the model the deltas are added to.
        args: namespace providing ``device``.

    Returns:
        The bfloat16 model with ``param + v[k]`` as its ``k``-th parameter.
    """
    model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.bfloat16).to(args.device)
    # Add in place: building ``delta + param`` for every weight and reloading it
    # through load_state_dict materialized a full second copy of the model.
    with torch.no_grad():
        for (_, param), delta in zip(model.named_parameters(), v):
            param.add_(delta)
    return model

def aggregate(va, vb, f):
    """Apply binary ``f`` pairwise over ``va`` and ``vb``; stops at the shorter input."""
    return list(map(f, va, vb))


def solve_drift(ckpt1, ckpt2, ft1, ft2, ckpt3, base_model, args):
    """Build a drift-corrected model from ``ckpt3`` and return it.

    With ``W(x)`` the weights of checkpoint ``x``:
        b = (W(ft1) - W(ckpt1)) - (W(ft2) - W(ckpt2))
        a = W(ckpt2) - W(ckpt1)
        c = W(ckpt3) - W(ckpt1)
        v = (W(ckpt3) - W(base_model)) + (b + eps) * ((c + eps) / (a + eps))
    and the result is ``apply_model(v, base_model, args)``.

    Each checkpoint is loaded once, and intermediates are released as soon as
    they are no longer needed.
    """
    # NOTE(review): param_list loads bfloat16 tensors, so adding eps=1e-9 only
    # changes entries that are (near) zero; where ``a`` is ~0 the ratio c/a can
    # become very large — confirm this is intended.
    eps = 1e-9

    def load(path):
        return param_list(path, args)

    def sub(xs, ys):
        return aggregate(xs, ys, lambda x, y: x - y)

    w_ft1 = load(ft1)
    w1 = load(ckpt1)
    d1 = sub(w_ft1, w1)
    del w_ft1
    w_ft2 = load(ft2)
    w2 = load(ckpt2)
    d2 = sub(w_ft2, w2)
    del w_ft2
    b = sub(d1, d2)
    del d1, d2
    a = sub(w2, w1)
    del w2
    w3 = load(ckpt3)
    c = sub(w3, w1)
    del w1
    tv3 = sub(w3, load(base_model))  # task vector of ckpt3 w.r.t. base_model
    del w3

    a = [x + eps for x in a]
    b = [x + eps for x in b]
    c = [x + eps for x in c]
    ratio = aggregate(c, a, lambda x, y: x / y)
    del a, c
    v = aggregate(b, ratio, lambda x, y: x * y)
    del b, ratio
    v = aggregate(tv3, v, lambda x, y: x + y)
    del tv3
    return apply_model(v, base_model, args)

def solve_simple_drift(ckpt1, ckpt2, ft1, ft2, ckpt3, base_model, args):
    """Transfer the fine-tuning delta ``ft1 - ckpt1`` onto ``ckpt3``.

    Computes ``v = (W(ckpt3) - W(base_model)) + (W(ft1) - W(ckpt1))`` and
    returns ``apply_model(v, base_model, args)``. ``ckpt2`` and ``ft2`` are
    accepted (same signature as ``solve_drift``) but not used.
    """
    def sub(xs, ys):
        return aggregate(xs, ys, lambda x, y: x - y)

    # Load raw weights directly: going through task_vector(..., biased=False)
    # used to load and discard base_model for each of these two checkpoints.
    delta = sub(param_list(ft1, args), param_list(ckpt1, args))
    v = aggregate(task_vector(ckpt3, base_model, args), delta, lambda x, y: x + y)
    return apply_model(v, base_model, args)


if __name__ == '__main__':
    def model_name_to_ckpt(model_name):
        """Resolve a model name to a local checkpoint directory, print and return it.

        Names containing 'my-' or 'experimental' resolve to
        ``../llama_on_glue/checkpoints/hub/<name>/snapshots/<snap>/checkpoint-<N>``
        with the largest ``N``; all other names resolve to
        ``../llama_on_glue/checkpoints/hub/models--<name>/snapshots/<snap>``.
        """
        hub_name = model_name.replace('/', '--')
        root = '../llama_on_glue/checkpoints/hub/'
        if 'my-' in model_name or 'experimental' in model_name:
            snapshots = root + hub_name + '/snapshots'
            # NOTE(review): takes the first listed snapshot; assumes there is only one.
            snapshot = snapshots + '/' + os.listdir(snapshots)[0]
            prefix = 'checkpoint-'
            # Skip non-checkpoint entries (logs, configs) instead of crashing in int().
            steps = [int(entry[len(prefix):]) for entry in os.listdir(snapshot)
                     if entry.startswith(prefix) and entry[len(prefix):].isdigit()]
            if not steps:
                raise FileNotFoundError(f'no {prefix}<step> directories under {snapshot}')
            path = snapshot + '/' + prefix + str(max(steps))
        else:
            snapshots = root + 'models--' + hub_name + '/snapshots'
            # NOTE(review): takes the first listed snapshot; assumes there is only one.
            path = snapshots + '/' + os.listdir(snapshots)[0]
        print(path)
        return path

    parser = argparse.ArgumentParser()
    # Optional so that the default is actually reachable.
    parser.add_argument("--device", type=str, default='cuda:7')
    args = parser.parse_args()

    ckpt1 = model_name_to_ckpt('Qwen--Qwen2.5-3B-Instruct')
    ckpt2 = model_name_to_ckpt('my-qwen3B--Qwen2.5-3B-Instructmnli')
    ckpt3 = model_name_to_ckpt('my-qwen3B--Qwen2.5-3B-Instructsst2')
    ft1 = model_name_to_ckpt('experimental--Qwen2.5-3B-Instructcola')
    ft2 = model_name_to_ckpt('experimental--Qwen2.5-3B-Instructmnlicola')
    ft3 = model_name_to_ckpt('experimental--Qwen2.5-3B-Instructsst2cola')
    base_model = model_name_to_ckpt('Qwen--Qwen2.5-3B-Instruct')

    ckpt = 'Qwen--Qwen2.5-3B-Instruct'
    data = 'cola'

    def report(label, model, tokenizer):
        """Run test_glue on ``data`` with ``model`` and print the log head and answer."""
        answer, log = test_glue(ckpt, data, args.device, smaller_batch=1, model=model, tokenizer=tokenizer)
        print(log[:500])
        print(label, answer)

    model = solve_drift(ckpt1, ckpt2, ft1, ft2, ckpt3, base_model, args)
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    report('drift answer:', model, tokenizer)
    # Drop the previous model before building the next, so two full models are
    # not held on the device at the same time.
    del model
    torch.cuda.empty_cache()

    model = solve_simple_drift(ckpt1, ckpt2, ft1, ft2, ckpt3, base_model, args)
    report('simple drift answer:', model, tokenizer)
    del model
    torch.cuda.empty_cache()

    model = AutoModelForCausalLM.from_pretrained(ckpt3, torch_dtype=torch.bfloat16).to(args.device)
    report('base answer:', model, tokenizer)
    del model
    torch.cuda.empty_cache()

    model = AutoModelForCausalLM.from_pretrained(ft3, torch_dtype=torch.bfloat16).to(args.device)
    report('finetuned answer:', model, tokenizer)