from explainers import InclusiveExplainer

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sklearn.pipeline import Pipeline

from customized_models.bert import BertForFactualCounterfactualSequenceClassification
from customized_models.roberta import RobertaForFactualCounterfactualSequenceClassification
from customized_models.gpt2 import GPT2ForFactualCounterfactualSequenceClassification
from customized_models.lstm.lstm import LSTMForFactualCounterfactualSequenceClassification

from utils.constants import BERT, ROBERTA, LSTM, GPT2

from transformers import (
    AutoTokenizer,
    Trainer,
    TrainingArguments
)
from transformers.trainer_utils import get_last_checkpoint
from transformers import EarlyStoppingCallback

from datasets import Dataset, load_metric

import numpy as np
import os
import json

import torch
from torch.nn.functional import softmax, log_softmax
from torch.nn import CrossEntropyLoss, KLDivLoss

from sklearn.metrics.pairwise import cosine_similarity

# TODO: save these somewhere else
# When True, the factual distillation loss is added to the training loss
# in calculate_loss_or_metrics (on top of the counterfactual loss).
TRAIN_FACTUAL_DISTILLATION = True
# When True, the intervention is encoded as a single special token such as
# "[food-Positive]"; when False it is encoded as plain text ("food is Positive").
USE_TOKEN_FLAGS = True
# Selects the counterfactual loss used in calculate_loss_or_metrics.
# Supported values: 'CE', 'KLD', 'COSINE'.
# COUNTERFACTUAL_LOSS = 'CE'
COUNTERFACTUAL_LOSS = 'KLD'
# COUNTERFACTUAL_LOSS = 'COSINE'

# KLD loss with temperature
temperature = 2.0
kld_loss_function = KLDivLoss(reduction='batchmean')


def kld_loss_with_temperature(pred_logits, target_logits, temperature = temperature):
    """Return the temperature-scaled KL divergence between two sets of logits.

    Both logit tensors are divided by ``temperature`` before being normalised
    over the last dimension. The predictions are turned into log-probabilities
    and the targets into probabilities, as expected by ``KLDivLoss``. The
    batch-mean divergence is finally multiplied by ``temperature ** 2``.

    Args:
        pred_logits: logits produced by the model being trained.
        target_logits: logits of the target distribution (any additive
            constant per row is irrelevant because of the softmax).
        temperature: softening temperature; defaults to the module-level value.

    Returns:
        A scalar tensor holding the scaled divergence.
    """
    softened_pred_log_probs = log_softmax(pred_logits / temperature, dim=-1)
    softened_target_probs = softmax(target_logits / temperature, dim=-1)
    divergence = kld_loss_function(softened_pred_log_probs, softened_target_probs)
    return divergence * temperature ** 2

def calculate_loss_or_metrics(target_counterfactual_probs, target_factual_probs, pred_counterfactual_logits, pred_factual_logits, gt_counterfactual_labels, gt_factual_labels, return_loss = False):
    """Compute the distillation training loss, or the evaluation metrics.

    The loss is ``3.0 * counterfactual_loss`` plus, when
    ``TRAIN_FACTUAL_DISTILLATION`` is set, ``1.0 * factual_loss``. The form of
    the two terms is selected by the module-level ``COUNTERFACTUAL_LOSS``.

    Args:
        target_counterfactual_probs: target-model probabilities for the
            counterfactual input (assumed (batch, num_classes) -- TODO confirm).
        target_factual_probs: target-model probabilities for the factual input.
        pred_counterfactual_logits: logits predicted for the counterfactual input.
        pred_factual_logits: logits predicted for the factual input.
        gt_counterfactual_labels: per-class ground-truth scores for the
            counterfactual input; only their argmax over axis 1 is used.
        gt_factual_labels: same, for the factual input.
        return_loss: when True, return the loss and skip the metrics.

    Returns:
        The scalar loss tensor when ``return_loss`` is True; otherwise a dict
        with distillation/ground-truth accuracy and macro-F1 (factual and
        counterfactual) plus the ``icace_l2``, ``icace_cosine`` and
        ``icace_normdiff`` metrics.

    Raises:
        ValueError: if ``COUNTERFACTUAL_LOSS`` is not 'CE', 'KLD' or 'COSINE'.
    """
    # predicted counterfactual probabilities
    pred_counterfactual_probs = softmax(pred_counterfactual_logits, dim=1)

    # True and predicted effects.
    # NOTE: the predicted effect is measured against the *target* factual
    # probabilities, which matches what estimate_icace() does at inference.
    target_effect = target_counterfactual_probs - target_factual_probs
    pred_effect = pred_counterfactual_probs - target_factual_probs

    if COUNTERFACTUAL_LOSS == 'CE':
        # CE loss expects input logits and target probs
        distillation_loss_function = CrossEntropyLoss()

        # counterfactual distillation loss
        counterfactual_loss = distillation_loss_function(pred_counterfactual_logits, target_counterfactual_probs)

        # factual distillation loss
        factual_loss = distillation_loss_function(pred_factual_logits, target_factual_probs)
    elif COUNTERFACTUAL_LOSS == 'KLD':
        # get target logits up to a constant (constant is irrelevant)
        target_counterfactual_logits = torch.log(target_counterfactual_probs)
        target_factual_logits = torch.log(target_factual_probs)

        # counterfactual distillation loss
        counterfactual_loss = kld_loss_with_temperature(pred_counterfactual_logits, target_counterfactual_logits)

        # factual distillation loss
        factual_loss = kld_loss_with_temperature(pred_factual_logits, target_factual_logits)
    elif COUNTERFACTUAL_LOSS == 'COSINE':
        # CE loss for factual
        distillation_loss_function = CrossEntropyLoss()
        factual_loss = distillation_loss_function(pred_factual_logits, target_factual_probs)

        # COSINE loss for counterfactual
        counterfactual_loss = 1 - torch.nn.functional.cosine_similarity(pred_effect, target_effect, dim=1)
        counterfactual_loss = torch.mean(counterfactual_loss)
    else:
        # previously this fell through and failed later with an UnboundLocalError
        raise ValueError(f"Unknown COUNTERFACTUAL_LOSS: {COUNTERFACTUAL_LOSS!r}")

    # training loss
    loss = 3.0 * counterfactual_loss

    if TRAIN_FACTUAL_DISTILLATION:
        loss += 1.0 * factual_loss

    if return_loss:
        return loss

    # ground-truth labels
    gt_counterfactual_labels = np.argmax(gt_counterfactual_labels, axis=1)
    gt_factual_labels = np.argmax(gt_factual_labels, axis=1)

    # predicted labels
    pred_counterfactual_labels = np.argmax(pred_counterfactual_logits, axis=1)
    pred_factual_labels = np.argmax(pred_factual_logits, axis=1)

    # target-model labels
    target_counterfactual_labels = np.argmax(target_counterfactual_probs, axis=1)
    target_factual_labels = np.argmax(target_factual_probs, axis=1)

    accuracy_metric = load_metric('accuracy')
    f1_metric = load_metric('f1')

    # Every (reference set, metric, factual/counterfactual) combination is
    # reported. Generating the keys in one loop guarantees none is dropped
    # (the ground-truth factual accuracy used to be computed but never stored).
    reference_sets = (
        ('distillation', target_counterfactual_labels, target_factual_labels),
        ('groundtruth', gt_counterfactual_labels, gt_factual_labels),
    )
    metrics = (
        (accuracy_metric, {}),
        (f1_metric, {'average': 'macro'}),
    )

    result = {}
    for prefix, reference_counterfactual, reference_factual in reference_sets:
        for metric, compute_kwargs in metrics:
            comparisons = (
                ('counterfactual', pred_counterfactual_labels, reference_counterfactual),
                ('factual', pred_factual_labels, reference_factual),
            )
            for suffix, predictions, references in comparisons:
                scores = metric.compute(predictions=predictions, references=references, **compute_kwargs)
                result.update({f'{prefix}_{k}_{suffix}': v for k, v in scores.items()})

    # CEBaB L2
    l2_metric = np.linalg.norm(pred_effect - target_effect, ord=2, axis=1)
    l2_metric = np.average(l2_metric)

    # CEBaB cosine (diagonal of the pairwise matrix = row-wise cosine distance)
    cosine_metric = 1 - cosine_similarity(pred_effect, target_effect)[range(pred_effect.shape[0]), range(pred_effect.shape[0])]
    cosine_metric = np.average(cosine_metric)

    # CEBaB normdiff
    normdiff_metric = abs(np.linalg.norm(pred_effect, ord=2, axis=1) - np.linalg.norm(target_effect, ord=2, axis=1))
    normdiff_metric = np.average(normdiff_metric)

    result.update({
        'icace_l2': l2_metric,
        'icace_cosine': cosine_metric,
        'icace_normdiff': normdiff_metric,
    })

    return result

def get_best_checkpoint(last_checkpoint, output_dir):
    """Return the best-model checkpoint path recorded in the trainer state.

    Reads ``trainer_state.json`` inside ``last_checkpoint`` and takes its
    ``best_model_checkpoint`` entry. If ``output_dir`` carries an
    "approximate" (checked first) or "inclusive" flag that the stored path
    lacks, the flag is inserted after the first "__"-separated segment of the
    path (to deal with some mistakes in the naming conventions).

    Args:
        last_checkpoint: directory containing ``trainer_state.json``.
        output_dir: output directory of the current run.

    Returns:
        The (possibly patched) best checkpoint path.
    """
    state_file = os.path.join(last_checkpoint, 'trainer_state.json')
    with open(state_file, 'r') as state_fp:
        trainer_state = json.load(state_fp)
    best_checkpoint_path = trainer_state['best_model_checkpoint']

    # only the first applicable flag is added
    for flag in ("approximate", "inclusive"):
        if flag in output_dir and flag not in best_checkpoint_path:
            head, *tail = best_checkpoint_path.split("__")
            best_checkpoint_path = "__".join([f"{head}__{flag}", *tail])
            break

    return best_checkpoint_path

class E2EExplainer(InclusiveExplainer):
    """Explainer backed by a factual/counterfactual sequence classifier.

    The wrapped model receives a review together with an encoded intervention
    and is trained (through ``DistillationTrainer``) to reproduce the target
    model's factual and counterfactual predictions. ``estimate_icace`` turns
    its counterfactual predictions into effect estimates.

    NOTE(review): ``InclusiveExplainer.__init__`` is not called here; confirm
    the base class does not need initialisation.
    """

    def __init__(self, model_path, eval_split, output_dir, device = 'cuda', batch_size=128, gradient_accumulation_steps=1):
        """Load tokenizer and model, build the training arguments and resume state.

        Args:
            model_path: path or hub id of the model; the architecture is
                detected by substring match against BERT/ROBERTA/GPT2/LSTM.
            eval_split: training is enabled only when this equals ``'dev'``.
            output_dir: directory for checkpoints; an existing checkpoint in
                it is picked up for resuming / best-model loading.
            device: device the model is moved to.
            batch_size: per-device train and eval batch size.
            gradient_accumulation_steps: forwarded to ``TrainingArguments``.

        Raises:
            ValueError: if the architecture cannot be detected, or if
                ``output_dir`` is non-empty but contains no checkpoint.
        """
        self.model_path = model_path
        self.batch_size = batch_size
        self.device = device
        self.output_dir = output_dir
        self.eval_split = eval_split
        self.gradient_accumulation_steps = gradient_accumulation_steps

        # Get tokenizer
        if BERT in self.model_path or LSTM in self.model_path:
            self.tokenizer_path = BERT
        elif ROBERTA in self.model_path:
            self.tokenizer_path = ROBERTA
        elif GPT2 in self.model_path:
            self.tokenizer_path = GPT2
        else:
            # Fallback: derive the tokenizer name from the second path
            # component. Only evaluated when needed, so a slash-less path of
            # a known architecture no longer fails with an IndexError.
            self.tokenizer_path = self.model_path.split('/')[1].split('.')[0]

        self.tokenizer = AutoTokenizer.from_pretrained(
                self.tokenizer_path,
                cache_dir="./huggingface_cache"
        )

        # Get model
        if BERT in self.model_path:
            self.model = BertForFactualCounterfactualSequenceClassification.from_pretrained(
                    self.model_path,
                    cache_dir="./huggingface_cache"
            )
        elif ROBERTA in self.model_path:
            self.model = RobertaForFactualCounterfactualSequenceClassification.from_pretrained(
                    self.model_path,
                    cache_dir="./huggingface_cache"
            )
        elif GPT2 in self.model_path:
            self.model = GPT2ForFactualCounterfactualSequenceClassification.from_pretrained(
                    self.model_path, cache_dir="./huggingface_cache")
            # GPT-2 has no padding token; reuse the end-of-sequence token
            self.tokenizer.pad_token = self.tokenizer.eos_token
        elif LSTM in self.model_path:
            self.model = LSTMForFactualCounterfactualSequenceClassification.from_pretrained(
                    self.model_path, cache_dir="./huggingface_cache")
        else:
            # previously surfaced later as an AttributeError on self.model
            raise ValueError(f"Cannot detect model architecture from model_path: {self.model_path!r}")

        # add new special tokens to the tokenizer (one per aspect/value pair)
        if USE_TOKEN_FLAGS:
            aspects = ['food', 'ambiance', 'noise', 'service']
            values = ['Negative', 'Positive', 'unknown', 'no majority']
            self.special_tokens_dict = {'additional_special_tokens': [f'[{aspect}-{value}]' for aspect in aspects for value in values]}

            self.tokenizer.add_special_tokens(self.special_tokens_dict)
            self.model.resize_token_embeddings(len(self.tokenizer))

        self.model.to(self.device)

        # Trainer args
        # TODO: get these arguments in a conf file
        self.args = TrainingArguments(
            output_dir = self.output_dir,
            do_train = self.eval_split == 'dev',
            do_eval = True,
            do_predict= True,
            per_device_train_batch_size = self.batch_size,
            per_device_eval_batch_size = self.batch_size,
            gradient_accumulation_steps = self.gradient_accumulation_steps,
            # training
            max_steps = 4616,
            # max_steps = 3239,
            # max_steps = 2000,
            overwrite_output_dir = False,
            # logging
            logging_strategy = 'steps',
            logging_steps = 20,
            # report_to = 'wandb',
            # evaluating and saving
            evaluation_strategy = 'steps',
            save_strategy = 'steps',
            # eval_steps = 200,
            # save_steps = 200,
            eval_steps = 50,
            save_steps = 50,
            save_total_limit = 1,
            load_best_model_at_end = True,
            # lower cosine distance between predicted and target effect is better
            metric_for_best_model = "icace_cosine",
            greater_is_better = False,
        )

        # Detecting last checkpoint.
        last_checkpoint = None
        if os.path.isdir(self.args.output_dir) and not self.args.overwrite_output_dir:
            last_checkpoint = get_last_checkpoint(self.args.output_dir)
            if last_checkpoint is None and len(os.listdir(self.args.output_dir)) > 0:
                raise ValueError(
                    f"Output directory ({self.args.output_dir}) already exists and is not empty. "
                    "Use --overwrite_output_dir to overcome."
                )
            elif last_checkpoint is not None and self.args.resume_from_checkpoint is None:
                print(
                    f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                    "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
                )
        self.last_checkpoint = last_checkpoint

        # Find best checkpoint and load model (evaluation-only runs)
        if not self.args.do_train:

            if self.last_checkpoint:
                best_checkpoint_path = get_best_checkpoint(self.last_checkpoint, self.args.output_dir)
                print(f'Loading best model from checkpoint {best_checkpoint_path}.')
                # map_location lets a checkpoint saved on another device
                # (e.g. GPU) be loaded on the device requested here
                state_dict = torch.load(
                    os.path.join(best_checkpoint_path, 'pytorch_model.bin'),
                    map_location=self.device
                )
                self.model.load_state_dict(state_dict)

            self.model.eval()


    def __str__(self):
        """Return the explainer name followed by the first characters of the model repr."""
        return 'E2EExplainer' + '-' + str(self.model)[:10]

    def create_hf_dataset(self, pairs):
        """Build a tokenized HuggingFace ``Dataset`` from a dataframe of pairs.

        Args:
            pairs: dataframe with ``description_base``, ``intervention_type``
                and ``intervention_aspect_counterfactual`` columns. When it
                also has ``prediction_counterfactual`` (and is non-empty), the
                ``prediction_base``, ``review_majority_counterfactual`` and
                ``review_majority_base`` columns are stacked into ``labels``.

        Returns:
            A dataset with counterfactual inputs under the usual tokenizer
            keys and factual inputs under the same keys suffixed with
            ``_factual``; ``labels`` is present for labelled or empty input.
        """
        # get the text
        first_sentences = pairs.description_base.to_list()

        if not USE_TOKEN_FLAGS:
            second_sentences = (pairs.intervention_type + " is " + pairs.intervention_aspect_counterfactual).to_list()
        else:
            # TODO: no majority in intervention aspect counterfactual??!!
            second_sentences = ("[" + pairs.intervention_type + "-" + pairs.intervention_aspect_counterfactual + "]").to_list()

        # get dataset
        if 'prediction_counterfactual' in pairs.columns and len(pairs) > 0:
            # Target-model predictions. These are consumed as probabilities
            # downstream (compute_loss / estimate_icace).
            target_counterfactual_probs = np.stack(pairs.prediction_counterfactual.to_list())
            target_factual_probs = np.stack(pairs.prediction_base.to_list())
            # get ground truth labels
            gt_counterfactual_labels = np.stack(pairs['review_majority_counterfactual'].to_list())
            gt_factual_labels = np.stack(pairs['review_majority_base'].to_list())
            # little hack: stack all the labels in one tensor; index along
            # axis 1 is 0/1 = target counterfactual/factual, 2/3 = ground truth
            dataset = Dataset.from_dict({
                'first_sentence': first_sentences,
                'second_sentence': second_sentences,
                'labels': np.stack((target_counterfactual_probs, target_factual_probs, gt_counterfactual_labels, gt_factual_labels), axis=1)
            })
        elif len(pairs) == 0:
            # get empty dataset for k = 0
            dataset = Dataset.from_dict({
                'first_sentence': [],
                'second_sentence': [],
                'labels': []
            })
        else:
            dataset = Dataset.from_dict({
                'first_sentence': first_sentences,
                'second_sentence': second_sentences,
            })

        # preprocess dataset
        max_seq_length = 128
        def preprocess(example):
            # get both factual and counterfactual inputs (make sure both are same size)
            result_counterfactual = self.tokenizer(example['first_sentence'], example['second_sentence'], padding="max_length", max_length=max_seq_length, truncation=True)
            max_counterfactual_len = len(result_counterfactual['input_ids'][0])
            # truncation must be requested explicitly: with padding enabled,
            # max_length alone does not shorten over-long reviews, which would
            # leave factual inputs longer than the counterfactual ones
            result_factual = self.tokenizer(example['first_sentence'], padding='max_length', max_length=max_counterfactual_len, truncation=True)

            # rename keys
            result_factual = {f'{k}_factual':v for k,v in result_factual.items()}

            # both factual and counterfactual inputs should have the same length (for downstream batching)
            if not len(result_factual['input_ids_factual'][0]) == len(result_counterfactual['input_ids'][0]):
                print('Alert!')

            # merge
            result = result_factual
            result.update(result_counterfactual)
            return result

        dataset = dataset.map(preprocess, batched=True, remove_columns=['first_sentence', 'second_sentence'])
        return dataset


    def fit(self, pairs, singles, classifier, dev_dataset=None):
        """Create the trainer and, when training is enabled, train the model.

        Args:
            pairs: training pairs, converted with ``create_hf_dataset``.
            singles: not used by this explainer.
            classifier: not used by this explainer.
            dev_dataset: optional dev pairs used for evaluation during training.

        Side effects:
            Sets ``self.trainer``, which ``estimate_icace`` relies on.
        """
        # create train and dev dataset
        train_dataset = self.create_hf_dataset(pairs)
        if dev_dataset is not None:
            dev_dataset = self.create_hf_dataset(dev_dataset)

        def compute_metrics(p):
            # get pred logits (index 0 = counterfactual, 1 = factual)
            pred_counterfactual_logits, pred_factual_logits = p.predictions[:,0,:], p.predictions[:,1,:]
            pred_counterfactual_logits = torch.Tensor(pred_counterfactual_logits)
            pred_factual_logits = torch.Tensor(pred_factual_logits)

            # get target probs
            target_counterfactual_probs, target_factual_probs = p.label_ids[:,0,:], p.label_ids[:,1,:]
            target_counterfactual_probs = torch.Tensor(target_counterfactual_probs)
            target_factual_probs = torch.Tensor(target_factual_probs)

            # get gt labels
            gt_counterfactual_labels, gt_factual_labels = p.label_ids[:,2,:], p.label_ids[:,3,:]

            results = calculate_loss_or_metrics(target_counterfactual_probs, target_factual_probs, pred_counterfactual_logits, pred_factual_logits, gt_counterfactual_labels, gt_factual_labels, return_loss=False)
            return results


        # create trainer
        early_stopping = EarlyStoppingCallback(early_stopping_patience=20)
        self.trainer = DistillationTrainer(
            model = self.model,
            args = self.args,
            train_dataset = train_dataset,
            eval_dataset = dev_dataset,
            tokenizer = self.tokenizer,
            compute_metrics = compute_metrics,
            callbacks = [early_stopping]
        )


        # evaluate without training (baseline metrics before any update)
        if self.args.do_train:
            dry_evaluate_metrics = self.trainer.evaluate(eval_dataset = dev_dataset)
            self.trainer.log_metrics('eval', dry_evaluate_metrics)

        # train (skipped for an empty training set, i.e. k = 0)
        if self.args.do_train and len(pairs) > 0:
            self.trainer.train(resume_from_checkpoint=self.last_checkpoint)


    def estimate_icace(self, pairs):
        """Estimate the effect of each intervention in ``pairs``.

        Requires ``fit`` to have been called, since it predicts through
        ``self.trainer``.

        Args:
            pairs: dataframe accepted by ``create_hf_dataset`` that also has a
                ``prediction_base`` column.

        Returns:
            A list with one array per pair: softmax of the predicted
            counterfactual logits minus ``prediction_base``.
        """
        # create test dataset
        test_dataset = self.create_hf_dataset(pairs)

        # predict (index 0 along axis 1 holds the counterfactual logits)
        test_result = self.trainer.predict(test_dataset).predictions[:,0,:]
        counterfactual_predictions = torch.Tensor(test_result)

        # apply softmax
        counterfactual_predictions = softmax(counterfactual_predictions, dim=1).numpy()

        # get ICaCE
        factual_predictions = np.stack(pairs.prediction_base.to_list())
        icace = counterfactual_predictions - factual_predictions

        return list(icace)

class DistillationTrainer(Trainer):
    """Trainer whose loss is the factual/counterfactual distillation loss."""

    def compute_loss(self, model, inputs, return_outputs=False):
        """Run the model and return the distillation loss for one batch.

        ``inputs["labels"]`` stacks four label sets along axis 1: target
        counterfactual probs, target factual probs, ground-truth
        counterfactual labels and ground-truth factual labels. The model
        logits stack counterfactual (index 0) and factual (index 1) predictions.

        Args:
            model: the model being trained.
            inputs: batch dict; its ``labels`` entry is removed in place
                before the forward pass.
            return_outputs: when True, also return the model outputs.

        Returns:
            The loss, or a ``(loss, outputs)`` tuple if ``return_outputs``.
        """
        stacked_labels = inputs.get("labels")

        # targets produced by the model being distilled
        target_counterfactual_probs = stacked_labels[:, 0, :]
        target_factual_probs = stacked_labels[:, 1, :]

        # ground truth labels
        gt_counterfactual_labels = stacked_labels[:, 2, :]
        gt_factual_labels = stacked_labels[:, 3, :]

        # forward pass (remove labels so the model won't try to calculate the loss)
        inputs.pop('labels')
        outputs = model(**inputs)

        # split predicted logits into counterfactual and factual parts
        logits = outputs.get("logits")
        pred_counterfactual_logits = logits[:, 0, :]
        pred_factual_logits = logits[:, 1, :]

        loss = calculate_loss_or_metrics(
            target_counterfactual_probs,
            target_factual_probs,
            pred_counterfactual_logits,
            pred_factual_logits,
            gt_counterfactual_labels,
            gt_factual_labels,
            return_loss=True,
        )

        if return_outputs:
            return (loss, outputs)
        return loss
