import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import numpy as np
import pickle
from time import time

from os.path import exists
import os

class Comm(nn.Module):
    """Population of message encoders and listeners.

    The module owns three families of small MLPs:

    * ``agent_colorblinds[i]`` and ``agent_shapeblinds[i]`` for
      ``i in range(n_agent)``: ``input_size -> h_dim1 -> output_size``
      encoders that emit raw (un-normalised) message vectors.
    * ``listeners[j]`` for ``j in range(2 * n_agent)``:
      ``2 * output_size -> h_dim2 -> n_class`` classifiers ending in a softmax
      over the class dimension.

    The sub-networks live in plain dicts (kept as public attributes) and are
    additionally registered with ``add_module`` under the names
    ``colorblind_<i>``, ``shapeblind_<i>`` and ``merger_<j>``; those names are
    the prefixes that appear in ``state_dict()``.

    Args:
        input_size: Feature width fed to every encoder.
        h_dim1: Hidden width of the encoders.
        output_size: Width of each encoder's message.
        h_dim2: Hidden width of the listeners.
        n_class: Number of output classes of the listeners.
        n_agent: Number of encoders of each kind (and half the listener count).
    """

    def __init__(self, input_size=20+11, h_dim1=256, output_size=10, h_dim2=64, n_class=10, n_agent=1):
        super(Comm, self).__init__()

        self.n_agent = n_agent
        self.output_size = output_size

        self.agent_colorblinds = dict()
        self.agent_shapeblinds = dict()
        self.listeners = dict()

        # Creation order (all colorblind encoders, then all shapeblind
        # encoders, then all listeners) is kept stable because it determines
        # both the parameter order and the seeded initialisation.
        for agent_id in range(n_agent):
            encoder = self._make_encoder(input_size, h_dim1, output_size)
            self.agent_colorblinds[agent_id] = encoder
            self.add_module('colorblind_{}'.format(agent_id), encoder)

        for agent_id in range(n_agent):
            encoder = self._make_encoder(input_size, h_dim1, output_size)
            self.agent_shapeblinds[agent_id] = encoder
            self.add_module('shapeblind_{}'.format(agent_id), encoder)

        for idx in range(2*n_agent):
            listener = nn.Sequential(
                nn.Linear(output_size*2, h_dim2),
                nn.ReLU(),
                nn.Linear(h_dim2, n_class),
                # Explicit class dimension: the implicit-dim form is deprecated
                # and warns on every forward pass.
                nn.Softmax(dim=1)
            )
            self.listeners[idx] = listener
            self.add_module('merger_{}'.format(idx), listener)

    @staticmethod
    def _make_encoder(input_size, h_dim1, output_size):
        """Build one two-layer encoder that outputs raw message logits."""
        return nn.Sequential(
            nn.Linear(input_size, h_dim1),
            nn.ReLU(),
            nn.Linear(h_dim1, output_size),
        )

    def forward(self, x1, x2, colorblind_id, shapeblind_id, speaker, cont=True):
        """Encode both inputs, merge the two messages and classify them.

        Args:
            x1: Input for the selected colorblind encoder; expected to be
                2-D, ``(batch, input_size)``, since messages are concatenated
                along dimension 1.
            x2: Input for the selected shapeblind encoder, same layout as x1.
            colorblind_id: Key into ``agent_colorblinds``.
            shapeblind_id: Key into ``agent_shapeblinds``.
            speaker: ``0`` routes the merged message to listener
                ``colorblind_id``; any other value routes it to listener
                ``n_agent + shapeblind_id``.
            cont: When False, both messages are passed through
                ``discretize_message`` before being merged.

        Returns:
            The chosen listener's softmax output, one row per input row and
            ``n_class`` columns.
        """
        h1 = self.agent_colorblinds[colorblind_id](x1)
        h2 = self.agent_shapeblinds[shapeblind_id](x2)

        if not cont:
            # Both messages are discretized, regardless of who is speaking.
            h1 = discretize_message(h1, self.output_size)
            h2 = discretize_message(h2, self.output_size)

        h = torch.cat((h1, h2), dim=1)

        if speaker == 0:
            listener = self.listeners[colorblind_id]
        else:
            listener = self.listeners[self.n_agent + shapeblind_id]

        return listener(h)

def discretize_message(message, output_size):
    """Turn a message into a sequence of one-hot pairs.

    The first ``2 * (output_size // 2)`` columns of ``message`` are read as
    consecutive two-column groups of logits. Each group is sampled with a hard
    Gumbel-softmax and the samples are concatenated back along dimension 1.
    When ``output_size`` is odd the trailing column is not used.

    Args:
        message: 2-D tensor of logits, indexed as ``message[:, a:b]``.
        output_size: Number of leading columns to discretize.

    Returns:
        Tensor holding the concatenated hard samples, two columns per group.
    """
    n_pairs = output_size // 2
    one_hot_pairs = [
        F.gumbel_softmax(message[:, start:start + 2], hard=True)
        for start in range(0, 2 * n_pairs, 2)
    ]
    return torch.cat(one_hot_pairs, dim=1)

def early_stopping(val_accs, patience=30):
    """Decide whether training should stop.

    Nothing is decided until at least ``2 * patience`` scores exist. After
    that, stopping is signalled when the best score over the whole history is
    strictly greater than the best score among the last ``patience + 1``
    entries, i.e. the overall best was reached before that recent window.

    Args:
        val_accs: Sequence of validation scores, oldest first.
        patience: Size parameter of the recent window (see above).

    Returns:
        True when training should stop, otherwise False.
    """
    n_seen = len(val_accs)
    if n_seen < 2 * patience:
        return False

    recent_window = val_accs[n_seen - 1 - patience:]
    return bool(np.max(val_accs) > np.max(recent_window))

if __name__ == '__main__':
    # Training driver: loads the per-agent inputs and the shared labels, trains
    # a Comm population by sampling a random (colorblind, shapeblind, speaker)
    # triple per step, then saves the model weights and the per-epoch curves.

    # One device for model, labels and batches. Previously only the model was
    # guarded by cuda.is_available(), so CPU-only machines crashed on .cuda().
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    cont = False
    contordiscrete = 'cont' if cont else 'disc'

    bs = 256
    n_agent = 64
    train_frac = 0.9  # leading fraction of each array used for training

    # Discrete messages are sampled as two-column one-hot groups (see
    # discretize_message); presumably that is why they get twice the width.
    output_size = 10 if cont else 20

    X_colorblinds = dict()
    X_colorblinds_val = dict()
    X_shapeblinds = dict()
    X_shapeblinds_val = dict()

    # Inputs stay on the CPU; only the current batch is moved to the device.
    input_sources = (
        ('../data/X/colorblind/X_{}_train.npy', X_colorblinds, X_colorblinds_val),
        ('../data/X/shapeblind/X_{}_train.npy', X_shapeblinds, X_shapeblinds_val),
    )
    for path_format, train_split, val_split in input_sources:
        for agent_id in range(n_agent):
            tmp = torch.from_numpy(np.load(path_format.format(agent_id))).float()
            n_train = int(len(tmp) * train_frac)
            train_split[agent_id] = tmp[:n_train]
            val_split[agent_id] = tmp[n_train:]

    # NOTE: the labels are split at the n_train of the last array loaded above;
    # this assumes every input array has the same number of rows as Y.
    Y = torch.from_numpy(np.load('../data/Y_train.npy')).float().to(device)
    Y_train = Y[:n_train]
    Y_val = Y[n_train:]

    comm = Comm(output_size=output_size, n_agent=n_agent)
    comm.to(device)

    # One optimizer per sub-network, so a step only updates the networks that
    # actually took part in it.
    optimizer_colorblinds = dict()
    optimizer_shapeblinds = dict()
    optimizer_listeners = dict()
    for i in range(n_agent):
        optimizer_colorblinds[i] = torch.optim.Adam(comm.agent_colorblinds[i].parameters(), lr=1e-4)
        optimizer_shapeblinds[i] = torch.optim.Adam(comm.agent_shapeblinds[i].parameters(), lr=1e-4)

        optimizer_listeners[i] = torch.optim.Adam(comm.listeners[i].parameters(), lr=1e-4)
        optimizer_listeners[n_agent+i] = torch.optim.Adam(comm.listeners[n_agent+i].parameters(), lr=1e-4)

    epochs = 1000000

    n_batch = len(X_colorblinds[0]) // bs
    if n_batch == 0:
        raise ValueError(
            'need at least {} training rows to form one batch, got {}'.format(bs, len(X_colorblinds[0])))

    # Metrics are stored as plain Python floats so the results pickle does not
    # hold (possibly CUDA) tensors and can be loaded without torch.
    losses = []
    training_accs = []
    val_accs = []
    enable_early_stopping = True

    start_t = time()
    for epoch in range(epochs):
        for _ in range(n_batch):
            # Batches are drawn with replacement; an "epoch" is n_batch steps.
            rand_idx = np.random.randint(n_batch)
            colorblind_id, shapeblind_id = (int(v) for v in np.random.randint(n_agent, size=2))
            speaker = int(np.random.randint(2))

            batch = slice(bs*rand_idx, bs*(rand_idx+1))
            x_colorblind = X_colorblinds[colorblind_id][batch].to(device)
            x_shapeblind = X_shapeblinds[shapeblind_id][batch].to(device)
            y = Y_train[batch]

            output = comm(x_colorblind, x_shapeblind, colorblind_id, shapeblind_id, speaker, cont)
            # reduction='sum' is the supported spelling of size_average=False.
            loss = F.binary_cross_entropy(output, y, reduction='sum')

            # Same listener index that Comm.forward selects for this speaker.
            listener_id = colorblind_id if speaker == 0 else n_agent + shapeblind_id
            active_optimizers = (
                optimizer_colorblinds[colorblind_id],
                optimizer_shapeblinds[shapeblind_id],
                optimizer_listeners[listener_id],
            )

            for optimizer in active_optimizers:
                optimizer.zero_grad()
            loss.backward()
            for optimizer in active_optimizers:
                optimizer.step()

        # Training accuracy and loss are those of the last batch of the epoch.
        batch_loss = loss.item() / bs
        training_acc = (y.argmax(dim=1) == output.argmax(dim=1)).float().mean().item()

        # Validation reuses the last sampled agent pair and speaker; no graph
        # is needed for it.
        with torch.no_grad():
            x_colorblind_val = X_colorblinds_val[colorblind_id].to(device)
            x_shapeblind_val = X_shapeblinds_val[shapeblind_id].to(device)
            output_val = comm(x_colorblind_val, x_shapeblind_val, colorblind_id, shapeblind_id, speaker, cont)
            val_acc = (Y_val.argmax(dim=1) == output_val.argmax(dim=1)).float().mean().item()

        print(n_agent, "Epoch[{}/{}] Loss: {:.8f}, training_acc: {:.4f}, val_acc: {:.4f}, time elapsed: {}".format(epoch+1, epochs, batch_loss, training_acc, val_acc, time() - start_t))
        losses.append(batch_loss)
        training_accs.append(training_acc)
        val_accs.append(val_acc)

        # Patience scales with the population size.
        if enable_early_stopping and early_stopping(val_accs, patience=50*n_agent):
            break

    es_tag = 'esYes' if enable_early_stopping else 'esNo'
    ckpt_path = '../models/mult_comm/num_agent_{}_{}_{}_discretizeboth.pkl'.format(n_agent, contordiscrete, es_tag)
    results_path = '../results/num_agent_{}_{}_{}_discretizeboth.pkl'.format(n_agent, contordiscrete, es_tag)

    # Create the output directories up front so a long run cannot be lost to a
    # missing folder at save time.
    for out_path in (ckpt_path, results_path):
        os.makedirs(os.path.dirname(out_path), exist_ok=True)

    print("saving model to %s..." % ckpt_path)
    torch.save(comm.state_dict(), ckpt_path)

    results = {'losses': losses, 'training_accs': training_accs, 'val_accs': val_accs}
    with open(results_path, 'wb') as results_file:
        pickle.dump(results, results_file)
