import socket
import pickle
import os
import threading
from vllm_inject.utils import *
import torch
import torch.nn.functional as F
import copy
import numpy as np
from torch.autograd import Variable

def num_curr_seqs_task(port):
    """Serve the minimum ``num_curr_seqs`` across three workers, forever.

    Binds a TCP server on ``(socket.gethostname(), port)``.  Each round it
    accepts three connections, reads one pickled
    ``(model_name, num_curr_seqs, run)`` tuple from each, and replies to every
    client with the pickled minimum of the three ``num_curr_seqs`` values
    before closing the connections.

    Args:
        port: TCP port to listen on.

    This function never returns normally; it only exits via an exception.
    """
    import time  # local import: only the bind-retry loop needs it

    expected_clients = 3
    host = socket.gethostname()

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
        # Keep retrying until the address can be bound, but only for socket
        # errors, and back off between attempts instead of spinning.
        while True:
            try:
                server_socket.bind((host, port))
                break
            except OSError:
                time.sleep(0.5)
        server_socket.listen(expected_clients)
        print("num_curr_seqs Socket server is listening...")

        # Debug bookkeeping: number of messages and accumulated `run` value
        # seen per model name.  Not used for the reply.
        model_number_dict = {}
        run_dict = {}

        while True:
            # Establish the client connections for this round.
            client_sockets = []
            try:
                for _ in range(expected_clients):
                    client_socket, _addr = server_socket.accept()
                    client_sockets.append(client_socket)

                num_curr_seqs = 999999999
                for client_socket in client_sockets:
                    # A single recv is assumed to carry the whole (small)
                    # pickled tuple.
                    payload = client_socket.recv(4096)
                    # SECURITY: pickle.loads executes arbitrary code from the
                    # peer; only expose this port to trusted workers.
                    model_name, num_curr_seqs_now, run_now = pickle.loads(payload)
                    model_number_dict[model_name] = model_number_dict.get(model_name, 0) + 1
                    run_dict[model_name] = run_dict.get(model_name, 0) + run_now
                    num_curr_seqs = min(num_curr_seqs_now, num_curr_seqs)

                result_data = pickle.dumps(num_curr_seqs)
                for client_socket in client_sockets:
                    # sendall: plain send() may transmit only part of the buffer.
                    client_socket.sendall(result_data)
            finally:
                # Always release the accepted sockets, even if a round fails.
                for client_socket in client_sockets:
                    client_socket.close()
# fo = open("/xx/analysis/alpha_truthfulqa.txt", "a")
# fo.write("hello\n")
import sys
def logits_task(port, model_str_list, alpha, upa, downa):
    """Serve fused logits to three cooperating model workers, forever.

    Binds a TCP server on ``(socket.gethostname(), port)``.  Each round it
    accepts three connections, receives a ``(model_name, logits)`` pair from
    each via ``receive_large_tensor``, and stores the logits at the index that
    ``model_name`` has in ``model_str_list``.  It then searches a fusion weight
    over ``np.arange(downa, upa, 0.1)`` that minimises ``verify``'s loss,
    prints the chosen weight, and sends

        ``logits[0] + weight * (logits[2] - logits[1])``

    (cast to float16) back to every client via ``send_large_tensor``.

    Args:
        port: TCP port to listen on.
        model_str_list: Model names; a name's position selects its role
            (index 0 is passed to ``verify`` as ``l``, index 1 as ``s``,
            index 2 as ``s_expert``).
        alpha: Unused here; the weight is searched every round instead.
        upa: Exclusive upper bound of the weight search range.
        downa: Inclusive lower bound of the weight search range.

    This function never returns normally; it only exits via an exception.
    """
    import time  # local import: only the bind-retry loop needs it

    # The fusion formula is defined for exactly three logit sets.
    num_models = 3

    # Debug bookkeeping: number of messages seen per model name.
    model_number_dict = {}
    model_dict_id = {name: idx for idx, name in enumerate(model_str_list)}

    # ------------------------------------------------------------------
    # Objective / solver helpers.  Only `verify` is used by the serving
    # loop; the others are alternative objectives and solvers kept for
    # experimentation.  They are defined once here rather than being
    # re-created on every round.  All of them apply softmax over dim=1.
    # ------------------------------------------------------------------
    def verify(l_expert, l, s_expert, s, eps=1e-20):
        """Compare elementwise KL terms of the (l_expert, l) and (s_expert, s) pairs.

        Returns ``(l0, l1)``: the summed squared difference between the two
        pairs' elementwise KL terms, for the expert-to-base direction (``l0``)
        and the base-to-expert direction (``l1``).
        """
        s_expert_p = F.softmax(s_expert, dim=1)
        l_expert_p = F.softmax(l_expert, dim=1)
        s_p = F.softmax(s, dim=1)
        l_p = F.softmax(l, dim=1)
        l_gap = l_expert_p * torch.log((l_expert_p + eps) / (l_p + eps))
        s_gap = s_expert_p * torch.log((s_expert_p + eps) / (s_p + eps))
        l0 = ((l_gap - s_gap) ** 2).sum()
        l_gap = l_p * torch.log((l_p + eps) / (l_expert_p + eps))
        s_gap = s_p * torch.log((s_p + eps) / (s_expert_p + eps))
        l1 = ((l_gap - s_gap) ** 2).sum()
        return l0, l1

    def verify2(l_expert, l, s_expert, s, eps=1e-20):
        """Like `verify`, but each gap sums both KL directions; returns ``(l0, 0)``."""
        s_expert_p = F.softmax(s_expert, dim=1)
        l_expert_p = F.softmax(l_expert, dim=1)
        s_p = F.softmax(s, dim=1)
        l_p = F.softmax(l, dim=1)
        l_gap = l_expert_p * torch.log((l_expert_p + eps) / (l_p + eps)) + l_p * torch.log((l_p + eps) / (l_expert_p + eps))
        s_gap = s_expert_p * torch.log((s_expert_p + eps) / (s_p + eps)) + s_p * torch.log((s_p + eps) / (s_expert_p + eps))
        l0 = ((l_gap - s_gap) ** 2).sum()
        return l0, 0

    def verify3(l_expert, l, s_expert, s, eps=1e-20):
        """Like `verify`, but each gap measures both distributions against their mean; returns ``(l0, 0)``."""
        s_expert_p = F.softmax(s_expert, dim=1)
        l_expert_p = F.softmax(l_expert, dim=1)
        s_p = F.softmax(s, dim=1)
        l_p = F.softmax(l, dim=1)
        l_mean = (l_p+l_expert_p) / 2
        s_mean = (s_p+s_expert_p) / 2
        l_gap = l_expert_p * torch.log((l_expert_p + eps) / (l_mean + eps)) / 2 + l_p * torch.log((l_p + eps) / (l_mean + eps)) / 2
        s_gap = s_expert_p * torch.log((s_expert_p + eps) / (s_mean + eps)) / 2 + s_p * torch.log((s_p + eps) / (s_mean + eps)) / 2
        l0 = ((l_gap - s_gap) ** 2).sum()
        return l0, 0

    def verify4(l_expert, l, s_expert, s, i, eps=1e-20):
        """Sum of two-direction KL terms of l_expert vs l, plus 30x those of l_expert vs softmax(s_expert - s).

        ``i`` is accepted but not used.  Returns ``(l0, 0)``.
        """
        s_expert_p = F.softmax(s_expert, dim=1)
        l_expert_p = F.softmax(l_expert, dim=1)
        s_p = F.softmax(s, dim=1)
        l_p = F.softmax(l, dim=1)
        ss_p = F.softmax(s_expert - s, dim=1)
        ll_gap = l_expert_p * torch.log((l_expert_p + eps) / (l_p + eps)) + l_p * torch.log((l_p + eps) / (l_expert_p + eps))
        lss_gap = 30*l_expert_p * torch.log((l_expert_p + eps) / (ss_p + eps)) + 30*ss_p * torch.log((ss_p + eps) / (l_expert_p + eps))
        l0 = (ll_gap + lss_gap).sum()
        return l0, 0

    def calc(s_expert, s, l, eps=1e-20):
        """Run 100 Newton-style updates on ``x*log(x) - a*x - b`` and return ``log(x)`` of the iterate with the lowest `verify` l0."""
        s_expert_p = F.softmax(s_expert, dim=1)
        s_p = F.softmax(s, dim=1)
        l_p = F.softmax(l, dim=1)
        a = torch.log(l_p + eps)
        b = s_expert_p * torch.log((s_expert_p + eps) / (s_p + eps))
        x0 = torch.exp(a - 1) + eps
        lmin, lminx = 999999, 0
        for _ in range(100):
            zi = x0 * torch.log(x0) - a * x0 - b
            mu = torch.log(x0) + 1 -a
            x0 = x0 - (zi + eps)/(mu + eps)
            # Keep the iterate strictly positive so log(x0) stays finite.
            x0 = torch.clamp(x0, eps, 999999)
            l0, l1 = verify(torch.log(x0), l, s_expert, s)
            if l0 < lmin:
                lmin, lminx = l0, torch.log(x0)
        return lminx

    def calc2(s_expert, s, l, eps=1e-20):
        """Closed-form alternative to `calc`: returns ``(c - b) / (a + eps)`` from the softmaxed inputs."""
        s_expert_p = F.softmax(s_expert, dim=1)
        s_p = F.softmax(s, dim=1)
        l_p = F.softmax(l, dim=1)
        a = l_p
        b = l_p * torch.log(l_p + eps)
        c = s_p * torch.log(s_p / (s_expert_p + eps))
        x0 = (c - b) / (a + eps)
        return x0

    def opt(s_expert, s, l, eps=1e-20):
        """Optimise logits directly with 100 Adam steps, starting from ``l + (s_expert - s)``.

        Returns the optimised tensor moved to ``cuda:0``.
        """
        s_expert_p = F.softmax(s_expert, dim=1)
        s_p = F.softmax(s, dim=1)
        l_p = F.softmax(l, dim=1)
        a1 = torch.log(l_p + eps)
        b1 = s_expert_p * torch.log((s_expert_p + eps) / (s_p + eps))
        # a2/b2/c2 belong to an alternative loss term that is currently disabled.
        a2 = l_p
        b2 = l_p * torch.log(l_p + eps)
        c2 = s_p * torch.log(s_p / (s_expert_p + eps))

        x = Variable(l + (s_expert - s), requires_grad=True)
        learning_rate = 0.001
        optimizer = torch.optim.Adam([x], lr=learning_rate)

        for _ in range(100):
            optimizer.zero_grad()
            x_p = F.softmax(x, dim=1)
            loss = ((x_p * torch.log(x_p + eps) - a1 * x_p - b1)**2).sum()
            loss.backward()
            optimizer.step()
        return x.data.to("cuda:0")

    def opt2(s_expert, s, l, eps=1e-20):
        """Learn one fusion weight per row with 100 Adam steps on `verify`'s ``l0 + l1``.

        The weight tensor is created on ``cuda:0``.  Returns
        ``l + x * (s_expert - s)``.
        """
        x = Variable(torch.ones(s_expert.shape[0], 1, device="cuda:0"), requires_grad=True)
        learning_rate = 0.001
        optimizer = torch.optim.Adam([x], lr=learning_rate)

        for _ in range(100):
            optimizer.zero_grad()
            l_expert = l + x * (s_expert - s)
            l0, l1 = verify(l_expert, l, s_expert, s)
            loss = l0 + l1
            loss.backward()
            optimizer.step()
        l_expert = l + x * (s_expert - s)
        return l_expert

    host = socket.gethostname()
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
        # Keep retrying until the address can be bound, but only for socket
        # errors, and back off between attempts instead of spinning.
        while True:
            try:
                server_socket.bind((host, port))
                break
            except OSError:
                time.sleep(0.5)
        server_socket.listen(len(model_str_list))
        print("logits Socket server is listening...")

        while True:
            # Establish the client connections for this round.
            client_sockets = []
            try:
                for _ in range(num_models):
                    client_socket, _addr = server_socket.accept()
                    client_sockets.append(client_socket)

                logits_all = [None] * num_models
                for client_socket in client_sockets:
                    model_name, logits_now = receive_large_tensor(client_socket)
                    model_number_dict[model_name] = model_number_dict.get(model_name, 0) + 1
                    logits_all[model_dict_id[model_name]] = logits_now

                # Do the search in float32 regardless of the dtype received.
                logits_all = [t.to(torch.float32) for t in logits_all]
                base = logits_all[0]
                delta = logits_all[2] - logits_all[1]

                # Grid-search the weight that minimises verify's l0 + l1.
                best_loss, alpha_x = 99999, 0
                for candidate in np.arange(downa, upa, 0.1):
                    logits_x = candidate * delta + base
                    # NOTE: the candidate must NOT be passed as a fifth
                    # positional argument; that slot is `eps`, and doing so
                    # made the smoothing differ per candidate (and be 0 for
                    # the first one), so the losses were not comparable.
                    l0, l1 = verify(logits_x, base, logits_all[2], logits_all[1])
                    if l0 + l1 < best_loss:
                        best_loss, alpha_x = l0 + l1, candidate
                print(alpha_x)

                logits = alpha_x * delta + base
                fused = logits.to(torch.float16)
                for client_socket in client_sockets:
                    send_large_tensor(client_socket, fused)
            finally:
                # Release the accepted sockets explicitly instead of relying
                # on garbage collection, also when a round fails.
                for client_socket in client_sockets:
                    client_socket.close()

def main(*,
         model_str_list: str="[meta-llama/Llama-2-13b-hf,/xx/alt/models/llama2-gsm-7b,meta-llama/Llama-2-7b-hf]",
         alpha: float=0.9,
         downa: float=0.0,
         upa: float=2.0,
):
    """Start the num_curr_seqs server and the logits server, then wait for both.

    :param model_str_list: Comma-separated model names, optionally wrapped in
        square brackets. The order of the names is passed through to the
        logits server.
    :param alpha: Forwarded to the logits server.
    :param downa: Lower bound of the fusion-weight search range.
    :param upa: Upper bound of the fusion-weight search range.
    """
    # Strip the surrounding brackets only when they are actually present, and
    # trim whitespace around each name, so "a, b, c" and "[a,b,c]" both parse
    # to the same list.  Previously the first and last characters were always
    # dropped and spaces were kept as part of the names.
    spec = model_str_list.strip()
    if spec.startswith("[") and spec.endswith("]"):
        spec = spec[1:-1]
    model_names = [name.strip() for name in spec.split(",")]
    # NOTE(review): logits_task treats index 1 as `s` and index 2 as
    # `s_expert`; confirm the default ordering above matches that intent.

    thread1 = threading.Thread(target=num_curr_seqs_task, args=(11455,))
    thread2 = threading.Thread(target=logits_task, args=(11454, model_names, alpha, upa, downa))

    # Start the threads.
    thread1.start()
    thread2.start()

    # Wait for both threads to finish.
    thread1.join()
    thread2.join()

if __name__ == '__main__':
    import defopt
    try:
        defopt.run(main)
    except Exception:
        # Catch Exception rather than using a bare `except:` so that
        # SystemExit (raised e.g. for `--help` or a usage error) and
        # KeyboardInterrupt terminate the process normally instead of
        # dropping into the debugger.
        import bdb
        import pdb
        import sys
        exc_type, exc_value, tb = sys.exc_info()
        # Quitting an interactive debugger session is not a crash.
        if issubclass(exc_type, bdb.BdbQuit):
            sys.exit()
        print(exc_type, exc_value)
        pdb.post_mortem(tb)
    