import os
from functools import lru_cache

import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from .data import infer_preprocess
from .download_utils import (RemoteFileMetadata,
                                            check_download_resource)
from .model import BERTFlatClassifier
from tqdm import trange
from transformers import AutoTokenizer
from transformers import logging as trf_logging

PRETRAIN_ECHECKERS = {
    "echecker_clotho_audiocaps_base": (
        "https://github.com/blmoistawinde/fense/releases/download/V0.1/echecker_clotho_audiocaps_base.ckpt",
        "1a719f090af70614bbdb9f9437530b7e133c48cfa4a58d964de0d47fc974a2fa",
    ),
    "echecker_clotho_audiocaps_tiny": (
        "https://github.com/blmoistawinde/fense/releases/download/V0.1/echecker_clotho_audiocaps_tiny.ckpt",
        "90ed0ac5033ec497ec66d4f68588053813e085671136dae312097c96c504f673",
    ),
    "none": (None, None),
}


def load_pretrain_echecker(
    echecker_model, device="cuda", use_proxy=False, proxies=None
):
    trf_logging.set_verbosity_error()  # suppress loading warnings
    url, checksum = PRETRAIN_ECHECKERS[echecker_model]
    remote = RemoteFileMetadata(
        filename=f"{echecker_model}.ckpt", url=url, checksum=checksum
    )
    file_path = check_download_resource(remote, use_proxy, proxies)
    # print(f'file_path: {file_path}')
    # /root/.fense_data
    model_states = torch.load(file_path)
    clf = BERTFlatClassifier(
        model_type=model_states["model_type"], num_classes=model_states["num_classes"]
    )
    dict_new = clf.state_dict().copy()
    trained_list = [
        i
        for i in model_states["state_dict"].keys()
        if not ("encoder.embeddings.position_ids" in i)
    ]
    for i in range(len(trained_list)):
        dict_new[trained_list[i]] = model_states["state_dict"][trained_list[i]]
    clf.load_state_dict(dict_new)
    clf.eval()
    clf.to(device)
    return clf


class Evaluator:
    def __init__(
        self,
        batch_size=32,
        device="cuda",
        sbert_model="paraphrase-TinyBERT-L6-v2",
        echecker_model="echecker_clotho_audiocaps_base",
        error_threshold=0.9,
        penalty=0.9,
        use_proxy=False,
        proxies=None,
    ):
        # assert sbert_model in {'paraphrase-MiniLM-L6-v2', 'paraphrase-TinyBERT-L6-v2', 'paraphrase-mpnet-base-v2'}
        assert echecker_model in PRETRAIN_ECHECKERS
        self.batch_size = batch_size
        self.device = device
        self.sbert_model = sbert_model
        self.echecker_model = echecker_model
        self.error_threshold = error_threshold
        self.penalty = penalty

        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        self.sbert = SentenceTransformer(sbert_model, device=device)
        if echecker_model != "none":
            self.echecker = load_pretrain_echecker(
                echecker_model, device, use_proxy, proxies
            )
            self.echecker_tokenizer = AutoTokenizer.from_pretrained(
                self.echecker.model_type
            )
            self.echecker.to(device)
            self.echecker.eval()

    def encode_sents_sbert(self, sents, batch_size=32):
        return self.sbert.encode(
            sents,
            convert_to_tensor=True,
            normalize_embeddings=True,
            batch_size=batch_size,
            show_progress_bar=False,
        )

    @lru_cache(maxsize=32)  # reuse cache if encode the same sentence
    def encode_sent_sbert(self, sent):
        return self.sbert.encode(
            sent,
            convert_to_tensor=True,
            normalize_embeddings=True,
            show_progress_bar=False,
        )

    def detect_error_sents(self, sents, batch_size=32):
        if len(sents) <= batch_size:
            batch = infer_preprocess(self.echecker_tokenizer, sents, max_len=64)
            for k, v in batch.items():
                batch[k] = v.to(self.device)
            with torch.no_grad():
                logits = self.echecker(**batch)
                probs = torch.sigmoid(logits).detach().cpu().numpy()
        else:
            probs = []
            for i in trange(0, len(sents), batch_size):
                batch = infer_preprocess(
                    self.echecker_tokenizer, sents[i : i + batch_size], max_len=64
                )
                for k, v in batch.items():
                    batch[k] = v.to(self.device)
                with torch.no_grad():
                    batch_logits = self.echecker(**batch)
                    batch_probs = (
                        torch.sigmoid(batch_logits).detach().cpu().numpy()[:, -1]
                    )
                probs.append(batch_probs)
            probs = np.concatenate(probs)
        return (probs > self.error_threshold).astype(float)

    @lru_cache(maxsize=32)  # reuse cache if infer with the same sentence
    def detect_error_sent(self, sent, return_error_prob=False):
        batch = infer_preprocess(self.echecker_tokenizer, [sent], max_len=64)
        for k, v in batch.items():
            batch[k] = v.to(self.device)
        with torch.no_grad():
            logits = self.echecker(**batch)
            probs = torch.sigmoid(logits).detach().cpu().numpy()
        has_error = probs[0, -1] > self.error_threshold
        if return_error_prob:
            return has_error, probs[0, -1]
        else:
            return has_error

    def corpus_score(self, cands, list_refs, agg_score="mean"):
        assert len(cands) == len(list_refs)
        assert agg_score in {"none", "mean", "max"}
        rng_ids = [0]
        all_refs = []
        for lst in list_refs:
            rng_ids.append(rng_ids[-1] + len(lst))
            all_refs.extend(lst)
        print("Encoding sentences")
        emb_cands = self.encode_sents_sbert(cands, self.batch_size)
        emb_refs = self.encode_sents_sbert(all_refs, self.batch_size)
        sim_scores = [
            (emb_cands[i] @ emb_refs[rng_ids[i] : rng_ids[i + 1]].T)
            .mean()
            .detach()
            .cpu()
            .item()
            for i in range(len(cands))
        ]
        if self.echecker_model == "none":
            if agg_score == "mean":
                return np.mean(sim_scores)
            elif agg_score == "max":
                return np.max(sim_scores)
            else:
                return sim_scores
        else:
            sim_scores = np.array(sim_scores)
            print("Performing error detection")
            has_error = self.detect_error_sents(cands, self.batch_size)
            penalized_scores = sim_scores * (1 - self.penalty * has_error)
            if agg_score == "mean":
                return np.mean(penalized_scores)
            elif agg_score == "max":
                return np.max(penalized_scores)
            else:
                return penalized_scores

    def sentence_score(self, cand, refs, return_error_prob=False):
        emb_cand = self.encode_sent_sbert(cand)
        emb_refs = self.encode_sents_sbert(refs, self.batch_size)
        scores = emb_cand @ emb_refs.T

        if self.echecker_model == "none":
            return scores.mean().detach().cpu().item()
        else:
            score = scores.mean().detach().cpu().item()
            if not return_error_prob:
                has_error = self.detect_error_sent(cand)
                penalized_score = (1 - self.penalty) * score if has_error else score
                return penalized_score
            else:
                has_error, error_prob = self.detect_error_sent(cand, return_error_prob)
                penalized_score = (1 - self.penalty) * score if has_error else score
                return score, error_prob, penalized_score


if __name__ == "__main__":
    pred_cap = "A person is pounding the pavement in their footwear."
    ref_cap = [
        "a person brushing their teeth while getting faster at the end",
        "a person is brushing their teeth while brushing faster towards the end",
        "a person uses a toothbrush to brush their teeth",
        "someone is brushing their teeth loudly and very close by",
        "someone very close by is brushing their teeth loudly",
    ]

    evaluator = Evaluator(
        device="cpu",
        sbert_model="paraphrase-MiniLM-L6-v2",
        echecker_model="echecker_clotho_audiocaps_base",
    )

    has_error, error_prob = evaluator.detect_error_sent(pred_cap, True)
    print("{}, has_error:{}, error_prob:{}".format(pred_cap, has_error, error_prob))

    # score, error_prob, penalized_score = evaluator.sentence_score(pred_cap, ref_cap, return_error_prob=True)
    # print("score:{}, error_prob:{}, penalized_score:{}".format(score, error_prob, penalized_score))
