import sys
sys.path.append('../')

import os
from torch.utils.data.sampler import SubsetRandomSampler, RandomSampler
from torch.utils.tensorboard import SummaryWriter
import random
from transformers.tokenization_bert import BertTokenizer
from _collections import defaultdict
random.seed(0)
from random import shuffle
import json
import copy
from prettytable import PrettyTable
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from multiprocessing import Pool
import torch
torch.manual_seed(0)
from torch import nn
import time
from emb2emb.utils import get_data, pretty_print_prediction, word_index_mapping, Namespace,\
    read_all
from emb2emb.train import get_encoder
from emb2emb.analyze_l0drop import compute_neighborhood_preservation
import argparse
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from autoencoder import AutoEncoder, Encoder, Decoder
from rnn_decoder import RNNDecoder
from rnn_encoder import RNNEncoder
from bert_encoder import BERTEncoder
from pretrained_encoder import PretrainedEncoder
from bow_encoder import BoWEncoder
from cmow_encoder import CMOWEncoder
from concat_encoder import ConcatEncoder
from transformer_encoder import TransformerEncoder
from transformer_decoder import TransformerDecoder
from transformer_decoder_simple import SimpleTransformerDecoder
import numpy as np
np.random.seed(0)
from tqdm import tqdm
from data_loaders import HDF5Dataset, get_tokenizer, TOKENIZER_LIST

# Base configuration; it is loaded first and then overridden by the config
# file given on the command line.
DEFAULT_CONFIG = "config/default.json"
# Sub-directory of the save directory that receives the TensorBoard logs.
LOG_DIR_NAME = "logs/"
# If set to True, every run of the trainer logs into its own numbered
# sub-directory of the log directory, so the logs of multiple runs are kept
# separately.
LOG_SEPERATE = False


def parse_args():
    """Parse the command line and build the run configuration.

    The JSON file ``DEFAULT_CONFIG`` is loaded first; the JSON file named by
    the positional ``config`` argument is then applied on top of it with
    ``dict.update`` (top-level keys are replaced, nested dictionaries are not
    merged recursively).

    Returns:
        Namespace: object whose attributes are the merged configuration keys.
    """
    parser = argparse.ArgumentParser()
    # nargs="?" makes the positional optional. Without it argparse always
    # requires the argument and the declared default is never used.
    parser.add_argument("config", type=str, nargs="?", default=DEFAULT_CONFIG,
                        help="The config file specifying all params.")
    params = parser.parse_args()
    with open(DEFAULT_CONFIG) as f:
        config = json.load(f)
    with open(params.config) as f:
        config.update(json.load(f))
    n = Namespace()
    n.__dict__.update(config)
    return n

# NOTE: a model is expected to be at most one of adversarial / variational,
# never both at the same time.


def train_batch(model, optimizer, X, X_lens, lambda_r=1, lambda_kl=1, lambda_a=1, update=True, gradient_accumulation=1, clip_gradient=False, max_norm=2.0):
    """Run one forward/backward pass and optionally step the optimizer.

    The loss function is selected from the model's mode flags, checked in
    this order: ``variational``, ``adversarial``, ``use_l0drop``, ``act``.
    When none of them is set, the plain ``model.loss`` is used.

    Args:
        model: the autoencoder; it is called as ``model(X, X_lens)``.
        optimizer: optimizer that is stepped and zeroed when ``update`` is
            true.
        X: input batch, passed to the model and to the loss function.
        X_lens: lengths passed to the model alongside ``X``.
        lambda_r: forwarded to ``model.loss_variational``.
        lambda_kl: forwarded to ``model.loss_variational``.
        lambda_a: forwarded to ``model.loss_adversarial``.
        update: whether to call ``optimizer.step()`` / ``zero_grad()`` after
            the backward pass. Pass ``False`` to keep accumulating gradients.
        gradient_accumulation: divisor applied to the loss before the
            backward pass.
        clip_gradient: whether to clip the gradient norm of all model
            parameters after the backward pass.
        max_norm: maximum L2 norm used when clipping.

    Returns:
        dict: the loss dictionary produced by the selected loss function. Its
        ``"loss"`` entry is the (unscaled) value that was backpropagated.
    """
    model.train()
    output = model(X, X_lens)
    if model.variational:
        output, mu, z, embeddings = output
        losses = model.loss_variational(
            output, embeddings, X, mu, z, lambda_r, lambda_kl)
    elif model.adversarial:
        output, fake_z_g, fake_z_d, true_z, embeddings = output
        losses = model.loss_adversarial(
            output, embeddings, X, fake_z_g, fake_z_d, true_z, lambda_a)

        # The discriminator has its own optimizer and is updated
        # independently of `update` / gradient accumulation.
        model.optimD.zero_grad()
        losses['d_loss'].backward()
        model.optimD.step()
    elif model.use_l0drop:
        output, embeddings, l0_loss = output
        losses = model.loss_l0drop(
            output, embeddings, X, l0_loss)
    elif model.act:
        output, embeddings, act_cost = output
        losses = model.loss_act(
            output, embeddings, X, act_cost)
    else:
        predictions, embeddings = output
        losses = model.loss(predictions, embeddings, X)

    # Scale only the backpropagated value; the returned dictionary keeps the
    # unscaled loss for logging.
    scaled_loss = losses["loss"] / gradient_accumulation
    scaled_loss.backward()
    if clip_gradient:
        # NOTE(review): clipping runs after every backward pass, i.e. also on
        # partially accumulated gradients when `update` is false.
        nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=max_norm, norm_type=2)
    if update:
        optimizer.step()
        optimizer.zero_grad()

    return losses


def test_batch(model, X, X_lens, lambda_r=1, lambda_kl=1, lambda_a=1):
    """Compute the losses for one batch without tracking gradients.

    Despite its name this is an evaluation helper, not a unit test. The loss
    function is chosen from the model's mode flags in the order
    ``variational``, ``adversarial``, ``use_l0drop``, ``act``; otherwise the
    plain ``model.loss`` is used. The train/eval mode of the model is left
    untouched.

    Returns:
        dict: the loss dictionary produced by the selected loss function.
    """
    with torch.no_grad():
        result = model(X, X_lens)

        if model.variational:
            decoded, mu, z, emb = result
            return model.loss_variational(
                decoded, emb, X, mu, z, lambda_r, lambda_kl)

        if model.adversarial:
            decoded, fake_z_g, fake_z_d, true_z, emb = result
            return model.loss_adversarial(
                decoded, emb, X, fake_z_g, fake_z_d, true_z, lambda_a)

        if model.use_l0drop:
            decoded, emb, l0_loss = result
            return model.loss_l0drop(decoded, emb, X, l0_loss)

        if model.act:
            decoded, emb, act_cost = result
            return model.loss_act(decoded, emb, X, act_cost)

        decoded, emb = result
        return model.loss(decoded, emb, X)


def prepare_batch(indexed, lengths, device, sort=False):
    """Pad a batch of index sequences and move it to ``device``.

    Args:
        indexed: iterable of 1-D index tensors (one per example).
        lengths: tensor holding the length of every example.
        device: device the padded batch and the lengths are moved to.
        sort: if true, order the batch by decreasing length.

    Returns:
        tuple: ``(X, lengths)`` where ``X`` is the zero-padded batch
        (batch first), truncated to the largest value in ``lengths``.
    """
    on_device = []
    for sequence in indexed:
        on_device.append(sequence.to(device))
    padded = pad_sequence(on_device, batch_first=True, padding_value=0)

    # Drop columns beyond the longest sequence of this batch.
    longest = lengths.max()
    padded = padded[:, :longest]

    lengths = lengths.to(device)
    if not sort:
        return padded, lengths

    sorted_lengths, order = torch.sort(lengths, descending=True)
    return padded[order], sorted_lengths


def evaluate(data, device, batch_size, lambda_r=1, lambda_kl=1, lambda_a=1, model=None):
    """Average every loss term of the model over a data loader.

    Args:
        data: iterable yielding ``(sequences, lengths)`` batches.
        device: device the batches are moved to.
        batch_size: unused; kept for backward compatibility.
        lambda_r: forwarded to ``test_batch``.
        lambda_kl: forwarded to ``test_batch``.
        lambda_a: forwarded to ``test_batch``.
        model: model to evaluate. When omitted, the module-level ``model``
            variable is used, which is how existing callers rely on this
            function.

    Returns:
        dict: loss name -> mean of that loss over all batches. Empty when
        ``data`` yields no batches.

    Raises:
        NameError: if ``model`` is not given and no module-level ``model``
            exists.
    """
    if model is None:
        try:
            model = globals()["model"]
        except KeyError:
            raise NameError("name 'model' is not defined") from None

    # Switch to evaluation mode; the caller is responsible for switching back.
    model.train(mode=False)
    aggregated_losses = defaultdict(list)
    for data_b, lens_b in data:
        X_valid, X_valid_lens = prepare_batch(data_b, lens_b, device)
        valid_loss = test_batch(
            model, X_valid, X_valid_lens, lambda_r, lambda_kl, lambda_a)

        for k, v in valid_loss.items():
            aggregated_losses[k].append(v.item())

    return {k: np.array(v).mean(axis=0)
            for k, v in aggregated_losses.items()}


def eval(model, X, X_lens, noise, device):
    """Encode a batch, optionally perturb the encoding, and decode it twice.

    NOTE: this function shadows the built-in ``eval`` inside this module.

    Args:
        model: object providing ``encode``, ``beam_decode`` and
            ``greedy_decode``.
        X: input batch passed to ``model.encode``.
        X_lens: lengths passed to ``model.encode``.
        noise: scale of the Gaussian noise added in place to the encoding;
            ``0.0`` disables the perturbation.
        device: device on which the noise is sampled.

    Returns:
        tuple: ``(beam_decoded, greedy_decoded)``.
    """
    latent = model.encode(X, X_lens)
    if noise != 0.0:
        perturbation = torch.randn_like(latent, device=device)
        latent += perturbation * noise
    beam_output = model.beam_decode(latent)
    greedy_output = model.greedy_decode(latent)
    return (beam_output, greedy_output)


def evaluate_sentence(model, data, device, tokenizer, config):
    """Greedily reconstruct validation sentences and optionally score BLEU.

    The model is switched to evaluation mode and is NOT switched back; the
    caller has to call ``model.train()`` again.

    Args:
        model: autoencoder providing ``encode`` and ``decode``.
        data: iterable yielding ``(sequences, lengths)`` batches.
        device: device the batches are moved to.
        tokenizer: tokenizer used to turn ids back into text. A
            ``BertTokenizer`` is given the id tensor directly, any other
            tokenizer gets a plain list of ids.
        config: run configuration; only ``config.compute_bleu`` is read.
            A positive value ``k`` scores batches ``0..k`` with BLEU; ``0``
            or ``-1`` stop after the first batch without computing BLEU.

    Returns:
        tuple: ``(s, greedy_score, encoding_length)`` where ``s`` is a dict
        with the keys ``"original"``, ``"greedy"`` and ``"beam"`` describing
        the first example of the last processed batch, ``greedy_score`` is
        the corpus BLEU of the greedy reconstructions (0 when BLEU is
        disabled or nothing was decoded) and ``encoding_length`` is the mean
        of ``encoded[1]`` for the last processed batch (presumably the
        per-example encoding lengths; 0.0 when ``data`` is empty).
    """
    if model.training:
        model.train(mode=False)

    is_bert = isinstance(tokenizer, BertTokenizer)
    collect_bleu = config.compute_bleu > 0

    def decode_input(ids):
        # BertTokenizer accepts the tensor, other tokenizers need a list.
        return tokenizer.decode(ids if is_bert else ids.tolist())

    orig_sentences = []
    greedy_sentences = []
    # Defaults so that an empty `data` iterable still yields a well-formed
    # result instead of an unbound-variable error.
    s = {"original": "", "greedy": "", "beam": "None"}
    encoding_length = 0.0
    with torch.no_grad():
        for i, (data_b, lens_b) in enumerate(data):
            X, X_lens = prepare_batch(data_b, lens_b, device)
            encoded = model.encode(X, X_lens)
            encoding_length = encoded[1].float().mean()
            greedy = model.decode(encoded)
            s = {"original": decode_input(X[0]),
                 "greedy": tokenizer.decode(greedy[0]),
                 "beam": "none" if is_bert else "None"}

            # Collect sentences for every tokenizer type; previously this
            # only happened for non-BERT tokenizers, so BLEU was computed on
            # empty lists when a BertTokenizer was used.
            if collect_bleu:
                orig_sentences.extend(decode_input(x) for x in X)
                greedy_sentences.extend(tokenizer.decode(g) for g in greedy)

            if i == config.compute_bleu or config.compute_bleu == -1:
                break

    if collect_bleu and greedy_sentences:
        references = [[x.split()] for x in orig_sentences]
        hypotheses = [x.split() for x in greedy_sentences]
        greedy_score = corpus_bleu(references, hypotheses,
                                   smoothing_function=SmoothingFunction().method1)
    else:
        greedy_score = 0
    return s, greedy_score, encoding_length


def get_model_info(config):
    """Flatten the configuration into a dict without dict/list values.

    Works on a deep copy of ``config.__dict__``:

    * a dict-valued entry whose key equals ``config.encoder`` /
      ``config.decoder`` is expanded into ``"e_<name>"`` / ``"d_<name>"``
      entries; every dict-valued entry is then removed;
    * a list-valued entry whose key equals ``config.encoder`` adds one
      ``"e_<item>"`` entry per item (the item is also the value); every
      list-valued entry is then removed.

    Entries added by the expansion are not processed again.

    Returns:
        dict: the flattened copy; ``config`` itself is not modified.
    """
    info = copy.deepcopy(vars(config))
    # Iterate over a snapshot of the keys because `info` is mutated below.
    for name in tuple(info):
        value = info[name]

        if isinstance(value, dict):
            if name == config.encoder:
                prefix = "e_"
            elif name == config.decoder:
                prefix = "d_"
            else:
                prefix = None
            if prefix is not None:
                for sub_name, sub_value in value.items():
                    info[prefix + sub_name] = sub_value
            del info[name]
            continue

        if isinstance(value, list):
            if name == config.encoder:
                for item in value:
                    info["e_" + item] = item
            del info[name]
    return info


def inverse_sqrt_schedule_w_warmup(step_num, warmup_steps, lr):
    """Return the learning-rate multiplier for a 0-based optimizer step.

    The multiplier grows linearly up to 1 during the warm-up phase and then
    decays proportionally to the inverse square root of the step number.

    Args:
        step_num: 0-based step index (shifted by one internally so that the
            first step has a non-zero multiplier).
        warmup_steps: number of warm-up steps. A value ``<= 0`` disables the
            warm-up and yields a pure inverse-square-root decay.
        lr: unused; kept for backward compatibility.

    Returns:
        float: factor by which the base learning rate is multiplied.
    """
    step_num += 1  # To start w/ non-zero
    if warmup_steps <= 0:
        # Without this guard the decay factor would be 0 (multiplier stuck
        # at zero) or, for negative values, a complex number.
        return step_num**-0.5
    if step_num < warmup_steps:
        return step_num / warmup_steps
    return warmup_steps**0.5 * step_num**-0.5


def count_parameters(model):
    """Print a per-parameter size table and return the trainable total.

    Parameters with ``requires_grad == False`` are left out of both the
    table and the total.

    Returns:
        int: number of trainable scalar parameters of ``model``.
    """
    trainable = [(name, tensor.numel())
                 for name, tensor in model.named_parameters()
                 if tensor.requires_grad]

    table = PrettyTable(["Modules", "Parameters"])
    for name, size in trainable:
        table.add_row([name, size])
    total_params = sum(size for _, size in trainable)

    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params


if __name__ == "__main__":
    # Script entry point: build tokenizer, encoder, decoder and autoencoder
    # from the JSON configuration, optionally resume from a checkpoint, then
    # run the training loop with periodic validation and checkpointing.
    # TODO: LR Annealing
    config = parse_args()

    # Untouched copy that is written next to every checkpoint; `config`
    # itself is extended below with values such as the torch device.
    original_config = copy.deepcopy(config)

    print(json.dumps(config.__dict__, indent=4))

    model_info = get_model_info(config)

    device = torch.device(
        config.device if torch.cuda.is_available() else "cpu")
    print(device)

    # TODO: Generalize to use pretrained tokenizers
    if config.encoder == "BERTEncoder":
        tokenizer = get_tokenizer(
            config.tokenizer, config.BERTEncoder["bert_location"])
        config.__dict__["vocab_size"] = tokenizer.vocab_size
        config.__dict__["sos_idx"] = tokenizer.cls_token_id
        config.__dict__["eos_idx"] = tokenizer.sep_token_id
        config.__dict__["unk_idx"] = tokenizer.unk_token_id
        config.__dict__["pad_idx"] = tokenizer.pad_token_id
    else:
        tokenizer = get_tokenizer(config.tokenizer, config.tokenizer_location)
        config.__dict__["vocab_size"] = tokenizer.get_vocab_size()
        config.__dict__["sos_idx"] = tokenizer.token_to_id("<SOS>")
        config.__dict__["eos_idx"] = tokenizer.token_to_id("<EOS>")
        config.__dict__["unk_idx"] = tokenizer.token_to_id("<unk>")
        config.__dict__["pad_idx"] = tokenizer.token_to_id("[PAD]")

    config.__dict__["device"] = device

    # parse encoder
    encoder_config = copy.deepcopy(config)

    def instantiate_encoder(c):
        """Build the encoder named by ``c.encoder`` from the config ``c``."""
        print(c.encoder)
        if c.encoder == "RNNEncoder":
            encoder = RNNEncoder(c)
        elif c.encoder == "BERTEncoder":
            encoder = BERTEncoder(c)
        elif c.encoder == "PretrainedEncoder":
            encoder = PretrainedEncoder(c)
        elif c.encoder == "TransformerEncoder":
            encoder = TransformerEncoder(c)
        elif c.encoder == "BoWEncoder":
            encoder = BoWEncoder(c)
        elif c.encoder == "CMOWEncoder":
            encoder = CMOWEncoder(c)
        else:
            # Fail with a clear message instead of an unbound-variable error.
            raise ValueError(f"Unknown encoder: {c.encoder!r}")
        return encoder

    if isinstance(config.encoder, list):
        # we have multiple encoders; each one gets the shared config
        # overridden by its own entry of the "ConcatEncoder" list

        encoder_configs = config.__dict__["ConcatEncoder"]

        encoder_list = []
        for idx, encoder_name in enumerate(config.encoder):
            enc_c = copy.deepcopy(config)
            enc_c.__dict__.update(encoder_configs[idx])
            enc_c.encoder = encoder_name
            encoder_list.append(instantiate_encoder(enc_c))
        encoder = ConcatEncoder(config, encoder_list)
    else:
        encoder_config.__dict__.update(config.__dict__[config.encoder])
        encoder = instantiate_encoder(encoder_config)
    encoder_config.__dict__["tokenizer"] = tokenizer

    # parse decoder
    decoder_config = copy.deepcopy(config)
    decoder_config.__dict__.update(config.__dict__[config.decoder])
    if config.decoder == "RNNDecoder":
        decoder = RNNDecoder(decoder_config)
    elif config.decoder == "TransformerDecoder":
        decoder = TransformerDecoder(decoder_config)
    elif config.decoder == "SimpleTransformerDecoder":
        decoder = SimpleTransformerDecoder(decoder_config)
    else:
        raise ValueError(f"Unknown decoder: {config.decoder!r}")

    model = AutoEncoder(encoder, decoder, tokenizer, config)
    count_parameters(model)

    # Resume from an existing checkpoint if there is one.
    model_path = os.path.join(config.savedir, config.model_file)
    if os.path.isfile(model_path):
        # map_location lets a checkpoint written on another device (e.g. a
        # GPU) be resumed on the device selected for this run.
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        best_model_state_dict = checkpoint.get("best_model_state_dict")
    else:
        checkpoint = None
        best_model_state_dict = None
    if checkpoint is not None and 'iteration' in checkpoint:
        i = checkpoint['iteration']
    else:
        i = 0

    model.to(device)
    model.train()

    if config.adversarial:
        # The discriminator is trained by its own optimizer (model.optimD),
        # so only encoder and decoder parameters go into the main optimizer.
        model_parameters = []
        for name, param in model.named_parameters():
            if name.startswith("encoder") or name.startswith("decoder"):
                model_parameters.append(param)
            elif name.startswith("discriminator"):
                pass
            else:
                raise AssertionError(
                    "Found a model parameter " + name + " that we do not know how to handle.")
    else:
        model_parameters = model.parameters()
    optimizer = optim.Adam(model_parameters, lr=config.lr)
    if checkpoint is not None and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    if config.lr_scheduler == "inverse_sqrt":
        def lr_multiplier(step):
            return inverse_sqrt_schedule_w_warmup(step, config.warmup_steps, config.lr)
        # Must be bound to `lr_scheduler`: that is the name the training loop
        # reads. (It used to be assigned to `scheduler`, which made the loop
        # fail with a NameError.)
        # NOTE(review): the scheduler state is not checkpointed, so the
        # schedule starts from step 0 again after a resume - confirm intent.
        lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_multiplier)
    else:
        lr_scheduler = None
    print(model)

    logdir = os.path.join(config.savedir, LOG_DIR_NAME)

    if LOG_SEPERATE:
        # Use the first numbered sub-directory that does not exist yet.
        log_num = 0
        while os.path.isdir(os.path.join(logdir, str(log_num))):
            log_num += 1
        logdir = os.path.join(logdir, str(log_num))

    def collate_batches(batch):
        """
        'batch' is a list of pairs (X, X_len) which are of size [batch_size, max_len] and [batch_size], respectively.
        The default collate_batches would create a new dimension, but we want to stack alongside the batch_dimension.
        """
        Xs, X_lens = zip(*batch)
        X = torch.cat(Xs, dim=0)
        X_len = torch.cat(X_lens, dim=0)

        return X, X_len

    dataset = HDF5Dataset(config.dataset_path, False, False,
                          data_cache_size=3, transform=None)

    if config.val_dataset_path != "unknown":
        val_dataset = HDF5Dataset(config.val_dataset_path, False, False,
                                  data_cache_size=3, transform=None)

        # make sure we shuffle the training dataset so that we dont observe the
        # same order every time the process crashes and has to be restarted
        torch.manual_seed(i)
        trainloader = torch.utils.data.DataLoader(dataset, batch_size=1, sampler=RandomSampler(dataset),
                                                  num_workers=config.workers, collate_fn=collate_batches)
        valloader = torch.utils.data.DataLoader(val_dataset, batch_size=1,
                                                sampler=RandomSampler(
                                                    val_dataset),
                                                num_workers=config.workers, collate_fn=collate_batches)
    else:
        # No separate validation set: hold out a fraction of the training set.
        indices = list(range(len(dataset)))
        shuffle(indices)
        num_val_samples = int(len(indices) * config.valsize)
        # Split at an explicit position. Slicing with `-num_val_samples`
        # breaks for 0: `indices[:-0]` is empty and `indices[-0:]` is
        # everything, i.e. the training split would vanish.
        split_at = max(len(indices) - num_val_samples, 0)
        train_indices = indices[:split_at]
        val_indices = indices[split_at:]

        # downsample if appropriate
        train_indices = train_indices[:int(
            config.data_fraction * len(train_indices))]
        val_indices = val_indices[:int(
            config.data_fraction * len(val_indices))]

        # make sure we shuffle the training dataset so that we dont observe the
        # same order every time the process crashes and has to be restarted
        torch.manual_seed(i)
        trainloader = torch.utils.data.DataLoader(dataset, batch_size=1, sampler=SubsetRandomSampler(
            train_indices), num_workers=config.workers, collate_fn=collate_batches)
        valloader = torch.utils.data.DataLoader(dataset, batch_size=1, sampler=SubsetRandomSampler(
            val_indices), num_workers=config.workers, collate_fn=collate_batches)

    epoch = 0
    levenshtein_precomputed = None
    time_s = time.time()
    min_val_loss = checkpoint['min_val_loss'] if checkpoint is not None and 'min_val_loss' in checkpoint else float(
        'inf')
    stop_training = False

    epoch_batches = len(trainloader)
    print(f"Epoch batches: {epoch_batches}")
    with SummaryWriter(log_dir=logdir) as sw:
        sw.add_hparams(model_info, {})

        if os.path.isfile(model_path):
            print("Running initial validation step.")
            val_losses = evaluate(
                valloader, device, 1, config.lambda_r, config.lambda_kl, config.lambda_a)
            for k, v in val_losses.items():
                sw.add_scalar("Validation/" + k, v, i)

            # NOTE(review): this replaces the `min_val_loss` restored from
            # the checkpoint with the loss of the resumed weights - confirm
            # that this is intended.
            min_val_loss = val_losses["loss"]

        print("Starting training")
        model.train()
        skip = True
        print("Number of batches", len(trainloader))

        # `i` counts the batches trained on (restored from the checkpoint),
        # `cur_i` counts the batches drawn from the loader in this process.
        cur_i = 0
        while not stop_training:
            pbar = tqdm(
                trainloader, desc=f"[E{epoch}, B{i % len(trainloader)}]")
            for s_batch, l_batch in pbar:

                epoch = int(cur_i / len(trainloader))

                # skip batches we have already seen
                if skip and cur_i < i:
                    cur_i += 1
                    continue
                elif skip:
                    skip = False
                    print(f"\nEnd skipping at {cur_i}\n")

                i += 1

                # KL weight: off before the `kl_delay` epoch, ramped up
                # linearly during it, and at full strength afterwards.
                if epoch < config.kl_delay:
                    lambda_kl = 0
                elif epoch > config.kl_delay:
                    lambda_kl = config.lambda_kl
                else:
                    lambda_kl = config.lambda_kl * \
                        ((i - epoch_batches * epoch) / epoch_batches)

                # Train on batch
                X, X_lens = prepare_batch(s_batch, l_batch, device)
                actual_batch_size = X_lens.size(0)
                update = (i % config.gradient_accumulation) == 0
                losses = train_batch(model, optimizer, X, X_lens, lambda_r=config.lambda_r, lambda_kl=lambda_kl,
                                     lambda_a=config.lambda_a, update=update, gradient_accumulation=config.gradient_accumulation,
                                     clip_gradient=config.clip_gradient, max_norm=config.gradient_clipping_max_norm)

                if lr_scheduler is not None and update:
                    lr_scheduler.step()

                msg = ""
                if i % (config.print_frequency) == 0:
                    l = losses["loss"]
                    msg = f"[E{epoch}, B{i}] tr={l:0.2f}"
                    msg += f", val={min_val_loss:0.2f}"

                    for k, v in losses.items():
                        msg += f", {k}={v:0.2f}"
                        sw.add_scalar("Train/" + k, v.cpu().item(), i)

                    speed = ((config.print_frequency *
                              actual_batch_size) // (time.time() - time_s))
                    time_s = time.time()
                    sw.add_scalar("Speed/Speed", speed, i)
                    pbar.set_description(msg)
                    sw.flush()

                # Validation
                do_val = (i % config.validation_frequency) == 0
                do_checkpoint = (i % config.checkpoint_frequency) == 0
                do_continuous_checkpoint = (
                    i % config.continuous_checkpoint_frequency) == 0
                if do_val or do_checkpoint or do_continuous_checkpoint:

                    if do_val:
                        val_losses = evaluate(
                            valloader, device, 1, config.lambda_r, config.lambda_kl, config.lambda_a)
                        val_loss = val_losses["loss"]
                        for k, v in val_losses.items():
                            sw.add_scalar("Validation/" + k, v, i)

                        if i % (config.validation_frequency * config.sentence_eval_frequency) == 0:
                            es, greedy_score, encoding_length = evaluate_sentence(
                                model, valloader, device, tokenizer, config)
                            sw.add_text("Validation/Original",
                                        es["original"], i)
                            sw.add_text("Validation/Greedy", es["greedy"], i)
                            sw.add_scalar(
                                "Validation/GreedyBleu", greedy_score, i)
                            sw.add_scalar(
                                "Validation/EncodingLength", encoding_length, i)

                            es, greedy_score, encoding_length = evaluate_sentence(
                                model, trainloader, device, tokenizer, config)
                            sw.add_text("Train/Original", es["original"], i)
                            sw.add_text("Train/Greedy", es["greedy"], i)
                            sw.add_scalar("Train/GreedyBleu", greedy_score, i)
                            sw.add_scalar(
                                "Train/EncodingLength", encoding_length, i)

                        if val_loss < min_val_loss:
                            min_val_loss = val_loss
                            # state_dict() hands out references to the live
                            # parameters, which keep changing while training
                            # continues. Deep-copy them so the snapshot
                            # really stays the best model.
                            best_model_state_dict = copy.deepcopy(
                                model.state_dict())

                    # save the checkpoint
                    checkpoint = {"model_state_dict": model.state_dict(),
                                  "best_model_state_dict": best_model_state_dict,
                                  "optimizer_state_dict": optimizer.state_dict(),
                                  "iteration": i,
                                  'min_val_loss': min_val_loss}

                    def checkpoint_by_path(p):
                        """Write the current checkpoint and config into ``p``."""
                        os.makedirs(p, exist_ok=True)
                        torch.save(checkpoint, os.path.join(
                            p, config.model_file))
                        with open(os.path.join(p, 'config.json'), 'w') as f:
                            json.dump(original_config.__dict__, f)

                    if config.checkpoint_all and do_checkpoint:
                        checkpoint_by_path(config.savedir + str(i))
                    if do_val or do_continuous_checkpoint:
                        checkpoint_by_path(config.savedir)

                    if do_val and config.eval_neighborhood_preservation:
                        params = Namespace(modeldir=config.savedir,
                                           remove_sos_and_eos=False,
                                           max_input_length=999,
                                           batch_size=64,
                                           autoencoder="file")

                        valid = {"Sx": read_all(
                            config.neighborhood_preservation_path), "Sy": []}  # Sy is left empty on purpose, all data is in Sx
                        ae_enc = get_encoder(params, device).to(device)
                        recall_vals, levenshtein_precomputed = compute_neighborhood_preservation(
                            ae_enc, valid, params, return_levenshtein=True, levenshtein_precomputed=levenshtein_precomputed)
                        sw.add_scalar(
                            "Validation/NeighborhoodPreservation", recall_vals[1], i)

                    # return to training mode
                    model.train()

                if i >= config.max_steps:
                    stop_training = True
                    break
                cur_i += 1

    # save the final model: only the best weights are kept in `model_path`
    os.makedirs(config.savedir, exist_ok=True)
    if best_model_state_dict is None:
        # No validation step ever recorded a best model; fall back to the
        # current weights so the saved file can still be loaded.
        best_model_state_dict = model.state_dict()
    checkpoint = {"model_state_dict": best_model_state_dict}
    torch.save(checkpoint, model_path)
    with open(os.path.join(config.savedir, 'config.json'), 'w') as f:
        json.dump(original_config.__dict__, f)

    print("<<<JOB_FINISHED>>>")
