import datetime
import itertools
import json
import logging
import os

import numpy as np
import pandas as pd
from dotenv import load_dotenv

from src.evaluation.quan_eval import evaluate_dataset, perf_degradation, fidelity
from src.models.components.llms.ollama_component import OllamaComponent
from src.models.components.llms.rits_component import RITSComponent
from src.models.datasets.dataset_factory import DatasetFactory
from src.models.global_explainers.global_expl import GlobalExplainer
from src.models.guardians.rits_guardian import RITSGuardian
from src.models.local_explainers.lime import LIME
from src.models.local_explainers.shap_vals import SHAP
from src.pipeline.clusterer import Clusterer
from src.pipeline.concept_extractor import Extractor
from src.pipeline.pipeline import Pipeline
from src.utils.data_util import seed_everything


def run_experiments():
    """Evaluate saved global explanations for every experiment configuration.

    Iterates over the Cartesian product of LLM components, guardians (keys of
    ``guardian_map.json``), local explainers and datasets (keys of
    ``dataset_map.json``). For each combination and each ablation it:

    1. loads a previously saved global explanation from
       ``results/ablations/<ablation>/<guardian>/<llm>/<local_expl>/<dataset>/global/global_expl.pkl``;
    2. evaluates it against the guardian on the dataset's training split via
       ``evaluate_dataset``;
    3. logs performance degradation and fidelity for the ablation alongside
       the same metrics computed from
       ``results/robustness/<guardian>/<llm>/<local_expl>/<dataset>/original.csv``.

    Combinations whose saved explanation file is missing are skipped with a
    warning rather than aborting the whole run.
    """
    # Resolve the logger locally instead of relying on a module global that is
    # only created under ``__main__``; importing and calling this function
    # would otherwise raise NameError. getLogger returns the same object the
    # entry point configures.
    logger = logging.getLogger('logger')

    seed_everything(42)
    load_dotenv()

    # Load the name -> configuration mappings for datasets and guardians.
    with open('dataset_map.json', encoding='utf-8') as f:
        data_mappings = json.load(f)

    with open('guardian_map.json', encoding='utf-8') as f:
        guardian_mappings = json.load(f)

    local_explainers = {'lime': LIME}

    llm_components = {
        'llama3.3:70b': RITSComponent('llama-3-3-70b-instruct', 'meta-llama/llama-3-3-70b-instruct'),
    }

    # Ablation name -> flags. The flag values are not read in this function;
    # only the ablation name is used (to build result paths).
    ablations = {
        'no_fr': {
            'fr': False,
            'local_expl': True,
        },
    }

    all_combinations = list(itertools.product(llm_components,
                                              guardian_mappings,
                                              local_explainers,
                                              data_mappings))

    logger.info(f'Running {len(all_combinations)} experiments...')

    for llm_component_name, guardian_name, local_expl_name, dataset_name in all_combinations:
        logger.info(f'Running an experiment:\n\tGuardian = {guardian_name}\n\tDataset = {dataset_name}\n\tLocal explainer = {local_expl_name}\n\tLLM Component = {llm_component_name}')

        # defining guardian model
        guardian_config = guardian_mappings[guardian_name]
        label_names = guardian_config["label_names"]
        guardian = RITSGuardian(guardian_config['rits']['model_name'],
                                guardian_config['rits']['model_served_name'],
                                guardian_config,
                                guardian_name)

        # defining dataset
        csv_path = os.path.join('datasets', 'perturbations', dataset_name, 'original.csv')
        dataframe = pd.read_csv(csv_path, header=0)
        logger.info(f'Loaded {len(dataframe)} rows from {csv_path}')
        dataset = DatasetFactory.get_dataset(data_mappings[dataset_name], dataframe=dataframe)
        logger.info(f'Training split size = {len(dataset.train)}')

        for abl_name, abl_config in ablations.items():
            logger.info(f'Running ablation = {abl_name}...')

            expl_dir = os.path.join('results', 'ablations', abl_name, guardian_name,
                                    llm_component_name, local_expl_name)
            expl_file = os.path.join(expl_dir, dataset_name, 'global', 'global_expl.pkl')

            # The global explanation must already have been produced by the
            # pipeline; this script only evaluates it.
            try:
                expl = GlobalExplainer(expl_path=expl_file)
            except FileNotFoundError:
                logger.warning(f'No global explanation found at {expl_file}; skipping ablation {abl_name}.')
                continue

            # NOTE(review): presumably evaluate_dataset writes/caches its
            # per-example results at this path — confirm.
            evaluated_path = os.path.join(expl_dir, dataset_name, 'results.csv')
            train_evaluated = evaluate_dataset(expl, guardian, dataset.train, dataset.expl_input, 'expl_answer',
                                               'guard_answer', evaluated_path)

            original_dataset = pd.read_csv(os.path.join('results', 'robustness', guardian_name,
                                                        llm_component_name, local_expl_name,
                                                        dataset_name, 'original.csv'))

            # Datasets without a ground-truth label column get a constant
            # label of 1 on both frames so perf_degradation can still run.
            label_col = dataset.label_col
            if label_col not in original_dataset.columns:
                label_col = 'label'
                original_dataset[label_col] = [1] * len(original_dataset)
                train_evaluated[label_col] = [1] * len(train_evaluated)

            original_fidelity_acc, original_fidelity_f1 = fidelity(original_dataset, 'expl_answer', 'guard_answer', label_names)
            orig_perf_degr = perf_degradation(original_dataset, 'expl_answer', 'guard_answer', label_col, label_names)

            perf_degr = perf_degradation(train_evaluated, 'expl_answer', 'guard_answer', label_col, label_names)
            fidelity_acc, fidelity_f1 = fidelity(train_evaluated, 'expl_answer', 'guard_answer', label_names)

            logger.info('''--GloVE: Performance degradation = {}\nFidelity (acc) = {}\nFidelity (f1) = {}\nGloVE_NOFR:  Performance degradation = {}\nFidelity (acc) = {}\nFidelity (f1) = {}\n '''.format(
                orig_perf_degr,
                original_fidelity_acc,
                original_fidelity_f1,
                perf_degr,
                fidelity_acc,
                fidelity_f1,
            ))


if __name__ == '__main__':
    logger = logging.getLogger('logger')
    logger.setLevel(logging.INFO)
    # Write INFO-and-above records to a timestamped file under logs/. Create
    # the directory first: FileHandler opens the file immediately and raises
    # FileNotFoundError if logs/ does not exist (e.g. on a fresh checkout).
    os.makedirs('logs', exist_ok=True)
    fh = logging.FileHandler(f'logs/{datetime.datetime.now().strftime("%m_%d__%H_%M")}.log')
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)

    run_experiments()


