import argparse
import os
import shutil

os.environ['TRANSFORMERS_CACHE'] = './huggingface_cache'
os.environ['HF_HOME'] = './huggingface_cache'

import datasets
import numpy as np
import pandas as pd

import sys
# Make the parent of the directory containing this file importable. The real
# ``__file__`` is used (not the string literal "__file__", whose dirname is
# always empty) so the result does not depend on the current working directory.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))

from explainers.inclusive_explainers.e2e_explainer import get_best_checkpoint, get_last_checkpoint
from eval_pipeline.explainers import CONEXP, CausaLM, INLP, ConceptShap, TCAV, RandomExplainer, SLearner, ApproxCounterfactual, RandomCounterfactual
from eval_pipeline.explainers import ATEExplainer, CaCEExplainer, CONATEExplainer, GPT3Counterfactual, E2EExplainer, RandomInclusiveExplainer
from eval_pipeline.models import BERTForCEBaB, RoBERTaForCEBaB, GPT2ForCEBaB, LSTMForCEBaB, LRForCEBaB
from eval_pipeline.pipeline import run_pipelines
# TODO: get rid of these seed maps or describe somewhere how they work
from eval_pipeline.utils import (
    OPENTABLE_BINARY,
    OPENTABLE_TERNARY,
    OPENTABLE_5_WAY,
    BERT,
    GPT2,
    ROBERTA,
    LSTM,
    LR,
    K_ARRAY,
    SEEDS_ELDAR2ZEN,
    preprocess_hf_dataset,
    preprocess_hf_dataset_inclusive,
    save_output,
    average_over_seeds, SEEDS_ELDAR, TREATMENTS
)


def get_exclusive_explainers(seed, args, model):
    """Build the explainers evaluated in the exclusive train setting.

    Args:
        seed: seed used verbatim in the CausaLM model names and, after
            ``int()`` conversion, as the key looked up in ``SEEDS_ELDAR2ZEN``
            for the S-Learner model name.
        args: parsed command-line namespace; reads ``num_classes``,
            ``model_architecture``, ``device``, ``batch_size`` and ``verbose``.
        model: the model under explanation, handed to ``ConceptShap`` as its
            ``original_model``.

    Returns:
        A list with, in order: RandomExplainer, CONEXP, SLearner, CausaLM,
        TCAV, INLP and ConceptShap instances.
    """
    # ``args.num_classes`` looks like '2-class'; its first character is the count.
    n_classes = int(args.num_classes[0])
    shared_kwargs = {'device': args.device, 'batch_size': args.batch_size}

    # One CausaLM model per variant; the names differ only in the variant tag.
    causalm_prefix = f'CEBaB/{args.model_architecture}.CEBaB.causalm'
    causalm_suffix = f'{args.num_classes}.exclusive.seed_{seed}'
    causalm_paths = {
        f'{variant}_model_path': f'{causalm_prefix}.{variant}.{causalm_suffix}'
        for variant in ('factual', 'ambiance', 'food', 'noise', 'service')
    }
    causalm_explainer = CausaLM(
        empty_cache_after_run=True,
        **causalm_paths,
        **shared_kwargs,
    )

    slearner_explainer = SLearner(
        f'CEBaB/bert-base-uncased.CEBaB.absa.exclusive.seed_{SEEDS_ELDAR2ZEN[int(seed)]}',
        **shared_kwargs,
    )

    random_explainer = RandomExplainer()
    conexp_explainer = CONEXP()
    tcav_explainer = TCAV(treatments=TREATMENTS, num_classes=n_classes, **shared_kwargs)
    inlp_explainer = INLP(treatments=TREATMENTS, **shared_kwargs)
    concept_shap_explainer = ConceptShap(
        concepts=TREATMENTS,
        original_model=model,
        verbose=args.verbose,
        **shared_kwargs,
    )

    return [
        random_explainer,
        conexp_explainer,
        slearner_explainer,
        causalm_explainer,
        tcav_explainer,
        inlp_explainer,
        concept_shap_explainer,
    ]


def get_inclusive_explainers(seed, args, k, model):
    """Build the explainers evaluated in the inclusive/approximate train settings.

    Only the ``E2EExplainer`` is currently enabled; the other candidates are
    kept as commented-out code so they can be switched back on.

    Side effect: creates ``args.model_output_dir`` (including missing parent
    directories) if it does not exist yet.

    Args:
        seed: converted with ``int()`` and looked up in ``SEEDS_ELDAR2ZEN`` to
            name the base model and the explainer output directory.
        args: parsed command-line namespace; reads ``num_classes``,
            ``model_output_dir``, ``task_name``, ``train_setting``,
            ``model_architecture``, ``eval_split``, ``batch_size``, ``device``
            and ``gradient_accumulation_steps``.
        k: shot count, used here only in the output directory name.
        model: not used by the currently enabled explainer.

    Returns:
        A one-element list containing the configured ``E2EExplainer``.
    """
    # Only referenced by the disabled explainers below; kept so they can be
    # re-enabled without further changes.
    num_classes = int(args.num_classes[0])
    # TODO: we can also add exclusive explainers here
    # slearner = SLearner(
    #     f'CEBaB/bert-base-uncased.CEBaB.absa.exclusive.seed_{SEEDS_ELDAR2ZEN[int(seed)]}',
    #     device=args.device,
    #     batch_size=args.batch_size
    # )
    # approx = ApproxCounterfactual(
    #     f'CEBaB/bert-base-uncased.CEBaB.absa.exclusive.seed_{SEEDS_ELDAR2ZEN[int(seed)]}',
    #     device=args.device,
    #     batch_size=args.batch_size,
    #     num_classes = num_classes
    # )
    # random = RandomCounterfactual(
    #     f'CEBaB/bert-base-uncased.CEBaB.absa.exclusive.seed_{SEEDS_ELDAR2ZEN[int(seed)]}',
    #     device=args.device,
    #     batch_size=args.batch_size,
    #     num_classes = num_classes
    # )
    # sentence = SEmbeddingCounterfactual(
    #     device=args.device,
    #     batch_size=args.batch_size,
    #     num_classes = num_classes
    # )

    # makedirs with exist_ok handles nested paths and avoids the check-then-create
    # race that a separate isdir()/mkdir() pair has when runs start concurrently.
    os.makedirs(args.model_output_dir, exist_ok=True)

    output_dir = f'{args.model_output_dir}/e2e_{args.task_name}__{args.train_setting}__{k}-shot__seed-{SEEDS_ELDAR2ZEN[int(seed)]}__{args.model_architecture}'

    e2e = E2EExplainer(
        f'CEBaB/{args.model_architecture}.CEBaB.sa.{args.num_classes}.exclusive.seed_{SEEDS_ELDAR2ZEN[int(seed)]}',
        args.eval_split,
        output_dir,
        batch_size=args.batch_size,
        device=args.device,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
    )

    # gpt3 = GPT3Counterfactual('./GPT3_output_true/', k = k, eval_split = args.eval_split, num_classes = num_classes)

    # random1 = RandomInclusiveExplainer(random_factual=True, use_real_inputs=False)
    # random2 = RandomInclusiveExplainer(random_factual=True, use_real_inputs=True)
    # random3 = RandomInclusiveExplainer(random_factual=False, use_real_inputs=False)
    # random4 = RandomInclusiveExplainer(random_factual=False, use_real_inputs=True)

    return [
        # approx,
        # random,
        # slearner,
        # sentence,
        # ATEExplainer(),
        # CaCEExplainer(),
        # CONATEExplainer(['food']),
        # CONATEExplainer(['service', 'food']),
        # gpt3,
        e2e
        # random1,
        # random2,
        # random3,
        # random4,
    ]


def get_model(seed, args, k):
    """Load the model to be explained for one seed and shot count.

    When ``args.cpm_self_evaluate`` is set and ``k > 0`` the model is loaded
    from a checkpoint inside the e2e output directory under
    ``args.model_output_dir``; otherwise the ``CEBaB/...`` model name built
    from the architecture, class count and seed is used.

    Args:
        seed: converted with ``int()`` and looked up in ``SEEDS_ELDAR2ZEN``.
        args: parsed command-line namespace; reads ``cpm_self_evaluate``,
            ``model_output_dir``, ``task_name``, ``train_setting``,
            ``model_architecture``, ``num_classes`` and ``device``.
        k: shot count; selects the checkpoint directory when self-evaluating.

    Returns:
        The model wrapper matching ``args.model_architecture``. Note that the
        LR wrapper is constructed without the path or device.

    Raises:
        ValueError: if ``args.model_architecture`` is not a supported value.
    """
    if args.cpm_self_evaluate and k > 0:
        run_name = f'e2e_{args.task_name}__{args.train_setting}__{k}-shot__seed-{SEEDS_ELDAR2ZEN[int(seed)]}__{args.model_architecture}'
        # os.path.join keeps the './<dir>/<name>' form for relative output
        # directories but, unlike plain string concatenation, leaves an
        # absolute --model_output_dir absolute.
        path = os.path.join('.', args.model_output_dir, run_name)
        # get best checkpoint
        checkpoint = get_last_checkpoint(path)
        checkpoint = get_best_checkpoint(checkpoint, path)
        path = checkpoint
    else:
        path = f'CEBaB/{args.model_architecture}.CEBaB.sa.{args.num_classes}.exclusive.seed_{SEEDS_ELDAR2ZEN[int(seed)]}'

    print(f"Loading model from: {path}")

    if args.model_architecture == BERT:
        return BERTForCEBaB(path, device=args.device)
    elif args.model_architecture == ROBERTA:
        return RoBERTaForCEBaB(path, device=args.device)
    elif args.model_architecture == LSTM:
        return LSTMForCEBaB(path, device=args.device)
    elif args.model_architecture == GPT2:
        return GPT2ForCEBaB(path, device=args.device)
    elif args.model_architecture == LR:
        return LRForCEBaB()

    # Fail loudly instead of silently returning None for an unknown architecture.
    raise ValueError(f'Unsupported model architecture "{args.model_architecture}".')

      
def _str2bool(value):
    """Convert a command-line string such as 'true', 'False', '1' or 'no' to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    'False') as True; this converter parses the text instead.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string stays False, as it was with ``type=bool``.
    if normalized in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value, got "{value}".')


def main():
    """Run the explainer evaluation pipelines from the command line.

    Parses the arguments, loads and preprocesses the CEBaB dataset, and then,
    for every ``k`` in the k-array, runs the pipelines once per seed, averages
    the per-seed outputs and saves them under ``--output_dir``. Optionally
    deletes a HuggingFace cache directory at the end.

    Raises:
        ValueError: for an unsupported train setting or task name, or when
            CPM self-evaluation is requested together with the dev split.
    """
    # arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--task_name', type=str, default=OPENTABLE_BINARY,
                        help='task to evaluate; selects the 2-, 3- or 5-way dataset variant')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='batch size passed to the explainers')
    # Boolean flags accept an explicit value ("--verbose false") or no value
    # at all ("--verbose"), which means True.
    parser.add_argument('--verbose', type=_str2bool, nargs='?', const=True, default=False,
                        help='verbosity flag passed to ConceptShap')
    parser.add_argument('--seeds', type=int, nargs='+', default=SEEDS_ELDAR[:3],
                        help='seeds to run; outputs are averaged over them')
    parser.add_argument('--model_architecture', type=str, default=BERT,
                        help='architecture of the model to explain')
    parser.add_argument('--output_dir', type=str, default='output',
                        help='directory for the averaged outputs; empty disables saving')
    parser.add_argument('--device', type=str, default='cuda',
                        help='device passed to the models and explainers')
    parser.add_argument('--eval_split', type=str, default='dev',
                        help="'dev' evaluates on the dev split; any other value uses the test split")
    parser.add_argument('--flush_cache', type=_str2bool, nargs='?', const=True, default=False,
                        help='delete ~/.cache/huggingface/transformers when done')
    parser.add_argument('--train_setting', type=str, default='exclusive',
                        help="one of 'exclusive', 'inclusive' or 'approximate'")
    parser.add_argument('--k_array', type=int, nargs='+', default=K_ARRAY,
                        help="values of k to run; ignored (k=0 only) for 'approximate'")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help='gradient accumulation steps passed to the E2E explainer')
    parser.add_argument('--model_output_dir', type=str, default='model_output',
                        help='directory holding the E2E explainer output directories')
    parser.add_argument('--cpm_self_evaluate', type=_str2bool, nargs='?', const=True, default=False,
                        help='for k > 0, explain the checkpoint stored under --model_output_dir')
    args = parser.parse_args()

    if args.train_setting not in ['exclusive', 'inclusive', 'approximate']:
        raise ValueError(f'Unsupported train setting "{args.train_setting}".')

    if args.cpm_self_evaluate and args.eval_split == 'dev':
        raise ValueError("Can only self-evaluate the CPM on test set.")

    # Resolve the task before touching the dataset so that an invalid task
    # name fails before the dataset is loaded.
    if args.task_name == OPENTABLE_BINARY:
        args.dataset_type = '2-way'
        args.num_classes = '2-class'
    elif args.task_name == OPENTABLE_TERNARY:
        args.dataset_type = '3-way'
        args.num_classes = '3-class'
    elif args.task_name == OPENTABLE_5_WAY:
        args.dataset_type = '5-way'
        args.num_classes = '5-class'
    else:
        raise ValueError(f'Unsupported task "{args.task_name}"')

    # data
    cebab = datasets.load_dataset('CEBaB/CEBaB')

    # 'inclusive' and 'approximate' share the preprocessing and the explainers.
    uses_inclusive_data = args.train_setting in ('inclusive', 'approximate')
    if uses_inclusive_data:
        train, dev, test = preprocess_hf_dataset_inclusive(cebab, verbose=1, dataset_type=args.dataset_type)
    else:
        train, dev, test = preprocess_hf_dataset(cebab, verbose=1, dataset_type=args.dataset_type)

    if args.train_setting == 'approximate':
        k_array = [0]
    else:
        k_array = args.k_array

    # check if k's are valid
    # TODO: this is wrong!
    # TODO: should set k based on the amount of actual counterfactuals
    # len_pairs = len(train[1])
    # k_array = [value if value <= len_pairs else len_pairs for value in args.k_array ]
    # k_array = list(np.unique(np.array(k_array)))

    print(f'Running experiments for k in {k_array}')

    # The evaluation split does not depend on k or the seed.
    eval_dataset = dev if args.eval_split == 'dev' else test

    # for every k
    for k in k_array:

        # for every seed
        pipeline_outputs = []
        for seed in args.seeds:
            # TODO: support multiple models
            model = get_model(seed, args, k)

            if uses_inclusive_data:
                explainers = get_inclusive_explainers(seed, args, k, model)
            else:
                explainers = get_exclusive_explainers(seed, args, model)

            # TODO: these are shallow model copies! If one explainer manipulates a model without copying, this could give bugs for other methods!
            models = [model] * len(explainers)

            pipeline_output = run_pipelines(
                models,
                explainers,
                train,
                eval_dataset,
                seed,
                k,
                dataset_type=args.dataset_type,
                shorten_model_name=True,
                train_setting=args.train_setting,
                approximate=args.train_setting == 'approximate',
            )
            pipeline_outputs.append(pipeline_output)

        # average over the seeds
        pipeline_outputs_averaged = average_over_seeds(pipeline_outputs)

        # save output
        if args.output_dir:
            # Self-evaluation runs get their own sub-directory and file tag.
            if args.cpm_self_evaluate:
                setting_tag = f'{args.train_setting}-cpm-self-evaluate'
            else:
                setting_tag = args.train_setting
            filename_suffix = f'{args.task_name}__{setting_tag}__{k}-shot__{args.model_architecture}__{args.eval_split}'
            save_output(
                os.path.join(args.output_dir, setting_tag, f'{k}-shot', f'final__{filename_suffix}'),
                filename_suffix,
                *pipeline_outputs_averaged,
            )

    if args.flush_cache:
        # NOTE(review): this module sets TRANSFORMERS_CACHE/HF_HOME to
        # './huggingface_cache' at import time, so this default location may
        # not be the cache actually in use -- confirm which one should be flushed.
        home = os.path.expanduser('~')
        hf_cache = os.path.join(home, '.cache', 'huggingface', 'transformers')
        print(f'Deleting HuggingFace cache at {hf_cache}.')
        shutil.rmtree(hf_cache, ignore_errors=True)

# Run the experiments only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
