import copy
import itertools
import json
import math
from argparse import Namespace
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple

import hydra
import numpy as np
from omegaconf import OmegaConf
from tqdm import tqdm

from lm_understanding.baselines import make_baseline
from lm_understanding.baselines.baseline import BaselineResults
from lm_understanding.data import (dataset_path, model_behavior_path,
                                   template_path)
from lm_understanding.models import load_model
from lm_understanding.prompting.completion import Completer
from lm_understanding.prompting.prompting import Prompter
from lm_understanding.question_template import (QuestionTemplate, Substitution,
                                                TemplateDataset,
                                                TemplateModelBehavior,
                                                TemplateTask, ValueCombination,
                                                Variables,
                                                answer_template_questions,
                                                load_question_templates,
                                                make_question_df,
                                                measure_template_diversity)


def filter_info(templates: List[QuestionTemplate], filtered_templates: List[QuestionTemplate], filtering_model_info: Dict, filtering_metric: List[float]):
    """Summarise the outcome of the initial ETVD template filter.

    Args:
        templates: Every template that was scored, in scoring order. Must be
            non-empty (the kept fraction divides by its length).
        filtered_templates: The subset of ``templates`` that passed the filter.
        filtering_model_info: Extra entries merged into the result as keyword
            arguments; a key colliding with one of the fixed keys raises TypeError.
        filtering_metric: One ETVD score per template, aligned with ``templates``.

    Returns:
        A dict containing the kept fraction, the template diversity before and
        after filtering, a ``template_id -> ETVD`` mapping, and every entry of
        ``filtering_model_info``.
    """
    kept_fraction = len(filtered_templates) / len(templates)
    diversity_before = measure_template_diversity(templates)
    diversity_after = measure_template_diversity(filtered_templates)
    etvd_by_template_id = dict(zip((tpl.template_id for tpl in templates), filtering_metric))
    return dict(
        kept_fraction=kept_fraction,
        pre_filter_diversity=diversity_before,
        post_filter_diversity=diversity_after,
        template_etvds=etvd_by_template_id,
        **filtering_model_info,
    )


def update_template_metadata(templates: List[QuestionTemplate], model_name: str, filtering_metric: List[float]) -> None:
    """Record each template's ETVD score in its metadata, in place.

    The entry is stored under ``filter_by_<model_name>``. Any existing entry
    with that key is replaced.
    """
    metadata_key = f'filter_by_{model_name}'
    for template, distance in zip(templates, filtering_metric):
        template.metadata[metadata_key] = {'expected_tv_distance_between_answers': distance}


def measure_template_etvds(templates: List[QuestionTemplate], config: Namespace) -> Tuple[List[float], Dict]:
    """Score every template by querying ``config.model``.

    Each template is answered on ``config.initial_filter.samples`` sampled
    questions (seeded with ``config.seed``, ``sample_fast=True``).

    Returns:
        ``(etvds, model.info)``: one ``etvd`` per template, in input order, and
        the loaded model's info dict.
    """
    model, prompter = load_model(config.model)
    etvds = []
    for template in tqdm(templates, desc=f'Evaluating Templates: {model.name}'):
        answered = answer_template_questions(
            template,
            template.variables,
            model,
            n_questions=config.initial_filter.samples,
            prompter=prompter,
            seed=config.seed,
            sample_fast=True,
        )
        etvds.append(answered.etvd)
    return etvds, model.info


def lookup_template_etvds(templates: List[QuestionTemplate], config: Namespace) -> Tuple[List[float], Dict]:
    """Reuse ETVD scores recorded in an existing dataset's ``run_info.json``.

    Scores are read from ``initial_filter.template_etvds``. A template whose id
    is missing there raises KeyError.

    Returns:
        ``(etvds, {})``: scores aligned with ``templates``. The model-info dict
        is empty because no model is loaded.
    """
    run_info_file = dataset_path(config) / 'run_info.json'
    recorded = json.loads(run_info_file.read_text())
    etvd_by_id = recorded['initial_filter']['template_etvds']
    etvds = [etvd_by_id[template.template_id] for template in templates]
    return etvds, {}


def filter_templates(templates: List[QuestionTemplate], config: Namespace) -> Tuple[List[QuestionTemplate], Dict]:
    """Keep templates whose ETVD score reaches ``config.initial_filter.threshold``.

    If the dataset directory for ``config`` already exists, scores are looked up
    from it; otherwise they are measured with ``config.model``. Every input
    template (kept or not) gets its score written into its metadata.

    Returns:
        ``(kept_templates, info)``, where ``info`` is the ``filter_info`` summary.
        An empty input yields ``([], {})``.
    """
    if not templates:
        return [], {}

    scorer = lookup_template_etvds if dataset_path(config).exists() else measure_template_etvds
    etvds, model_info = scorer(templates, config)

    threshold = config.initial_filter.threshold
    kept = [template for template, etvd in zip(templates, etvds) if etvd >= threshold]

    update_template_metadata(templates, config.model.name, etvds)
    print(f'Templates Kept: {len(kept)} / {len(templates)}')

    return kept, filter_info(templates, kept, model_info, etvds)


def use_substitution_in_combinations(substitution: Substitution, combinations: List[ValueCombination]) ->  List[ValueCombination]:
    """Return a copy of each combination with ``substitution`` swapped in.

    The input combinations are passed to ``copy_with_replaced_substitution``
    rather than modified here.
    """
    replaced = []
    for combination in combinations:
        replaced.append(combination.copy_with_replaced_substitution(substitution))
    return replaced


def answer_probs_for_variable_values_across_model_and_questions(model_behaviors: List[TemplateModelBehavior]) -> np.ndarray:
    """Stack answer probabilities into one array indexed (variable_value, model, question).

    The variable values are taken from the first behaviour's ``split_names``.
    Every behaviour must expose an ``answers(value)`` array of the same shape
    for each of them.

    NOTE(review): the (value, model, question) axis order assumes ``answers``
    returns a 1-D per-question array. Confirm against ``TemplateModelBehavior``.
    """
    per_value = [
        np.stack([behavior.answers(value) for behavior in model_behaviors], axis=0)  # (model, question)
        for value in model_behaviors[0].split_names
    ]
    return np.stack(per_value, axis=0)


def variable_reponsiveness(model_behaviors: List[TemplateModelBehavior]):
    """Measure how strongly answers move when one variable's value changes.

    Computes the variance of answer probabilities across the variable's values
    (axis 0), averages it over questions (axis 2), then averages the result
    into a single scalar.
    """
    answer_probs = answer_probs_for_variable_values_across_model_and_questions(model_behaviors)
    spread_across_values = answer_probs.var(axis=0, keepdims=True)
    spread_per_model = spread_across_values.mean(axis=2, keepdims=True)
    return spread_per_model.mean()

@dataclass
class VariableSelector:
    """Choose, for one template, the variable whose value most changes model answers.

    For each template variable, one fixed sample of value combinations is
    re-rendered once per possible value of that variable. The resulting tasks
    are evaluated per model, and the variable with the highest
    ``variable_reponsiveness`` becomes the distributional-shift variable.

    NOTE(review): ``possible_shift_variables`` is stored but never consulted.
    Every name in ``template.variable_names`` is evaluated. Confirm whether the
    candidate set was meant to be restricted to it.
    """

    # Template whose variables are the shift candidates.
    template: QuestionTemplate
    # Configured candidate variables (currently unused by this class).
    possible_shift_variables: List[str]
    # How many base combinations are sampled and re-rendered per variable value.
    eval_questions_per_value: int
    # Filled in by select_distributional_shift_variable().
    distributional_shift_variable: Optional[str] = None
    # variable name -> list with one evaluated behaviour per model.
    model_behaviors: defaultdict = field(default_factory=lambda: defaultdict(list))

    @property
    def template_id(self):
        """Identifier of the wrapped template."""
        return self.template.template_id

    def combinations_across_all_variable_values(self, variable_name: str) -> TemplateTask:
        """Build a task with one question split per possible value of ``variable_name``.

        The same base sample (seed 0, fast sampling with replacement) is reused
        for every value, so the splits differ only in that variable.
        """
        template = self.template
        base_sample = template.variables.sample(self.eval_questions_per_value, seed=0, fast_with_replacement=True)
        substitutions = [Substitution(variable_name, value) for value in template.variables[variable_name]]

        combinations_by_value = {}
        for substitution in substitutions:
            combinations_by_value[substitution.value] = use_substitution_in_combinations(substitution, base_sample)

        questions_by_value = {
            value: make_question_df(template, combinations)
            for value, combinations in combinations_by_value.items()
        }
        return TemplateTask(template, questions_by_value)

    def variable_evaluation_tasks(self):
        """Map every template variable name to its per-value evaluation task."""
        tasks = {}
        for name in self.template.variable_names:
            tasks[name] = self.combinations_across_all_variable_values(name)
        return tasks

    def eval_model_behavior(self, completer: Completer, prompter: Prompter) -> None:
        """Evaluate every variable's task with one model and record the behaviour."""
        task_items = list(self.variable_evaluation_tasks().items())
        progress = tqdm(task_items, desc=f"Selecting Distributional Shift Variable: {self.template_id}: {completer.name}")
        for name, task in progress:
            behavior = task.evaluate(completer, prompter, verbose=False)
            self.model_behaviors[name].append(behavior)

    def select_distributional_shift_variable(self) -> str:
        """Store and return the most responsive variable (first one wins ties).

        Raises ValueError if no behaviour has been collected yet.
        """
        scores = {
            name: variable_reponsiveness(behaviors)
            for name, behaviors in self.model_behaviors.items()
        }
        self.distributional_shift_variable = max(scores, key=scores.__getitem__)
        return self.distributional_shift_variable


def make_distributional_shift_template_task(template: QuestionTemplate, distributional_shift_variable: str, config: Namespace) -> TemplateTask:
    """Split a template so train and test differ in one variable's value.

    Train questions fix ``distributional_shift_variable`` to its first possible
    value. Test questions use every other value. Both splits are sampled with
    ``config.seed`` and sized by ``config.questions_per_split``.
    """
    held_in_value = template.variables[distributional_shift_variable][0]
    in_distribution = template.variables.isolate_value(distributional_shift_variable, held_in_value)
    shifted = template.variables.exclude_value(distributional_shift_variable, held_in_value)

    sizes = config.questions_per_split
    train_sample = in_distribution.sample(sizes.train, config.seed)
    test_sample = shifted.sample(sizes.test, config.seed)

    train_questions = make_question_df(template, train_sample)
    test_questions = make_question_df(template, test_sample)
    return TemplateTask(template, questions={'train': train_questions, 'test': test_questions})


def create_new_variable_distributional_shift_dataset(templates: List[QuestionTemplate], config: Namespace) -> TemplateDataset:
    """Build a "new variable" shift dataset.

    For each template, a ``VariableSelector`` evaluates every variable with
    ``config.model``. The most responsive variable is then used to split train
    from test.
    """
    selectors = [
        VariableSelector(template, config.task.possible_shift_variables, config.distributional_shift.value_eval_questions)
        for template in templates
    ]

    # Only one model is configured today; the loop keeps the multi-model shape.
    for model_config in [config.model]:
        completer, prompter = load_model(model_config)
        for selector in selectors:
            selector.eval_model_behavior(completer, prompter)

    shift_variables = [selector.select_distributional_shift_variable() for selector in selectors]
    template_and_variable = list(zip(templates, shift_variables))
    tasks = []
    for template, shift_variable in tqdm(template_and_variable, 'Creating Datasets'):
        tasks.append(make_distributional_shift_template_task(template, shift_variable, config))
    return TemplateDataset(tasks)


def new_combination_test_variables(template: QuestionTemplate, config: Namespace) -> Variables:
    """Return a copy of the template's variables without each variable's train prefix.

    The first ``config.distributional_shift.train_variables`` values of every
    variable are dropped. The template's own ``variables`` is left untouched.
    """
    n_prefix = config.distributional_shift.train_variables
    shifted = copy.deepcopy(template.variables)
    for name in template.variable_names:
        remaining = shifted.possible_values[name][n_prefix:]
        shifted.possible_values[name] = remaining
    return shifted


def new_combination_train_variables(template: QuestionTemplate, config: Namespace) -> List[Variables]:
    """Return one variables copy per template variable acting as the "primary" variable.

    In each copy:
    - the primary variable keeps values after the first ``train_variables``
      values; when ``new_values`` is truthy, its last ``new_values`` values are
      also dropped;
    - every other variable keeps only its first ``train_variables`` values.
    """
    shift_config = config.distributional_shift
    n_prefix = shift_config.train_variables
    n_new = shift_config.new_values

    def _variables_for(primary: str) -> Variables:
        # Builds the train pool in which `primary` is the variable allowed beyond the prefix.
        pool = copy.deepcopy(template.variables)
        # Drop the prefix values from the primary variable to avoid redundant sampling.
        primary_values = pool.possible_values[primary][n_prefix:]
        if n_new:
            primary_values = primary_values[:-n_new]
        pool.possible_values[primary] = primary_values
        for other in template.variable_names:
            if other != primary:
                pool.possible_values[other] = pool.possible_values[other][:n_prefix]
        return pool

    return [_variables_for(primary) for primary in template.variable_names]


def create_new_combination_task(
        template: QuestionTemplate,
        config: Namespace,
        ) -> TemplateTask:
    """Build a "new combinations" train/test split for one template.

    Train questions are drawn from each pool returned by
    ``new_combination_train_variables``, ``ceil(train / n_pools)`` per pool, so
    the train split may exceed ``questions_per_split.train``. Test questions
    come from ``new_combination_test_variables``.

    NOTE(review): unlike ``make_distributional_shift_template_task``, these
    ``sample`` calls pass no seed, so ``config.seed`` is not applied. Confirm
    what ``Variables.sample`` defaults to.
    """
    train_pools = new_combination_train_variables(template, config)
    per_pool = math.ceil(config.questions_per_split.train / len(train_pools))
    train_combinations = []
    for pool in train_pools:
        train_combinations.extend(pool.sample(per_pool))

    test_pool = new_combination_test_variables(template, config)
    test_combinations = test_pool.sample(config.questions_per_split.test)

    return TemplateTask(template, {
        'train': make_question_df(template, train_combinations),
        'test': make_question_df(template, test_combinations),
    })


def create_new_combination_distributional_shift_dataset(templates: List[QuestionTemplate], config) -> TemplateDataset:
    """Build a "new combinations" task for every template and bundle them."""
    tasks = []
    for template in tqdm(templates, desc='Creating Dataset'):
        tasks.append(create_new_combination_task(template, config))
    return TemplateDataset(tasks)


def create_new_values_task(
        template: QuestionTemplate,
        config: Namespace,
        ) -> TemplateTask:
    """Build a "new values" train/test split for one template.

    The train split draws every variable from its first
    ``config.distributional_shift.train_values`` values. The test split is the
    union of one pool per combination of
    ``config.distributional_shift.shifted_variables`` variables. In each pool,
    the chosen variables take only values beyond the train prefix, while the
    remaining variables stay within it. Each pool contributes
    ``ceil(test / n_pools)`` questions.

    Args:
        template: Template whose ``variables`` define the value pools.
        config: Run config; reads ``distributional_shift.train_values``,
            ``distributional_shift.shifted_variables``,
            ``questions_per_split.train`` and ``questions_per_split.test``.

    Returns:
        A ``TemplateTask`` with ``train`` and ``test`` question frames.

    Raises:
        ValueError: if ``itertools.combinations`` yields no variable combination
            (e.g. more shifted variables requested than the template has), which
            previously surfaced as a ZeroDivisionError.

    NOTE(review): the ``sample`` calls pass no seed, so ``config.seed`` is not
    applied (unlike ``make_distributional_shift_template_task``). Confirm what
    ``Variables.sample`` defaults to.
    """
    n_train_values = config.distributional_shift.train_values

    train_variables = copy.deepcopy(template.variables)
    for variable_name in template.variable_names:
        train_variables[variable_name] = train_variables[variable_name][:n_train_values]

    all_test_variables = []
    for shifted_names in itertools.combinations(template.variable_names, config.distributional_shift.shifted_variables):
        test_variables = copy.deepcopy(template.variables)
        for variable_name in template.variable_names:
            if variable_name in shifted_names:
                test_variables[variable_name] = test_variables[variable_name][n_train_values:]
            else:
                test_variables[variable_name] = test_variables[variable_name][:n_train_values]
        all_test_variables.append(test_variables)

    if not all_test_variables:
        raise ValueError(
            f'No combination of {config.distributional_shift.shifted_variables} shifted variable(s) '
            f'exists for template {template.template_id!r} with variables {list(template.variable_names)}'
        )

    samples_per_test_variable = math.ceil(config.questions_per_split.test / len(all_test_variables))
    test_combinations = [c for v in all_test_variables for c in v.sample(samples_per_test_variable)]
    # The train split is sized by questions_per_split.train (it was previously
    # sized by questions_per_split.test by mistake).
    train_combinations = train_variables.sample(config.questions_per_split.train)

    questions = dict(
        train=make_question_df(template, train_combinations),
        test=make_question_df(template, test_combinations)
    )
    return TemplateTask(template, questions)


def create_new_values_distributional_shift_dataset(templates: List[QuestionTemplate], config) -> TemplateDataset:
    """Build a "new values" task for every template and bundle them."""
    tasks = []
    for template in tqdm(templates, desc='Creating Dataset'):
        tasks.append(create_new_values_task(template, config))
    return TemplateDataset(tasks)


def create_distributional_shift_dataset(templates: List[QuestionTemplate], config) -> TemplateDataset:
    """Dispatch on ``config.distributional_shift.name`` to build the dataset.

    Supported names are ``new_combinations``, ``new_variable`` and
    ``new_values``. Any other name raises KeyError.
    """
    builders = {
        'new_combinations': create_new_combination_distributional_shift_dataset,
        'new_variable': create_new_variable_distributional_shift_dataset,
        'new_values': create_new_values_distributional_shift_dataset,
    }
    build = builders[config.distributional_shift.name]
    return build(templates, config)


@dataclass
class TemplateBaselineEvaluation:
    """Model behaviours and baseline results gathered for one template task."""

    # The template's train/test questions.
    template_task: TemplateTask
    # One collected behaviour per evaluated model.
    model_behaviors: List[TemplateModelBehavior] = field(default_factory=list)
    # One result per (baseline config, model behaviour) pair that was run.
    baseline_results: List[BaselineResults] = field(default_factory=list)

    @property
    def model_names(self) -> List[str]:
        """Model name of each collected behaviour, in collection order."""
        return [behavior.model_name for behavior in self.model_behaviors]

    @property
    def summary(self) -> Dict:
        """Per-model behaviour summaries plus per-model, per-baseline result summaries."""
        behavior_summaries = {behavior.model_name: behavior.summary for behavior in self.model_behaviors}
        result_summaries = {}
        for name, results in self.baseline_results_by_model().items():
            result_summaries[name] = {result.baseline: result.summary for result in results}
        return {'model_behavior': behavior_summaries, 'baseline_results': result_summaries}

    @property
    def max_log_score(self) -> float:
        """Best ``normalized_log_scores['test']`` across all results (ValueError if none)."""
        return max(result.normalized_log_scores['test'] for result in self.baseline_results)

    def baseline_results_for_model(self, model_name: str) -> List[BaselineResults]:
        """Results whose ``model_name`` equals ``model_name``, in insertion order."""
        return list(filter(lambda result: result.model_name == model_name, self.baseline_results))

    def baseline_results_by_model(self) -> Dict[str, List[BaselineResults]]:
        """Group results by each collected behaviour's model name."""
        grouped = {}
        for name in self.model_names:
            grouped[name] = self.baseline_results_for_model(name)
        return grouped

    def update_template_metadata(self) -> None:
        """Store ``summary`` under ``baseline_filtering`` in the template's metadata."""
        self.template_task.template.metadata['baseline_filtering'] = self.summary


@dataclass
class TemplateBaselineFilterer:
    """Collect model behaviour for each template task and run baselines on it.

    The evaluations in ``baseline_evals`` are mutated in place: behaviours and
    baseline results are appended to them.
    """

    # One evaluation record per template task.
    baseline_evals: List[TemplateBaselineEvaluation]
    # Run configuration; reads ``model`` and ``baseline_filter``.
    config: Namespace

    @property
    def baseline_model_configs(self):
        """Model configs to collect behaviour from (currently only ``config.model``)."""
        return [self.config.model]

    @property
    def baseline_configs(self):
        """Entries of ``config.baseline_filter`` whose key contains ``'baseline_'``.

        The key ``'baseline_filter'`` itself is excluded.
        """
        return [
            baseline_cfg
            for key, baseline_cfg in self.config.baseline_filter.items()
            if key != 'baseline_filter' and 'baseline_' in key
        ]

    def _cached_model_behavior(self, template_task) -> Optional[TemplateModelBehavior]:
        """Return saved behaviour for ``template_task`` when its questions match exactly.

        Returns ``None`` if nothing is saved for the template, or if the saved
        train/test question texts differ (in content or order) from the task's.

        NOTE(review): the saved behaviour is not checked against the model being
        evaluated. That is harmless with a single model config, but confirm this
        if more models are added.
        """
        template_id = template_task.template_id
        assert template_id
        behavior_dir = model_behavior_path(self.config) / template_id
        if not behavior_dir.exists():
            return None
        cached = TemplateModelBehavior.load(behavior_dir)
        cached_df = cached.df
        cached_train = cached_df[cached_df.split == 'train'].question_text
        cached_test = cached_df[cached_df.split == 'test'].question_text
        train_matches = template_task.questions['train'].question_text.tolist() == cached_train.tolist()
        if train_matches and template_task.questions['test'].question_text.tolist() == cached_test.tolist():
            return cached
        return None

    def collect_model_behavior(self) -> None:
        """Append one behaviour per model config to every evaluation.

        A matching saved behaviour is reused; otherwise the task is evaluated
        with the loaded model.
        """
        for model_config in self.baseline_model_configs:
            completer, prompter = load_model(model_config)
            progress = tqdm(self.baseline_evals, desc=f'Collecting Model Behavior {model_config.name}')
            for evaluation in progress:
                behavior = self._cached_model_behavior(evaluation.template_task)
                if behavior is None:
                    behavior = evaluation.template_task.evaluate(completer, prompter)
                evaluation.model_behaviors.append(behavior)

    def run_baselines(self) -> None:
        """Train and test every baseline config on every collected behaviour.

        Results on the ``'test'`` split are appended to the owning evaluation.
        Iteration is baseline-config-major.
        """
        behavior_pairs = [
            (evaluation, behavior)
            for evaluation in self.baseline_evals
            for behavior in evaluation.model_behaviors
        ]
        jobs = [
            (baseline_config, evaluation, behavior)
            for baseline_config in self.baseline_configs
            for evaluation, behavior in behavior_pairs
        ]
        for baseline_config, evaluation, behavior in tqdm(jobs, desc='Running Baselines'):
            baseline = make_baseline(baseline_config, behavior)
            baseline.train()
            evaluation.baseline_results.append(baseline.test(['test']))

    def update_template_metadata(self):
        """Write each evaluation's summary into its template's metadata."""
        for evaluation in self.baseline_evals:
            evaluation.update_template_metadata()

    def filter_by(self, valid_fn: Callable[[TemplateBaselineEvaluation], bool]) -> List[TemplateBaselineEvaluation]:
        """Evaluations for which ``valid_fn`` returns a truthy value, in order."""
        return list(filter(valid_fn, self.baseline_evals))

    @property
    def summary(self) -> Dict:
        """Model names, baseline names, and every evaluation's summary."""
        return {
            'model_names': [model_cfg.name for model_cfg in self.baseline_model_configs],
            'baseline_names': [baseline_cfg.name for baseline_cfg in self.baseline_configs],
            'template_summaries': [evaluation.summary for evaluation in self.baseline_evals],
        }


def save_model_behaviors(filtered_baseline_evals: List[TemplateBaselineEvaluation], config) -> None:
    """Save every collected behaviour under ``<dataset>/all_dataset_behavior/<template_id>``.

    NOTE(review): all behaviours of one template are saved to the same
    directory, so with more than one model the later save is written over the
    earlier one. Confirm this is acceptable.
    """
    behavior_root = dataset_path(config) / 'all_dataset_behavior'
    for evaluation in filtered_baseline_evals:
        for behavior in evaluation.model_behaviors:
            template_id = evaluation.template_task.template_id
            assert template_id
            behavior.save(behavior_root / template_id)


def evaluate_baselines(dataset, config):
    """Collect model behaviour, run baselines and record results for every task.

    Template metadata is updated in place, and behaviours are saved via
    ``save_model_behaviors``.

    Returns:
        ``(dataset, {})`` unchanged when there are no template tasks. Otherwise
        a new ``TemplateDataset`` over the same tasks plus
        ``{'template_baseline_results': <filterer summary>}``.
    """
    tasks = dataset.template_tasks
    if not tasks:
        return dataset, {}

    evaluations = [TemplateBaselineEvaluation(task) for task in dataset.template_tasks]
    filterer = TemplateBaselineFilterer(evaluations, config)
    filterer.collect_model_behavior()
    filterer.run_baselines()
    filterer.update_template_metadata()
    save_model_behaviors(filterer.baseline_evals, config)

    return TemplateDataset(tasks), {'template_baseline_results': filterer.summary}


def save_questions(dataset: TemplateDataset, save_path: Path) -> None:
    """Write each task's question splits to ``<save_path>/<template_id>/<split>.csv``.

    Directories are created as needed. The CSVs are written without the
    DataFrame index.
    """
    for task in dataset:
        template_id = task.template.template_id
        assert template_id
        out_dir = save_path / template_id
        out_dir.mkdir(parents=True, exist_ok=True)
        for split_name, frame in task.questions.items():
            frame.to_csv(out_dir / f'{split_name}.csv', index=False)


def save_results(dataset, run_info, save_path, config):
    """Persist the run info, the dataset, and per-template metadata and questions.

    Layout under ``save_path``:
    - ``run_info.json``, the output of ``dataset.save`` and per-template question CSVs;
    - ``dataset_creation/<template_id>/``: metadata and questions for every task;
    - ``final_dataset_templates/<template_id>/``: the same for the tasks kept by
      ``dataset.filter_by_config(config)``.
    """
    def _write_metadata_and_questions(target_dataset, root):
        # Writes template_metadata.json for every task, then that dataset's question CSVs.
        for task in target_dataset.template_tasks:
            template = task.template
            metadata_dir = root / template.template_id
            metadata_dir.mkdir(exist_ok=True, parents=True)
            with open(metadata_dir / 'template_metadata.json', 'w') as metadata_file:
                json.dump(template.metadata, metadata_file, indent=4, ensure_ascii=True)
        save_questions(target_dataset, root)

    save_path.mkdir(exist_ok=True, parents=True)
    with open(save_path / 'run_info.json', 'w') as run_info_file:
        json.dump(run_info, run_info_file, indent=4, ensure_ascii=True)
    dataset.save(save_path, dict())
    save_questions(dataset, save_path)

    _write_metadata_and_questions(dataset, save_path / 'dataset_creation')
    _write_metadata_and_questions(dataset.filter_by_config(config), save_path / 'final_dataset_templates')


@hydra.main(config_path='config', config_name='create_dataset.yaml', version_base='1.2')
def main(config):
    """Create a distributional-shift dataset from the configured question templates.

    Pipeline: load templates (capped at ``config.n_templates_to_evaluate``),
    run the initial ETVD filter, build the shift dataset, run the baselines,
    then save everything together with a ``run_info`` record.
    """
    run_info: Dict[str, Any] = {'config': OmegaConf.to_container(config)}  # type: ignore
    templates_dir = template_path(config)
    save_path = dataset_path(config, new=True)

    candidates = load_question_templates(templates_dir / 'question_templates.json')[:config.n_templates_to_evaluate]

    kept_templates, initial_filter_info = filter_templates(candidates, config)
    run_info['initial_filter'] = initial_filter_info

    shift_dataset = create_distributional_shift_dataset(kept_templates, config)

    final_dataset, baseline_info = evaluate_baselines(shift_dataset, config)
    run_info['baseline_filtering_info'] = baseline_info

    save_results(final_dataset, run_info, save_path, config)


if __name__ == "__main__":
    # The @hydra.main decorator supplies `config`, so no arguments are passed here.
    main()