import argparse
import random
import sys
import datetime
from pathlib import Path
from tempfile import mkdtemp

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from dataloader import DsrpitesDataset
from utils import Logger, display_losses, save_toFile, gumbel_softmax, straight_through_discretize
from models.base_model import Encoder, Projector


class LanguageCoder(nn.Module):
    """Discrete message bottleneck built from two LSTM cells.

    ``Encoder`` unrolls an LSTM cell for ``word_length`` steps and emits one
    dictionary token per step; ``Decoder`` reads a sequence of token vectors
    back into a single vector of size ``latent_dim``.

    Args:
        latent_dim: Size of the latent vector; also the hidden size of both
            LSTM cells.
        word_length: Number of tokens in every message.
        dictionary_size: Number of distinct tokens.
        device: Stored on the instance for callers. Tensors created inside
            the methods follow the device/dtype of their input instead.
        temperature: Temperature handed to ``gumbel_softmax`` in training mode.
    """

    def __init__(self, latent_dim, word_length, dictionary_size, device, temperature=1.0):
        super().__init__()
        self.input_size = latent_dim
        self.word_length = word_length
        self.dictionary_size = dictionary_size
        self.device = device
        self.temperature = temperature

        self.encoder_lstm = nn.LSTMCell(self.input_size, self.input_size)
        self.decoder_lstm = nn.LSTMCell(self.input_size, self.input_size)
        self.hidden_to_token = nn.Linear(self.input_size, self.dictionary_size)
        self.token_to_hidden = nn.Linear(self.dictionary_size, self.input_size)

        # NOTE(review): these two heads are not used by any method of this
        # class; they are kept so the set of parameters (and therefore saved
        # state dicts) stays the same.
        self.output_mean = nn.Linear(self.input_size, self.input_size)
        self.output_logvar = nn.Linear(self.input_size, self.input_size)

    def Encoder(self, x):
        """Turn a batch of latent vectors into messages of ``word_length`` tokens.

        Args:
            x: Latent batch. It seeds the LSTM cell state, so it must have
                shape ``(batch, latent_dim)``.

        Returns:
            Tuple ``(logits, one_hots, messages)`` -- in that order:
            ``logits`` has shape ``(batch, word_length, dictionary_size)``;
            ``one_hots`` and ``messages`` are the per-step outputs of
            ``straight_through_discretize`` stacked batch-first (presumably
            one-hot vectors and token indices -- confirm against ``utils``).
        """
        step_logits, step_one_hots, step_words = [], [], []

        # Hidden state and first input start at zero; the latent itself is
        # injected through the cell state.
        hx = torch.zeros_like(x)
        cx = x
        lstm_input = torch.zeros_like(x)

        for _ in range(self.word_length):
            hx, cx = self.encoder_lstm(lstm_input, (hx, cx))
            pre_logits = self.hidden_to_token(hx)
            step_logits.append(pre_logits)

            # Stochastic relaxation while training, plain softmax otherwise.
            if self.training:
                z_sampled_soft = gumbel_softmax(pre_logits, self.temperature)
            else:
                z_sampled_soft = torch.softmax(pre_logits, dim=-1)

            z_sampled_onehot, word = straight_through_discretize(z_sampled_soft)
            step_one_hots.append(z_sampled_onehot)
            step_words.append(word)
            # The token just emitted becomes the next LSTM input.
            lstm_input = self.token_to_hidden(z_sampled_onehot)

        # Stacking yields step-major tensors; move the batch axis to the front.
        logits = torch.stack(step_logits).permute(1, 0, 2)
        one_hots = torch.stack(step_one_hots).permute(1, 0, 2)
        messages = torch.stack(step_words).t()
        return logits, one_hots, messages

    def Decoder(self, z):
        """Read a message back into a vector of size ``latent_dim``.

        Args:
            z: Token vectors of shape ``(batch, word_length, dictionary_size)``.

        Returns:
            The decoder LSTM's hidden state after the last token, shape
            ``(batch, latent_dim)``.
        """
        batch_size = z.shape[0]
        # Embed every token at once, then restore the (batch, step) layout.
        flat_tokens = z.reshape(-1, z.shape[-1])
        z_embeddings = self.token_to_hidden(flat_tokens).view(batch_size, self.word_length, -1)
        hx = z_embeddings.new_zeros(batch_size, self.input_size)
        cx = z_embeddings.new_zeros(batch_size, self.input_size)

        for step_input in z_embeddings.unbind(dim=1):
            hx, cx = self.decoder_lstm(step_input, (hx, cx))
        return hx

    def forward(self, input):
        """Encode ``input`` into a message and decode that message again.

        Returns:
            Tuple ``(recons, one_hot_tokens, logits, messages)``.
        """
        # Encoder returns (logits, one_hots, messages); unpack in that order
        # so the decoder receives the discrete tokens, not the raw logits.
        logits, one_hot_tokens, messages = self.Encoder(input)
        recons = self.Decoder(one_hot_tokens)
        return recons, one_hot_tokens, logits, messages


class SimSiamVAE(nn.Module):
    """Agent network: backbone -> projector -> discrete language coder.

    Args:
        feature_dim: Feature size handed to ``Projector``.
        latent_dim: Latent size handed to ``Projector`` and ``LanguageCoder``.
        word_length: Number of tokens per message.
        dictionary_size: Number of distinct tokens.
        device: Stored on the instance and passed to ``LanguageCoder``.
        backbone: Backbone name passed to ``Encoder``.
    """

    def __init__(self, feature_dim, latent_dim, word_length, dictionary_size, device, backbone='linear-dsprites'):
        super().__init__()
        self.word_length = word_length
        self.dictionary_size = dictionary_size
        self.device = device
        self.backbone = Encoder(backbone=backbone)
        self.perception = Projector(feature_dim=feature_dim, latent_dim=latent_dim)
        self.langCoder = LanguageCoder(latent_dim=latent_dim, word_length=word_length, dictionary_size=dictionary_size,
                                       device=device, temperature=1.0)

    def forward(self, x):
        """Encode ``x`` to a latent, express it as a message and decode it back.

        Returns:
            Tuple ``(logit, latent, latent_recon, onehot, message)``.
        """
        latent = self.perception(self.backbone(x))
        logit, onehot, message = self.langCoder.Encoder(latent)
        latent_recon = self.langCoder.Decoder(onehot)
        return logit, latent, latent_recon, onehot, message

    def speak(self, x):
        """Produce the message for ``x``; returns ``(onehot, message)``."""
        latent = self.perception(self.backbone(x))
        _, onehot, message = self.langCoder.Encoder(latent)
        return onehot, message

    def listen(self, onehot):
        """Decode a message (token vectors) into a latent vector."""
        return self.langCoder.Decoder(onehot)

    def loss_fn(self, logit, latent, latent_recon, sp_latent, loss_type='cosine', loss_rec=False, beta=1.0):
        """Combine similarity, optional reconstruction and KL terms.

        Args:
            logit: Token logits; the KL term is summed over dimension 1 and
                averaged over dimension 0 (presumably
                ``(batch, word_length, dictionary_size)``).
            latent: This agent's own latent.
            latent_recon: Latent decoded from this agent's own message.
            sp_latent: Latent decoded from the speaker's message.
            loss_type: ``'cosine'`` for negative cosine similarity; any other
                value selects mean squared error.
            loss_rec: Whether to add the ``latent`` vs ``latent_recon`` term.
            beta: Weight of the KL term.

        Returns:
            Tuple ``(total_loss, loss_similarity, loss_recon, loss_kld)`` of
            scalar tensors. ``loss_recon`` is a zero tensor when ``loss_rec``
            is false, so callers can always call ``.item()`` on it.
        """

        def compute_KLD_loss(token_logits):
            posterior = torch.distributions.OneHotCategorical(logits=token_logits)
            # All-zero logits normalise to log(1 / dictionary_size): a uniform
            # prior that already lives on the logits' device and dtype.
            uniform_prior = torch.distributions.OneHotCategorical(logits=torch.zeros_like(token_logits))
            kl = torch.distributions.kl_divergence(posterior, uniform_prior)
            return kl.sum(1).mean()

        def negative_cosine_similarity(x, y):
            x = F.normalize(x, dim=-1)
            y = F.normalize(y, dim=-1)
            return -(x * y).sum(dim=-1).mean()

        # Anything other than 'cosine' falls back to MSE.
        distance = negative_cosine_similarity if loss_type == 'cosine' else F.mse_loss

        loss_kld = compute_KLD_loss(logit)
        loss_similarity = distance(latent, sp_latent)
        if loss_rec:
            loss_recon = distance(latent, latent_recon)
        else:
            # Zero *tensor* rather than the int 0 keeps the return type uniform.
            loss_recon = latent.new_zeros(())

        total_loss = loss_similarity + loss_recon + beta * loss_kld
        return total_loss, loss_similarity, loss_recon, loss_kld


class Agent:
    """One player of the naming game: a model plus its optimiser and data.

    Args:
        args: Parsed command-line namespace; the attributes read here are
            ``run_path``, ``device``, ``feature_dim``, ``latent_dim``,
            ``word_length``, ``dictionary_size``, ``batch_size``, ``backbone``
            and ``learning_rate``.
        name: ``'a'`` selects ``imgs_a.npy`` as training data; any other
            value selects ``imgs_b.npy``.
    """

    def __init__(self, args, name):
        super().__init__()
        self.args = args
        self.name = name
        self.runPath = args.run_path
        self.device = args.device
        self.D = 0

        self.feature_dim = args.feature_dim
        self.latent_dim = args.latent_dim
        self.word_length = args.word_length
        self.dictionary_size = args.dictionary_size
        self.batch_size = args.batch_size
        self.backbone = args.backbone
        self.loss_history = []
        self.latents = []
        self.messages = []
        self.latents_recon = []

        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.dataloader_train = None
        self.dataloader_test = None
        self.true_label = []
        self.initialize()

    def initialize(self):
        """Load labels and data, then build the model, optimiser and scheduler."""
        # Hard-coded dataset size; kept as an attribute for external users.
        self.D = 18432
        label_link = 'labels.npy'
        data_link = 'imgs_a.npy' if self.name == 'a' else 'imgs_b.npy'

        self.true_label = np.load(label_link)
        # Drop label columns 0 and 3 (presumably factors that are not of
        # interest -- confirm against the dataset layout).
        self.true_label = np.delete(self.true_label, [0, 3], axis=1)
        test_link = 'imgs_test.npy'
        dataset_train = DsrpitesDataset([data_link])
        dataset_test = DsrpitesDataset([test_link])

        # No shuffling: play_game zips the two agents' loaders, so batches
        # must come out in the same order for both.
        self.dataloader_train = DataLoader(dataset_train, batch_size=self.batch_size, shuffle=False)
        self.dataloader_test = DataLoader(dataset_test, batch_size=self.batch_size, shuffle=False)
        self.model = SimSiamVAE(feature_dim=self.feature_dim, latent_dim=self.latent_dim, word_length=self.word_length,
                                dictionary_size=self.dictionary_size, backbone=self.backbone, device=self.device)
        # Move the model before the optimiser captures its parameters.
        self.model.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=10, gamma=0.5)

    def play_game(self, speaker):
        """Train this agent for one epoch as the listener of ``speaker``.

        For every pair of batches the speaker emits a message for its own
        batch, this agent decodes it, and the decoded latent is pulled towards
        this agent's own latent. Only this agent's parameters are updated.
        The per-batch mean of the total loss is appended to ``loss_history``.

        Args:
            speaker: The other ``Agent``; its model is put in eval mode.
        """
        self.model.train()
        speaker.model.eval()
        train_loss = 0.0
        recon_loss = 0.0
        similarity_loss = 0.0
        kld_loss = 0.0
        num_batches = 0
        for data, sp_data in zip(self.dataloader_train, speaker.dataloader_train):
            data = data.view(data.size(0), -1).to(self.device)
            # Flatten with the speaker batch's own size, not the listener's.
            sp_data = sp_data.view(sp_data.size(0), -1).to(speaker.device)
            self.optimizer.zero_grad()
            logit, latent, latent_recon, _, _ = self.model(data)
            # The speaker is not trained here, so skip building its graph.
            with torch.no_grad():
                sp_onehot, _ = speaker.model.speak(sp_data)
            sp_latent = self.model.listen(sp_onehot.to(self.device))
            loss, loss_similarity, loss_recon, loss_kld = self.model.loss_fn(logit, latent, latent_recon, sp_latent)
            loss.backward()
            self.optimizer.step()

            train_loss += loss.item()
            # float() accepts both a scalar tensor and a plain number.
            recon_loss += float(loss_recon)
            similarity_loss += loss_similarity.item()
            kld_loss += loss_kld.item()
            num_batches += 1

        # Each accumulated value is already a per-batch mean, so the epoch
        # average is taken over batches (guarding against an empty loader).
        denom = max(num_batches, 1)
        avg_loss = train_loss / denom
        avg_recon_loss = recon_loss / denom
        avg_similarity_loss = similarity_loss / denom
        avg_kld_loss = kld_loss / denom
        self.loss_history.append(avg_loss)
        print(self.name + f' Avg Loss: {avg_loss:.4f}, SimSiam Loss: {avg_similarity_loss:.4f}, KLD Loss: {avg_kld_loss:.4f}, Recon Loss: {avg_recon_loss:.4f}')
        self.scheduler.step()

    def get_messages(self, speaker):
        """Run the test set and store latents, messages and decoded latents.

        Fills ``self.latents`` and ``self.messages`` from this agent's own
        forward pass, and ``self.latents_recon`` with this agent's decoding of
        the *speaker's* message for the same test batch.

        Args:
            speaker: The other ``Agent``; its model is put in eval mode.
        """
        self.model.eval()
        speaker.model.eval()
        latents, messages, latents_recon = [], [], []
        with torch.no_grad():
            for data in self.dataloader_test:
                data = data.view(data.size(0), -1).to(self.device)
                _, latent, _, _, message = self.model(data)
                # Only the speaker's message is needed, not its decoding.
                sp_onehot, _ = speaker.model.speak(data)
                latent_recon = self.model.listen(sp_onehot)
                latents.append(latent.cpu().numpy())
                messages.append(message.cpu().numpy())
                latents_recon.append(latent_recon.cpu().numpy())
        self.latents = np.concatenate(latents, axis=0)
        self.messages = np.concatenate(messages, axis=0)
        self.latents_recon = np.concatenate(latents_recon, axis=0)


def args_define():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace: The parsed options (``dataset``, ``backbone``,
        ``feature_dim``, ``latent_dim``, ``word_length``, ``dictionary_size``,
        ``mh_epochs``, ``batch_size``, ``learning_rate``, ``run_path``,
        ``device``, ``debug`` and ``file``).
    """

    def parse_bool(value):
        """Interpret common textual spellings of a boolean option value."""
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))

    parser = argparse.ArgumentParser(description='SimSiam Naming Game')
    parser.add_argument('--dataset', type=str, default='dsprites', help='Datasets [dsprites]')
    parser.add_argument('--backbone', type=str, default='linear', help='Backbone [linear, cnn]')
    parser.add_argument('--feature-dim', type=int, default=128, metavar='N', help='feature dim from backbone [default: 128]')
    parser.add_argument('--latent-dim', type=int, default=256, metavar='N', help='latent dim from projector [default: 256]')
    parser.add_argument('--word-length', type=int, default=10, metavar='L', help='word dimensionality (default: 10)')
    parser.add_argument('--dictionary-size', type=int, default=100, metavar='L', help='dictionary size (default: 100)')
    # Help texts below state the defaults that are actually in effect.
    parser.add_argument('--mh-epochs', type=int, default=500, metavar='N', help='No of epochs of naming game [default: 500]')
    parser.add_argument('--batch-size', type=int, default=256, metavar='N', help='batch size of model [default: 256]')
    parser.add_argument('--learning-rate', type=float, default=1e-5, metavar='LR', help='learning rate [default: 1e-5]')
    parser.add_argument('--run-path', type=str, default=None, help='directory for saving models')
    parser.add_argument('--device', type=str, default='mps', help='device for training [mps, cuda, cpu]')
    # type=bool would turn every non-empty string (even "False") into True.
    # A bare ``--debug`` now also enables debug mode.
    parser.add_argument('--debug', type=parse_bool, nargs='?', const=True, default=False, help='debug vs running')
    parser.add_argument('-f', '--file', help='Path for input file')
    return parser.parse_args()


def SimSiam_naming_game(args, A, B):
    """Alternate listener/speaker roles between two agents, then collect messages.

    In every one of ``args.mh_epochs`` epochs ``A`` first trains as listener
    of ``B``, then ``B`` trains as listener of ``A``. After the last epoch
    both agents gather their test-set messages in the same order.

    Args:
        args: Namespace providing ``mh_epochs``.
        A: First agent (needs ``play_game`` and ``get_messages``).
        B: Second agent (same interface).
    """
    print('Playing SimSiam Naming Game')
    # (listener, speaker) pairs, in the order the turns are taken.
    turns = ((A, B), (B, A))
    epoch_banner = '====> Epoch: {}'
    for epoch in range(args.mh_epochs):
        print(epoch_banner.format(epoch))
        for listener, talker in turns:
            listener.play_game(speaker=talker)
    for listener, talker in turns:
        listener.get_messages(speaker=talker)


def set_seeds(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Seed to use. ``-1`` requests a random seed drawn from
            ``[1, 100]``.

    Returns:
        int: The seed that was actually applied, so it can be logged or
        reused.
    """
    if seed == -1:
        seed = random.randint(1, 100)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Print the exact value; a '.2g' format would round e.g. 1234 to 1.2e+03.
    print('Seed: {}'.format(seed))
    return seed


def initialize(args):
    """Apply debug overrides, create the run directory and redirect stdout.

    When ``args.debug`` is truthy a fixed set of small settings is written
    onto ``args``. A fresh directory named after the current ISO timestamp is
    then created under ``experiments/`` and ``sys.stdout`` is replaced by a
    ``Logger`` constructed with the path of ``run.log`` inside it.

    Args:
        args: Parsed command-line namespace (modified in place in debug mode).

    Returns:
        str: Path of the newly created run directory (no trailing slash).
    """
    if args.debug:
        debug_overrides = (
            ('backbone', 'linear'),
            ('self_train_epochs', 2),
            ('mh_epochs', 2),
            ('feature_dim', 128),
            ('latent_dim', 16),
            ('variable_dim', 16),
        )
        for option, value in debug_overrides:
            setattr(args, option, value)

    run_id = datetime.datetime.now().isoformat()
    experiments_root = Path('experiments/')
    experiments_root.mkdir(parents=True, exist_ok=True)
    # mkdtemp appends a random suffix, so two runs never share a directory.
    run_dir = mkdtemp(prefix=run_id, dir=str(experiments_root))
    # From here on everything printed goes through the Logger.
    sys.stdout = Logger('{}/run.log'.format(run_dir))
    print('Expt:', run_dir)
    print('RunID:', run_id)
    return run_dir


def main():
    """Run one full experiment: set up, train both agents, save all outputs.

    Parses the arguments, creates the run directory, seeds the RNGs with 1,
    builds agents ``a`` and ``b``, plays the naming game, and finally writes
    both model checkpoints, the loss plot and the collected messages and
    latents into the run directory.
    """
    args = args_define()
    args.run_path = initialize(args) + '/'
    set_seeds(1)
    print(args)

    A = Agent(args=args, name='a')
    B = Agent(args=args, name='b')
    for agent in (A, B):
        agent.model.to(args.device)
    print(A.model)

    SimSiam_naming_game(args, A, B)

    for agent, checkpoint in ((A, 'agentA.pth'), (B, 'agentB.pth')):
        torch.save(agent.model.state_dict(), args.run_path + checkpoint)
    display_losses(A.loss_history, B.loss_history, save_path=args.run_path + 'loss.png')

    for heading, agent in (('Messages of A:', A), ('Messages of B:', B)):
        print(heading)
        print(agent.messages)

    # (file name, array) pairs, written in this order.
    dumps = (
        ('a_messages', A.messages),
        ('b_messages', B.messages),
        ('a_latents', A.latents),
        ('b_latents', B.latents),
        ('a_latents_recon', A.latents_recon),
        ('b_latents_recon', B.latents_recon),
    )
    for file_name, payload in dumps:
        save_toFile(path=args.run_path, file_name=file_name, data_saved=payload, rows=1)


# Run the experiment only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
