from abc import ABC, abstractmethod
from argparse import Namespace
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Protocol, Union

import numpy as np
import torch
from sentence_transformers import SentenceTransformer

from lm_understanding.explanations.explanations import (LocalExplanation,
                                                        LocalExplanationSet)
from lm_understanding.metrics import (GloVeEncoder, accuracy,
                                      expected_log_score, kl_divergence,
                                      most_similar_idx,
                                      normalized_expected_log_score,
                                      tv_distance_of_preds)
from lm_understanding.question_template import TemplateModelBehavior


@dataclass
class BaselineResults:
    """Predictions, labels and metadata produced by evaluating one baseline.

    All per-split containers are keyed by split name. Metric properties
    evaluate the corresponding metric function independently on every split
    and return a ``{split: value}`` mapping.
    """

    # Predicted values for every question, keyed by split name.
    preds: Dict[str, List[float]]
    # Target values aligned index-by-index with ``preds``, keyed by split name.
    labels: Dict[str, List[float]]
    # Name of the model whose behavior is being predicted.
    model_name: str
    # Identifier of the question template that was evaluated.
    template_id: str
    # Name of the baseline that produced the predictions.
    baseline: str
    # Extra per-question outputs: split name -> info key -> one value per question.
    prediction_info: Dict[str, Dict[str, list]]

    @property
    def preds_and_labels(self):
        """Return ``{split: (preds, labels)}`` for every split present in ``preds``."""
        return {split: (preds, self.labels[split]) for split, preds in self.preds.items()}

    def eval_metric(self, metric: Callable[[np.ndarray, np.ndarray], float]) -> Dict[str, float]:
        """Apply ``metric(preds, labels)`` to each split.

        Predictions and labels are converted to numpy arrays before the call.

        Returns:
            Mapping from split name to the metric value on that split.
        """
        return {split: metric(np.array(preds), np.array(labels)) for split, (preds, labels) in self.preds_and_labels.items()}

    @property
    def accuracies(self) -> Dict[str, float]:
        """Per-split result of the ``accuracy`` metric."""
        return self.eval_metric(accuracy)

    @property
    def tv_distances(self) -> Dict[str, float]:
        """Per-split result of the ``tv_distance_of_preds`` metric."""
        return self.eval_metric(tv_distance_of_preds)

    @property
    def log_scores(self) -> Dict[str, float]:
        """Per-split result of the ``expected_log_score`` metric."""
        return self.eval_metric(expected_log_score)

    @property
    def normalized_log_scores(self) -> Dict[str, float]:
        """Per-split result of the ``normalized_expected_log_score`` metric."""
        return self.eval_metric(normalized_expected_log_score)

    @property
    def kl_divergences(self) -> Dict[str, float]:
        """Per-split result of the ``kl_divergence`` metric."""
        return self.eval_metric(kl_divergence)

    @property
    def split_names(self) -> List[str]:
        """Names of the splits that have predictions, in insertion order."""
        return list(self.preds.keys())

    @property
    def summary(self):
        """Return a freshly built dict of identifying metadata and per-split metrics."""
        return dict(
            model_name=self.model_name,
            baseline_name=self.baseline,
            template_id=self.template_id,
            accuracy=self.accuracies,
            tv_distance=self.tv_distances,
            log_score=self.log_scores,
            kl_divergence=self.kl_divergences,
        )

    def as_dict(self) -> Dict[str, Any]:
        """Return the summary extended with the raw predictions and labels."""
        # ``summary`` builds a new dict on every access, so adding keys here
        # does not leak into later ``summary`` reads.
        data: Dict[str, Any] = self.summary
        data['predictions'] = self.preds
        data['labels'] = self.labels
        return data

    def as_records(self) -> List[Dict[str, Any]]:
        """Flatten the results into one record (dict) per question.

        Each record holds ``template_id``, ``split``, ``question_idx``,
        ``prediction`` and ``label``, plus one entry per key found in
        ``prediction_info`` for the record's own split.

        Returns:
            Records for all splits, ordered by split and then question index.
        """
        records: List[Dict[str, Any]] = []
        for split in self.split_names:
            split_records = [
                dict(
                    template_id=self.template_id,
                    split=split,
                    question_idx=i,
                    prediction=pred,
                    label=label,
                )
                for i, (pred, label) in enumerate(zip(self.preds[split], self.labels[split]))
            ]
            # Attach extra info to this split's records only. Zipping against
            # the accumulated list would write the values onto the records of
            # the first split instead.
            for info_key, info_values in self.prediction_info.get(split, {}).items():
                for record, value in zip(split_records, info_values):
                    record[info_key] = value
            records.extend(split_records)
        return records


class Encoder(Protocol):
    """Structural interface for objects that turn sentences into embeddings.

    Any object with a compatible ``encode`` method satisfies this protocol;
    no inheritance is required.
    """
    def encode(self, sentences: Union[str, List[str]]) -> Union[List[torch.Tensor], torch.Tensor, np.ndarray]:
        """Encode a single sentence or a list of sentences.

        The return is a list of tensors, a single tensor, or a numpy array,
        depending on the implementation.
        """
        ...


def load_encoder(name: str) -> Optional[Encoder]:
    """Build the sentence encoder selected by ``name``.

    Args:
        name: ``'none'`` for no encoder, ``'all-mpnet-base-v2'`` for the
            SentenceTransformer model of that name, or ``'GloVe'`` for a
            ``GloVeEncoder``.

    Returns:
        The encoder instance, or ``None`` when ``name`` is ``'none'``.

    Raises:
        NotImplementedError: If ``name`` is not one of the supported values.
    """
    if name == 'none':
        return None
    if name == 'all-mpnet-base-v2':
        return SentenceTransformer(name)
    if name == 'GloVe':
        return GloVeEncoder()
    raise NotImplementedError(f'{name} is not an implemented embedder')


def encode_questions(encoder: Optional[Encoder], model_behavior: TemplateModelBehavior) -> Optional[Dict]:
    """Encode the questions of every split with ``encoder``.

    Args:
        encoder: Encoder to apply, or ``None`` to skip encoding.
        model_behavior: Source of the split names and of each split's questions.

    Returns:
        ``None`` when no encoder is given; otherwise a dict mapping each split
        name to whatever ``encoder.encode`` returns for that split's questions.
    """
    if encoder is None:
        return None
    encoded_by_split = {}
    for split in model_behavior.split_names:
        split_questions = model_behavior.questions(split)
        encoded_by_split[split] = encoder.encode(split_questions)  # type: ignore
    return encoded_by_split


class ExampleRetriever(ABC):
    """Abstract interface for choosing few-shot examples for a given question.

    Instances are callable: given a split name, a question index within that
    split and a count ``k``, they return a list of integer indices.
    """
    def __init__(self, model_behavior: TemplateModelBehavior, retriever_config: Namespace, *args, **kwargs):
        """Accept the common constructor arguments; the base class stores nothing."""
        pass

    @abstractmethod
    def __call__(self, split_name: str, question_idx: int, k: int) -> List[int]:
        """Return indices of ``k`` examples for question ``question_idx`` of ``split_name``.

        NOTE(review): the concrete retrievers in this file return indices into
        the 'train' split and exclude the query itself when ``split_name`` is
        'train' — presumably all implementations should; confirm.
        """
        ...


class RandomRetriever(ExampleRetriever):
    """Retriever that samples train-question indices uniformly without replacement.

    Uses a private ``RandomState`` seeded with 0, so the sequence of samples
    is reproducible for a fixed sequence of calls.
    """

    def __init__(self, model_behavior: TemplateModelBehavior, *args, **kwargs):
        """Record the number of train questions and create the seeded generator."""
        train_questions = model_behavior.questions('train')
        self.n_train_questions = len(train_questions)
        self.rng = np.random.RandomState(seed=0)

    def __call__(self, split_name: str, question_idx: int, k: int) -> List[int]:
        """Sample ``k`` distinct train indices.

        When the query itself comes from the 'train' split, its own index is
        removed from the candidate pool first so it is never returned.
        """
        candidates = np.arange(self.n_train_questions)
        query_is_train = split_name == 'train'
        if query_is_train:
            # Drop the query's own position from the candidates.
            candidates = np.delete(candidates, question_idx)
        sampled = self.rng.choice(candidates, k, replace=False)
        return sampled.tolist()


class NearestNeighborRetriever(ExampleRetriever):
    """Retriever that returns the train questions most similar to the query.

    Similarity is delegated to ``most_similar_idx`` over precomputed question
    embeddings.
    """

    def __init__(self, model_behavior: TemplateModelBehavior, encoded_questions: Optional[Dict[str, np.ndarray]], *args, **kwargs):
        """Store the per-split question embeddings.

        Args:
            model_behavior: Unused here; accepted for signature parity with
                the other retrievers.
            encoded_questions: Mapping from split name to that split's
                question embeddings. Must not be ``None`` and must contain a
                'train' entry for lookups to work.
        """
        assert encoded_questions is not None, 'NearestNeighborRetriever requires encoded questions'
        self.encoded_questions = encoded_questions

    def __call__(self, split_name: str, question_idx: int, k: int) -> List[int]:
        """Return up to ``k`` train indices most similar to the query question.

        For a query taken from the 'train' split, the query's own index is
        excluded from the result.

        Args:
            split_name: Split the query question belongs to.
            question_idx: Index of the query question within ``split_name``.
            k: Number of neighbours wanted.

        Returns:
            A list of at most ``k`` plain Python indices into the train split.
        """
        embedding = self.encoded_questions[split_name][question_idx]
        train_embeddings = self.encoded_questions['train']
        if split_name != 'train':
            return most_similar_idx(embedding, train_embeddings, k).tolist()
        # Ask for one extra neighbour so that k remain after removing the
        # query itself.
        neighbors = most_similar_idx(embedding, train_embeddings, k + 1).tolist()
        # Truncate to k: if the query is not among the k + 1 hits (e.g. another
        # train question has an identical embedding), nothing is filtered out
        # and one surplus index would otherwise be returned.
        return [i for i in neighbors if i != question_idx][:k]


def make_retriever(retriever_config: Namespace, model_behavior: TemplateModelBehavior, encoded_questions: Optional[Dict[str, np.ndarray]]) -> Optional[ExampleRetriever]:
    """Construct the example retriever selected by ``retriever_config.name``.

    Args:
        retriever_config: Config whose ``name`` is ``'none'``, ``'random'`` or
            ``'nearest_neighbor'``.
        model_behavior: Passed through to the retriever constructor.
        encoded_questions: Per-split question embeddings; only forwarded to
            the nearest-neighbor retriever.

    Returns:
        The retriever, or ``None`` when the name is ``'none'``.

    Raises:
        NotImplementedError: For any other name.
    """
    kind = retriever_config.name
    if kind == 'none':
        return None
    elif kind == 'random':
        return RandomRetriever(model_behavior)
    elif kind == 'nearest_neighbor':
        return NearestNeighborRetriever(model_behavior, encoded_questions)
    else:
        raise NotImplementedError()


class Baseline(ABC):
    """Abstract base for baselines that predict a model's answers on a template.

    Construction loads the configured encoder, encodes the questions of every
    split, and builds the configured few-shot example retriever. Subclasses
    must implement ``train`` and either ``predict`` or ``split_predictions``.
    """

    def __init__(self, model_behavior: TemplateModelBehavior, baseline_config: Namespace, explainer: Optional[LocalExplanationSet] = None, **kwargs):
        """Set up encoder outputs, retriever and explainer.

        Args:
            model_behavior: Questions and answers of the model under study.
            baseline_config: Config providing ``embedder.name`` and ``retriever``.
            explainer: Optional explanation set used to build few-shot examples.
            **kwargs: Accepted and ignored by the base class.
        """
        self.model_behavior = model_behavior
        encoder = load_encoder(baseline_config.embedder.name)
        self._encoded_questions = encode_questions(encoder, model_behavior)
        self._retriever = make_retriever(baseline_config.retriever, model_behavior, self.encoded_questions)
        self._local_explainer = explainer

    @property
    def encoded_questions(self) -> Optional[Dict[str, np.ndarray]]:
        """Per-split question embeddings, or ``None`` when no encoder is configured."""
        return self._encoded_questions

    @property
    def retriever(self) -> Optional[ExampleRetriever]:
        """The few-shot example retriever, or ``None`` when none is configured."""
        return self._retriever

    @property
    def local_explainer(self) -> Optional[LocalExplanationSet]:
        """The explanation set passed at construction, if any."""
        return self._local_explainer

    @property
    def global_explanation(self) -> Optional[str]:
        """The explainer's ``global_explanation``, or ``None`` if the explainer is falsy."""
        return self.local_explainer.global_explanation if self.local_explainer else None

    @property
    def template_id(self) -> str:
        """Template id of the model behavior; asserts that it is set."""
        assert self.model_behavior.template_id
        return self.model_behavior.template_id

    @abstractmethod
    def train(self, *args, **kwargs):
        """Fit the baseline; must be implemented by subclasses."""
        ...

    def predict(self, split: str, question_idx: int) -> float:
        """Predict the answer for one question.

        Subclasses override this unless they override ``split_predictions``
        directly.

        Raises:
            NotImplementedError: Always, in the base class. Returning ``None``
                here would silently produce a list of ``None`` predictions in
                ``split_predictions`` and fail later inside the metrics.
        """
        raise NotImplementedError(
            f'{self.__class__.__name__} must implement predict() or override split_predictions()'
        )

    def get_question(self, split: str, question_idx: int) -> str:
        """Return the text of question ``question_idx`` in ``split``."""
        return self.model_behavior.questions(split)[question_idx]

    def few_shot_examples(self, question_idx: int, split_name: str, k: int) -> List[LocalExplanation]:
        """Return ``k`` retrieved examples with their local explanations.

        Args:
            question_idx: Index of the query question within ``split_name``.
            split_name: Split the query question belongs to.
            k: Number of examples; a falsy value yields an empty list without
                requiring a retriever or explainer.

        Returns:
            One ``explain(i)`` result per index returned by the retriever.
        """
        if not k:
            return []
        assert self.retriever is not None
        few_shot_example_idx = self.retriever(split_name, question_idx, k)
        assert self.local_explainer is not None
        return [self.local_explainer.explain(i) for i in few_shot_example_idx]

    def split_predictions(self, split_name: str, *args, **kwargs) -> Dict[str, List]:
        """Predict every question of a split.

        Returns:
            A dict with a ``'predictions'`` list. Overrides may add further
            keys holding one value per question; ``test`` stores those as
            prediction info.
        """
        n_questions = len(self.model_behavior.questions(split_name))
        return dict(predictions=[self.predict(split_name, i) for i in range(n_questions)])

    def test(self, split_names: Optional[List[str]] = None) -> BaselineResults:
        """Run the baseline on the given splits and collect the results.

        Args:
            split_names: Splits to evaluate; defaults to every split of the
                model behavior except 'train'.

        Returns:
            A ``BaselineResults`` with predictions, labels and any extra
            per-question info for each evaluated split.
        """
        preds = {}
        labels = {}
        prediction_info = {}
        if split_names is None:
            split_names = [name for name in self.model_behavior.split_names if name != 'train']
        for split_name in split_names:
            output = self.split_predictions(split_name)
            preds[split_name] = output.pop('predictions')
            # Whatever is left after removing the predictions is extra info.
            prediction_info[split_name] = output
            # Labels are cut to the number of predictions actually produced.
            labels[split_name] = self.model_behavior.answers(split_name)[:len(preds[split_name])]
        return BaselineResults(
            preds=preds,
            labels=labels,
            model_name=self.model_behavior.model_name,
            template_id=self.template_id,
            baseline=self.__class__.__name__,
            prediction_info=prediction_info
        )
