"""Using semantic uncertainty to cluster answers
"""

import numpy as np
import torch
from functools import partial
from tqdm import tqdm
# from transformers import pipeline
# from parallelformers import parallelize
from typing import Text, List, Iterable, Dict, Any, Tuple
# Load model directly
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    T5ForConditionalGeneration,
    T5TokenizerFast,
    DebertaV2ForSequenceClassification,
    DebertaV2TokenizerFast
)
from .utils.helpers import deconstruct, model_tokenizer_aware_hash
from .utils.disjoint_set import find_disjoint_sets
from .utils.helpers import batchify as _batchify


def get_entailment_model() -> Tuple[T5ForConditionalGeneration, T5TokenizerFast]:
    """Load the T5-XXL NLI checkpoint together with its tokenizer.

    Both are obtained from ``google/t5_xxl_true_nli_mixture`` through the
    ``Auto*`` factories. No multi-GPU parallelization is applied; the original
    author notes that a single A100 is sufficient. The model is not moved to
    any device here.

    Returns:
        A ``(model, tokenizer)`` tuple, model first.
    """
    checkpoint = "google/t5_xxl_true_nli_mixture"
    tok = AutoTokenizer.from_pretrained(checkpoint)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    return seq2seq, tok


def get_small_entailment_model() -> Tuple[DebertaV2ForSequenceClassification, DebertaV2TokenizerFast]:
    """Load the DeBERTa-v3-large NLI sequence classifier and its fast tokenizer.

    Both come from the
    ``MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli`` checkpoint.
    The model is returned exactly as loaded; moving it to a device is left to
    the caller.

    Returns:
        A ``(model, tokenizer)`` tuple, model first.
    """
    repo_id = "MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli"
    tok = DebertaV2TokenizerFast.from_pretrained(repo_id)
    classifier = AutoModelForSequenceClassification.from_pretrained(repo_id)
    return classifier, tok


def test_entailment(
    pairs: Iterable[Tuple[Text, Text]],
    entailment_model,
    tokenizer,
    device,
    batch_size: int = 1,
) -> List[bool]:
    """Run the seq2seq NLI model over (premise, hypothesis) pairs.

    Each pair is rendered as ``"premise: ... hypothesis: ..."``, passed to
    ``entailment_model.generate`` (``max_length=5``), and the decoded output
    is read as entailment when it starts with ``"1"``.

    Args:
        pairs: (premise, hypothesis) pairs. Any iterable is accepted,
            including generators; it is materialized once.
        entailment_model: seq2seq model exposing ``generate``.
        tokenizer: tokenizer matching ``entailment_model``.
        device: device the tokenized input tensors are moved to.
        batch_size: number of pairs per ``generate`` call. Defaults to 1,
            the previously hard-coded value.

    Returns:
        One bool per input pair, in input order:
        True --- entailment (decoded "1"), False --- not entailed.
    """
    # Materialize first: pairs are indexed below, which generators don't support.
    pairs = list(pairs)

    # Process pairs from shortest to longest so each padded batch holds
    # similarly sized inputs; the permutation is undone before returning.
    length_ids = np.argsort([len(pair[0]) + len(pair[1]) for pair in pairs])
    sorted_pairs = [pairs[i] for i in length_ids]

    def _prepare_inputs(batch):
        """Tokenize one batch in the T5 NLI prompt format and move it to ``device``."""
        encoded = tokenizer(
            [f"premise: {pair[0]} hypothesis: {pair[1]}" for pair in batch],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256,
        )
        return {k: v.to(device) for k, v in encoded.items()}

    # ``deconstruct`` adapts ``generate`` so it can be called with the
    # tokenizer's dict of tensors (presumably unpacking it as kwargs --
    # confirm in .utils.helpers).
    generate = deconstruct(partial(entailment_model.generate, max_length=5))

    sorted_is_entailed: List[bool] = []
    for batch in _batchify(sorted_pairs, batch_size=batch_size):
        outputs = generate(_prepare_inputs(batch))
        decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        # ``startswith`` rather than ``text[0]`` so an empty decode counts as
        # "not entailed" instead of raising IndexError.
        sorted_is_entailed.extend(text.startswith("1") for text in decoded)

    # argsort of the sorting permutation is its inverse: it maps each original
    # index back to that pair's position in the sorted order.
    return [sorted_is_entailed[i] for i in np.argsort(length_ids)]


def test_entailment_small(
    pairs: Iterable[Tuple[Text, Text]],
    entailment_model,
    tokenizer,
    device,
    batch_size: int = 256,
) -> List[bool]:
    """Classify (premise, hypothesis) pairs with a sequence-classification NLI model.

    Args:
        pairs: (premise, hypothesis) pairs. Any iterable is accepted,
            including generators; it is materialized once.
        entailment_model: sequence-classification model whose output exposes
            ``.logits``.
        tokenizer: tokenizer accepting parallel lists of texts and text pairs.
        device: device the tokenized input tensors are moved to.
        batch_size: number of pairs per forward pass. Defaults to 256, the
            previously hard-coded value.

    Returns:
        One bool per input pair, in input order: True when the arg-max class
        index is 0.
        NOTE(review): class 0 is assumed to mean "entailment" for the
        checkpoint in use -- confirm against ``entailment_model.config.id2label``.
    """
    pairs = list(pairs)

    # Process pairs from shortest to longest to reduce padding per batch;
    # the permutation is undone before returning.
    length_ids = np.argsort([len(pair[0]) + len(pair[1]) for pair in pairs])
    sorted_pairs = [pairs[i] for i in length_ids]

    def _prepare_inputs(batch):
        """Tokenize premises/hypotheses as sentence pairs and move tensors to ``device``."""
        # Both lists are built from the ``batch`` argument itself; the old
        # lambda read premises from the enclosing loop variable by accident.
        encoded = tokenizer(
            [pair[0] for pair in batch],
            [pair[1] for pair in batch],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256,
        )
        return {k: v.to(device) for k, v in encoded.items()}

    def _inference(inputs) -> List[bool]:
        """Return, per row, whether the predicted class index is 0."""
        # Pure inference: disable autograd so activations are not retained.
        with torch.no_grad():
            logits = entailment_model(**inputs).logits
        prediction = torch.argmax(logits, dim=1).cpu().numpy()
        return (prediction == 0).tolist()

    sorted_is_entailed: List[bool] = []
    for batch in _batchify(sorted_pairs, batch_size=batch_size):
        sorted_is_entailed.extend(_inference(_prepare_inputs(batch)))

    # argsort of the sorting permutation is its inverse permutation.
    return [sorted_is_entailed[i] for i in np.argsort(length_ids)]


# @task(
#     cache_key_fn=model_tokenizer_aware_hash,
#     version="0.1.1"
# )
def cluster(
    sentences: List[Text],
    entailment_model,
    tokenizer,
    device: torch.device
) -> Tuple[List[Dict[Text, Any]], np.ndarray]:
    """Group sentences into clusters of mutual (bidirectional) entailment.

    Every ordered pair ``(sentences[i], sentences[j])`` with ``i != j`` is
    scored by ``test_entailment_small`` (first element as premise, second as
    hypothesis). Two sentences meet the clustering criterion when the pair is
    judged entailed in both directions; ``find_disjoint_sets`` then groups
    sentence indices under that criterion.

    Args:
        sentences: texts to cluster.
        entailment_model: model passed through to ``test_entailment_small``.
        tokenizer: tokenizer passed through to ``test_entailment_small``.
        device: device passed through to ``test_entailment_small``.

    Returns:
        ``(clusters, entailment_mat)``:
        - ``clusters``: one dict per set returned by ``find_disjoint_sets``,
          with keys "cluster_id" (its enumeration index), "set_ids" (the
          sentence indices) and "sentences" (the corresponding texts).
        - ``entailment_mat``: ``(n, n)`` int8 matrix where entry ``[i, j]`` is
          1 iff pair (premise=sentences[i], hypothesis=sentences[j]) was
          judged entailed; the diagonal is fixed at 1.
        For empty ``sentences`` this is ``([], <(0, 0) int8 array>)``.
    """
    n = len(sentences)
    if n == 0:
        # Nothing to compare; the reshape-based layout below cannot handle n == 0.
        return [], np.eye(0, dtype=np.int8)

    def _create_pairs(sents):
        """Yield every ordered pair (sents[i], sents[j]), i != j, in row-major order."""
        for i in range(len(sents)):
            for j in range(len(sents)):
                if i != j:
                    yield (sents[i], sents[j])

    entailments: List[bool] = test_entailment_small(_create_pairs(sentences), entailment_model, tokenizer, device=device)

    # Scatter the n*(n-1) results into the off-diagonal cells. Boolean-mask
    # assignment fills cells in row-major order, matching _create_pairs.
    # This replaces an int8 index array that overflowed for n > 127.
    entailment_mat = np.eye(n, dtype=np.int8)
    off_diagonal = ~np.eye(n, dtype=bool)
    entailment_mat[off_diagonal] = np.asarray(entailments, dtype=np.int8)

    def criteria(i, j):
        """Sentences i and j belong together iff each entails the other."""
        return entailment_mat[i, j] == 1 and entailment_mat[j, i] == 1

    return (
        [
            {
                "cluster_id": index,
                "set_ids": set_ids,
                "sentences": [sentences[i] for i in set_ids]
            } for index, set_ids in enumerate(find_disjoint_sets(n, criteria))
        ],
        entailment_mat
    )
