import copy
from collections import defaultdict
from typing import Callable, Dict, List, Tuple, Union

import hydra
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from tqdm import tqdm

from lm_understanding.data import dataset_path, model_behavior_path
from lm_understanding.explanations.explanations import LocalExplanationSet
from lm_understanding.metrics import GloVeEncoder
from lm_understanding.question_template import (QuestionModelBehavior,
                                                QuestionTemplate,
                                                TemplateDataset,
                                                TemplateModelBehavior,
                                                TemplateTask)

# Type accepted wherever an embedding model is needed. The functions in this
# module only ever call `encode(words, show_progress_bar=False)` on it.
Embedder = Union[SentenceTransformer, GloVeEncoder]

def get_scores(embedder: Embedder, words: List[str]) -> np.ndarray:
    """Score words by their position along the first principal component.

    The words are encoded with ``embedder``, a one-component PCA is fitted on
    those encodings, and each encoding is projected onto the fitted axis. The
    projections are scaled by a factor of 4 before being returned.

    Args:
        embedder: Model whose ``encode`` method turns the words into vectors.
        words: Words to score.

    Returns:
        The scaled projections, in the same order as ``words``.
    """
    word_vectors = embedder.encode(words, show_progress_bar=False)  # type: ignore
    main_axis = PCA(n_components=1).fit(word_vectors).components_[0]  # type: ignore
    # The raw encodings are projected as-is; no mean is subtracted here.
    projections = word_vectors.dot(main_axis)  # type: ignore
    return projections * 4


def get_clusters(embedder: Embedder, words: List[str]) -> np.ndarray:
    """Assign each word to one of three k-means clusters of its embedding.

    Args:
        embedder: Model whose ``encode`` method turns the words into vectors.
        words: Words to cluster.

    Returns:
        The ``labels_`` array of the fitted ``KMeans`` model.
    """
    word_vectors = embedder.encode(words, show_progress_bar=False)  # type: ignore
    fitted = KMeans(n_clusters=3).fit(word_vectors)  # type: ignore
    return fitted.labels_


def load_dataset(config) -> TemplateDataset:
    """Load the template dataset selected by ``config`` and filter it.

    The dataset is read from ``dataset_path(config)`` and then narrowed with
    ``filter_from_config(config)``.
    """
    full_dataset = TemplateDataset.load(dataset_path(config))
    return full_dataset.filter_from_config(config)

# Nested mapping: template variable name -> (variable value -> score).
Scores = Dict[str, Dict[str, float]]

def word_score_dict(embedder: Embedder, template: QuestionTemplate, score_fn: Callable) -> Scores:
    """Score every value of every template variable.

    Args:
        embedder: Passed through unchanged as the first argument of ``score_fn``.
        template: Object whose ``variables`` mapping goes from variable name
            to the list of values that variable can take.
        score_fn: Called once per variable as ``score_fn(embedder, values)``;
            must return one score per value, in order.

    Returns:
        For each variable name, a mapping from value to its score.
    """
    scores_by_variable = {}
    for variable, values in template.variables.items():
        value_scores = score_fn(embedder, values)
        scores_by_variable[variable] = dict(zip(values, value_scores))
    return scores_by_variable

def filter_out_test_variables(variable_dict, max_values=12):
    """Return the first ``max_values`` ``(key, value)`` pairs of ``variable_dict``.

    Pairs are taken in the mapping's iteration order; anything after the
    first ``max_values`` entries is dropped.
    NOTE(review): the name suggests the later entries are held-out "test"
    values - confirm against how the template variables are split.

    Args:
        variable_dict: Mapping whose leading items should be kept.
        max_values: How many leading items to keep (12 by default).

    Returns:
        A list of ``(key, value)`` tuples.
    """
    return list(variable_dict.items())[:max_values]


def sort(variable_list):
    """Return the pairs of ``variable_list`` ordered by their second element, ascending."""
    return sorted(variable_list, key=lambda pair: pair[1])

def variable_encoding_explanation(scores: Scores) -> str:
    """Render example value scores for every variable as a string.

    For each variable, the values kept by ``filter_out_test_variables`` are
    ordered by ascending score and formatted as ``'<value>: <score>'`` with
    two decimals. The result is the ``str()`` of a dict mapping each variable
    name to the ``str()`` of that list.
    """
    rendered = {}
    for variable_name, variable_dict in scores.items():
        kept = filter_out_test_variables(variable_dict)
        ordered = sorted(kept, key=lambda pair: pair[1])
        rendered[variable_name] = str([f'{k}: {v:.2f}' for k, v in ordered])
    return str(rendered)

def variable_spectrum_explanation(scores: Scores, weights: np.ndarray) -> str:
    """Describe, per variable, which values sit at the No and Yes ends.

    Variables are paired with ``weights`` positionally. For each variable the
    values kept by ``filter_out_test_variables`` are ordered by score and the
    lowest- and highest-scoring values are taken as the two ends. With a
    positive weight the lowest-scoring value is the "No" end; otherwise the
    two ends are swapped.

    Returns:
        The ``str()`` of a dict mapping each variable name to its sentence.
    """
    spectrum_by_variable = {}
    for w, (variable_name, variable_dict) in zip(weights, scores.items()):
        ranked = sort(filter_out_test_variables(variable_dict))
        no_end, yes_end = ranked[0][0], ranked[-1][0]
        if not w > 0:
            # A non-positive weight flips which end pushes toward Yes.
            no_end, yes_end = yes_end, no_end
        spectrum_by_variable[variable_name] = (
            f"From '{no_end}' (inclining toward No) to '{yes_end}' (inclining toward Yes)"
        )
    return str(spectrum_by_variable)

def weights_ranked_by_influence(weights):
    """Return the variable letters ordered from largest to smallest |weight|.

    Weights are paired positionally with the letters ``'abcde'``. Letters with
    equal magnitude come out in reverse alphabetical order, because an
    ascending stable sort is reversed as a whole.
    """
    ascending = sorted(zip('abcde', weights), key=lambda pair: abs(pair[1]))
    return [letter for letter, _ in ascending[::-1]]

def scores_for_df_row(scores: Scores, row: pd.Series) -> list:
    """Look up the score of the row's value for each of the variables a-e.

    Each row value is converted with ``str()`` before being used as a key
    into ``scores[variable]``.

    Returns:
        The five scores, ordered a, b, c, d, e.
    """
    looked_up = []
    for variable in 'abcde':
        value_key = str(row[variable])
        looked_up.append(scores[variable][value_key])
    return looked_up

def score_dict_for_df_row(scores: Scores, row: pd.Series) -> dict:
    """Map the row's value for each variable a-e to its score, rounded to 2 decimals.

    The row value itself is the key of the result; its ``str()`` form is used
    to look the score up in ``scores[variable]``. If two variables share the
    same value in the row, the later variable's score overwrites the earlier.

    Scores are converted to built-in ``float`` before rounding. Rounding a
    NumPy scalar yields another NumPy scalar, whose ``repr`` under NumPy >= 2
    is ``np.float32(0.12)`` rather than ``0.12``; that would leak into any
    text built by formatting the returned dict.

    Returns:
        Dict from row value to rounded ``float`` score, in a-e order.
    """
    rounded = {}
    for variable in ['a', 'b', 'c', 'd', 'e']:
        value = row[variable]
        rounded[value] = round(float(scores[variable][str(value)]), 2)
    return rounded


def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` of ``x``."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator

def make_question_behavior(template: QuestionTemplate, split: str, row: pd.Series, x: np.ndarray, y: float):
    """Build a ``QuestionModelBehavior`` for one question from a synthetic prediction.

    A deep copy of ``row`` is extended with the fields describing a model
    answer and handed to ``QuestionModelBehavior.from_df_row``. The caller's
    row is left untouched.

    Args:
        template: Template the question was generated from.
        split: Name of the split the question belongs to.
        row: Question row to annotate.
        x: Unused by this function.
        y: Value stored as the row's ``answer_prob``.
    """
    annotated = copy.deepcopy(row)
    # Inserted in this order so the resulting row's field order is stable.
    synthetic_fields = {
        'completion': '',
        'total_answer_prob': 1,
        'answer_prob': y,
        'cot': False,
        'valid_answer': True,
        'split': split,
    }
    for field_name, field_value in synthetic_fields.items():
        annotated[field_name] = field_value
    return QuestionModelBehavior.from_df_row(annotated, template)

class SyntheticLinearModel:
    """Synthetic model: sigmoid of a weighted sum of per-word scores.

    Each of the five template variables (a-e) has one weight. ``train``
    computes a score for every value of every variable, and ``predict``
    combines the scores of a question's values into a probability.
    """

    def __init__(self, embedder: Embedder):
        """Store the embedder and draw five weights from an exponential distribution."""
        self.embedder = embedder
        # Drawn once per instance; not re-drawn by ``train``.
        self.weights = np.random.exponential(1, 5)

    def __str__(self) -> str:
        """Return a stable model name.

        ``eval`` passes ``str(self)`` on to ``TemplateModelBehavior``. Without
        this method that would be the default object repr, which embeds a
        memory address and therefore differs between runs.
        """
        return 'synthetic_linear_model'

    def predict(self, sample: np.ndarray) -> float:
        """Return the sigmoid of the weighted sum of ``sample``."""
        return sigmoid(sum(self.weights * sample))

    def train(self, template_task: TemplateTask) -> None:
        """Compute and store the value scores for the task's template.

        Must be called before ``eval`` or any explanation method, since those
        read ``self.variable_value_scores``.
        """
        self.variable_value_scores = word_score_dict(self.embedder, template_task.template, get_scores)

    def weights_and_activations_explanation(self, row: pd.Series) -> str:
        """Return a three-line arithmetic trace of the prediction for ``row``.

        The lines are: the rounded score of each variable value, the weighted
        sum written out term by term, and the sigmoid of that sum.
        """
        x = np.array(scores_for_df_row(self.variable_value_scores, row))
        pre_sigmoid = sum(self.weights * x)
        return '\n'.join([
            f'Variable Scores: {score_dict_for_df_row(self.variable_value_scores, row)}',
            ' + '.join(f'({w:.2f} * {a:.2f})' for w, a in zip(self.weights, x)) + f' = {pre_sigmoid:.2f}',
            f'Sigmoid({pre_sigmoid:.2f}) = {sigmoid(pre_sigmoid):.2f}'
        ])

    def natural_language_explanation(self, row: pd.Series) -> str:
        """Return one sentence listing which values pushed toward yes and which toward no.

        A value counts as increasing the likelihood when its weight times its
        score is positive, and as decreasing it when that product is negative.
        Values whose product is exactly zero appear in neither list.
        """
        x = np.array(scores_for_df_row(self.variable_value_scores, row))
        activations = self.weights * x
        positive_activations = [row[c] for c, a in zip('abcde', activations) if a > 0]
        negative_activations = [row[c] for c, a in zip('abcde', activations) if a < 0]
        return f'The variables {positive_activations} increased the likelihood of yes, while {negative_activations} decreased it.'

    @property
    def global_explanation(self) -> Dict[str, str]:
        """Return the model-wide explanation text for each explanation type.

        Keys are ``'weights_and_activations'`` and ``'natural_language'``.
        Both strings are rebuilt on every access.
        """
        # Cast to built-in float so the list renders as plain numbers; NumPy >= 2
        # would otherwise print each entry as ``np.float64(...)``.
        rounded_weights = [round(float(w), 2) for w in self.weights]
        return dict(
            weights_and_activations=f'To get the probability of a yes answer, this model assigns scores to each variable word in the question, take a weighted sum, then applies the sigmoid function. The weights are {rounded_weights}. The scores for each variable represent variation along the primary axis of semantic meaning. For reference, here are some examples of words scored along that axis: {variable_encoding_explanation(self.variable_value_scores)}',
            natural_language=f'To get the probability of a yes answer, the model evaluates each variable word along a qualitative spectrum, and assigns a score to each. Here are the ends of the spectrums: {variable_spectrum_explanation(self.variable_value_scores, self.weights)}. Each variable has a different degree of influence on the final answer. In order from most influential to least influential, they are {weights_ranked_by_influence(self.weights)}'
        )

    def eval(self, template_task: TemplateTask) -> Tuple[TemplateModelBehavior, Dict[str, LocalExplanationSet]]:
        """Predict every question of the task and explain the train questions.

        Predictions are made for all splits in ``template_task.split_names``;
        local explanations are generated only for rows of the ``'train'`` split.

        Returns:
            The model behaviour over all questions, and one
            ``LocalExplanationSet`` per explanation type (empty dict when no
            train rows were seen).
        """
        assert template_task.template_id
        results = []
        explanations = defaultdict(list)
        for split in template_task.split_names:
            question_df = template_task.questions[split]
            for _, row in question_df.iterrows():
                x = np.array(scores_for_df_row(self.variable_value_scores, row))
                y = self.predict(x)
                results.append(make_question_behavior(template_task.template, split, row, x, y))
                if split == 'train':
                    explanations['weights_and_activations'].append(self.weights_and_activations_explanation(row))
                    explanations['natural_language'].append(self.natural_language_explanation(row))

        model_behavior = TemplateModelBehavior(template_task.template, results, str(self))
        # Build the global texts once rather than once per explanation type,
        # and only when there is something to attach them to.
        global_explanations = self.global_explanation if explanations else {}
        explanation_sets = dict()
        for explanation_name, local_explanations in explanations.items():
            explanation_sets[explanation_name] = LocalExplanationSet(
                template_task.template_id,
                model_behavior.questions('train'),
                model_behavior.answers('train').tolist(),
                local_explanations,
                global_explanations[explanation_name]
            )
        return model_behavior, explanation_sets


@hydra.main(config_path='../../config', config_name='model_behavior.yaml', version_base='1.2')
def main(config):
    """Generate and save synthetic model behaviour and explanations for each template task.

    For every template task in the configured dataset the model is trained and
    evaluated, its behaviour is saved under
    ``model_behavior_path(config) / <template_id>``, and each explanation set
    is written to ``explanations/<explanation_name>.csv`` inside that directory.

    Raises:
        ValueError: If ``config.model.name`` is not ``'synthetic_linear_model'``.
    """
    supported_models = ['synthetic_linear_model']
    # Raised explicitly (rather than asserted) so the check survives ``python -O``.
    if config.model.name not in supported_models:
        raise ValueError(
            f'Unsupported model {config.model.name!r}; expected one of {supported_models}'
        )
    embedder = SentenceTransformer('all-distilroberta-v1')
    # A single model instance is reused for every template task.
    model = SyntheticLinearModel(embedder)
    dataset = load_dataset(config)
    save_path = model_behavior_path(config)
    for template_task in tqdm(dataset.template_tasks):
        assert template_task.template_id
        save_dir = save_path / template_task.template_id
        model.train(template_task)
        model_behavior, local_explanations = model.eval(template_task)
        model_behavior.save(save_dir)
        explanation_save_dir = save_dir / 'explanations'
        # parents=True: do not rely on ``save`` having created ``save_dir``.
        explanation_save_dir.mkdir(parents=True, exist_ok=True)
        for explanation_name, explanations in local_explanations.items():
            explanations.save(explanation_save_dir / f'{explanation_name}.csv')

# Run the generation pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()