import itertools
import pickle
import zipfile
from pathlib import Path
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
from huggingface_hub.file_download import hf_hub_download
from sentence_transformers.util import cos_sim
from tqdm import tqdm


def etvd(answer_probs: Sequence[float]) -> float:
    """Mean absolute difference between pairs of the given probabilities.

    Pairs are unordered and drawn with replacement, so every value is also
    paired with itself (contributing a zero difference).
    """
    pairs = itertools.combinations_with_replacement(answer_probs, 2)
    gaps = [abs(first - second) for first, second in pairs]
    return np.array(gaps).mean()


def tv_distance_of_preds(preds: np.ndarray, labels: np.ndarray) -> float:
    """Return the mean absolute element-wise difference between preds and labels."""
    gaps = np.abs(preds - labels)
    return gaps.mean()


def accuracy(preds: np.ndarray, labels: np.ndarray) -> float:
    """Fraction of entries where preds and labels fall on the same side of 0.5.

    Values strictly greater than 0.5 count as positive; 0.5 itself counts as negative.
    """
    predicted_positive = preds > 0.5
    labelled_positive = labels > 0.5
    agreement = predicted_positive == labelled_positive
    return agreement.mean()


def expected_log_score(preds: np.ndarray, labels: np.ndarray) -> float:
    """Mean of ``labels * log(p) + (1 - labels) * log(1 - p)`` over all entries.

    ``p`` is ``preds`` clipped to [0.0001, 0.9999] so the logarithms stay finite.
    """
    clipped = preds.clip(.0001, .9999)
    positive_term = np.log(clipped) * labels
    negative_term = np.log(1 - clipped) * (1 - labels)
    return (positive_term + negative_term).mean()


def normalized_expected_log_score(preds: np.ndarray, labels: np.ndarray) -> float:
    """Expected log score of preds minus the score obtained by predicting labels themselves."""
    achieved = expected_log_score(preds, labels)
    reference = expected_log_score(labels, labels)
    return achieved - reference


def kl_divergence(preds: np.ndarray, labels: np.ndarray) -> float:
    """Return the negation of ``normalized_expected_log_score(preds, labels)``."""
    score_gap = normalized_expected_log_score(preds, labels)
    return -score_gap


def positive_fraction(answer_probs: Sequence[float]) -> float:
    """Fraction of the given probabilities that are strictly greater than 0.5."""
    probs = np.array(answer_probs)
    is_positive = (probs > 0.5).astype(float)
    return is_positive.mean()


def most_similar_idx(embedding: np.ndarray, embedding_table: np.ndarray, k: int, subsample: Optional[int] = None):
    """Return the row indices of the ``k`` table entries most cosine-similar to ``embedding``.

    Args:
        embedding: Query embedding passed to ``cos_sim``.
        embedding_table: Candidate embeddings, one per row.
        k: Number of indices to return.
        subsample: If truthy, only this many randomly chosen rows (drawn without
            replacement) are compared against the query. Values larger than the
            table are clamped to the table size.

    Returns:
        Array of ``k`` indices into the ORIGINAL ``embedding_table``. The indices
        come from ``np.argpartition`` and are therefore not sorted by similarity.
    """
    sampled_rows = None
    if subsample:
        n_rows = embedding_table.shape[0]
        # Clamp instead of letting np.random.choice raise when asked for more
        # rows than exist; every previously valid call draws exactly as before.
        sampled_rows = np.random.choice(n_rows, min(subsample, n_rows), replace=False)
        embedding_table = embedding_table[sampled_rows, :]
    sims = cos_sim(embedding, embedding_table)  # type: ignore
    top = np.argpartition(sims.squeeze(), -k)[-k:]
    if sampled_rows is not None:
        # argpartition yields positions within the sub-table; translate them back
        # so callers can index the table they actually passed in.
        top = sampled_rows[top]
    return top


def parse_glove_line(line: str) -> Tuple[str, np.ndarray]:
    """Split a whitespace-separated line into its first token and a float vector.

    The first token is returned as the word; every remaining token is parsed
    with ``float`` and collected into a numpy array.
    """
    token, *components = line.split()
    vector = np.array(list(map(float, components)))
    return token, vector


def diversity(embedding_table: np.ndarray) -> float:
    """Mean non-zero pairwise Euclidean distance between the rows of the table.

    All ordered row pairs are considered; distances equal to zero (each row with
    itself, and any exactly identical rows) are left out of the average.
    Returns 0 when no non-zero distance exists.
    """
    row_diffs = embedding_table[:, None, :] - embedding_table[None, :, :]
    pairwise = np.linalg.norm(row_diffs, axis=-1)
    # Boolean indexing flattens in row-major order, dropping the zero entries.
    nonzero = pairwise[pairwise != 0]
    if nonzero.size == 0:
        return 0
    return nonzero.mean()


class GloVeEncoder:
    """Sentence encoder that averages GloVe word vectors.

    Construction fetches ``glove.6B.zip`` from the ``stanfordnlp/glove`` Hub
    repository, parses ``glove.6B.300d.txt`` into a ``{word: vector}`` dict and
    caches that dict as ``glove_embeddings.pkl`` next to the archive, so later
    constructions only need to unpickle it.
    """

    def __init__(self) -> None:
        zip_path = hf_hub_download(repo_id="stanfordnlp/glove", filename="glove.6B.zip")
        assert zip_path is not None
        parent_dir = Path(zip_path).parent
        pickle_file = parent_dir / 'glove_embeddings.pkl'
        if pickle_file.exists():
            # NOTE(security): unpickling executes arbitrary code. This is only safe
            # because the file is a cache this class wrote itself; do not point it
            # at a directory other users or processes can write to.
            with open(pickle_file, 'rb') as f:
                self.embeddings = pickle.load(f)
            return
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(parent_dir)
        glove_file = parent_dir / 'glove.6B.300d.txt'
        # Decode explicitly as UTF-8 instead of the platform default, and stream the
        # file line by line rather than holding the whole text plus a list of all
        # of its lines in memory at once.
        with open(glove_file, encoding='utf-8') as lines:
            self.embeddings = dict(
                parse_glove_line(line) for line in tqdm(lines, desc='Loading GloVe Embeddings')
            )
        with open(pickle_file, 'wb') as f:
            pickle.dump(self.embeddings, f)

    def encode(self, sentences: Union[str, List[str]]) -> np.ndarray:
        """Encode one sentence or a list of sentences as mean word vectors.

        Sentences are split on whitespace; tokens missing from the vocabulary are
        ignored.

        Args:
            sentences: A single string, or an iterable of strings.

        Returns:
            A 1-D vector for a single string; otherwise a 2-D array with one row
            per sentence (shape ``(0, dim)`` when no sentences are given).

        Raises:
            ValueError: If a sentence contains no token that is in the vocabulary.
        """
        def encode_single(sentence: str):
            looked_up = (self.embeddings.get(word) for word in sentence.split())
            known = [vector for vector in looked_up if vector is not None]
            if not known:
                # np.stack([]) would raise a ValueError that does not say which
                # input was at fault; fail with the same type but a useful message.
                raise ValueError(
                    f"Cannot encode {sentence!r}: none of its tokens are in the GloVe vocabulary"
                )
            return np.stack(known, axis=0).mean(axis=0)

        if isinstance(sentences, str):
            return encode_single(sentences)
        encoded = [encode_single(q) for q in sentences]
        if not encoded:
            # Nothing to stack: return an empty batch whose width matches the
            # stored vectors (0 if the vocabulary itself is empty).
            sample = next(iter(self.embeddings.values()), None)
            dim = 0 if sample is None else sample.shape[0]
            return np.empty((0, dim))
        return np.stack(encoded)


def measure_diversity(texts: List[str]) -> float:
    """Encode the texts with a fresh ``GloVeEncoder`` and return their ``diversity``.

    Returns 0 without constructing an encoder when ``texts`` is empty.
    """
    if not texts:
        return 0
    text_vectors = GloVeEncoder().encode(texts)
    return diversity(text_vectors)
