"""Convert QA into claims using the model.
"""

from typing import Text, Dict, List, Any, Tuple
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    T5ForConditionalGeneration,
    T5TokenizerFast,
)
import numpy as np
import torch
from .utils.helpers import batchify


def get_declarativizer_model(
    model_name: Text = "domenicrosati/question_converter-3b",
) -> Tuple[T5ForConditionalGeneration, T5TokenizerFast]:
    """Load the seq2seq model and tokenizer used to turn QA pairs into claims.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            model and the tokenizer. Defaults to the checkpoint this module
            was originally written against.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is returned as loaded; this
        function does not move it to any particular device.
    """
    # Model and tokenizer are loaded from the same identifier so they stay in sync.
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    return model, tokenizer


def convert(
    inputs: List[Dict[Text, Text]],
    model: T5ForConditionalGeneration,
    tokenizer: T5TokenizerFast,
    device: torch.device,
    batch_size: int = 16,
    max_length: int = 512,
) -> List[Text]:
    """Generate one output string per question/answer pair.

    Each input is rendered as ``"<question> </s> <answer>"``, the rendered
    strings are processed in batches ordered by character length, and the
    generations are returned in the original input order.

    Args:
        inputs: Dicts that each provide a ``'question'`` and an ``'answer'`` key.
        model: Model whose ``generate`` method produces the output token ids.
        tokenizer: Tokenizer used to encode the batches and decode the outputs.
        device: Device the tokenized tensors are moved to before generation.
            The model itself is not moved here.
        batch_size: Number of input strings handed to ``batchify`` per batch.
        max_length: Used both as the truncation length when tokenizing and as
            the ``max_length`` argument of ``model.generate``.

    Returns:
        The decoded generations, aligned index-by-index with ``inputs``.
        An empty ``inputs`` yields an empty list without touching the model.
    """
    if not inputs:
        # Nothing to generate; avoid handing empty batches to the tokenizer/model.
        return []

    input_strs = [f"{item['question']} </s> {item['answer']}" for item in inputs]

    # Order by character length so each batch groups strings of similar size,
    # which should limit the padding added by ``padding=True`` below.
    sorted_ids = np.argsort([len(text) for text in input_strs])
    sorted_input_strs = [input_strs[i] for i in sorted_ids]

    sorted_generations: List[Text] = []
    for batch in batchify(sorted_input_strs, batch_size=batch_size):
        encoded = tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        batch_inputs = {key: tensor.to(device) for key, tensor in encoded.items()}
        outputs = model.generate(**batch_inputs, max_length=max_length)
        sorted_generations.extend(
            tokenizer.batch_decode(outputs, skip_special_tokens=True)
        )

    # The argsort of a permutation is its inverse, so this maps every
    # generation back to the position of the input it was produced from.
    restore_order = np.argsort(sorted_ids)
    return [sorted_generations[i] for i in restore_order]


def declarativize(
    instances: List[Dict[Text, Any]],
    model: T5ForConditionalGeneration,
    tokenizer: T5TokenizerFast,
    device: torch.device
) -> List[Dict[Text, Any]]:
    """Attach a generated claim to every cluster of every instance.

    The clusters of all instances are flattened into a single list of
    question/answer pairs so that they can be decoded together in batches by
    ``convert``; the generations are then written back, in the same order, as
    ``cluster['claim']``.

    Args:
        instances: Dicts with a ``'question'`` entry and a ``'clusters'`` list;
            each cluster provides a ``'sentences'`` collection. The shortest
            sentence of a cluster is used as the answer for that cluster.
        model: Model forwarded to ``convert``.
        tokenizer: Tokenizer forwarded to ``convert``.
        device: Device forwarded to ``convert``.

    Returns:
        The same ``instances`` list, modified in place.

    Raises:
        ValueError: If a cluster has an empty ``'sentences'`` collection
            (raised by ``min``).
        RuntimeError: If the number of generations does not match the number
            of clusters.
    """
    input_dicts = [
        {
            "question": instance['question'],
            "answer": min(cluster['sentences'], key=len)
        } for instance in instances for cluster in instance['clusters']
    ]

    if not input_dicts:
        # No clusters anywhere: nothing to generate, so skip the model call.
        return instances

    with torch.no_grad():
        generations = convert(input_dicts, model, tokenizer, device)

    # Guard against silent misalignment between clusters and claims.
    if len(generations) != len(input_dicts):
        raise RuntimeError(
            f"Expected {len(input_dicts)} generations, got {len(generations)}."
        )

    # Generations follow the flattening order above, so consuming them with the
    # same nested iteration assigns each claim to the cluster it came from.
    claims = iter(generations)
    for instance in instances:
        for cluster in instance['clusters']:
            cluster['claim'] = next(claims)

    return instances