# coding=utf-8
import csv
import argparse
from argparse import ArgumentParser
import os
import logging
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, Dataset
from tqdm import tqdm, trange
import json

from tokenizer_custom import ClassifTokenizer

try:
    from transformers import (get_linear_schedule_with_warmup)
except:
    from transformers import WarmupLinearSchedule as get_linear_schedule_with_warmup
from models import *
from tensorboardX import SummaryWriter
from transformers import BertTokenizer
import copy
from transformers import *
from models_transfert_style import *

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)


class TextDataset(Dataset):
    """Pre-tokenised classification split read from ``data/classification``.

    Each non-blank line of the split file must look like
    ``"<id> <id> ...\\t<downstream label>\\t<protected label>"``: a
    space-separated list of integer token ids followed by two integer labels.
    Items are dicts with ``'line'`` (the token ids, padded to
    ``args.max_length``), ``'label'`` (the protected label) and
    ``'downstream_labels'``, all ``torch.long`` tensors.
    """

    def __init__(self, args, dev, reny_ds=False):
        """Load the validation (``dev=True``) or training split.

        Args:
            args: namespace providing ``use_mention`` (selects the
                ``processed_mention_splitted`` instead of the
                ``processed_sentiment_splitted`` directory), ``max_length``
                and ``padding_idx``.
            dev: read ``x_val`` when true, ``x_train`` otherwise.
            reny_ds: unused; kept so existing call sites keep working.
        """
        logger.info("Loading dataset {}".format('Validation' if dev else 'Train'))

        def open_classif(path):
            """Parse ``path`` into (token id lists, downstream labels, protected labels)."""
            texts = []
            label_downstream = []
            label_protected = []
            with open(path, 'r', encoding='utf-8') as file:
                for raw_line in file:
                    line = raw_line.replace('\n', '')
                    if not line.strip():
                        # Skip blank lines (e.g. trailing ones) instead of failing on int('').
                        continue
                    fields = line.split('\t')
                    # Truncate before padding: otherwise a sequence longer than
                    # max_length keeps its full length and collating a batch
                    # fails on the mismatched tensor sizes.
                    text = [int(w_id) for w_id in fields[0].split(' ')][:args.max_length]
                    text += (args.max_length - len(text)) * [args.padding_idx]
                    texts.append(text)
                    label_downstream.append(int(fields[1]))
                    label_protected.append(int(fields[2]))
            return texts, label_downstream, label_protected

        file_name = 'x_val' if dev else 'x_train'
        split_dir = 'processed_mention_splitted' if args.use_mention else 'processed_sentiment_splitted'
        self.lines, self.label_downstream, self.label = open_classif(
            os.path.join('data/classification', split_dir, file_name))

    def __len__(self):
        """Number of examples in the split."""
        return len(self.label)

    def __getitem__(self, item):
        """Return example ``item`` as a dict of ``torch.long`` tensors."""
        return {'line': torch.tensor(self.lines[item], dtype=torch.long),
                'label': torch.tensor(self.label[item], dtype=torch.long),
                'downstream_labels': torch.tensor(self.label_downstream[item], dtype=torch.long)
                }


class ClassifierDataset(Dataset):
    """Latent-space embeddings of a dataset paired with its ``'label'`` values.

    NOTE(review): not used by ``main`` at the moment; the training path
    encodes batches on the fly instead.
    """

    # TODO : change everything here.
    def __init__(self, args, test_dataset, model):
        """Encode ``test_dataset`` once with ``model.predict_latent_space``.

        Args:
            args: namespace providing ``batch_size`` and ``device``.
            test_dataset: dataset whose items are dicts with ``'line'`` and
                ``'label'`` tensors.
            model: exposes ``eval()`` and ``predict_latent_space(inputs)``.

        Batches are drawn sequentially with ``drop_last=True``, so the last
        ``len(test_dataset) % args.batch_size`` examples are not encoded; a
        dataset smaller than one batch yields an empty ``ClassifierDataset``.
        """
        self.l_embeddings = []
        self.l_labels = []
        eval_dataloader = DataLoader(
            test_dataset, sampler=SequentialSampler(test_dataset), batch_size=args.batch_size, drop_last=True)

        # Eval!
        logger.info("***** Embedding evaluation *****")
        logger.info("  Batch size = %d", args.batch_size)
        model.eval()
        embedding_batches = []
        label_batches = []
        for batch in tqdm(eval_dataloader, desc="Converting The Dataset"):
            inputs = batch['line'].to(args.device)
            with torch.no_grad():
                embeddings = model.predict_latent_space(inputs)
            embedding_batches.append(embeddings.cpu().detach())
            label_batches.append(batch['label'].cpu())

        if not label_batches:
            # Nothing to encode; torch.cat would raise on an empty list.
            return
        # Batches are joined along dim 1 and dim 1 is then moved to the front, so
        # predict_latent_space presumably returns (n_layers, batch, dim) -- TODO confirm.
        # unbind() yields one tensor per example without the former tolist()
        # round-trip through Python floats.
        self.l_embeddings = list(torch.cat(embedding_batches, dim=1).permute(1, 0, 2).unbind(0))
        self.l_labels = list(torch.cat(label_batches, dim=0).unbind(0))

    def __len__(self):
        """Number of encoded examples."""
        return len(self.l_labels)

    def __getitem__(self, item):
        """Return the embedding (float32) and label (int64) of example ``item``."""
        return {'line': self.l_embeddings[item].float(),
                'label': self.l_labels[item].long()}


def set_seed(args):
    """Make runs reproducible by seeding every RNG this script relies on.

    Python's ``random``, NumPy's global RNG and PyTorch's CPU RNG are seeded
    with ``args.seed``; when CUDA is available, so is every visible GPU.
    """
    seed = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def _serializable_args(args):
    """Return a shallow copy of ``vars(args)`` with non-JSON values replaced by 0.

    ``None``, strings, bools, numbers, tuples/lists and dicts are kept as-is
    (their contents are not inspected); anything else (devices, tokenizers,
    the optimizer/scheduler stored on ``args``) becomes ``0``.
    """
    dict_to_save = copy.copy(args.__dict__)
    for key, value in dict_to_save.items():
        if value is None or isinstance(value, (str, bool, int, float, tuple, list, dict)):
            continue
        dict_to_save[key] = 0
    return dict_to_save


def _dev_loss(args, classifier, dev_dataloader, model, loss_fct):
    """Return the summed ``loss_fct`` over ``dev_dataloader`` (gradient-free; leaves ``classifier`` in eval mode)."""
    classifier.eval()
    total = 0.0
    with torch.no_grad():
        for dev_batch in tqdm(dev_dataloader, desc="Dev Iteration"):
            inputs = model.predict_latent_space(dev_batch['line'].to(args.device))
            labels = dev_batch['label'].to(args.device)
            # .item() so no autograd graph is kept alive across batches.
            total += loss_fct(classifier(inputs), labels.long()).item()
    return total


def train_classifer(args, classifier, train_dataset, dev_dataset, model):
    """Train ``classifier`` to predict ``'label'`` from ``model``'s latent space.

    Each batch is encoded with ``model.predict_latent_space`` (no gradient
    flows into ``model``) and ``classifier`` is optimised with NLL loss, AdamW
    and a linear warmup/decay schedule. Every ``args.save_step`` steps the
    summed dev loss is computed; when it beats the best so far, the weights go
    to ``<args.output_dir>/classifier_latent_space.pt`` together with
    ``training_args.txt``. At the end the best checkpoint is reloaded into
    ``classifier`` (if no checkpoint was ever saved, the final weights are kept).

    Side effects: stores the optimizer/scheduler on ``args.optimizer`` and
    ``args.scheduler``; writes TensorBoard logs to ``runs/<args.output_dir>``.
    """
    tb_writer = SummaryWriter('runs/{}'.format(args.output_dir))
    train_dataloader = DataLoader(
        train_dataset, sampler=RandomSampler(train_dataset), batch_size=args.batch_size, drop_last=True)
    # Sequential order: with drop_last a random sampler would drop a different
    # tail at every evaluation, making successive dev losses incomparable.
    dev_dataloader = DataLoader(
        dev_dataset, sampler=SequentialSampler(dev_dataset), batch_size=args.batch_size, drop_last=True)

    t_total = len(train_dataloader) * args.num_train_epochs
    optimizer = AdamW(classifier.parameters(), lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer, args.warmup_steps, t_total)
    loss_fct = torch.nn.NLLLoss()

    args.scheduler = scheduler
    args.optimizer = optimizer
    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size = %d", args.batch_size)
    logger.info("  Total optimization steps = %d", t_total)

    output_dir = args.output_dir
    classifier_path = os.path.join(output_dir, 'classifier_latent_space.pt')
    best_loss = float('inf')
    saved_checkpoint = False
    global_step = 0
    classifier.zero_grad()
    set_seed(args)  # Added here for reproducibility
    for _ in trange(int(args.num_train_epochs), desc="Epoch"):
        for batch in tqdm(train_dataloader, desc="Iteration"):
            with torch.no_grad():
                inputs = model.predict_latent_space(batch['line'].to(args.device))
            labels = batch['label'].to(args.device)
            classifier.train()
            loss = loss_fct(classifier(inputs), labels.long())
            loss.backward()
            torch.nn.utils.clip_grad_norm_(classifier.parameters(), 1.0)
            current_lr = optimizer.param_groups[0]['lr']  # lr actually used for this step
            optimizer.step()
            scheduler.step()  # advance the warmup/decay schedule (was never stepped)
            classifier.zero_grad()

            tb_writer.add_scalar("train_loss", loss.item(), global_step)
            tb_writer.add_scalar("train_lr", current_lr, global_step)

            global_step += 1
            if global_step % args.save_step == 0:
                dev_loss = _dev_loss(args, classifier, dev_dataloader, model, loss_fct)
                if dev_loss < best_loss:
                    best_loss = dev_loss
                    # Save model checkpoint
                    os.makedirs(output_dir, exist_ok=True)
                    torch.save(classifier.state_dict(), classifier_path)
                    with open(os.path.join(output_dir, 'training_args.txt'), 'w') as f:
                        json.dump(_serializable_args(args), f, indent=2)
                    saved_checkpoint = True
                    logger.info("Saving model checkpoint to %s", output_dir)
    tb_writer.close()
    logger.info("Last checkpoint %s", global_step)
    if not saved_checkpoint:
        logger.warning("No checkpoint was saved (save_step=%d never reached); keeping the final weights",
                       args.save_step)
        return
    # Load last best classifier :)
    logger.info("Reloading Best Saved model at %s", classifier_path)
    classifier.load_state_dict(torch.load(classifier_path, map_location=torch.device(args.device)))


def evaluate_disantanglement(args, classifer, eval_dataset, model):
    """Score how well ``classifer`` recovers ``'label'`` from ``model``'s latent space.

    Every full batch of ``eval_dataset`` (sequential order, ``drop_last=True``)
    is encoded with ``model.predict_latent_space`` and classified. The mean
    per-batch NLL loss and accuracy are logged and written to
    ``results_classif/<args.suffix>/disantanglement.txt``.

    Raises:
        ValueError: if ``eval_dataset`` has fewer than ``args.batch_size``
            examples, so there is no batch to evaluate.
    """
    loss_fct = torch.nn.NLLLoss()
    eval_dataloader = DataLoader(
        eval_dataset, sampler=SequentialSampler(eval_dataset), batch_size=args.batch_size, drop_last=True)

    # Eval!
    logger.info("***** Running evaluation *****")
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.batch_size)
    classifer.eval()
    losses = []
    accuracies = []
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        labels = batch['label'].to(args.device)
        with torch.no_grad():
            latent = model.predict_latent_space(batch['line'].to(args.device))
            prediction = classifer(latent)
        losses.append(loss_fct(prediction, labels.long()).item())
        predicted = prediction.topk(1)[-1].squeeze(-1).tolist()
        gold = labels.tolist()
        accuracies.append(sum(p == g for p, g in zip(predicted, gold)) / len(gold))
    if not losses:
        raise ValueError(
            "evaluate_disantanglement: no full batch of size {} in a dataset of {} examples".format(
                args.batch_size, len(eval_dataset)))
    # drop_last=True makes all batches equally sized, so the mean of the
    # per-batch means equals the overall mean.
    loss = sum(losses) / len(losses)
    accuracy = sum(accuracies) / len(accuracies)
    logger.info("***** loss evaluation {} *****".format(loss))
    logger.info("***** accuracy evaluation {} *****".format(accuracy))
    with open(os.path.join('results_classif', args.suffix, 'disantanglement.txt'), "w") as f:
        f.write('Evaluation for disantanglement latent space: \n')
        f.write('accuracy\t:{}\n'.format(accuracy))
        f.write('loss\t:{}\n'.format(loss))


def compute_disparate_impact(args, eval_dataset, model):
    """Report downstream accuracy and disparate impact (DI) of ``model``.

    ``model.predict_downstream`` is run on every full batch of ``eval_dataset``
    (sequential, ``drop_last=True``). Its output is exponentiated below, so it
    is presumably per-class log-probabilities -- TODO confirm.

    * accuracy: fraction of examples whose arg-max class equals
      ``'downstream_labels'``.
    * DI: mean probability of downstream class 0 over examples whose protected
      ``'label'`` is 1, divided by the same mean over examples whose protected
      ``'label'`` is 0 (a ratio of group rates, so unequal group sizes do not
      bias it). ``nan`` when a group is empty or the denominator is 0.

    Both values are written to
    ``results_classif/<args.suffix>/classification_report_0.txt``.

    Raises:
        ValueError: if ``eval_dataset`` has fewer than ``args.batch_size``
            examples, so no batch is evaluated.
    """
    eval_dataloader = DataLoader(
        eval_dataset, sampler=SequentialSampler(eval_dataset), batch_size=args.batch_size, drop_last=True)

    # Eval!
    logger.info("***** Running evaluation *****")
    logger.info("  Num examples = %5d", len(eval_dataset))
    logger.info("  Batch size = %d", args.batch_size)

    model.eval()
    class0_probs = []  # P(downstream class 0) per example, used for DI
    predictions = []  # arg-max downstream class per example, used for accuracy
    labels = []
    label_protected = []
    with torch.no_grad():
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            outputs = model.predict_downstream(batch['line'].to(args.device))
            class0_probs += torch.exp(outputs[:, 0]).tolist()
            predictions += outputs.topk(1, dim=-1)[1].squeeze(-1).tolist()
            labels += batch['downstream_labels'].tolist()
            label_protected += batch['label'].tolist()

    if not labels:
        raise ValueError(
            "compute_disparate_impact: no full batch of size {} in a dataset of {} examples".format(
                args.batch_size, len(eval_dataset)))
    # Compare predicted classes (not class-0 probabilities) with the gold labels.
    accuracy = sum(1 for p, y in zip(predictions, labels) if p == y) / len(labels)

    def group_rate(group):
        """Mean class-0 probability over examples whose protected label equals ``group``."""
        probs = [p for p, s in zip(class0_probs, label_protected) if s == group]
        return sum(probs) / len(probs) if probs else float('nan')

    rate_1 = group_rate(1)
    rate_0 = group_rate(0)
    DI = rate_1 / rate_0 if rate_0 else float('nan')
    with open(os.path.join('results_classif', args.suffix, 'classification_report_0.txt'), "w") as f:
        f.write('Evaluation for disantanglement latent space: \n')
        f.write('accuracy\t:{}\n'.format(accuracy))
        f.write('DI\t:{}\n'.format(DI))


def test_classif(args, eval_dataset, model):
    """Report the downstream accuracy and NLL loss of ``model`` on ``eval_dataset``.

    ``model.predict_downstream`` is run on every full batch (sequential,
    ``drop_last=True``); its output is fed to ``NLLLoss``, so it is presumably
    per-class log-probabilities -- TODO confirm. The mean per-batch loss and
    the overall accuracy are written to
    ``results_classif/<args.suffix>/classification_report_di.txt``.

    Raises:
        ValueError: if ``eval_dataset`` has fewer than ``args.batch_size``
            examples, so no batch is evaluated.
    """
    loss_fct = torch.nn.NLLLoss()
    eval_dataloader = DataLoader(
        eval_dataset, sampler=SequentialSampler(eval_dataset), batch_size=args.batch_size, drop_last=True)

    # Eval!
    logger.info("***** Running evaluation *****")
    logger.info("  Num examples = %5d", len(eval_dataset))
    logger.info("  Batch size = %d", args.batch_size)

    model.eval()
    predictions = []
    losses = []
    labels = []
    with torch.no_grad():
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            inputs = batch['line'].to(args.device)
            downstream_labels = batch['downstream_labels'].to(args.device)
            outputs = model.predict_downstream(inputs)
            losses.append(loss_fct(outputs, downstream_labels).item())
            predictions += outputs.topk(1, dim=-1)[1].squeeze(-1).tolist()
            labels += downstream_labels.tolist()
    if not losses:
        raise ValueError(
            "test_classif: no full batch of size {} in a dataset of {} examples".format(
                args.batch_size, len(eval_dataset)))
    loss = sum(losses) / len(losses)
    accuracy = sum(1 for p, y in zip(predictions, labels) if p == y) / len(labels)
    with open(os.path.join('results_classif', args.suffix, 'classification_report_di.txt'), "w") as f:
        f.write('Evaluation for disantanglement latent space: \n')
        f.write('accuracy\t:{}\n'.format(accuracy))
        f.write('loss\t:{}\n'.format(loss))


def main():
    """Evaluate a pretrained style-transfer model on the classification data.

    Loads ``model.pt`` from ``--model_path_to_load`` (using the
    ``training_args.txt`` saved next to it as model hyper-parameters), then
    reports downstream accuracy/loss on the validation split. After that it
    either computes disparate impact (default), or, with
    ``--train_classifier``, trains a classifier on the latent space to recover
    the protected label and reports how well it does. Reports are written
    under ``results_classif/<--suffix>/``.
    """
    parser = argparse.ArgumentParser()

    # Run / optimisation parameters
    parser.add_argument("--suffix", default='yelp', type=str,
                        help="Sub-directory of results_classif/ that receives the evaluation reports.")
    parser.add_argument("--output_dir", default='debug',
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--batch_size", default=14, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument("--save_step", type=int, default=1000,
                        help="Evaluate the latent-space classifier on the dev set every N training steps.")
    parser.add_argument("--num_train_epochs", type=int, default=1000,
                        help="Number of training epochs for the latent-space classifier.")
    parser.add_argument("--warmup_steps", type=int, default=0, help="Number of linear learning-rate warmup steps.")
    parser.add_argument("--learning_rate", type=float, default=0.001, help="Initial learning rate for AdamW.")
    parser.add_argument("--adam_epsilon", type=float, default=0.001, help="Epsilon for AdamW.")
    parser.add_argument("--max_length", type=int, default=92,
                        help="Token sequences are padded to this length.")

    # Pretrained Seq2Seq model to evaluate. Its architecture hyper-parameters
    # (hidden_dim, number_of_styles, ...) are read from the checkpoint's
    # training_args.txt rather than passed on the command line.
    parser.add_argument("--model_path_to_load", default='checkpoint-70000',
                        help="Directory holding model.pt and training_args.txt of the model to evaluate.")

    parser.add_argument("--use_complex_classifier", action='store_true',
                        help="Currently has no effect: the value is forced to True after parsing.")
    parser.add_argument("--use_mention", action='store_true',
                        help="Use the 'mention' tokenizer and data split instead of the 'sentiment' ones.")
    parser.add_argument("--train_classifier", action='store_true',
                        help="Train and evaluate a protected-label classifier on the latent space "
                             "instead of computing disparate impact.")

    args = parser.parse_args()
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = ClassifTokenizer(
        'data/classification/processed_mention' if args.use_mention else 'data/classification/processed_sentiment')
    args.sos_token = tokenizer.sep_token_id
    args.number_of_tokens = tokenizer.vocab_size
    args.tokenizer = tokenizer
    # NOTE(review): unconditionally overrides the --use_complex_classifier flag.
    args.use_complex_classifier = True
    args.padding_idx = tokenizer.pad_token_id
    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s", datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO)
    # Set seed
    set_seed(args)
    os.makedirs(os.path.join('results_classif', args.suffix), exist_ok=True)

    # Hyper-parameters the pretrained model was trained with.
    with open(os.path.join(args.model_path_to_load, 'training_args.txt'), 'r') as f:
        args_model = argparse.Namespace(**json.load(f))

    args_model.sos_token = tokenizer.sep_token_id
    args_model.device = args.device
    args_model.number_of_tokens = tokenizer.vocab_size
    args_model.tokenizer = tokenizer
    args_model.batch_size = args.batch_size
    args_model.padding_idx = tokenizer.pad_token_id

    model = ClassificationStyleEmdedding(args_model, None)
    # Snapshot the freshly initialised embeddings to check that the checkpoint
    # really overwrote them (load_state_dict copies into the same tensors,
    # hence the clone).
    initial_embedding = model.encoder.embedding.weight.detach().clone()
    model.load_state_dict(torch.load(os.path.join(args.model_path_to_load, 'model.pt'),
                                     map_location=torch.device(args.device)))
    if torch.equal(initial_embedding, model.encoder.embedding.weight.detach()):
        # Explicit raise instead of assert, which is stripped under `python -O`.
        raise RuntimeError("Loading {} did not change the encoder embeddings".format(
            os.path.join(args.model_path_to_load, 'model.pt')))
    args_model.content_dim = args_model.hidden_dim
    model.to(args.device)
    model.eval()

    logger.info("Model parameters %s", args)
    dev_dataset = TextDataset(args, True)
    test_classif(args, dev_dataset, model)
    logger.info(" Test Over ")

    if not args.train_classifier:
        compute_disparate_impact(args, dev_dataset, model)
        return

    logger.info("Training Classifier")
    # The raw text datasets are used directly; train_classifer encodes each
    # batch with model.predict_latent_space on the fly.
    train_dataset = TextDataset(args, False)
    classifier = Classifier(args_model.content_dim, args_model.number_of_styles, True).to(args.device)
    model.eval()
    train_classifer(args, classifier, train_dataset, dev_dataset, model)
    logger.info(" Training Over ")
    evaluate_disantanglement(args, classifier, dev_dataset, model)
    logger.info(" Program Over ")


if __name__ == "__main__":
    main()
