import json
from argparse import Namespace
from pprint import pprint
from typing import Dict, List

import hydra
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
from tqdm import tqdm

from lm_understanding.baselines import make_baseline
from lm_understanding.baselines.baseline import BaselineResults
from lm_understanding.data import (baseline_results_path,
                                   load_local_explanations,
                                   load_model_behaviors)
from lm_understanding.explanations.explanations import LocalExplanationSet
from lm_understanding.question_template import (QuestionTemplate,
                                                TemplateModelBehavior)


def make_explainers(config: Namespace, model_behaviors: List[TemplateModelBehavior]) -> List[LocalExplanationSet]:
    """Build one LocalExplanationSet per model behavior, in the same order as ``model_behaviors``.

    If ``config.baseline.explanation.name`` is anything other than ``'none'``, previously
    saved explanation sets are loaded with ``load_local_explanations(config)`` and matched to
    each behavior by ``template_id``; a behavior with no saved set raises ``KeyError``.
    Otherwise a set without explanations is built from each behavior's ``'train'`` questions
    and answers.
    """
    uses_saved_explanations = config.baseline.explanation.name != 'none'
    if uses_saved_explanations:
        by_template = {}
        for explanation_set in load_local_explanations(config):
            # Later sets with the same template_id replace earlier ones.
            by_template[explanation_set.template_id] = explanation_set
        return [by_template[behavior.template_id] for behavior in model_behaviors]  # type: ignore

    def _from_train_split(behavior):
        # Explanation-free set built from the behavior's own training data.
        assert behavior.template_id is not None
        return LocalExplanationSet(behavior.template_id, behavior.questions('train'), behavior.answers('train').tolist())

    return [_from_train_split(behavior) for behavior in model_behaviors]


def combine_templates(model_behaviors: List[TemplateModelBehavior], explainers: List[LocalExplanationSet]):
    """Merge all per-template behaviors/explainers into a single ``'full_task'`` pair.

    Args:
        model_behaviors: Per-template model behaviors. Must be non-empty.
        explainers: Explanation sets, aligned one-to-one with ``model_behaviors``.

    Returns:
        A tuple ``([behavior], [explainer])`` of single-element lists, so it can be used
        anywhere the per-template lists were. The combined behavior uses an empty template
        whose id is ``'full_task'`` and the first behavior's ``model_name``. The combined
        explainer concatenates questions, answers and (when present) explanations in
        template order.

    Raises:
        ValueError: if ``model_behaviors`` is empty, the two lists differ in length, or
            only some explainers carry explanations (the concatenated explanations would
            then no longer line up with the concatenated questions).
    """
    if not model_behaviors:
        raise ValueError('combine_templates needs at least one model behavior')
    if len(model_behaviors) != len(explainers):
        raise ValueError(
            f'Got {len(model_behaviors)} model behaviors but {len(explainers)} explainers')
    has_explanations = [explainer.explanations is not None for explainer in explainers]
    if any(has_explanations) and not all(has_explanations):
        raise ValueError('Either all or none of the explainers must have explanations')

    question_behaviors = []
    questions = []
    answers = []
    # Keep None (the "no explanations" marker checked above) rather than an empty list,
    # which would look like a present-but-empty explanation list to downstream code.
    explanations = [] if all(has_explanations) else None
    for model_behavior, explainer in zip(model_behaviors, explainers):
        question_behaviors.extend(model_behavior.question_behaviors)
        questions.extend(explainer.questions)
        answers.extend(explainer.answers)
        if explanations is not None:
            explanations.extend(explainer.explanations)
    template = QuestionTemplate.from_str('')
    template.template_id = 'full_task'
    return (
        [TemplateModelBehavior(template, question_behaviors, model_behaviors[0].model_name)],
        [LocalExplanationSet(template.template_id, questions, answers, explanations)]
    )


def evaluate_baseline(baseline_config: Namespace, model_behaviors: List[TemplateModelBehavior], explainers: List[LocalExplanationSet]) -> List[BaselineResults]:
    """Train and test a fresh baseline for every (behavior, explainer) pair.

    The ``'test'`` split is always evaluated; ``'train'`` is evaluated first as well when
    ``baseline_config.test_on_train_data`` is present and truthy. Each returned result is
    tagged with ``baseline_config.name``, and results are in input order.
    """
    wants_train_split = 'test_on_train_data' in baseline_config and baseline_config.test_on_train_data
    split_names = ['train', 'test'] if wants_train_split else ['test']
    print(baseline_config.name)

    # Materialised so tqdm knows the total number of templates.
    pairs = list(zip(model_behaviors, explainers))
    results = []
    for behavior, explainer in tqdm(pairs, desc="Evaluating Baselines"):
        baseline = make_baseline(baseline_config, behavior, explainer)
        baseline.train()
        template_results = baseline.test(split_names)
        template_results.baseline = baseline_config.name
        results.append(template_results)
    return results


def average_results(baseline_results: List[BaselineResults], split_name: str = 'test') -> Dict[str, float]:
    """Average each evaluation metric across templates for one data split.

    Args:
        baseline_results: Per-template results; each metric attribute is a mapping keyed
            by split name.
        split_name: Which split's values to average (default ``'test'``).

    Returns:
        A dict with keys ``log_score``, ``normalized_log_score``, ``kl_divergence``,
        ``accuracy`` and ``tv_distance`` (in that order), each the mean over templates.
    """
    # (output key, BaselineResults attribute holding the per-split values)
    metric_sources = (
        ('log_score', 'log_scores'),
        ('normalized_log_score', 'normalized_log_scores'),
        ('kl_divergence', 'kl_divergences'),
        ('accuracy', 'accuracies'),
        ('tv_distance', 'tv_distances'),
    )
    averages = {}
    for metric_name, attribute in metric_sources:
        per_template = [getattr(result, attribute)[split_name] for result in baseline_results]
        averages[metric_name] = np.array(per_template).mean()
    return averages


def results_to_df(baseline_results: List[BaselineResults]) -> pd.DataFrame:
    """Flatten every result's ``as_records()`` output into one DataFrame, preserving order."""
    all_records = [record for result in baseline_results for record in result.as_records()]
    return pd.DataFrame.from_records(all_records)


def save_results(baseline_results: List[BaselineResults], config: Namespace, info: Dict) -> None:
    """Write per-template results (CSV) and an averaged run summary (JSON) for one baseline.

    Both files are written to ``baseline_results_path(config)`` (created if missing) and are
    named after ``config.baseline.name``. The JSON summary contains the full config, the
    entries of ``info``, the number and ids of evaluated templates, and test-split averages,
    plus train-split averages when ``config.baseline.test_on_train_data`` is set. The
    summary is also pretty-printed to stdout.

    Args:
        baseline_results: Per-template results to save.
        config: The run configuration (converted with ``OmegaConf.to_container``).
        info: Extra entries merged into the top level of the JSON summary.
    """
    results_path = baseline_results_path(config)
    results_path.mkdir(exist_ok=True, parents=True)

    results_to_df(baseline_results).to_csv(results_path / f'{config.baseline.name}.csv', index=False)

    results = dict(results=average_results(baseline_results, 'test'))
    if 'test_on_train_data' in config.baseline and config.baseline.test_on_train_data:
        results['results_on_train_data'] = average_results(baseline_results, 'train')

    run_info = dict(
        config=OmegaConf.to_container(config),
        **info,
        n_templates_evaluated=len(baseline_results),
        templates_evaluated=[r.template_id for r in baseline_results],
        **results,
    )
    # ensure_ascii=False writes raw non-ASCII characters, so the file encoding must be set
    # explicitly: the platform default (e.g. cp1252 on Windows) cannot encode all of them.
    with open(results_path / f'{config.baseline.name}.json', 'w', encoding='utf-8') as f:
        json.dump(run_info, f, indent=4, ensure_ascii=False)
    pprint(run_info)


@hydra.main(config_path='config', config_name='run_baseline.yaml', version_base='1.2')
def main(config):
    """Hydra entry point: load behaviors, build explainers, evaluate the baseline, save results.

    When ``config.baseline.full_topic`` is set, all templates are merged into a single
    combined task before evaluation.
    """
    print(config)
    behaviors, filtering_info = load_model_behaviors(config)
    template_explainers = make_explainers(config, behaviors)
    merge_into_single_task = 'full_topic' in config.baseline and config.baseline.full_topic
    if merge_into_single_task:
        behaviors, template_explainers = combine_templates(behaviors, template_explainers)
    results = evaluate_baseline(config.baseline, behaviors, template_explainers)
    save_results(results, config, filtering_info)


if __name__ == '__main__':
    # Called with no arguments: the @hydra.main decorator on main() supplies `config`
    # (from config/run_baseline.yaml plus any command-line overrides).
    main()