import socket
import pickle
import os
import threading
from vllm_inject.utils import *
import torch
import torch.nn.functional as F
import copy
import numpy as np
from torch.autograd import Variable

def num_curr_seqs_task(port, num_clients=5):
    """Serve the "minimum current sequences" rendezvous forever.

    Each round accepts ``num_clients`` connections. Every client sends one
    pickled ``(model_name, num_curr_seqs, run)`` tuple; once all clients of
    the round have reported, every client is sent the pickled minimum
    ``num_curr_seqs`` of that round and its connection is closed.

    Args:
        port: TCP port to listen on; the socket is bound on
            ``socket.gethostname()``.
        num_clients: Number of connections that make up one round. Defaults
            to 5, the previously hard-coded value.

    Note:
        Incoming payloads are decoded with ``pickle.loads``. Unpickling
        attacker-controlled bytes can execute arbitrary code, so only expose
        this port on a trusted network.
    """
    import time

    server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Let a restarted server rebind a port still lingering from a previous
    # run instead of waiting for the OS to release it.
    server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    host = socket.gethostname()
    warned = False
    while True:
        try:
            server_socket.bind((host, port))
            break
        except OSError as err:
            # Only socket errors are retried, with a back-off, so the thread
            # does not spin a CPU core while the port is busy.
            if not warned:
                print(f"num_curr_seqs: bind to {host}:{port} failed ({err}); retrying...")
                warned = True
            time.sleep(0.5)
    server_socket.listen(num_clients)
    print("num_curr_seqs Socket server is listening...")

    while True:
        # Accept all clients of this round first; nobody gets an answer until
        # everyone has reported.
        client_sockets = []
        try:
            for _ in range(num_clients):
                client_socket, _addr = server_socket.accept()
                client_sockets.append(client_socket)

            num_curr_seqs = 999999999  # sentinel larger than any real count
            for client_socket in client_sockets:
                # Protocol assumption (unchanged): the whole pickled tuple
                # arrives in a single recv() of at most 4096 bytes.
                payload = client_socket.recv(4096)
                _model_name, num_curr_seqs_now, _run_now = pickle.loads(payload)
                num_curr_seqs = min(num_curr_seqs_now, num_curr_seqs)

            result_data = pickle.dumps(num_curr_seqs)
            for client_socket in client_sockets:
                # sendall: plain send() may transmit only part of the buffer.
                client_socket.sendall(result_data)
        finally:
            # Always release this round's connections, even on error.
            for client_socket in client_sockets:
                client_socket.close()

def _verify(l_expert, l, s_expert, s, eps=1e-20):
    """Score how well the shift ``l -> l_expert`` mirrors ``s -> s_expert``.

    All four inputs are logits; softmax is taken over dim 1 (assumes
    ``(num_tokens, vocab)``-shaped tensors -- TODO confirm with the clients).

    Returns:
        ``(l0, l1)`` scalar tensors. ``l0`` is the summed squared difference
        between the element-wise terms ``p_le*log(p_le/p_l)`` and
        ``p_se*log(p_se/p_s)``. ``l1`` is the same for the reversed terms
        ``p_l*log(p_l/p_le)`` and ``p_s*log(p_s/p_se)``. ``eps`` guards
        against ``log(0)`` and division by zero.
    """
    s_expert_p = F.softmax(s_expert, dim=1)
    l_expert_p = F.softmax(l_expert, dim=1)
    s_p = F.softmax(s, dim=1)
    l_p = F.softmax(l, dim=1)
    l_gap = l_expert_p * torch.log((l_expert_p + eps) / (l_p + eps))
    s_gap = s_expert_p * torch.log((s_expert_p + eps) / (s_p + eps))
    l0 = ((l_gap - s_gap) ** 2).sum()
    l_gap = l_p * torch.log((l_p + eps) / (l_expert_p + eps))
    s_gap = s_p * torch.log((s_p + eps) / (s_expert_p + eps))
    l1 = ((l_gap - s_gap) ** 2).sum()
    return l0, l1


def _grid_search_mixture(logits_all):
    """Return the mixed logits with the lowest total ``_verify`` loss.

    Writing ``L = logits_all``, the candidates are::

        L[0] + i*(L[2]-L[1]) + j*(L[3]-L[1]) + k*(L[4]-L[1])

    with ``i``, ``j``, ``k`` on a 0.1-step ``np.arange`` grid where
    ``j < 1 - i`` and ``k < 1 - i - j`` (float rounding in ``np.arange`` can
    shift boundary points). Each candidate is scored by summing both terms of
    ``_verify(candidate, L[0], L[e], L[1])`` for ``e`` in 2, 3, 4. If every
    loss is NaN, ``L[0]`` (the all-zero mixture) is returned instead of
    failing.
    """
    base, anchor = logits_all[0], logits_all[1]
    expert_a, expert_b, expert_c = logits_all[2], logits_all[3], logits_all[4]
    # Loop-invariant deltas: computed once instead of once per grid point.
    delta_a = expert_a - anchor
    delta_b = expert_b - anchor
    delta_c = expert_c - anchor

    best_loss = float("inf")
    best_logits = base
    for i in np.arange(0.0, 1.0, 0.1):
        for j in np.arange(0.0, 1.0 - i, 0.1):
            for k in np.arange(0.0, 1.0 - i - j, 0.1):
                candidate = i * delta_a + j * delta_b + k * delta_c + base
                l0, l1 = _verify(candidate, base, expert_a, anchor)
                l2, l3 = _verify(candidate, base, expert_b, anchor)
                l4, l5 = _verify(candidate, base, expert_c, anchor)
                loss = l0 + l1 + l2 + l3 + l4 + l5
                if loss < best_loss:
                    best_loss, best_logits = loss, candidate
    return best_logits


def logits_task(port, model_str_list, alpha):
    """Serve the logit-fusion rendezvous forever.

    Each round accepts one connection per entry of ``model_str_list``. Every
    client sends ``(model_name, logits)`` via ``receive_large_tensor``
    (presumably from ``vllm_inject.utils``, which this file star-imports).
    The logits are placed at the slot given by the model's position in
    ``model_str_list``, cast to float32 and fused by
    ``_grid_search_mixture``. The fused result is cast to float16 and sent
    back to every client via ``send_large_tensor``, after which that round's
    connections are closed.

    Args:
        port: TCP port to listen on; the socket is bound on
            ``socket.gethostname()``.
        model_str_list: Model identifiers. Slots 0-4 are used by the mixture
            search, so at least five are required.
        alpha: Unused by the current grid search; kept for interface
            compatibility.

    Raises:
        ValueError: fewer than five models are configured, or a client
            reports a model name that is not in ``model_str_list``.
        RuntimeError: a round finished without logits for some model
            (e.g. duplicate names in ``model_str_list`` or in the clients).
    """
    import time

    num_models = len(model_str_list)
    if num_models < 5:
        raise ValueError(
            f"logits_task needs at least 5 models (slots 0-4), got {num_models}: {model_str_list}"
        )
    model_dict_id = {name: idx for idx, name in enumerate(model_str_list)}

    server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Let a restarted server rebind a port still lingering from a previous run.
    server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    host = socket.gethostname()
    warned = False
    while True:
        try:
            server_socket.bind((host, port))
            break
        except OSError as err:
            # Retry only socket errors, with a back-off instead of a busy spin.
            if not warned:
                print(f"logits: bind to {host}:{port} failed ({err}); retrying...")
                warned = True
            time.sleep(0.5)
    server_socket.listen(num_models)
    print("logits Socket server is listening...")

    while True:
        client_sockets = []
        try:
            for _ in range(num_models):
                client_socket, _addr = server_socket.accept()
                client_sockets.append(client_socket)

            logits_all = [None] * num_models
            for client_socket in client_sockets:
                model_name, logits_now = receive_large_tensor(client_socket)
                try:
                    slot = model_dict_id[model_name]
                except KeyError:
                    raise ValueError(
                        f"unexpected model {model_name!r}; expected one of {model_str_list}"
                    ) from None
                logits_all[slot] = logits_now

            missing = [name for name, t in zip(model_str_list, logits_all) if t is None]
            if missing:
                raise RuntimeError(f"no logits received this round for: {missing}")

            logits_all = [t.to(torch.float32) for t in logits_all]
            fused = _grid_search_mixture(logits_all).to(torch.float16)
            for client_socket in client_sockets:
                send_large_tensor(client_socket, fused)
        finally:
            # The server never reuses a round's connections, so release them
            # to avoid leaking one file descriptor per client per round.
            for client_socket in client_sockets:
                client_socket.close()

def main(*,
         model_str_list: str = "[meta-llama/Llama-2-13b-hf,/xx/alt/models/llama2-gsm-7b,meta-llama/Llama-2-7b-hf]",
         alpha: float = 0.9,
         num_seqs_port: int = 11455,
         logits_port: int = 11454,
):
    """Start the num_curr_seqs and logits rendezvous servers and block forever.

    Args:
        model_str_list: Comma-separated model identifiers, optionally wrapped
            in square brackets (e.g. ``[a,b,c]``). Whitespace around each name
            is ignored. The order defines each model's slot on the logits
            server.
        alpha: Forwarded unchanged to the logits server.
        num_seqs_port: TCP port of the num_curr_seqs server.
        logits_port: TCP port of the logits server.
    """
    # NOTE(review): the default lists three models, but the logits server
    # uses slots 0-4 and needs at least five -- confirm the intended default.
    # Strip surrounding brackets only if present; the old ``[1:-1]`` slice
    # silently chopped the first/last character of un-bracketed input.
    names = model_str_list.strip().strip("[]").split(",")
    model_names = [name.strip() for name in names if name.strip()]

    num_seqs_thread = threading.Thread(target=num_curr_seqs_task, args=(num_seqs_port,))
    logits_thread = threading.Thread(target=logits_task, args=(logits_port, model_names, alpha))

    # Both servers loop forever; joining keeps the process alive.
    for thread in (num_seqs_thread, logits_thread):
        thread.start()
    for thread in (num_seqs_thread, logits_thread):
        thread.join()

if __name__ == '__main__':
    import defopt
    try:
        defopt.run(main)
    except Exception:
        # Post-mortem debug real crashes only. SystemExit (raised by argparse
        # for --help or bad arguments) and KeyboardInterrupt are not Exception
        # subclasses, so they propagate normally instead of opening pdb.
        import sys, pdb, bdb
        exc_type, exc_value, tb = sys.exc_info()
        if exc_type is bdb.BdbQuit:
            # User quit an interactive debugger session; exit quietly.
            sys.exit()
        print(exc_type, exc_value)
        pdb.post_mortem(tb)
    