import random
from typing import Callable, Dict, List, Tuple, Union

import hydra
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from tqdm import tqdm

from lm_understanding.data import dataset_path, model_behavior_path
from lm_understanding.explanations.explanations import LocalExplanationSet
from lm_understanding.metrics import GloVeEncoder
from lm_understanding.question_template import (QuestionModelBehavior,
                                                QuestionTemplate,
                                                TemplateDataset,
                                                TemplateModelBehavior,
                                                TemplateTask)

# Text encoder type accepted by this module; get_clusters calls its `.encode(...)` method.
Embedder = Union[SentenceTransformer, GloVeEncoder]

def get_clusters(embedder: Embedder, words: List[str], n_clusters: int = 3, random_state=None) -> np.ndarray:
    """Embed ``words`` and assign each one a k-means cluster label.

    Args:
        embedder: Encoder whose ``encode`` method turns the words into vectors.
        words: The strings to cluster.
        n_clusters: Upper bound on the number of clusters (defaults to 3, the
            previously hard-coded value).
        random_state: Forwarded to ``KMeans`` so a run can be made reproducible;
            ``None`` keeps the original unseeded behaviour.

    Returns:
        One integer label per entry of ``words``, in the same order.
    """
    if len(words) == 0:
        # Nothing to embed or cluster; KMeans would raise on an empty sample set.
        return np.empty(0, dtype=int)
    embeddings = embedder.encode(words, show_progress_bar=False)  # type: ignore
    # KMeans refuses to fit when there are fewer samples than clusters, so cap
    # the cluster count at the number of words we actually have.
    effective_clusters = min(n_clusters, len(words))
    kmeans = KMeans(n_clusters=effective_clusters, n_init='auto', random_state=random_state)
    kmeans.fit(embeddings)  # type: ignore
    return kmeans.labels_


def load_dataset(config) -> TemplateDataset:
    """Load the template dataset located via ``dataset_path(config)`` and filter it with ``config``."""
    path = dataset_path(config)
    full_dataset = TemplateDataset.load(path)
    return full_dataset.filter_from_config(config)

# Per-variable lookup: variable name -> {variable value -> score for that value}.
Scores = Dict[str, Dict[str, float]]

def word_score_dict(embedder: Embedder, template: QuestionTemplate, score_fn: Callable) -> Scores:
    """Score every candidate value of every template variable.

    ``score_fn`` is called once per variable as ``score_fn(embedder, values)`` and
    its results are paired positionally with ``values``.

    Returns:
        A mapping ``variable name -> {value -> score}``.
    """
    scores_by_variable = {}
    for name, candidates in template.variables.items():
        candidate_scores = score_fn(embedder, candidates)
        scores_by_variable[name] = dict(zip(candidates, candidate_scores))
    return scores_by_variable


def scores_for_df_row(scores: Scores, row: pd.Series) -> list:
    """Return the scores of the row's ``a``..``e`` values, in that variable order.

    Each row value is converted with ``str`` before being looked up in ``scores``.
    """
    ordered_scores = []
    for name in ('a', 'b', 'c', 'd', 'e'):
        value_key = str(row[name])
        ordered_scores.append(scores[name][value_key])
    return ordered_scores

def score_dict_for_df_row(scores: Scores, row: pd.Series) -> dict:
    """Map each of the variables ``a``..``e`` to the score of the row's value, rounded to 2 digits.

    Each row value is converted with ``str`` before being looked up in ``scores``.
    """
    rounded = {}
    for name in ('a', 'b', 'c', 'd', 'e'):
        raw_score = scores[name][str(row[name])]
        rounded[name] = round(raw_score, 2)
    return rounded

def variable_values(scores: Scores, row: pd.Series) -> dict:
    """Map each of the variables ``a``..``e`` to ``(row value, score rounded to 2 digits)``.

    The score is looked up under ``str(row value)``; the row value itself is
    returned unchanged as the first tuple element.
    """
    described = {}
    for name in ('a', 'b', 'c', 'd', 'e'):
        raw_value = row[name]
        score = scores[name][str(raw_value)]
        described[name] = (raw_value, round(score, 2))
    return described

def create_inputs(question_df: pd.DataFrame, scores: Scores) -> np.ndarray:
    """Stack the per-row score lists of ``question_df`` into a single array.

    Every row is converted with ``scores_for_df_row``; the resulting lists are
    combined with ``np.stack``.
    """
    def row_to_scores(row):
        return scores_for_df_row(scores, row)

    per_row_scores = question_df.apply(row_to_scores, axis=1)
    return np.stack(per_row_scores)  # type: ignore


class Node:
    """A decision-tree node: either an internal split or a leaf.

    Attributes are plain instance fields:
      * ``attribute`` - name of the attribute an internal node splits on.
      * ``label`` - value stored at a leaf (``None`` for internal nodes).
      * ``children`` - child nodes keyed by attribute value.
      * ``node_name`` - identifier used when describing the node.
    """

    def __init__(self, attribute=None, label=None, children=None, node_name=None):
        self.attribute = attribute
        self.label = label
        # A falsy argument (None or an empty mapping) yields a fresh dict per node.
        self.children = children if children else {}
        self.node_name = node_name


def random_tree(depth, attributes, parent_name='', max_depth=5):
    """Build a random binary decision tree over ``attributes``.

    Each internal node picks one not-yet-used attribute and one value out of
    ``0, 1, 2``; samples with that value go to the ``L`` child (keyed by the
    value), samples with either of the other two values go to the ``R`` child
    (keyed by a tuple of both values). Leaves carry a label drawn uniformly
    from ``[0, 1]``.

    Args:
        depth: Depth of the node being created (0 for the root).
        attributes: Attribute names still available for splitting on this path.
        parent_name: Name given to the created node; children append ``L``/``R``.
        max_depth: Depth at which leaves are created.

    Returns:
        The root ``Node`` of the generated (sub)tree.
    """
    # Stop at the depth limit, and also once every attribute has been used on
    # this path: random.choice on an empty list would raise IndexError.
    if depth >= max_depth or not attributes:
        return Node(label=random.uniform(0, 1), node_name=parent_name)

    attribute = random.choice(attributes)
    remaining_attributes = [a for a in attributes if a != attribute]

    split_value = random.choice([0, 1, 2])
    remaining_values = [x for x in [0, 1, 2] if x != split_value]

    # Build the left subtree before the right one so the sequence of random
    # draws (and therefore seeded output) is unchanged.
    left_child = random_tree(depth + 1, remaining_attributes, parent_name=parent_name + 'L', max_depth=max_depth)
    right_child = random_tree(depth + 1, remaining_attributes, parent_name=parent_name + 'R', max_depth=max_depth)

    return Node(
        attribute=attribute,
        children={
            split_value: left_child,
            tuple(remaining_values): right_child,
        },
        node_name=parent_name,
    )


def explain(node, sample, explanation="Decision path:"):
    """Describe, step by step, how ``sample`` is routed from ``node`` to a leaf.

    Args:
        node: Tree node to start from; a node with a non-``None`` ``label`` is a leaf.
        sample: Mapping from attribute name to the sample's value for it.
        explanation: Text the description starts with.

    Returns:
        ``explanation`` followed by one line per traversed split and a final
        line with the leaf label formatted to two decimals.

    Raises:
        ValueError: If a node has no child for the sample's attribute value
            (previously this surfaced as an AttributeError on ``None``).
    """
    lines = [explanation]
    current = node
    while current.label is None:
        attribute_value = sample[current.attribute]
        # Children are keyed either by a single value or by a tuple of values.
        next_node = current.children.get(attribute_value)
        if next_node is None:
            next_node = next(
                (
                    child
                    for key, child in current.children.items()
                    if isinstance(key, tuple) and attribute_value in key
                ),
                None,
            )
        if next_node is None:
            raise ValueError(
                f"No branch for value {attribute_value!r} of attribute {current.attribute!r}"
            )
        lines.append(
            f"Attribute {current.attribute} is {attribute_value}, go to node {next_node.node_name or 'leaf'}"
        )
        current = next_node
    lines.append(f"Final label: {current.label:.2f}")
    return "\n".join(lines)

def predict(sample, tree):
    """Route ``sample`` through ``tree`` and return the label of the leaf it reaches.

    Args:
        sample: Mapping from attribute name to the sample's value for it.
        tree: Root node; a node with a non-``None`` ``label`` is a leaf.

    Returns:
        The ``label`` of the reached leaf.

    Raises:
        ValueError: If a node has no child for the sample's attribute value.
            Without this check the loop would never advance and spin forever.
    """
    node = tree
    while node.label is None:
        branch_key = sample[node.attribute]
        # Children are keyed either by a single value or by a tuple of values.
        child = node.children.get(branch_key)
        if child is None:
            child = next(
                (
                    candidate
                    for key, candidate in node.children.items()
                    if isinstance(key, tuple) and branch_key in key
                ),
                None,
            )
        if child is None:
            raise ValueError(
                f"No branch for value {branch_key!r} of attribute {node.attribute!r}"
            )
        node = child
    return node.label

class SyntheticDecisionTree:
    """Synthetic model whose answer probabilities come from a random decision tree.

    ``train`` clusters the values of each template variable and generates a new
    random tree over the variables ``a``..``e``; ``eval`` routes every question
    of the task through that tree and records the reached leaf label as the
    answer probability.
    """

    def __init__(self, embedder: Embedder):
        self.embedder = embedder
        # Both are populated by train(); None until then so eval() can report
        # a clear error instead of failing on a missing attribute.
        self.variable_value_clusters = None
        self.tree_classifier = None

    def __str__(self) -> str:
        # eval() stores str(self) as the model name. The default object repr
        # embeds a memory address, which differs on every run.
        return type(self).__name__

    def train(self, template_task: TemplateTask, n_train_samples: int = 100) -> None:
        """Cluster the template's variable values and draw a fresh random tree.

        ``n_train_samples`` is not used by this model; it is kept so the
        signature stays unchanged for existing callers.
        """
        self.variable_value_clusters = word_score_dict(self.embedder, template_task.template, get_clusters)
        self.tree_classifier = random_tree(0, ['a', 'b', 'c', 'd', 'e'], max_depth=3)

    def eval(self, template_task: TemplateTask) -> Tuple[TemplateModelBehavior, LocalExplanationSet, str]:
        """Run the tree on every question of every split of ``template_task``.

        Returns:
            A tuple of (model behavior over all splits, local explanations for
            the ``train`` split only, string form of the variable-value clusters).

        Raises:
            RuntimeError: If called before ``train``.
        """
        if self.variable_value_clusters is None or self.tree_classifier is None:
            raise RuntimeError('SyntheticDecisionTree.eval() called before train()')
        assert template_task.template_id
        results = []
        explanations = []
        for split in template_task.split_names:
            question_df = template_task.questions[split]
            for _, row in question_df.iterrows():
                x = score_dict_for_df_row(self.variable_value_clusters, row)
                y = predict(x, self.tree_classifier)
                # Fill in the fields QuestionModelBehavior.from_df_row is given below.
                row['completion'] = ''
                row['total_answer_prob'] = 1
                row['answer_prob'] = y
                row['cot'] = False
                row['valid_answer'] = True
                row['split'] = split
                # Explanations are only produced for the train split.
                if split == 'train':
                    explanation = f'Variable Values: {variable_values(self.variable_value_clusters, row)}\n{explain(self.tree_classifier, x)}'
                    explanations.append(explanation)
                results.append(QuestionModelBehavior.from_df_row(row, template_task.template))

        model_behavior = TemplateModelBehavior(template_task.template, results, str(self))
        explanation_set = LocalExplanationSet(
            template_task.template_id,
            model_behavior.questions('train'),
            model_behavior.answers('train').tolist(),
            explanations
        )
        return model_behavior, explanation_set, str(self.variable_value_clusters)


@hydra.main(config_path='../../config', config_name='model_behavior.yaml', version_base='1.2')
def main(config):
    """Generate and save synthetic decision-tree behavior for every template task.

    For each task in the configured dataset a new tree is trained and evaluated;
    the model behavior is saved under ``model_behavior_path(config) / template_id``
    and the explanations under its ``explanations`` subdirectory.

    Raises:
        ValueError: If the configured model name is not supported or a template
            task has no ``template_id``.
    """
    # NOTE(review): 'synthetic_linear_model' is accepted here although a
    # SyntheticDecisionTree is always built - confirm this is intended.
    allowed_model_names = ['synthetic_linear_model', 'synthetic_decision_tree']
    # Explicit check instead of assert: asserts are stripped under `python -O`.
    if config.model.name not in allowed_model_names:
        raise ValueError(
            f'Unsupported model name {config.model.name!r}; expected one of {allowed_model_names}'
        )
    embedder = SentenceTransformer('all-distilroberta-v1')
    model = SyntheticDecisionTree(embedder)
    dataset = load_dataset(config)
    save_path = model_behavior_path(config)
    for template_task in tqdm(dataset.template_tasks):
        if not template_task.template_id:
            raise ValueError('Template task is missing a template_id')
        save_dir = save_path / template_task.template_id
        model.train(template_task)
        model_behavior, local_explanations, global_explanation = model.eval(template_task)
        model_behavior.save(save_dir)
        explanation_save_dir = save_dir / 'explanations'
        # parents=True so this works even if save_dir has not been created yet.
        explanation_save_dir.mkdir(parents=True, exist_ok=True)
        local_explanations.save(explanation_save_dir / 'weights_and_activations.csv')
        (explanation_save_dir / 'weights_and_activations_global.txt').write_text(global_explanation)

# Run the hydra entry point only when this file is executed as a script.
if __name__ == '__main__':
    main()