from scripts.multiagent_communication import *
import cma
import multiprocessing as mp
import pickle
import warnings
import os
import time

warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == '__main__':
    bs = 1024
    n_agent = 16
    epochs = 50
    train_frac = 0.9  # leading fraction of every data file used for training

    # Labels are row-aligned with every agent's observation file, so all of
    # them are split at one shared index derived here (previously this index
    # leaked out of the last iteration of the shapeblind loading loop).
    Y_numpy = np.load('data/Y_train.npy')
    Y = torch.from_numpy(Y_numpy).float()
    n_train = int(len(Y) * train_frac)
    Y_train = Y[:n_train]
    Y_val = Y[n_train:]

    def _load_view_splits(view):
        """Load every agent's observations for *view* and split them train/val.

        Reads ``data/X/<view>/X_<agent_id>_train.npy`` for each agent id,
        converts it to a float tensor and splits it at ``n_train`` (the same
        index used for the labels).

        Returns:
            (train, val): dicts mapping agent id -> tensor slice.

        Raises:
            ValueError: if a file's row count differs from the label file's,
                which would otherwise silently misalign observations and labels.
        """
        train, val = {}, {}
        for agent_id in range(n_agent):
            path = 'data/X/{}/X_{}_train.npy'.format(view, agent_id)
            data = torch.from_numpy(np.load(path)).float()
            if len(data) != len(Y):
                raise ValueError('{} has {} rows but data/Y_train.npy has {}'.format(
                    path, len(data), len(Y)))
            train[agent_id] = data[:n_train]
            val[agent_id] = data[n_train:]
        return train, val

    X_colorblinds, X_colorblinds_val = _load_view_splits('colorblind')
    X_shapeblinds, X_shapeblinds_val = _load_view_splits('shapeblind')

    def sigmoid(x):
        """Return the logistic function 1 / (1 + e^(-x)), elementwise via NumPy."""
        denominator = 1 + np.exp(-x)
        return 1 / denominator

    def cosine(v1, v2):
        """Return the row-wise cosine similarity of two NumPy arrays.

        Entry k is dot(v1[k], v2[k]) / |v1[k]| / |v2[k]|, with axis 1 treated
        as the feature axis. A zero-norm row yields nan/inf (no guard here).
        """
        row_dots = np.sum(v1 * v2, axis=1)
        norms_first = np.linalg.norm(v1, axis=1)
        norms_second = np.linalg.norm(v2, axis=1)
        return row_dots / norms_first / norms_second

    def get_links(solution, l_per_node=3):
        """Decode a CMA-ES candidate matrix into a list of training links.

        ``solution`` has shape ``(2 * n, n)``:

        * Top ``n`` rows: row ``i`` holds colorblind agent ``i``'s scores over
          shapeblind partners. Its ``l_per_node`` highest-scoring partners give
          links ``(i, partner, 0)``.
        * Bottom ``n`` rows: column ``j`` holds shapeblind agent ``j``'s scores
          over colorblind partners. Its ``l_per_node`` highest-scoring partners
          give links ``(partner, j, 1)``.

        Args:
            solution: array-like of shape ``(2 * n, n)``.
            l_per_node: number of partners selected per agent; ``0`` selects
                none (previously ``[-0:]`` selected every partner).

        Returns:
            list of ``(colorblind_id, shapeblind_id, sender)`` tuples of plain
            ints, with ``sender`` 0 for a colorblind sender and 1 for a
            shapeblind sender. Pairs may appear more than once.

        Raises:
            ValueError: if ``solution`` is not 2-D with twice as many rows as
                columns.
        """
        solution = np.asarray(solution)
        if solution.ndim != 2 or solution.shape[0] != 2 * solution.shape[1]:
            raise ValueError(
                'solution must have shape (2n, n), got {}'.format(solution.shape))
        n = solution.shape[1]
        # argsort is ascending, so the last `l_per_node` indices are the best;
        # the clamp keeps l_per_node == 0 (and > n) well-defined.
        start = max(n - l_per_node, 0)
        colorblind_scores = solution[:n]
        shapeblind_scores = solution[n:]

        links = []
        for i, row in enumerate(colorblind_scores):
            for partner in np.argsort(row)[start:]:
                links.append((i, int(partner), 0))  # 0: sender is colorblind_id
        # Iterate over columns explicitly (the original used the row count).
        for j in range(n):
            for partner in np.argsort(shapeblind_scores[:, j])[start:]:
                links.append((int(partner), j, 1))  # 1: sender is shapeblind_id
        return links

    def blackbox_func(solution, cpu_id):
        """Score one CMA-ES candidate link topology.

        The candidate is decoded into links with ``get_links``. A fresh
        ``Comm(n_agent=n_agent)`` population is then trained for ``epochs``
        epochs, each of ``len(X_colorblinds[0]) // bs`` steps. Every step
        samples one link and one minibatch and updates the two speakers plus
        the listener chosen by the link's sender. Finally, every
        (colorblind, shapeblind) pair that was *not* a training link is
        evaluated on the validation split.

        Args:
            solution: flat vector with ``2 * n_agent * n_agent`` entries,
                reshaped to ``(2 * n_agent, n_agent)``.
            cpu_id: integer distinguishing parallel evaluations. It is mixed
                into the RNG seeds so concurrent workers draw different samples.

        Returns:
            Negated mean validation accuracy over the untrained pairs, so that
            minimising the return value maximises accuracy.
        """
        seed = cpu_id + int(time.time())
        torch.manual_seed(seed)
        # Pool workers are forked from the parent and otherwise inherit an
        # identical NumPy RNG state. Without this, every worker samples the same
        # links/batches and repeated evaluations are not independent.
        # np.random.seed requires a value in [0, 2**32).
        np.random.seed(seed % (2 ** 32))
        torch.set_num_threads(1)

        solution = solution.reshape((n_agent * 2, n_agent))
        links = get_links(solution)
        colorblind_ids, shapeblind_ids, senders = zip(*links)
        n_links = len(links)
        # Pairs used in training; built once so evaluation can skip them in O(1).
        trained_pairs = {(int(c), int(s)) for c, s in zip(colorblind_ids, shapeblind_ids)}

        comm = Comm(n_agent=n_agent)
        optimizer_colorblinds = {}
        optimizer_shapeblinds = {}
        # Listener index i pairs with colorblind sender i;
        # listener index n_agent + i pairs with shapeblind sender i.
        optimizer_listeners = {}
        for i in range(n_agent):
            optimizer_colorblinds[i] = torch.optim.Adam(comm.agent_colorblinds[i].parameters(), lr=1e-2)
            optimizer_shapeblinds[i] = torch.optim.Adam(comm.agent_shapeblinds[i].parameters(), lr=1e-2)
            optimizer_listeners[i] = torch.optim.Adam(comm.listeners[i].parameters(), lr=1e-2)
            optimizer_listeners[n_agent + i] = torch.optim.Adam(comm.listeners[n_agent + i].parameters(), lr=1e-2)

        n_batch = len(X_colorblinds[0]) // bs
        for _epoch in range(epochs):
            for _ in range(n_batch):
                link_id = np.random.randint(n_links)
                colorblind_id = colorblind_ids[link_id]
                shapeblind_id = shapeblind_ids[link_id]
                sender = senders[link_id]

                rand_idx = np.random.randint(n_batch)
                batch = slice(bs * rand_idx, bs * (rand_idx + 1))
                x_colorblind = X_colorblinds[colorblind_id][batch]
                x_shapeblind = X_shapeblinds[shapeblind_id][batch]
                y = Y_train[batch]

                output = comm(x_colorblind, x_shapeblind, colorblind_id, shapeblind_id, sender)
                # reduction='sum' is the non-deprecated spelling of size_average=False.
                loss = F.binary_cross_entropy(output, y, reduction='sum')

                if sender == 0:
                    listener_opt = optimizer_listeners[colorblind_id]
                else:
                    listener_opt = optimizer_listeners[n_agent + shapeblind_id]
                step_opts = (
                    optimizer_colorblinds[colorblind_id],
                    optimizer_shapeblinds[shapeblind_id],
                    listener_opt,
                )
                for opt in step_opts:
                    opt.zero_grad()
                loss.backward()
                for opt in step_opts:
                    opt.step()

        # Class index per validation row; loop-invariant, so computed once.
        labels_val = torch.argmax(Y_val, dim=1)
        board = []
        for colorblind_id in range(n_agent):
            x_colorblind = X_colorblinds_val[colorblind_id]
            for shapeblind_id in range(n_agent):
                if (colorblind_id, shapeblind_id) in trained_pairs:
                    continue
                x_shapeblind = X_shapeblinds_val[shapeblind_id]
                with torch.no_grad():
                    output = comm(x_colorblind, x_shapeblind, colorblind_id, shapeblind_id,
                                  np.random.randint(2))
                n_correct = (labels_val == torch.argmax(output, dim=1)).sum().item()
                board.append(n_correct / len(Y_val))
        return -1 * np.mean(board)

    os.chdir('es/')
    popsize = 64   # CMA-ES candidates per generation
    n_repeat = 2   # independent evaluations per candidate, averaged into one fitness
    es = cma.CMAEvolutionStrategy(2 * n_agent * n_agent * [0], 0.1,
                                  {'popsize': popsize, 'verb_disp': 1})
    cnt = 0
    bests = []
    while not es.stop():
        solutions = es.ask()
        n_cand = len(solutions)
        # Tasks are n_repeat consecutive copies of the population, so repeat r
        # of candidate j is at index r * n_cand + j. (The old code used
        # 16 * r + j, which mixed fitnesses of different candidates.) Each task
        # gets a distinct cpu_id, which blackbox_func folds into its seeds.
        tasks = solutions * n_repeat
        with mp.Pool(len(tasks)) as pool:
            raw_scores = pool.starmap(blackbox_func, zip(tasks, range(len(tasks))))
        result = [np.mean([raw_scores[r * n_cand + j] for r in range(n_repeat)])
                  for j in range(n_cand)]
        es.tell(solutions, result)
        es.logger.add()
        es.disp()
        bests.append(es.result.xbest)
        # Checkpoint every 10 generations, starting with the first.
        if cnt % 10 == 0:
            with open('es.pkl', 'wb') as f:
                pickle.dump(es, f)
            with open('es_bests.pkl', 'wb') as f:
                pickle.dump(bests, f)
        cnt += 1