
# from mergekit.sparsify import  SparsificationMethod, sparsify
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
import torch, os, json, threading
from evaluations import test_glue,load_ckpt
import sparsify

# Names of the model parameters in ``named_parameters()`` order. param_list()
# fills it on its first call (while it is still empty); merge_params() reads it
# to label its per-parameter diagnostic prints.
parameter_names = []
def aggregate(va, vb, f):
    """Apply ``f`` pairwise to ``va`` and ``vb`` and return the results as a list.

    Iteration stops at the shorter of the two inputs, like ``zip``.
    """
    return [f(left, right) for left, right in zip(va, vb)]

def param_list(ckpt, args):
    """Load ``ckpt`` and return its parameters as a list of detached tensors.

    Paths containing ``'LoRA'`` are treated as PEFT adapters. They are loaded on top
    of ``args.base_model`` (without an explicit ``torch_dtype``) and merged into it.
    Any other path is loaded directly with ``torch_dtype=torch.bfloat16``.

    Side effect: if the module-level ``parameter_names`` list is empty, it is
    filled with the parameter names in ``named_parameters()`` order.
    """
    global parameter_names
    torch.cuda.empty_cache()
    if 'LoRA' not in ckpt:
        model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.bfloat16)
    else:
        adapter = PeftModel.from_pretrained(AutoModelForCausalLM.from_pretrained(args.base_model), ckpt)
        model = adapter.merge_and_unload()

    # Names are recorded only on the very first call.
    record_names = not parameter_names
    tensors = []
    for name, param in model.named_parameters():
        if record_names:
            parameter_names.append(name)
        tensors.append(param.detach())

    del model
    torch.cuda.empty_cache()
    return tensors

def task_vector(model, base_model, args, use_base_model=1):
    """Return the per-parameter task vector of ``model`` relative to ``base_model``.

    Args:
        model: checkpoint path passed to :func:`param_list`.
        base_model: checkpoint path of the reference model.
        args: namespace. ``args.doppel_linear`` together with ``args.method == 'ties'``
            enables magnitude trimming of the deltas.
        use_base_model: when truthy, return ``model - base_model`` per parameter.
            When falsy, return ``model``'s raw parameters; the base model is then
            not loaded at all.

    With trimming enabled, each delta keeps only the entries whose magnitude is
    strictly greater than the ``int(0.8 * numel)``-th smallest magnitude (roughly
    the top 20%). All other entries are zeroed.

    Returns:
        list of tensors in ``named_parameters()`` order.
    """
    finetuned = param_list(model, args)
    if not use_base_model:
        # Raw weights requested: loading the base model would be wasted work.
        return finetuned

    trim = args.doppel_linear and args.method == 'ties'
    params = []
    for i, j in zip(finetuned, param_list(base_model, args)):
        delta = i - j
        if trim:
            magnitude = delta.reshape(-1).abs()
            k = int(magnitude.shape[0] * 0.8)
            # kthvalue requires 1 <= k <= numel. Tensors with fewer than 2
            # entries (k == 0) are kept untrimmed instead of raising.
            if k >= 1:
                threshold = magnitude.kthvalue(k).values
                delta = delta * (delta.abs() > threshold)
        params += [delta]
    return params
        

def merge_params(va, vb, args, verbose=True):
    """Merge two tensor lists: keep ``va`` where signs conflict, ``vb`` elsewhere.

    For each pair ``(a, b)``, the merged tensor takes ``a``'s value wherever
    ``a.sign() != b.sign()`` and ``b``'s value everywhere else.

    Args:
        va: iterable of tensors.
        vb: iterable of tensors with the same shapes as ``va``.
        args: unused; kept for interface compatibility.
        verbose: when True (the default, matching the previous behaviour), print
            for each parameter its name and the fractions of positive, zero and
            negative entries of ``a``, followed by ``sum(|a|)``.

    Returns:
        ``(params, overhead)``. ``overhead`` is the fraction of entries whose
        signs conflict, or 0 when the inputs are empty.
    """
    params = []
    overhead, total = 0, 0
    for idx, (a, b) in enumerate(zip(va, vb)):
        conflict = a.sign() != b.sign()
        params += [torch.where(conflict, a, b)]
        overhead += conflict.sum().int()
        numel = b.numel()
        total += numel
        if verbose:
            # parameter_names is filled by param_list(); fall back to the
            # index when it does not cover this entry.
            name = parameter_names[idx] if idx < len(parameter_names) else f'param[{idx}]'
            print(name)
            print((a > 0).sum().int() / numel, (a == 0).sum().int() / numel, (a < 0).sum().int() / numel)
            print(a.abs().sum())
    return params, (overhead / total if total else 0)

def doppel(params, v, vs, args):
    """Fold the tensor lists in ``vs`` into the ``(sum, count)`` accumulators of ``params``.

    ``params[k]`` is a ``(sum, count)`` pair for the k-th parameter. ``v[k]`` is
    the reference tensor whose sign decides agreement. For each ``b`` taken
    from ``vs``:

    * ``args.count``: ``sum = sum + b`` (out of place); ``count`` is untouched.
    * ``args.doppel_linear`` with ``args.method``:

      * ``'linear'``: ``sum += b``, ``count += 1``.
      * ``'task_arithmetic'``: ``sum += b * 0.4``, ``count += 1``.
      * ``'ties'``: sign-agreement update (below).
      * any other method: nothing changes.

    * otherwise: sign-agreement update.

    The sign-agreement update adds ``b`` only where ``(v[k] >= 0) == (b >= 0)``
    and adds that mask to ``count``.

    Returns:
        ``(params, 0)``. ``params`` is the same list object, updated in place.
    """
    for idx, ref in enumerate(v):
        acc, tally = params[idx]
        for other in vs:
            b = other[idx]
            if args.count:
                # Out of place on purpose: an int32 accumulator is promoted to float.
                acc = acc + b
            elif args.doppel_linear and args.method in ('linear', 'task_arithmetic'):
                acc += b if args.method == 'linear' else b * 0.4
                tally += 1
            elif not args.doppel_linear or args.method == 'ties':
                # "0 counts as positive" agreement with the reference sign.
                agree = (ref >= 0) == (b >= 0)
                acc += b * agree
                tally += agree
        params[idx] = (acc, tally)
    return params, 0

def shigure(v, vs, args):
    """Merge raw weight lists by combining normalised directions and magnitudes.

    For every parameter index ``i``:

    * Tensors with more than one dimension: each ``b[i]`` is divided by its L2
      norm taken over ``dim=1`` (per row for a 2-D tensor). The normalised
      tensors are summed across ``vs``, the sum is normalised the same way, and
      the result is scaled by the element-wise maximum of the input norms.
      Slices whose norm is exactly 0 (all-zero inputs, or directions that
      cancel out) become 0 instead of ``0/0 = NaN``.
    * 1-D tensors: the plain mean across ``vs``.

    ``v`` and ``args`` are unused and kept for interface compatibility.

    Returns:
        ``(params, 0)``.
    """
    def dim1_norm(x):
        return x.pow(2).sum(dim=1, keepdim=True).sqrt()

    def safe_div(x, norm):
        # Only exact-zero norms are replaced, so every other value is unchanged.
        return torch.where(norm == 0, torch.zeros_like(x), x / norm)

    params = []
    for i in range(len(vs[0])):
        tensors = [b[i] for b in vs]
        if tensors[0].dim() > 1:
            norms = [dim1_norm(t) for t in tensors]
            direction = sum(safe_div(t, n) for t, n in zip(tensors, norms))
            scale = torch.max(torch.stack(norms, dim=0), dim=0)[0]
            params += [safe_div(direction, dim1_norm(direction)) * scale]
        else:
            params += [sum(tensors) / len(vs)]

    return params, 0


def apply_model(v, ckpt, args, use_base_model=1):
    """Load ``ckpt`` in bfloat16, overwrite its parameters and move it to ``args.device``.

    Each parameter ``param`` (in ``named_parameters()`` order) is replaced by
    ``v[k] + param * use_base_model``. ``load_state_dict`` copies the values
    into the existing bfloat16 parameters.

    Raises:
        ValueError: if ``v`` does not hold exactly one tensor per parameter.
    """
    model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.bfloat16)
    v = list(v)
    named = list(model.named_parameters())
    if len(v) != len(named):
        # zip() would silently leave the remaining parameters untouched.
        raise ValueError(f'expected {len(named)} tensors for {ckpt}, got {len(v)}')
    pretrained_dict = model.state_dict()
    # Pure weight arithmetic: no autograd graph is needed for these tensors.
    with torch.no_grad():
        for (name, param), vv in zip(named, v):
            pretrained_dict[name] = vv + param * use_base_model
    model.load_state_dict(pretrained_dict)
    model = model.to(args.device)
    return model

def solve_doppel_merge(ckpt, ckpts, base_model, args):
    """Merge the task vectors of ``ckpts`` using ``ckpt``'s task vector as sign reference.

    Task vectors come from ``task_vector(x, base_model, args)``. Each one from
    ``ckpts`` is folded into per-parameter ``(sum, count)`` accumulators by
    :func:`doppel`. The merged delta is ``sum / count`` where ``count > 0`` and 0
    elsewhere; it is added to ``base_model`` by :func:`apply_model`.

    * ``args.count``: the accumulators are int32 zeros.
    * ``args.svd_rank > 0``: the reference is replaced by a rank-``svd_rank``
      reconstruction of ``sign(v)``, rounded to {-1, 0, 1} with a 0.5 threshold.
      1-D tensors keep their sign.
    * ``args.doppel_linear`` with ``args.method == 'ties'``: the reference is the
      sum of the *untrimmed* task vectors of ``ckpts``. ``args.method`` is
      switched to ``'linear'`` while computing it, so ``task_vector`` skips its
      trimming, and it is always restored to ``'ties'``. The trimmed vectors are
      then accumulated.

    Returns:
        ``(model, overhead)``. ``overhead`` is what the last :func:`doppel` call
        returned, or 0 when ``ckpts`` is empty.
    """
    # put ckpt in ckpts, please
    f = lambda x: task_vector(x, base_model, args)
    v = f(ckpt)

    # (sum, count) accumulators, created before any SVD post-processing of v.
    if args.count:
        params = [
            (torch.zeros_like(vi, dtype=torch.int32), torch.zeros_like(vi, dtype=torch.int32))
            for vi in v
        ]
    else:
        params = [(vi * 0, vi * 0) for vi in v]

    if args.svd_rank > 0:
        def torch_svd(x, k):
            if len(x.shape) > 1:
                U, S, Vh = torch.linalg.svd(x.float(), full_matrices=False)
                return (U[:, :k] * S[:k].unsqueeze(0)) @ Vh[:k, :]
            return x

        for i in range(len(v)):
            approx = torch_svd(v[i].to(args.device).sign(), args.svd_rank)
            v[i] = ((approx > 0.5).to(torch.int32) - (approx < -0.5).to(torch.int32)).to('cpu')

    overhead = 0  # previously unbound (UnboundLocalError) when ckpts was empty
    if args.doppel_linear and args.method == 'ties':
        args.method = 'linear'
        try:
            reference = [vi * 0 for vi in v]
            for ckpti in ckpts:
                reference = aggregate(reference, f(ckpti), lambda x, y: x + y)
        finally:
            # Restore even if loading a checkpoint fails.
            args.method = 'ties'
        for ckpti in ckpts:
            params, overhead = doppel(params, reference, [f(ckpti)], args)
        del reference
    else:
        for ckpti in ckpts:
            params, overhead = doppel(params, v, [f(ckpti)], args)

    for _, cnt in params:
        if not (cnt > 0).all():
            print('warning for 0 in cnt!!!')
    # Average where at least one contribution was counted; 0 elsewhere.
    params = [s / (cnt + 0.000001) * (cnt > 0) for s, cnt in params]

    # Release the reference tensors before loading the output model. The former
    # `for i in v: del i` only unbound the loop variable and freed nothing.
    v.clear()
    torch.cuda.empty_cache()

    return apply_model(params, base_model, args), overhead


def solve_twin_merge(ckpt, ckpts, base_model, args):
    """Build a model from the low-rank "twin" part of ``ckpt``'s task vector.

    Steps:
      1. ``v`` is the task vector of ``ckpt`` relative to ``base_model``.
      2. The shared vector is ``0.3 * mean`` of the task vectors of ``ckpts``.
      3. The twin vector is ``sparsify.svd(v - shared)`` in float32 with ``new_rank=8``.

    The twin vector is added to ``base_model`` by :func:`apply_model`.
    NOTE(review): only the twin vector (not shared + twin) is applied. Confirm
    this is intended.

    Returns:
        ``(model, 0)``.

    Raises:
        ValueError: if ``ckpts`` is empty. The mean would be ``0/0 = NaN``.
    """
    if not ckpts:
        raise ValueError('solve_twin_merge needs at least one checkpoint in ckpts')
    # put ckpt in ckpts, please
    f = lambda x: task_vector(x, base_model, args)
    v = f(ckpt)

    # task arithmetic for shared model
    shared = [vi * 0 for vi in v]
    for ckpti in ckpts:
        shared = aggregate(shared, f(ckpti), lambda x, y: x + y)
    shared = [x / len(ckpts) * 0.3 for x in shared]

    def extract_twin_vector(lora, merged, new_rank):
        return sparsify.svd(
            (lora - merged).to(torch.float32),
            density=0.9,  # useless
            new_rank=new_rank,
        )

    params = aggregate(v, shared, lambda x, y: extract_twin_vector(x, y, new_rank=8))
    del shared
    # Release the reference task vector before loading the output model. The
    # former `for i in v: del i` only unbound the loop variable.
    v.clear()
    torch.cuda.empty_cache()

    return apply_model(params, base_model, args), 0



def solve_shigure_merge(ckpt, ckpts, base_model, args):
    """Merge the raw weights of ``ckpts`` with :func:`shigure` and load the result.

    Weights are collected with ``use_base_model=0``, i.e. without subtracting
    ``base_model``. The merged tensors are passed to :func:`apply_model` with
    ``use_base_model=0``, so the loaded base weights are multiplied by zero.
    ``ckpt`` itself is not used; per the original note it should be in ``ckpts``.

    Returns:
        ``(model, overhead)`` where ``overhead`` comes from :func:`shigure`.
    """
    weight_lists = [task_vector(c, base_model, args, use_base_model=0) for c in ckpts]
    merged, overhead = shigure(None, weight_lists, args)
    return apply_model(merged, base_model, args, use_base_model=0), overhead
    

def solve_prefer_merge(ckpt, merged_ckpt, base_model, args):
    """Combine ``ckpt``'s task vector with ``merged_ckpt``'s via :func:`merge_params`.

    The first checkpoint is "self" and the second is the merged one. Wherever
    the two task vectors disagree in sign, ``ckpt``'s value is kept; otherwise
    ``merged_ckpt``'s value is used. The resulting model is built on
    ``base_model`` and placed on ``args.device`` by :func:`apply_model`.

    Returns:
        ``(model, overhead)`` with the sign-conflict fraction from :func:`merge_params`.
    """
    own = task_vector(ckpt, base_model, args)
    merged = task_vector(merged_ckpt, base_model, args)
    combined, overhead = merge_params(own, merged, args)
    return apply_model(combined, base_model, args), overhead

# if __name__ == '__main__':
#     base_model = "Qwen/Qwen2.5-3B-Instruct"

#     branches = [['wnli','rte'],['wnli','cola'],['qqp','cola'],['rte','sst2']]
#     branches += [['wnli','qqp'],['qqp','rte'],['qqp','sst2'],['qqp','rte']]

#     for b in branches:
#         get_running(solve, (base_model,b))
    
#     for k in range(gpus): 
#         t = threads[k]
#     if t is not None and t.is_alive():
#         t.join()
    