import copy
import itertools
import json
import random
import re
from dataclasses import asdict, dataclass, field
from pathlib import Path
from string import ascii_lowercase
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

import lm_understanding.metrics as metrics
import lm_understanding.prompting.text as text_utils
from lm_understanding.prompting import (AnswerSet, Completer,
                                        SingleWordAnswerPrompter,
                                        YesNoQuestion, answer_questions)
from lm_understanding.prompting.prompting import Prompter

# Cap on how many value combinations QuestionTemplate.questions() and
# QuestionTemplate.substitutions_and_questions() turn into questions.
MAX_COMBINATIONS = 50000


@dataclass
class Substitution:
    """One concrete value assigned to one variable of a question template."""
    # Name of the variable being filled; QuestionTemplate.make_question looks
    # for it in the template text as "[<variable_name>]".
    variable_name: str
    # Text that replaces the bracketed placeholder.
    value: str


@dataclass
class ValueCombination:
    """A set of values that replace variables in a particular template.

    Holds at most one ``Substitution`` per variable name. Iterating over an
    instance yields its substitutions in stored order.
    """
    substitutions: Tuple["Substitution", ...]

    def __post_init__(self):
        # Two substitutions for the same variable would make the result ambiguous.
        names = [sub.variable_name for sub in self.substitutions]
        assert len(set(names)) == len(names), f'duplicate variable names in substitutions: {names}'

    def __iter__(self):
        return iter(self.substitutions)

    @classmethod
    def from_df_row(cls, row: pd.Series) -> 'ValueCombination':
        """Build a combination from the variable columns of a data frame row.

        Only entries whose label is exactly one lowercase ASCII letter are
        treated as variables; every other entry of the row is ignored.

        Args:
            row: a row whose labels include the variable names.

        Returns:
            A ValueCombination with one substitution per variable column.
        """
        subs = tuple(
            Substitution(str(name), value)
            for name, value in row.items()
            # `name in ascii_lowercase` alone is a substring test and would
            # also accept labels such as 'ab' or '', hence the length check.
            if len(str(name)) == 1 and str(name) in ascii_lowercase
        )
        return cls(subs)

    def copy_with_replaced_substitution(self, substitution: "Substitution") -> "ValueCombination":
        """Return a new combination in which ``substitution`` overrides its variable.

        Any existing substitution for the same variable is dropped and the new
        one is appended at the end; this instance is left untouched.
        """
        kept = [s for s in self.substitutions if s.variable_name != substitution.variable_name]
        return ValueCombination(tuple(kept + [substitution]))

    def replacement_dict(self) -> Dict[str, str]:
        """Return the substitutions as a ``{variable_name: value}`` mapping."""
        return {sub.variable_name: sub.value for sub in self.substitutions}


@dataclass
class Variables:
    """A set of possible values for the variables in a particular question template.

    Wraps a mapping from variable name to the list of values that variable can
    take, with helpers to restrict, enumerate and sample value combinations.
    """
    # variable name -> candidate values. The dict's insertion order is the
    # variable order used by combinations(), combination_by_index() and the
    # sampling methods.
    possible_values: Dict[str, List[str]]

    def __getitem__(self, *args, **kwargs):
        return self.possible_values.__getitem__(*args, **kwargs)

    def __setitem__(self, *args, **kwargs):
        return self.possible_values.__setitem__(*args, **kwargs)

    def items(self):
        """Return the ``(variable_name, values)`` view of the underlying dict."""
        return self.possible_values.items()

    @property
    def variable_names(self) -> List[str]:
        """Variable names in insertion order."""
        return list(self.possible_values.keys())

    def isolate_value(self, variable_name: str, value: str) -> 'Variables':
        """Return a deep copy in which ``variable_name`` can only take ``value``."""
        variables = copy.deepcopy(self)
        variables.possible_values[variable_name] = [value]
        return variables

    def exclude_value(self, variable_name: str, value: str) -> 'Variables':
        """Return a deep copy in which ``value`` is removed from ``variable_name``'s values."""
        variables = copy.deepcopy(self)
        variables.possible_values[variable_name] = [v for v in variables.possible_values[variable_name] if v != value]
        return variables

    def isolate_values(self, substitutions: List["Substitution"]) -> 'Variables':
        """Return a deep copy in which one variable is restricted to the given values.

        Args:
            substitutions: non-empty list of substitutions that all target the
                same variable; their values become that variable's only values.

        Raises:
            IndexError: if ``substitutions`` is empty.
            AssertionError: if the substitutions name different variables.
        """
        variables = copy.deepcopy(self)
        variable_name = substitutions[0].variable_name
        assert all(s.variable_name == variable_name for s in substitutions), \
            'all substitutions must target the same variable'
        variables.possible_values[variable_name] = [s.value for s in substitutions]
        return variables

    def substitution_pairs(self, variable_name: str) -> List["Substitution"]:
        """Return one ``Substitution`` per possible value of ``variable_name``."""
        return [Substitution(variable_name, value) for value in self[variable_name]]

    def combinations(self) -> List["ValueCombination"]:
        """Return every combination of values (the full Cartesian product).

        The product is materialised as a list; use ``possible_combinations``
        when only the count is needed.
        """
        pairs = [self.substitution_pairs(variable_name) for variable_name in self.variable_names]
        return [ValueCombination(subs) for subs in itertools.product(*pairs)]

    def sample(self, n: int, seed: int = 0, fast_with_replacement: bool = False) -> List["ValueCombination"]:
        """Sample ``n`` combinations, with or without replacement.

        Args:
            n: number of combinations requested.
            seed: seed for the random generator used by the chosen sampler.
            fast_with_replacement: when true, draw each variable independently
                (duplicates possible); otherwise draw distinct combinations.
        """
        if fast_with_replacement:
            return self.sample_with_replacement(n, seed)
        return self.sample_without_replacement(n, seed)

    def sample_with_replacement(self, n: int, seed: int = 0) -> List["ValueCombination"]:
        """Draw ``n`` combinations by picking each variable's value independently.

        The same combination may be returned more than once.
        """
        rng = np.random.RandomState(seed)
        n_values = [len(values) for values in self.possible_values.values()]
        # One row of value indices per variable; transposing gives one row per sample.
        sample_idxs = np.array([rng.choice(v, n, replace=True) for v in n_values])
        vcs = []
        for sample_idx in sample_idxs.T:
            subs = [
                Substitution(variable_name, variable_values[i])
                for i, (variable_name, variable_values)
                in zip(sample_idx, self.possible_values.items())
            ]
            vcs.append(ValueCombination(tuple(subs)))
        return vcs

    def combination_by_index(self, idx: int) -> "ValueCombination":
        """Decode ``idx`` into a combination without enumerating the product.

        ``idx`` is read as a mixed-radix number whose least-significant digit
        is the first variable, so the ordering differs from ``combinations()``
        (where the last variable varies fastest). Any part of ``idx`` beyond
        ``possible_combinations`` is ignored.
        """
        substitutions = []
        for variable_name in self.variable_names:
            variable_values = self.possible_values[variable_name]
            idx, variable_idx = divmod(idx, len(variable_values))
            substitutions.append(Substitution(variable_name, variable_values[variable_idx]))
        return ValueCombination(tuple(substitutions))

    @property
    def possible_combinations(self) -> int:
        """Size of the Cartesian product of all variables' values."""
        combinations = 1
        for values in self.possible_values.values():
            combinations *= len(values)
        return combinations

    def sample_without_replacement(self, n: int, seed: int = 0) -> List["ValueCombination"]:
        """Draw ``n`` distinct combinations.

        If ``n`` is at least the number of possible combinations, all of them
        are returned (in ``combinations()`` order) and ``seed`` is not used.
        """
        possible_combinations = self.possible_combinations
        if n >= possible_combinations:
            return self.combinations()
        # A private generator yields the same indices as seeding the module-level
        # one, without resetting the global random state for the rest of the process.
        rng = random.Random(seed)
        combination_idxs = rng.sample(range(0, possible_combinations), n)
        return [self.combination_by_index(idx) for idx in combination_idxs]

    def example_variables(self) -> 'Variables':
        """Return variables that each hold only their first value, wrapped in brackets.

        Raises:
            IndexError: if any variable has no values.
        """
        return Variables({k: [f'[{v[0]}]'] for k, v in self.items()})


@dataclass
class QuestionTemplate:
    """A question text with bracketed placeholders plus the values that can fill them.

    Placeholders appear in ``template`` as ``[name]``, where ``name`` is a key
    of ``variables``.
    """
    template: str
    variables: Variables
    template_id: Optional[str] = None
    metadata: Dict = field(default_factory=dict)

    def make_question(self, substitutions: ValueCombination) -> YesNoQuestion:
        """Fill the template's placeholders with the given substitutions.

        Replacement is literal text replacement, so values may safely contain
        backslashes or other characters that are special in regular expressions.
        """
        question_text = self.template
        for _ in range(2):  # second pass to accommodate variables in variable values
            for substitution in substitutions:
                placeholder = f'[{substitution.variable_name}]'
                question_text = question_text.replace(placeholder, str(substitution.value))
        return YesNoQuestion(question_text)

    def questions(self, variables: Variables) -> List[YesNoQuestion]:
        """Return questions for the first ``MAX_COMBINATIONS`` combinations of ``variables``."""
        return [self.make_question(subs) for subs in variables.combinations()[:MAX_COMBINATIONS]]

    def substitutions_and_questions(self, variables: Variables) -> List[Tuple[ValueCombination, YesNoQuestion]]:
        """Like ``questions`` but pairs each question with the combination that produced it."""
        substitutions = variables.combinations()[:MAX_COMBINATIONS]
        questions = [self.make_question(subs) for subs in substitutions]
        return list(zip(substitutions, questions))

    @property
    def all_questions(self) -> List[YesNoQuestion]:
        """Questions generated from this template's own variables (capped at ``MAX_COMBINATIONS``)."""
        return self.questions(self.variables)

    @property
    def variable_names(self) -> List[str]:
        return self.variables.variable_names

    @property
    def variables_example(self) -> YesNoQuestion:
        """The template filled with each variable's first value, kept inside brackets."""
        return self.questions(self.variables.example_variables())[0]  # type: ignore

    @property
    def n_questions(self) -> int:
        """Number of distinct value combinations (not capped by ``MAX_COMBINATIONS``)."""
        # Counting via the product of list lengths avoids building every combination.
        return self.variables.possible_combinations

    @classmethod
    def from_str(cls, template: str) -> "QuestionTemplate":
        """Create a template whose variables are the bracketed items of ``template``, with no values yet."""
        return cls(template, Variables({name: list() for name in text_utils.bracketed_items(template)}))

    def reduce_to_5_variables(self) -> "QuestionTemplate":
        """Return a copy with at most five variables.

        For each excess variable, its placeholder in the text is replaced by
        that variable's first value, the text is passed through
        ``text_utils.revariable_text``, and the remaining values are re-keyed.
        Templates with five or fewer variables are returned as an unchanged copy.
        """
        n_final_variables = 5
        template = copy.deepcopy(self)
        n_to_remove = max(len(template.variable_names) - n_final_variables, 0)
        to_remove = template.variable_names[:n_to_remove]
        # NOTE(review): `to_remove` is computed once from the original names, while
        # the variables are re-keyed on every iteration and the rebuild below always
        # drops the alphabetically first name. When more than one variable is removed
        # the placeholder that is filled in and the values that are dropped may belong
        # to different variables -- confirm against text_utils.revariable_text.
        for variable_name in to_remove:
            offset_map = {k: v for k, v in zip(sorted(template.variable_names)[1:], sorted(template.variable_names)[:-1])}
            # Literal replacement: the value is plain text, not a regex replacement template.
            fixed_value = str(template.variables[variable_name][0])
            template.template = text_utils.revariable_text(template.template.replace(f'[{variable_name}]', fixed_value))
            template.variables = Variables({offset_map[k]: v for k, v in template.variables.possible_values.items() if k in offset_map})
            rename_variables = {k: v for k, v in zip(sorted(template.variable_names), 'abcdefghij')}
            template.variables = Variables({rename_variables[k]: v for k, v in template.variables.possible_values.items()})
        return template

    @classmethod
    def from_dict(cls, template_data: dict) -> "QuestionTemplate":
        """Build a template from a dict and reduce it to at most five variables.

        ``template_data['variables']`` may be either a plain
        ``{name: values}`` mapping or a dict with a ``'possible_values'`` key
        (the shape produced by ``dataclasses.asdict`` on a template).
        """
        template_text = template_data['template']
        if 'possible_values' in template_data['variables']:
            variables = Variables(template_data['variables']['possible_values'])
        else:
            variables = Variables(template_data['variables'])
        template_id = template_data.get('template_id')
        metadata = template_data.get('metadata', dict())
        template = cls(template_text, variables, template_id, metadata)
        return template.reduce_to_5_variables()

    @classmethod
    def from_id(cls, template_id: str) -> "QuestionTemplate":
        """Load the template with the given id from the ``templates`` directory.

        The task name is everything before the last underscore of
        ``template_id``; the file read is
        ``templates/<task_name>/_combined/question_templates.json`` relative to
        the current working directory.

        Raises:
            IndexError: if no valid template in that file has this id.
        """
        task_name = '_'.join(template_id.split('_')[:-1])
        template_path = Path('templates') / task_name / '_combined' / 'question_templates.json'
        templates = load_question_templates(template_path)
        template = [t for t in templates if t.template_id == template_id][0]
        return template

    @property
    def words(self) -> str:
        """The template text followed by every variable value, joined by spaces."""
        return ' '.join([self.template] + [word for values in self.variables.possible_values.values() for word in values])

    @property
    def valid(self):
        """True when the template has exactly variables a-e, all present in the text, each with at least 10 values."""
        return (
            len(self.variable_names) == 5
            and all(v in self.template for v in ['[a]', '[b]', '[c]', '[d]', '[e]'])
            and all(len(values) >= 10 for _variable, values in self.variables.items())
        )

    def __str__(self) -> str:
        n_values = 0
        combinations = 1
        for _name, values in self.variables.items():
            n_values += len(values)
            combinations *= len(values)
        return f"""{self.template}
{self.variables_example}
Total Number of Values: {n_values}
Combinations of Values: {combinations}"""


def measure_template_diversity(templates: List[QuestionTemplate]):
    """Pass each template's ``words`` string to ``metrics.measure_diversity`` and return its result."""
    word_strings = []
    for template in templates:
        word_strings.append(template.words)
    return metrics.measure_diversity(word_strings)


def load_question_templates(path: Path) -> List[QuestionTemplate]:
    """Load the templates stored in a JSON file, keeping only the valid ones.

    The file is read as UTF-8 and must contain an object with a ``'questions'``
    list; each entry is passed to ``QuestionTemplate.from_dict``. Templates
    whose ``valid`` property is false are dropped.
    """
    raw_text = path.read_text(encoding="utf-8")
    entries = json.loads(raw_text)['questions']
    parsed = [QuestionTemplate.from_dict(entry) for entry in entries]
    return list(filter(lambda candidate: candidate.valid, parsed))


@dataclass
class QuestionModelBehavior:
    """A model's recorded response to a single question built from a template."""
    # Template and variable values that together produce the question text.
    template: QuestionTemplate
    value_combination: ValueCombination
    # Name of the data split the question belongs to, if any.
    split: Optional[str]
    # Raw completion text returned for the question.
    completion: str
    # Stored under the 'total_answer_prob' column by as_record().
    total_yes_no_answer_prob: float
    answer_prob: float
    # Flags copied from the answering step (see TemplateModelBehavior.from_answers).
    cot: bool
    valid_answer: bool

    @property
    def template_id(self):
        """Id of the template this question was generated from."""
        return self.template.template_id

    @property
    def question_text(self):
        """Question text obtained by filling the template with this value combination."""
        question = self.template.make_question(self.value_combination)
        return question.text

    def as_record(self):
        """Return a flat dict for one data frame row: fixed columns, then one column per variable."""
        fixed_columns = {
            'template_id': self.template_id,
            'question_text': self.question_text,
            'split': self.split,
            'completion': self.completion,
            'answer_prob': self.answer_prob,
            'total_answer_prob': self.total_yes_no_answer_prob,
            'cot': self.cot,
            'valid_answer': self.valid_answer,
        }
        variable_columns = self.value_combination.replacement_dict()
        # Unpacking both into dict() keeps the TypeError on clashing column names.
        return dict(**fixed_columns, **variable_columns)

    @classmethod
    def from_df_row(cls, row, template: Optional[QuestionTemplate] = None):
        """Rebuild a behavior from a row shaped like the output of ``as_record``.

        When ``template`` is not given it is loaded via
        ``QuestionTemplate.from_id(row.template_id)``.
        """
        resolved_template = QuestionTemplate.from_id(row.template_id) if template is None else template
        return QuestionModelBehavior(
            template=resolved_template,
            value_combination=ValueCombination.from_df_row(row),
            split=row.split,
            completion=row.completion,
            total_yes_no_answer_prob=row.total_answer_prob,
            answer_prob=row.answer_prob,
            cot=row.cot,
            valid_answer=row.valid_answer,
        )


@dataclass
class TemplateModelBehavior:
    """A model's answers to a batch of questions generated from one template.

    ``__post_init__`` derives two attributes that are not dataclass fields:
    ``df`` (one row per question, see ``QuestionModelBehavior.as_record``) and
    ``split_dfs`` (split name -> ``TemplateModelBehavior`` restricted to that
    split, or ``None`` when there is at most one split).
    """
    template: QuestionTemplate
    question_behaviors: List[QuestionModelBehavior]
    model_name: str

    # File names used by save() and load() inside the results directory.
    model_behavior_filename = 'model_behavior.csv'
    split_summaries_filename = 'split_summaries.json'

    def __post_init__(self):
        self.df = pd.DataFrame.from_records([b.as_record() for b in self.question_behaviors])
        self.split_dfs = self.to_splits() if len(self.split_names) > 1 else None

    @property
    def template_id(self) -> Optional[str]:
        return self.template.template_id

    @property
    def split_names(self) -> List[str]:
        """Distinct split values in order of first appearance (empty when there are no rows)."""
        assert self.df is not None
        # With no question behaviors the frame has no columns at all.
        if 'split' not in self.df.columns:
            return []
        return self.df['split'].unique().tolist()

    @property
    def answer_probs(self) -> np.ndarray:
        return self.df.answer_prob.to_numpy()

    @property
    def mean_total_yes_no_answer_prob(self) -> float:
        return np.mean([qb.total_yes_no_answer_prob for qb in self.question_behaviors])

    @property
    def etvd(self) -> float:
        """Result of ``metrics.etvd`` on the answer probabilities."""
        return metrics.etvd(self.answer_probs)  # type: ignore

    @property
    def positive_fraction(self) -> float:
        """Result of ``metrics.positive_fraction`` on the answer probabilities."""
        return metrics.positive_fraction(self.answer_probs)  # type: ignore

    @property
    def answer_is_valid(self) -> np.ndarray:
        return self.df.valid_answer.to_numpy()

    def question_behaviors_for_split(self, split: str) -> List[QuestionModelBehavior]:
        return [b for b in self.question_behaviors if b.split == split]

    def to_splits(self) -> Dict[str, "TemplateModelBehavior"]:
        """Build a separate ``TemplateModelBehavior`` for every split name."""
        return {s: TemplateModelBehavior(self.template, self.question_behaviors_for_split(s), self.model_name) for s in self.split_names}

    def _behavior_for_split(self, split: str) -> "TemplateModelBehavior":
        """Return the behavior restricted to ``split``.

        Per-split objects only exist when there are several splits; when this
        object holds a single split, asking for that split returns ``self``.

        Raises:
            KeyError: if ``split`` is not one of this object's splits.
        """
        if self.split_dfs is not None:
            return self.split_dfs[split]
        if self.split_names == [split]:
            return self
        raise KeyError(split)

    def questions(self, split: Optional[str] = None) -> List[str]:
        """Return question texts, optionally restricted to one split."""
        if split:
            return self._behavior_for_split(split).questions()
        return self.df.question_text.to_list()

    def answers(self, split: Optional[str] = None) -> np.ndarray:
        """Return answer probabilities, optionally restricted to one split."""
        if split:
            return self._behavior_for_split(split).answers()
        return self.df.answer_prob.to_numpy()

    def value_combinations(self, split: Optional[str] = None) -> List[ValueCombination]:
        """Return the value combinations, optionally restricted to one split."""
        return [b.value_combination for b in self.question_behaviors if not split or b.split == split]

    @classmethod
    def from_answers(cls, template: QuestionTemplate, variable_values: List[ValueCombination], answers: AnswerSet, model_name: str, split_name: Optional[str] = None, cot: bool = False):
        """Pair each value combination with the corresponding answer.

        Args:
            template: template the questions were generated from.
            variable_values: value combinations, in the same order as ``answers``.
            answers: the model's answers.
            model_name: stored on the returned object.
            split_name: split recorded on every question behavior.
            cot: when true, probabilities are derived from ``answers.given_answers``
                instead of the completion probabilities (see below).
        """
        if cot:
            # A missing answer is invalid and recorded as 0.5; any other answer
            # is rounded to the nearest integer.
            valid_answers = [a is not None for a in answers.given_answers]
            total_answer_probs = [float(v) for v in valid_answers]
            answer_probs = [round(a) if a is not None else 0.5 for a in answers.given_answers]
        else:
            total_answer_probs = [a.completion.total_answer_prob for a in answers]
            answer_probs = answers.positive_answer_probs
            valid_answers = answers.valid_answers
        question_behaviors = [
            QuestionModelBehavior(template, value_combination, split_name, completion.text, total_answer_prob, answer_prob, cot, valid_answer)
            for value_combination, answer_prob, total_answer_prob, completion, valid_answer
            in zip(variable_values, answer_probs, total_answer_probs, answers.completions, valid_answers)
        ]
        return cls(template, question_behaviors, model_name)

    @classmethod
    def load(cls, results_path: Path):
        """Load results previously written by ``save`` into ``results_path``.

        The template is looked up once, from the ``template_id`` of the first
        row, and shared by every question behavior.
        """
        model_behavior = pd.read_csv(results_path / cls.model_behavior_filename)
        model_name = json.loads((results_path / cls.split_summaries_filename).read_text())['summary']['model_name']
        template = QuestionTemplate.from_id(model_behavior.template_id[0])
        question_behaviors = [QuestionModelBehavior.from_df_row(row, template) for _, row in model_behavior.iterrows()]
        return cls(template, question_behaviors, model_name)

    @property
    def summary(self) -> Dict:
        """Aggregate statistics over all question behaviors, as a JSON-friendly dict."""
        return {
                'template_id': self.template.template_id,
                'model_name': self.model_name,
                'n_samples': len(self.question_behaviors),
                'etvd': self.etvd,
                'positive_fraction': self.positive_fraction,
                'valid_fraction': self.answer_is_valid.mean(),
                'total_answer_prob': self.mean_total_yes_no_answer_prob,
            }

    def save(self, save_path: Path):
        """Write the per-question CSV and the summary JSON into ``save_path``.

        The directory is created if needed. The JSON holds the overall summary
        under ``'summary'`` and, when there are several splits, one summary per
        split keyed by split name.
        """
        save_path.mkdir(exist_ok=True, parents=True)
        self.df.to_csv(save_path / self.model_behavior_filename, index_label='question_id')

        summary_data = dict(summary=self.summary)
        if len(self.split_names) > 1:
            summary_data.update({s: mb.summary for s, mb in self.to_splits().items()})
        with open(save_path / self.split_summaries_filename, 'w') as f:
            json.dump(summary_data, f, indent=4, ensure_ascii=True)


def answer_template_questions(
        template: QuestionTemplate,
        variables: Variables,
        completer: Completer,
        n_questions: int,
        split_name: Optional[str] = None,
        prompter: Optional[Prompter] = None,
        seed: int = 0,
        sample_fast: bool = False
    ) -> TemplateModelBehavior:
    """Sample questions from a template, have the completer answer them, and record the behavior.

    Args:
        template: template used to render each sampled combination.
        variables: value pool to sample combinations from.
        completer: model wrapper; its ``name`` and ``multiple_choice``
            attributes are read here.
        n_questions: number of combinations to sample.
        split_name: split recorded on every resulting question behavior.
        prompter: prompt builder; defaults to a ``SingleWordAnswerPrompter``
            that shows the question text followed by 'Answer:'.
        seed: seed passed to the sampler.
        sample_fast: sample with replacement instead of without.
    """
    sampled_combinations = variables.sample(n_questions, seed, fast_with_replacement=sample_fast)
    questions = []
    for combination in sampled_combinations:
        questions.append(template.make_question(combination))
    if prompter is not None:
        active_prompter = prompter
    else:
        active_prompter = SingleWordAnswerPrompter(lambda q: q.text, 'Answer:')
    answers = answer_questions(questions, active_prompter, completer)  # type: ignore
    model_name = completer.name
    # Completers that are not multiple-choice are recorded with cot=True.
    treat_as_cot = not completer.multiple_choice
    return TemplateModelBehavior.from_answers(
        template,
        sampled_combinations,
        answers,
        model_name,
        split_name,
        treat_as_cot,
    )


def make_question_df(template: QuestionTemplate, value_combinations: List[ValueCombination]):
    """Return a data frame with one row per value combination.

    Each row holds the template id, the rendered question text, and one column
    per variable containing the value used for that question.
    """
    rendered = list(map(template.make_question, value_combinations))
    records = []
    for question, combination in zip(rendered, value_combinations):
        variable_columns = {sub.variable_name: sub.value for sub in combination}
        records.append(dict(
            template_id=template.template_id,
            question_text=question.text,
            **variable_columns
        ))
    return pd.DataFrame.from_records(records)


@dataclass
class TemplateTask:
    """A question template together with its generated questions, grouped by split."""
    template: QuestionTemplate
    # split name -> data frame of questions; the code below reads its
    # 'question_text' column and its single-letter variable columns.
    questions: Dict[str, pd.DataFrame]

    @property
    def template_id(self):
        return self.template.template_id

    @property
    def split_names(self) -> List[str]:
        """Names of the available splits (empty if ``questions`` is ``None``)."""
        if self.questions is not None:
            return [name for name in self.questions.keys()]
        return []

    @property
    def summary(self):
        """The template's fields as a dict, plus the number of questions per split."""
        per_split_counts = {}
        for split_name, frame in self.questions.items():
            per_split_counts[split_name] = len(frame)
        return dict(**asdict(self.template), n_questions=per_split_counts)

    def questions_to_eval(self, train_questions: Optional[int] = None, test_questions: Optional[int] = None, splits: Optional[List[str]] = None) -> pd.DataFrame:
        """Concatenate the questions of the selected splits, tagging each row with its split.

        Args:
            train_questions: if set, keep only this many leading rows of 'train'.
            test_questions: if set, keep only this many leading rows of 'test'.
            splits: if given and non-empty, only these splits are included.
        """
        row_limits = {'train': train_questions, 'test': test_questions}
        selected_frames = []
        for split_name, split_frame in self.questions.items():
            wanted = not splits or split_name in splits
            if not wanted:
                continue
            tagged = split_frame.copy()
            tagged['split'] = split_name
            limit = row_limits.get(split_name)
            if limit:
                tagged = tagged.iloc[:limit]
            selected_frames.append(tagged)
        return pd.concat(selected_frames)

    def model_behavior(self, all_questions_to_eval: pd.DataFrame, answers: AnswerSet, model_name: str) -> TemplateModelBehavior:
        """Combine evaluated questions and their answers into a ``TemplateModelBehavior``.

        ``answers`` must be in the same row order as ``all_questions_to_eval``;
        rows are grouped by their 'split' column and processed split by split.
        """
        collected_behaviors = []
        for split_name in all_questions_to_eval.split.unique():
            in_split_mask = all_questions_to_eval.split == split_name
            split_rows = all_questions_to_eval[in_split_mask]
            # Keep only the answers whose row belongs to this split.
            split_answers = AnswerSet(list(itertools.compress(answers.answers, in_split_mask)))
            combinations = []
            for _, row in split_rows.iterrows():
                combinations.append(ValueCombination.from_df_row(row))
            split_behavior = TemplateModelBehavior.from_answers(
                template=self.template,
                variable_values=combinations,
                answers=split_answers,
                model_name=model_name,
                split_name=split_name,
                cot=not answers.is_multiple_choice,
            )
            collected_behaviors += split_behavior.question_behaviors
        return TemplateModelBehavior(self.template, collected_behaviors, model_name)

    def evaluate(
            self,
            completer: Completer,
            prompter: Prompter,
            splits: Optional[List[str]] = None,
            train_questions: Optional[int] = None,
            test_questions: Optional[int] = None,
            verbose: bool = True) -> TemplateModelBehavior:
        """Have ``completer`` answer the selected questions and return the recorded behavior.

        ``splits``, ``train_questions`` and ``test_questions`` are forwarded to
        ``questions_to_eval``; ``verbose`` toggles the progress bar.
        """
        to_eval = self.questions_to_eval(train_questions, test_questions, splits)
        yes_no_questions = []
        for question_text in to_eval.question_text:
            yes_no_questions.append(YesNoQuestion(question_text))
        answers = answer_questions(yes_no_questions, prompter, completer, progress_bar=verbose) # type: ignore
        return self.model_behavior(to_eval, answers, completer.name)

    @property
    def average_words_per_question(self) -> float:
        """Mean number of whitespace-separated words per question, over all splits."""
        word_counts = [
            len(question.split())
            for frame in self.questions.values()
            for question in frame['question_text']
        ]
        return np.array(word_counts).mean()

    @property
    def fraction_words_changed_by_substitutions(self):
        """Share of an average question's words that do not come from the template's fixed words.

        Template words that start with '[' are treated as placeholders and not counted as fixed.
        """
        fixed_word_count = sum(1 for word in self.template.template.split() if not word.startswith('['))
        mean_length = self.average_words_per_question
        return (mean_length - fixed_word_count) / mean_length


@dataclass
class TemplateDataset:
    """A collection of template tasks, addressable by position or by template id."""
    template_tasks: List[TemplateTask]
    # Contents of 'run_info.json' when loaded from disk and readable, else None.
    metadata: Optional[Dict[str, Any]] = None

    def __len__(self) -> int:
        return len(self.template_tasks)

    def __getitem__(self, idx_or_id):
        """Look up a task by template id (str) or by list index/slice.

        Raises:
            IndexError: if a string id matches no task, or the index is out of range.
        """
        if isinstance(idx_or_id, str):
            return [template_task for template_task in self.template_tasks if template_task.template.template_id == idx_or_id][0]
        return self.template_tasks[idx_or_id]

    def __iter__(self):
        return iter(self.template_tasks)

    @property
    def templates(self) -> List[QuestionTemplate]:
        return [t.template for t in self.template_tasks]

    @property
    def template_ids(self) -> List[str]:
        return [t.template_id for t in self.templates]  # type: ignore

    def get_template_metadata(self, template_id: str) -> Dict:
        """Return the metadata dict of the template with the given id (IndexError if absent)."""
        return [t for t in self.templates if t.template_id == template_id][0].metadata

    def filter_by_baseline_performance(self, n_to_keep: int) -> "TemplateDataset":
        """Keep the ``n_to_keep`` tasks with the highest baseline score.

        The score of a task is read from its template metadata: the first entry
        of ``metadata['baseline_filtering']['baseline_results']``, then
        ``['LinearRegression']['kl_divergence']['test']``. The returned dataset
        lists the kept tasks in ascending score order and carries no metadata.
        """
        def get_score(template_task):
            metadata = template_task.template.metadata
            baseline_results: Dict = next(iter(metadata['baseline_filtering']['baseline_results'].values()))
            return baseline_results['LinearRegression']['kl_divergence']['test']
        scores = [get_score(t) for t in self.template_tasks]
        # NOTE(review): with n_to_keep == 0 the slice is [-0:], which keeps every
        # task rather than none -- confirm whether any config relies on that.
        keep_idx = np.argsort(scores)[-n_to_keep:]
        template_tasks = [self.template_tasks[i] for i in keep_idx]
        return TemplateDataset(template_tasks)

    def filter_from_config(self, config) -> "TemplateDataset":
        """Filter using ``config.dataset_construction.templates_per_topic`` as the number to keep."""
        return self.filter_by_baseline_performance(
            n_to_keep=config.dataset_construction.templates_per_topic
        )

    def evaluate(self, *args, **kwargs):
        """Call ``TemplateTask.evaluate`` on every task with the same arguments and return the results."""
        return [t.evaluate(*args, **kwargs) for t in self.template_tasks]

    def save(self, save_path: Path, metadata: Dict):
        """Write ``dataset.json`` (task summaries plus ``metadata``) into ``save_path``.

        The directory is created if needed. Only the summaries are written
        here; the per-split question CSVs are not.
        """
        save_path.mkdir(exist_ok=True, parents=True)
        template_data = [template_task.summary for template_task in self]
        data = dict(
            dataset_construction_metadata=metadata,
            templates=template_data
        )
        # ensure_ascii=False emits non-ASCII characters verbatim, so the file
        # encoding must be fixed instead of depending on the platform locale.
        with open(save_path / 'dataset.json', 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=4, ensure_ascii=False)

    @classmethod
    def load(cls, dataset_path: Path) -> 'TemplateDataset':
        """Load a dataset directory.

        Reads templates from ``dataset.json``, optional metadata from
        ``run_info.json``, and questions from the CSV files of every
        sub-directory of ``dataset_creation`` that contains a ``train.csv``.
        Templates without a matching question directory are skipped.
        """
        data = json.loads((dataset_path / 'dataset.json').read_text(encoding='utf-8'))['templates']
        try:
            metadata = json.loads((dataset_path / 'run_info.json').read_text())
        except (OSError, ValueError):
            # Run info is optional: a missing, unreadable or malformed file
            # simply means there is no metadata.
            metadata = None
        template_questions: Dict[str, Dict[str, pd.DataFrame]] = dict()
        for template_dir in (dataset_path / 'dataset_creation').iterdir():
            if not template_dir.is_dir():
                continue
            if not (template_dir / 'train.csv').exists():
                continue
            # Directory name is the template id; each CSV's stem is the split name.
            template_questions[template_dir.stem] = {file.stem: pd.read_csv(file) for file in template_dir.iterdir() if file.suffix == '.csv'}
        dataset = []
        for template_data in data:
            template = QuestionTemplate.from_dict(template_data)
            assert template.template_id is not None
            questions = template_questions.get(template.template_id)
            if questions:
                dataset.append(TemplateTask(template, questions))
        return cls(dataset, metadata)
