import copy
import csv
import os
import random

import numpy as np
import pandas as pd
import torch
import wandb
from torch.nn.functional import cosine_similarity
from tqdm import tqdm
from transformers import AdamW

from info_nce import InfoNCE
from models.model import Model
from utils.constants import DEVICE, CEBAB_CONCEPTS
from utils.metric_utils import cosine_similarity_matrix
from utils.results_utils import make_dir, calculate_cebab_score, calculate_cebab_score_stance_setup
from utils.training_utils import embeddings_for_batch_group_2, embeddings_for_batch_group_1


class MatchingRepresentation(Model):
    """Language-model wrapper that is fine-tuned into a matching representation.

    The wrapped model is either passed in directly (``pretrained_model``) or
    loaded from disk (``pretrained_model_path``); exactly one of the two must
    be supplied. A deep copy of it is stored as ``lm_model`` and is the model
    that :meth:`train` optimises with an ``InfoNCE`` loss.
    """

    def __init__(self, tokenizer, pretrained_model=None, pretrained_model_path=None, model_description=None,
                 examples_per_pair=5, group='group 1', to_device=False, setup_name='cebab', text_column='text'):
        """Store the configuration and build the trainable copy of the model.

        Args:
            tokenizer: Tokenizer object, stored and returned by ``get_tokenizer``.
            pretrained_model: In-memory model to start from.
            pretrained_model_path: Path handed to ``torch.load`` to obtain the
                model when ``pretrained_model`` is not given.
            model_description: Forwarded to ``Model.__init__``.
            examples_per_pair: Forwarded to the embeddings function on every batch.
            group: ``'group 1'`` or ``'group 2'``; selects the embeddings function.
            to_device: When true, the trainable copy is moved to ``DEVICE``.
            setup_name: ``'cebab'`` or ``'stance'``; selects the scoring helper
                used during :meth:`train`.
            text_column: Stored on the instance; not read inside this class.

        Raises:
            ValueError: If neither or both of ``pretrained_model`` and
                ``pretrained_model_path`` are provided.
        """
        self.pretrained_model_path = pretrained_model_path
        self.pretrained_model = pretrained_model
        self.tokenizer = tokenizer
        self.examples_per_pair = examples_per_pair
        self.group = group
        self.text_column = text_column
        # Any other group value leaves `embeddings_for_batch` unset; `train`
        # reports that case with an explicit error.
        if group == 'group 1':
            self.embeddings_for_batch = embeddings_for_batch_group_1
        elif group == 'group 2':
            self.embeddings_for_batch = embeddings_for_batch_group_2
        if self.pretrained_model is None and self.pretrained_model_path is None:
            raise ValueError('Either pretrained_model or pretrained_model_path must be provided')
        if self.pretrained_model is not None and self.pretrained_model_path is not None:
            raise ValueError('Only one of pretrained_model or pretrained_model_path must be provided')
        self.flag_train = True
        if pretrained_model is None:
            # Exactly one source is set at this point, so the path is available.
            self.flag_train = False
            # NOTE(review): torch.load unpickles the file; only load checkpoints
            # from trusted locations.
            self.pretrained_model = torch.load(self.pretrained_model_path)
        # Train on a copy so the originally supplied/loaded model is left untouched.
        self.lm_model = copy.deepcopy(self.pretrained_model)
        if to_device:
            self.lm_model.to(DEVICE)
        self.setup_name = setup_name
        super().__init__(pretrained_model_path=self.pretrained_model_path, model_description=model_description)

    def set_lm_model(self, model):
        """Replace the trainable/representation model with ``model``."""
        self.lm_model = model

    def to_device(self):
        """Move the representation model to ``DEVICE``."""
        self.lm_model.to(DEVICE)

    def get_representation_model(self):
        """Return the representation model (the copy that ``train`` optimises)."""
        return self.lm_model

    def get_lm_model(self):
        """Return the original pretrained model, not the trained copy."""
        return self.pretrained_model

    @staticmethod
    def _concepts_for_treatment(config):
        """Return the concept list to score for ``config.treatment``."""
        if config.treatment == 'all':
            return CEBAB_CONCEPTS
        return [config.treatment]

    def train(self, config, wandb_on, train_batch, mean_loss, results_dir=None, valid_batch=None, test_batch=None,
              cebab_eval=None, loss_version=0):
        """Fine-tune the representation model with an ``InfoNCE`` loss.

        Evaluation runs roughly twice per epoch (mid-epoch and after the last
        batch). With ``valid_batch`` given, the model with the lowest
        validation loss is kept as the best model; otherwise the current model
        is used.

        Args:
            config: Object exposing ``temperature``, ``learning_rate``,
                ``epochs``, ``treatment`` and the loss coefficients
                (``tcf_cfc``, ``tcf_pax``, ``tcf_nax``, ``pax_nax``, ``pax_cfc``,
                ``cfc_nax`` for ``loss_version=0``; ``approx``, ``c_cf``,
                ``t_cf``, ``negative`` for ``loss_version=1``).
            wandb_on: Forwarded to ``evaluation`` and the scoring helpers.
            train_batch: Sized iterable of dicts with the keys ``'query'``,
                ``'t_cf'``, ``'c_cf'``, ``'negative'`` and ``'approx'``.
            mean_loss: Forwarded to ``InfoNCE``.
            results_dir: Directory for per-epoch outputs and saved models.
                When ``None`` nothing is written to disk.
            valid_batch: Optional validation batches used for model selection.
            test_batch: Optional test batches, evaluated before training and
                at the end of every epoch.
            cebab_eval: Optional dict with the keys ``'model_to_explain'``,
                ``'explainers'``, ``'pairs_validation'``, ``'pairs_test'`` and
                ``'matching_description'``.
            loss_version: ``0`` or ``1``; selects the coefficient set.

        Returns:
            The best model, switched to eval mode. It is also installed as the
            representation model of this instance.

        Raises:
            ValueError: If ``loss_version`` is unsupported, or no embeddings
                function is configured for ``self.group``.
            KeyError: If ``cebab_eval`` is given but lacks one of its keys.
        """
        embeddings_fn = getattr(self, 'embeddings_for_batch', None)
        if embeddings_fn is None:
            # Previously surfaced as an AttributeError in the middle of training.
            raise ValueError(f'No embeddings function is configured for group {self.group!r}')

        rep_model = self.get_representation_model()
        rep_model.to(DEVICE)

        # Normalise cebab_eval so every key can be read unconditionally below.
        cebab_keys = ('model_to_explain', 'explainers', 'pairs_validation', 'pairs_test',
                      'matching_description')
        if cebab_eval is None:
            cebab_eval = dict.fromkeys(cebab_keys)
        else:
            cebab_eval = {key: cebab_eval[key] for key in cebab_keys}

        if loss_version == 0:
            coefficients = {
                'tcf_cfc': config.tcf_cfc,
                'tcf_pax': config.tcf_pax,
                'tcf_nax': config.tcf_nax,
                'pax_nax': config.pax_nax,
                'pax_cfc': config.pax_cfc,
                'cfc_nax': config.cfc_nax,
            }
        elif loss_version == 1:
            coefficients = {'approx': config.approx, 'c_cf': config.c_cf,
                            't_cf': config.t_cf, 'negative': config.negative}
        else:
            raise ValueError('Loss version not supported')

        loss = InfoNCE(negative_mode='unpaired', temperature=config.temperature,
                       coefficients=coefficients, device=DEVICE, mean_loss=mean_loss, loss_version=loss_version)

        if test_batch is not None:
            print('Evaluating model before training')
            # results_dir is optional; only create the output directory when it is set.
            if results_dir is not None:
                p = os.path.join(results_dir, 'epoch_0')
                make_dir(p)
            else:
                p = None
            with torch.no_grad():
                # NOTE(review): the stance-setup scorer is called here regardless of
                # `setup_name` and always with a single concept - confirm this is intended.
                calculate_cebab_score_stance_setup(model_to_explain_name=cebab_eval['model_to_explain'],
                                                   concepts=[config.treatment],
                                                   explainers=cebab_eval['explainers'],
                                                   pairs=cebab_eval['pairs_test'],
                                                   name='test', wandb_on=wandb_on,
                                                   path_dir=None, return_log=False,
                                                   save_outputs=False)
                evaluation(eval_set=test_batch, model=rep_model, loss_function=loss,
                           device=DEVICE,
                           epoch=0, wandb_on=wandb_on, coefficients=coefficients,
                           examples_per_pair=self.examples_per_pair,
                           name='test', path_dir=p, config=config,
                           return_loss=False, embeddings_for_batch_function=embeddings_fn)

        # evaluation() leaves the model in eval mode on the CPU; restore training state.
        rep_model.train()
        rep_model.to(DEVICE)

        optimizer = AdamW(rep_model.parameters(), lr=config.learning_rate)

        best_model = None
        best_eval_loss = np.inf
        best_epoch = 0

        num_batches = len(train_batch)
        # Evaluate about twice per epoch; clamp to 1 so a single-batch loader
        # does not trigger a modulo-by-zero.
        eval_interval = max(num_batches // 2, 1)

        for epoch in range(config.epochs):
            progress_bar = tqdm(train_batch)
            for i, t_batch in enumerate(progress_bar):
                query = t_batch['query']
                t_cf = t_batch['t_cf']
                c_cf = t_batch['c_cf']
                negative = t_batch['negative']
                approx = t_batch['approx']

                rep_model, query_embedding, t_cf_embeddings, approx_embeddings, c_cf_embeddings, negative_embeddings = embeddings_fn(
                    rep_model=rep_model, query=query, t_cf=t_cf, negative=negative, c_cf=c_cf, approx=approx,
                    examples_per_pair=self.examples_per_pair)
                outputs = loss(query=query_embedding, positive_keys=t_cf_embeddings,
                               negative_keys=negative_embeddings, confounder_counterfactual_keys=c_cf_embeddings,
                               approx_keys=approx_embeddings)

                # One optimiser step per batch (no gradient accumulation).
                outputs.backward()
                optimizer.step()
                optimizer.zero_grad()

                progress_bar.set_description(
                    f'Epoch {epoch + 1}/{config.epochs} | Loss: {outputs.item():.4f} |')

                is_last_batch = (i + 1) == num_batches
                if (i + 1) % eval_interval == 0 or is_last_batch:
                    if is_last_batch:
                        # End of epoch: per-epoch outputs and test-set scoring.
                        if results_dir is not None:
                            path_per_epoch = os.path.join(results_dir, f'epoch_{epoch + 1}')
                            make_dir(path_per_epoch)
                        else:
                            path_per_epoch = None
                        if test_batch is not None:
                            with torch.no_grad():
                                evaluation(eval_set=test_batch, model=rep_model, loss_function=loss,
                                           device=DEVICE,
                                           epoch=epoch, wandb_on=wandb_on, coefficients=coefficients,
                                           examples_per_pair=self.examples_per_pair,
                                           name='test', path_dir=path_per_epoch, config=config,
                                           return_loss=False, embeddings_for_batch_function=embeddings_fn)
                        if cebab_eval['pairs_test'] is not None:
                            concept = self._concepts_for_treatment(config)
                            if self.setup_name == 'cebab':
                                calculate_cebab_score(model_to_explain=cebab_eval['model_to_explain'],
                                                      concepts=concept,
                                                      explainers=cebab_eval['explainers'],
                                                      pairs=cebab_eval['pairs_test'],
                                                      name='test',
                                                      wandb_on=wandb_on, return_log=False, path_dir=path_per_epoch)
                            elif self.setup_name == 'stance':
                                calculate_cebab_score_stance_setup(model_to_explain_name=cebab_eval['model_to_explain'],
                                                                   concepts=concept,
                                                                   explainers=cebab_eval['explainers'],
                                                                   pairs=cebab_eval['pairs_test'],
                                                                   name='test', wandb_on=wandb_on,
                                                                   path_dir=None, return_log=False,
                                                                   save_outputs=False)
                    else:
                        # Mid-epoch evaluation writes nothing to disk.
                        path_per_epoch = None

                    print(f'starting evaluation, epoch = {epoch + 1}, iteration = {i}.')
                    rep_model.eval()
                    with torch.no_grad():
                        if valid_batch is not None:
                            loss_eval = evaluation(eval_set=valid_batch, model=rep_model, loss_function=loss,
                                                   device=DEVICE,
                                                   epoch=epoch, wandb_on=wandb_on, coefficients=coefficients,
                                                   examples_per_pair=self.examples_per_pair,
                                                   name='validation', path_dir=path_per_epoch, config=config,
                                                   return_loss=True,
                                                   embeddings_for_batch_function=embeddings_fn)

                            if best_eval_loss > loss_eval:
                                best_eval_loss = loss_eval
                                # Snapshot the weights; rep_model keeps training afterwards.
                                best_model = copy.deepcopy(rep_model)
                                best_epoch = epoch
                        else:
                            # No validation set: the current model is treated as the best one.
                            # (The message text is kept as-is for log compatibility.)
                            print('cebab eval is None')
                            best_model = rep_model
                            best_epoch = epoch

                        if (cebab_eval['model_to_explain'] is not None) and (
                                cebab_eval['explainers'] is not None):
                            # Point every explainer at this object, now wrapping the current weights.
                            self.set_lm_model(rep_model)
                            for e in cebab_eval['explainers']:
                                e.set_representation_model(self)
                            if cebab_eval['pairs_validation'] is not None:
                                concept = self._concepts_for_treatment(config)
                                if self.setup_name == 'cebab':
                                    calculate_cebab_score(model_to_explain=cebab_eval['model_to_explain'],
                                                          concepts=concept,
                                                          explainers=cebab_eval['explainers'],
                                                          pairs=cebab_eval['pairs_validation'],
                                                          name='validation',
                                                          wandb_on=wandb_on, path_dir=path_per_epoch)
                                elif self.setup_name == 'stance':
                                    calculate_cebab_score_stance_setup(
                                        model_to_explain_name=cebab_eval['model_to_explain'],
                                        concepts=concept,
                                        explainers=cebab_eval['explainers'],
                                        pairs=cebab_eval['pairs_validation'],
                                        name='validation', wandb_on=wandb_on, path_dir=None,
                                        return_log=False, save_outputs=False)
                    # evaluation() moves the model to the CPU; move it back and resume training.
                    rep_model.to(DEVICE)
                    rep_model.train()

                del t_batch
                torch.cuda.empty_cache()

        if best_model is None:
            # No evaluation selected a model (zero epochs, empty loader, or a
            # validation loss that never improved, e.g. NaN): use the last weights.
            print('no best model was selected during training, falling back to the last model')
            best_model = rep_model

        if results_dir:
            best_model_path = os.path.join(results_dir, 'best_model')
            make_dir(best_model_path)
            with open(f'{best_model_path}/description.txt', "w") as file:
                file.write(f"best epoch-{best_epoch + 1}, best loss: {best_eval_loss}")
            # Report the same 1-based epoch number as the description file.
            print(f'saving the model- best epoch{best_epoch + 1}, best loss {best_eval_loss}')

            torch.save(best_model, os.path.join(best_model_path, 'model.pt'))
            torch.save(rep_model, os.path.join(best_model_path, 'last_model.pt'))
        self.set_lm_model(best_model.eval())
        print('return the model')
        return best_model.eval()

    def get_tokenizer(self):
        """Return the tokenizer supplied at construction time."""
        return self.tokenizer

    def get_model_description(self):
        """Return the fixed description string of this model type."""
        return 'matching_representation'

    def get_model_group(self):
        """Return the group name supplied at construction time."""
        return self.group


def evaluation(eval_set, model, loss_function, wandb_on, device, epoch, name, coefficients, config, examples_per_pair,
               embeddings_for_batch_function,
               path_dir=None,
               return_loss=False, best=False):
    """Evaluate ``model`` on ``eval_set`` and report losses and mean similarities.

    For every batch the embeddings are computed with
    ``embeddings_for_batch_function``, the loss is evaluated twice (once as is
    and once with ``eval=True``), and the mean cosine similarity between the
    query and each key group is recorded. The aggregated metrics are printed,
    optionally written to ``<path_dir>/<name>.csv`` and optionally logged to
    wandb.

    Side effects: the model is put in eval mode, moved to ``device`` for the
    run and moved to the CPU afterwards.

    Args:
        eval_set: Iterable of dicts with the keys ``'query'``, ``'t_cf'``,
            ``'c_cf'``, ``'negative'`` and ``'approx'``.
        model: Model to evaluate.
        loss_function: Callable taking ``query``, ``positive_keys``,
            ``negative_keys``, ``confounder_counterfactual_keys``,
            ``approx_keys`` and an optional ``eval`` flag.
        wandb_on: When truthy, the metrics are sent to ``wandb.log``.
        device: Device the model is moved to for the evaluation.
        epoch, coefficients, config: Accepted for interface compatibility;
            not read by this function.
        name: Split name used in the metric keys and the CSV file name.
        examples_per_pair: Forwarded to ``embeddings_for_batch_function``.
        embeddings_for_batch_function: Callable producing the batch embeddings.
        path_dir: Directory for the CSV output; nothing is written when ``None``.
        return_loss: When true, the mean loss over all batches is returned.
        best: When true, ``name`` is prefixed with ``best_``.

    Returns:
        The mean loss if ``return_loss`` is true, otherwise ``None``.
    """
    print(f'starting evaluation on the {name} set')

    def _mean_similarity(anchor, keys):
        # Mean cosine similarity between the query embeddings and one key group.
        return np.mean(cosine_similarity_matrix(anchor, keys).mean())

    model.eval()

    t_cf_sims, c_cf_sims, approx_sims, negative_sims = [], [], [], []
    batch_losses, batch_losses_fixed = [], []

    model = model.to(device)

    with torch.no_grad():
        for batch in tqdm(eval_set):
            query = batch['query']
            t_cf = batch['t_cf']
            c_cf = batch['c_cf']
            negative = batch['negative']
            approx = batch['approx']

            _, query_emb, t_cf_emb, approx_emb, c_cf_emb, negative_emb = embeddings_for_batch_function(
                rep_model=model, query=query, t_cf=t_cf, negative=negative, c_cf=c_cf, approx=approx,
                examples_per_pair=examples_per_pair)

            loss_inputs = {'query': query_emb, 'positive_keys': t_cf_emb, 'negative_keys': negative_emb,
                           'confounder_counterfactual_keys': c_cf_emb, 'approx_keys': approx_emb}
            batch_losses.append(loss_function(**loss_inputs).detach().cpu().numpy())
            batch_losses_fixed.append(loss_function(**loss_inputs, eval=True).detach().cpu().numpy())

            t_cf_sims.append(_mean_similarity(query_emb, t_cf_emb))
            # Confounder counterfactuals are optional for a batch.
            if c_cf_emb is not None:
                c_cf_sims.append(_mean_similarity(query_emb, c_cf_emb))
            approx_sims.append(_mean_similarity(query_emb, approx_emb))
            negative_sims.append(_mean_similarity(query_emb, negative_emb))

            del batch
            torch.cuda.empty_cache()

    model.cpu()

    pos_mean = np.mean(t_cf_sims)
    # -1 is the placeholder used when no batch provided confounder embeddings.
    # NOTE(review): with this placeholder the order check below looks like it
    # can never succeed for such runs - confirm this is intended.
    confounder_mean = np.mean(c_cf_sims) if len(c_cf_sims) > 0 else -1
    approx_mean = np.mean(approx_sims)
    neg_mean = np.mean(negative_sims)

    # Expected ordering of similarities: t_cf > approx > c_cf > negative.
    gap_pos_approx = pos_mean - approx_mean
    gap_approx_conf = approx_mean - confounder_mean
    gap_conf_neg = confounder_mean - neg_mean
    if all(gap > 0 for gap in (gap_pos_approx, gap_conf_neg, gap_approx_conf)):
        is_order_relation = 1
    else:
        is_order_relation = 0

    if best:
        name = f'best_{name}'

    log = {
        f'treatment-counterfactual_{name}': round(pos_mean, 2),
        f'approx_{name}': round(approx_mean, 2),
        f'negative_{name}': round(neg_mean, 2),
        f'confounder-counterfactual_{name}': round(confounder_mean, 2),
        f'loss_{name}': round(np.mean(batch_losses), 2),
        f'loss_fixed_coefficients_{name}': round(np.mean(batch_losses_fixed), 2),
        f'is_order_relation_{name}': is_order_relation,
    }

    if path_dir is not None:
        # One-row table: every metric becomes a column.
        table = pd.DataFrame({metric: [value] for metric, value in log.items()})
        table.to_csv(os.path.join(path_dir, f'{name}.csv'))

    if wandb_on:
        wandb.log(log)
    print(log)

    if return_loss:
        return np.mean(batch_losses)
