# coding=utf-8
import csv
import argparse
import logging
from argparse import ArgumentParser
import os
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, Dataset
from tqdm import tqdm, trange
from transformers import AdamW

from metric import StyleTransfertMetric
import json

try:
    from transformers import (get_linear_schedule_with_warmup)
except:
    from transformers import WarmupLinearSchedule as get_linear_schedule_with_warmup
from models import *
from tensorboardX import SummaryWriter
from transformers import BertTokenizer
import copy

logger = logging.getLogger(__name__)


class TextDataset(Dataset):
    """Labelled dataset of index sequences read from CSV tensor files.

    Two files are read per split, ``tensor_sentiment.<split>.1`` and
    ``tensor_sentiment.<split>.0``; every row is a comma-separated list of
    integers. Rows from the ``.1`` file get label 1, rows from the ``.0`` file
    get label 0. The examples are shuffled with a fixed seed and truncated to
    the first ``args.filter`` entries.
    """

    # Relative location of one data file: (dataset, split suffix, label).
    _FILE_TEMPLATE = 'data/data_tensor_{}/tensor_sentiment.{}.{}'

    def __init__(self, args, dev):
        """Load, shuffle and truncate one split.

        Args:
            args: namespace providing ``dataset`` (used to build the data
                path) and ``filter`` (number of examples to keep).
            dev: when truthy the ``dev`` split is loaded, otherwise ``train``.

        Raises:
            OSError: if the files can be opened neither under ``//DIT`` nor
                relative to the current working directory.
            ValueError: if a row contains a non-integer field.
        """
        logger.info("Loading dataset {}".format('Validation' if dev else 'Train'))
        suffix = 'dev' if dev else 'train'

        lines_pos, lines_neg = self._load_split(args.dataset, suffix)

        labels = [1] * len(lines_pos) + [0] * len(lines_neg)
        lines = lines_pos + lines_neg

        # Shuffle (line, label) pairs together so the two can never get out of
        # sync. random.shuffle only depends on the list length and the RNG
        # state, so this yields the same order as shuffling both lists
        # separately after re-seeding with 42.
        pairs = list(zip(lines, labels))
        random.seed(42)
        random.shuffle(pairs)
        pairs = pairs[:args.filter]

        self.lines = [line for line, _ in pairs]
        self.label = [label for _, label in pairs]

    @staticmethod
    def _read_rows(path):
        """Read a CSV file and return its rows as lists of ints."""
        with open(path, 'r') as file:
            return [[int(index) for index in row] for row in csv.reader(file, delimiter=',')]

    @classmethod
    def _load_split(cls, dataset, suffix):
        """Return ``(positive_rows, negative_rows)`` for one split.

        The files are looked up under ``//DIT`` first; if either of them
        cannot be opened there, both are read relative to the working
        directory instead.
        """
        relative_paths = [cls._FILE_TEMPLATE.format(dataset, suffix, label) for label in (1, 0)]
        try:
            lines_pos, lines_neg = [cls._read_rows(os.path.join('//DIT', path))
                                    for path in relative_paths]
        except OSError:
            # Only a missing/unreadable file triggers the fallback; malformed
            # content is reported instead of being silently replaced by the
            # content of another file.
            lines_pos, lines_neg = [cls._read_rows(path) for path in relative_paths]
        return lines_pos, lines_neg

    def __len__(self):
        return len(self.label)

    def __getitem__(self, item):
        return {'line': torch.tensor(self.lines[item], dtype=torch.long),
                'label': torch.tensor(self.label[item], dtype=torch.long)}


class ClassifierDataset(Dataset):
    """Dataset of latent embeddings computed once from a text dataset.

    Every batch of ``test_dataset`` is passed through
    ``model.predict_latent_space`` and the results are cached on the CPU
    together with the labels.
    """

    def __init__(self, args, test_dataset, model):
        """Embed ``test_dataset`` with ``model`` and cache the result.

        Args:
            args: namespace providing ``batch_size`` and ``device``.
            test_dataset: dataset whose items are dicts with ``'line'`` and
                ``'label'`` tensors.
            model: object exposing ``eval()`` and
                ``predict_latent_space(inputs)``.

        Note:
            The loader uses ``drop_last=True``, so examples of a trailing
            incomplete batch are not embedded.
        """
        eval_sampler = SequentialSampler(test_dataset)
        eval_dataloader = DataLoader(
            test_dataset, sampler=eval_sampler, batch_size=args.batch_size, drop_last=True)

        # Eval!
        logger.info("***** Embedding evaluation *****")
        logger.info("  Batch size = %d", args.batch_size)
        model.eval()
        embedding_batches = []
        label_batches = []
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            inputs = batch['line'].to(args.device)
            with torch.no_grad():
                embeddings = model.predict_latent_space(inputs)
            embedding_batches.append(embeddings.detach().cpu())
            # Labels are only stored, so they never need to visit the device.
            label_batches.append(batch['label'].cpu())

        # Batches are concatenated along axis 1 and that axis is then moved to
        # the front, so stored item i is embeddings[:, i, :].
        all_embeddings = torch.cat(embedding_batches, dim=1).permute(1, 0, 2).contiguous().float()
        all_labels = torch.cat(label_batches, dim=0).long()

        # Keep per-example tensors as views instead of round-tripping every
        # value through nested Python lists.
        self.l_embeddings = list(all_embeddings.unbind(0))
        self.l_labels = list(all_labels.unbind(0))

    def __len__(self):
        return len(self.l_labels)

    def __getitem__(self, item):
        return {'line': self.l_embeddings[item].float(),
                'label': self.l_labels[item].long()}


def set_seed(args):
    """Seed Python's ``random``, NumPy and PyTorch with ``args.seed``.

    All CUDA devices are seeded as well when CUDA is available.
    """
    seed = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def _json_safe_args(args):
    """Return a copy of ``args.__dict__`` that ``json.dump`` can encode.

    ``None``, booleans, numbers, strings, tuples, lists and dicts are kept as
    they are; any other value (device, tokenizer, optimizer, scheduler, ...)
    is replaced by the placeholder ``0``.
    """
    keep = (bool, int, float, str, tuple, list, dict)
    return {
        key: value if value is None or isinstance(value, keep) else 0
        for key, value in vars(args).items()
    }


def _save_classifier_checkpoint(args, classifier, global_step):
    """Write the classifier weights and the run arguments; return the weights path."""
    output_dir = os.path.join(args.output_dir, "{}-{}".format("checkpoint", global_step))
    os.makedirs(output_dir, exist_ok=True)
    classifier_path = os.path.join(output_dir, 'classifier_latent_space.pt')
    torch.save(classifier.state_dict(), classifier_path)
    with open(os.path.join(output_dir, 'training_args.txt'), 'w') as f:
        json.dump(_json_safe_args(args), f, indent=2)
    logger.info("Saving model checkpoint to %s", output_dir)
    return classifier_path


def train_classifer(args, classifier, train_dataset):
    """Train ``classifier`` on ``train_dataset`` and reload its best checkpoint.

    The classifier is optimised with AdamW and a linear warmup/decay schedule
    using an NLL loss. Every ``args.save_step`` optimisation steps the current
    batch loss is compared with the best one seen at a previous save point; on
    improvement a checkpoint is written to
    ``<args.output_dir>/checkpoint-<step>/``. When training ends, the weights
    of the best checkpoint are loaded back into ``classifier``. If no
    checkpoint was written, the final in-memory weights are kept.

    Args:
        args: namespace providing ``output_dir``, ``batch_size``,
            ``num_train_epochs``, ``learning_rate``, ``adam_epsilon``,
            ``warmup_steps``, ``save_step``, ``device`` and ``seed``.
            ``args.scheduler`` and ``args.optimizer`` are set as a side effect.
        classifier: module to train; modified in place.
        train_dataset: dataset whose items are dicts with ``'line'`` and
            ``'label'`` tensors.
    """
    tb_writer = SummaryWriter('runs/{}'.format(args.output_dir))
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset, sampler=train_sampler, batch_size=args.batch_size, drop_last=True)
    t_total = len(train_dataloader) * args.num_train_epochs
    optimizer = AdamW(classifier.parameters(), lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer, args.warmup_steps, t_total)
    loss_fct = torch.nn.NLLLoss()

    # Published on args for code that holds a reference to it; both end up as
    # the placeholder 0 in training_args.txt because they are not JSON data.
    args.scheduler = scheduler
    args.optimizer = optimizer

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size = %f", args.batch_size)
    logger.info("  Total optimization steps = %f", t_total)

    best_loss = float('inf')
    best_classifier_path = None  # stays None until a checkpoint is written
    global_step = 0
    classifier.zero_grad()
    train_iterator = trange(0, int(args.num_train_epochs), desc="Epoch")
    set_seed(args)  # Added here for reproducibility
    try:
        for _ in train_iterator:
            for batch in tqdm(train_dataloader, desc="Iteration"):
                inputs = batch['line'].to(args.device).permute(1, 0, 2)
                labels = batch['label'].to(args.device)
                classifier.train()
                prediction = classifier(inputs)
                loss = loss_fct(prediction, labels.long())
                step_loss = loss.item()
                logger.debug("train_loss = %f", step_loss)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(classifier.parameters(), 1.0)
                optimizer.step()

                tb_writer.add_scalar("train_loss", step_loss, global_step)
                # Learning rate that was applied by the step above.
                tb_writer.add_scalar("train_lr", optimizer.param_groups[0]["lr"], global_step)

                # Without this call the schedule never advances: the rate
                # would stay at its step-0 value (0 whenever warmup_steps > 0).
                scheduler.step()
                classifier.zero_grad()

                global_step += 1

                if global_step % args.save_step == 0 and best_loss > step_loss:
                    # Remember the new best so that later, worse checkpoints
                    # do not replace it as the model reloaded at the end.
                    best_loss = step_loss
                    best_classifier_path = _save_classifier_checkpoint(args, classifier, global_step)
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        tb_writer.close()

    # Load last best classifier :)
    logger.info("Last checkpoint %s", global_step)
    if best_classifier_path is None:
        logger.warning("No checkpoint was saved during training; keeping the final weights")
        return
    logger.info("Reloading Best Saved model at %s", best_classifier_path)
    classifier.load_state_dict(torch.load(best_classifier_path, map_location=torch.device(args.device)))


def evaluate_disantanglement(args, classifer, eval_dataset):
    """Evaluate ``classifer`` on ``eval_dataset`` and write a small report.

    The NLL loss and the accuracy are computed per batch and averaged over
    the batches. The loader uses ``drop_last=True``, so examples of a
    trailing incomplete batch are not evaluated. Results are logged and
    written to ``<args.output_dir>/<args.saving_result_file>``; the directory
    is created if it does not exist.

    Args:
        args: namespace providing ``batch_size``, ``device``, ``output_dir``
            and ``saving_result_file``.
        classifer: callable model exposing ``eval()``; it is called with the
            batch of ``'line'`` tensors after swapping the first two axes.
        eval_dataset: dataset whose items are dicts with ``'line'`` and
            ``'label'`` tensors.

    Raises:
        ValueError: if ``eval_dataset`` does not contain a single full batch.
    """
    loss_fct = torch.nn.NLLLoss()
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset, sampler=eval_sampler, batch_size=args.batch_size, drop_last=True)

    # Eval!
    logger.info("***** Running evaluation *****")
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.batch_size)
    classifer.eval()
    losses = []
    accuracies = []
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        inputs = batch['line'].to(args.device).permute(1, 0, 2)
        labels = batch['label'].to(args.device).long()
        with torch.no_grad():
            prediction = classifer(inputs)
            loss = loss_fct(prediction, labels)
        losses.append(loss.item())
        # Predicted class = index of the highest score along the last axis.
        hits = (prediction.argmax(dim=-1) == labels).sum().item()
        accuracies.append(hits / labels.numel())

    if not losses:
        raise ValueError(
            "eval_dataset holds fewer than batch_size ({}) examples: nothing to evaluate".format(
                args.batch_size))

    loss = sum(losses) / len(losses)
    accuracy = sum(accuracies) / len(accuracies)
    logger.info("***** loss evaluation {} *****".format(loss))
    logger.info("***** accuracy evaluation {} *****".format(accuracy))

    # An evaluation-only run may be the first thing to write to output_dir.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
    with open(os.path.join(args.output_dir, args.saving_result_file), "w") as f:
        f.write('Evaluation for disantanglement latent space: \n')
        f.write('accuracy\t:{}\n'.format(accuracy))
        f.write('loss\t:{}\n'.format(loss))


def _build_arg_parser():
    """Build the command-line parser of this script.

    Every option is registered exactly once; argparse rejects a second
    registration of the same option string with a conflict error.
    """
    parser = argparse.ArgumentParser()

    # Data, output and optimisation
    parser.add_argument("--dataset", default='yelp', type=str,
                        help="Dataset name; tensors are read from data/data_tensor_<dataset>/.")
    parser.add_argument("--filter", default=25600, type=int,
                        help="Number of shuffled examples kept from the loaded split.")
    parser.add_argument("--output_dir", default='debug',
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--batch_size", default=128, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument("--max_length", default=43, type=int)
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument("--save_step", type=int, default=1000,
                        help="Consider saving a classifier checkpoint every save_step optimisation steps.")
    parser.add_argument("--num_train_epochs", type=int, default=1000,
                        help="Number of classifier training epochs.")
    parser.add_argument("--warmup_steps", type=int, default=0, help="Linear warmup over warmup_steps.")
    parser.add_argument("--learning_rate", type=float, default=0.001,
                        help="Initial learning rate of the AdamW optimizer.")
    parser.add_argument("--adam_epsilon", type=float, default=0.001, help="Epsilon of the AdamW optimizer.")

    # Architecture Seq2Seq
    parser.add_argument("--model", default='baseline_desantanglement_h',
                        help="Model type: 'reny', 'baseline_desantanglement', "
                             "or any other value for the vanilla seq2seq.")
    parser.add_argument("--path", default='models/vanilla_seq2seq_2',
                        help="Directory holding the pretrained model checkpoints.")
    parser.add_argument("--checkpoints", default='checkpoint-692000',
                        help="Checkpoint sub-directory of --path (model.pt and training_args.txt).")
    parser.add_argument("--saving_result_file", default='test_desantaglement.txt',
                        help="File name, inside --output_dir, of the evaluation report.")

    # Classifier
    parser.add_argument("--path_classifier", default='classifier_vanilla_seq2seq_2_for_desantaglement_results',
                        help="Directory holding the trained classifier checkpoints.")
    parser.add_argument("--checkpoints_path_classifier", default='checkpoint-173400',
                        help="Checkpoint sub-directory of --path_classifier (classifier_latent_space.pt).")

    # What to do
    parser.add_argument("--do_eval", action='store_true', help="Evaluate the classifier on latent embeddings.")
    parser.add_argument("--use_complex_classifier", action='store_true',
                        help="Passed to Classifier as use_complex_classifier.")
    parser.add_argument("--do_train_classifer", action='store_true',
                        help="Train the classifier on latent embeddings of the train split.")

    # Architecture. These options are not read by this script directly; the
    # parsed namespace is handed to the model constructor.
    parser.add_argument("--style_dim", type=int, default=8)
    parser.add_argument("--content_dim", type=int, default=128)
    parser.add_argument("--number_of_layers", type=int, default=2)
    parser.add_argument("--hidden_dim", type=int, default=256)
    parser.add_argument("--dec_hidden_dim", type=int, default=136)
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument("--number_of_styles", type=int, default=2)
    parser.add_argument("--mul_style", type=int, default=10)
    parser.add_argument("--adv_style", type=int, default=1)
    parser.add_argument("--mul_mi", type=float, default=1)
    parser.add_argument("--alpha", type=float, default=1.5)
    parser.add_argument("--ema_beta", type=float, default=0.99)
    parser.add_argument("--not_use_ema", action="store_true")
    parser.add_argument("--no_reny", action="store_true")
    parser.add_argument("--reny_training", type=int, default=2)
    parser.add_argument("--number_of_training_encoder", type=int, default=1)
    parser.add_argument("--add_noise", action="store_true")
    parser.add_argument("--noise_p", type=float, default=0.1)
    parser.add_argument("--number_of_perm", type=int, default=3)
    parser.add_argument("--alternative_hs", action="store_true")
    parser.add_argument("--no_minimization_of_mi_training", action="store_true")
    parser.add_argument("--special_clement", action="store_true")
    parser.add_argument("--complex_proj_content", action="store_true")
    return parser


def _load_checkpoint_weights(module, get_probe_weight, checkpoint_file, device):
    """Load ``checkpoint_file`` into ``module`` and verify that it took effect.

    ``get_probe_weight(module)`` must return one weight tensor of the module.
    A copy taken before loading is compared with the value after loading; an
    unchanged tensor means the module still carries its initial weights.

    Raises:
        RuntimeError: if the probed weight is identical before and after.
    """
    before = get_probe_weight(module).detach().clone()
    module.load_state_dict(torch.load(checkpoint_file, map_location=torch.device(device)))
    if torch.equal(before, get_probe_weight(module).detach()):
        raise RuntimeError("Loading {} did not change the model weights".format(checkpoint_file))


def _load_pretrained_model(args, args_model):
    """Instantiate the model selected by ``args.model`` and load its weights."""
    checkpoint_file = os.path.join(args.path, args.checkpoints, 'model.pt')
    if args.model == 'reny':
        # NOTE(review): this branch builds the model from the command-line
        # args while the other two use the args stored with the checkpoint;
        # confirm that this is intended.
        model = RenySeq2Seq(args, None)
        _load_checkpoint_weights(model, lambda m: m.proj_content[0].weight, checkpoint_file, args.device)
    elif args.model == 'baseline_desantanglement':
        model = BaselineDisentanglement(args_model)
        _load_checkpoint_weights(model, lambda m: m.proj_content[0].weight, checkpoint_file, args.device)
    else:
        model = VanillaSeq2seq(args_model)
        _load_checkpoint_weights(model, lambda m: m.decoder.out.weight, checkpoint_file, args.device)
        # The classifier input size is read from content_dim below; for this
        # model type it is set to the hidden size.
        args_model.content_dim = args_model.hidden_dim
    return model


def main():
    """Train and/or evaluate a classifier on the latent space of a pretrained model.

    Steps: parse the command line, load the pretrained model together with
    the arguments stored next to its checkpoint, embed a ``TextDataset`` with
    it, then

    * with ``--do_train_classifer``: train a ``Classifier`` on the embedded
      train split (and, with ``--do_eval``, evaluate it on that same data);
    * with ``--do_eval`` only: load a trained classifier checkpoint and
      evaluate it on the embedded dev split.
    """
    args = _build_arg_parser().parse_args()
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    args.sos_token = tokenizer.sep_token_id
    args.number_of_tokens = tokenizer.vocab_size
    args.tokenizer = tokenizer
    args.padding_idx = tokenizer.pad_token_id
    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s", datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO)
    # Set seed
    set_seed(args)

    # Arguments the pretrained model was trained with, completed with the
    # runtime-only values that are not stored in the JSON file.
    with open(os.path.join(args.path, args.checkpoints, 'training_args.txt'), 'r') as f:
        args_model = argparse.Namespace(**json.load(f))

    args_model.sos_token = tokenizer.sep_token_id
    args_model.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args_model.number_of_tokens = tokenizer.vocab_size
    args_model.tokenizer = tokenizer
    args_model.batch_size = args.batch_size
    args_model.padding_idx = tokenizer.pad_token_id

    logger.info("------------------------------------------ ")
    logger.info("model type = %s ", args.model)
    logger.info("------------------------------------------ ")
    model = _load_pretrained_model(args, args_model)
    model.to(args.device)
    model.eval()

    logger.info("Model parameters %s", args)

    if args.do_train_classifer:
        logger.info("------------------------------------------ ")
        logger.info("Training Classifier")
        logger.info("------------------------------------------ ")

        train_dataset = TextDataset(args, False)
        classifier_dataset = ClassifierDataset(args, train_dataset, model)

        classifier = Classifier(args_model.content_dim, args_model.number_of_styles,
                                use_complex_classifier=args.use_complex_classifier).to(args.device)
        train_classifer(args, classifier, classifier_dataset)
        logger.info(" Training Over ")
        if args.do_eval:
            evaluate_disantanglement(args, classifier, classifier_dataset)
            logger.info(" Testing Over ")
    elif args.do_eval:
        logger.info("------------------------------------------ ")
        logger.info("Disantanglement Evaluation")
        logger.info("------------------------------------------ ")
        test_dataset = TextDataset(args, True)
        classifier_dataset = ClassifierDataset(args, test_dataset, model)
        classifer = Classifier(args_model.content_dim, args_model.number_of_styles,
                               use_complex_classifier=args.use_complex_classifier).to(args.device)
        path_to_load = os.path.join(args.path_classifier, args.checkpoints_path_classifier,
                                    'classifier_latent_space.pt')
        _load_checkpoint_weights(classifer, lambda m: m.net[0].weight, path_to_load, args.device)

        evaluate_disantanglement(args, classifer, classifier_dataset)

        logger.info(" Testing Over ")
    logger.info(" Program Over ")


# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
