from logging import warning
import math
import spacy
from nltk.tokenize import sent_tokenize
import torch
from model.bert import BERTAlignModel
from transformers import AdamW, AutoConfig, AutoTokenizer, BartForConditionalGeneration
from transformers.models.bart.modeling_bart import shift_tokens_right
import torch.nn as nn
from tqdm import tqdm

class Inferencer():
    def __init__(self, ckpt_path, model='bert-base-uncased', batch_size=32, device='cuda', verbose=True) -> None:
        self.device = device
        if ckpt_path is not None:
            self.model = BERTAlignModel(model=model).load_from_checkpoint(checkpoint_path=ckpt_path, strict=False).to(self.device)
        else:
            warning('loading UNTRAINED model!')
            self.model = BERTAlignModel(model=model).to(self.device)
        self.model.eval()
        self.batch_size = batch_size

        self.config = AutoConfig.from_pretrained(model)
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.spacy = spacy.load('en_core_web_sm')

        self.loss_fct = nn.CrossEntropyLoss(reduction='none')
        self.softmax = nn.Softmax(dim=-1)

        self.smart_type = 'smart-n'
        self.smart_n_metric = 'f1'

        self.disable_progress_bar_in_inference = False

        self.nlg_eval_mode = None # bin, bin_sp, nli, nli_sp
        self.verbose = verbose
    
    def inference_example_batch(self, premise: list, hypo: list):
        """
        inference a example,
        premise: list
        hypo: list
        using self.inference to batch the process

        SummaC Style aggregation
        """
        self.disable_progress_bar_in_inference = True
        assert len(premise) == len(hypo), "Premise must has the same length with Hypothesis!"

        out_score = []
        for one_pre, one_hypo in tqdm(zip(premise, hypo), desc="Evaluating", total=len(premise), disable=(not self.verbose)):
            out_score.append(self.inference_per_example(one_pre, one_hypo))
        
        return None, torch.tensor(out_score), None

    def inference_per_example(self, premise:str, hypo: str):
        """
        inference a example,
        premise: string
        hypo: string
        using self.inference to batch the process
        """
        def chunks(lst, n):
            """Yield successive n-sized chunks from lst."""
            for i in range(0, len(lst), n):
                yield ' '.join(lst[i:i + n])
        
        premise_sents = sent_tokenize(premise)# [each.text for each in self.spacy(premise).sents]
        # premise_sents = [' '.join(premise_sents[:int(len(premise_sents)/2)]), ' '.join(premise_sents[int(len(premise_sents)/2):])]
        # premise_sents = [' '.join(premise_sents[:int(len(premise_sents)/2)]), ' '.join(premise_sents[int(len(premise_sents)/2):])] if len(premise_sents) > 2 else [premise]
        premise_sents = premise_sents or ['']

        n_chunk = len(premise.strip().split()) // 350 + 1
        n_chunk = max(len(premise_sents) // n_chunk, 1)
        premise_sents = [each for each in chunks(premise_sents, n_chunk)]

        hypo_sents = sent_tokenize(hypo) #[each.text for each in self.spacy(hypo).sents]
        # premise_sents = [premise_sents[i]+" "+premise_sents[i+1] for i in range(len(premise_sents)-1)] if len(premise_sents) > 1 else premise_sents

        premise_sent_mat = []
        hypo_sents_mat = []
        for i in range(len(premise_sents)):
            for j in range(len(hypo_sents)):
                premise_sent_mat.append(premise_sents[i])
                hypo_sents_mat.append(hypo_sents[j])
        
        if self.nlg_eval_mode is not None:
            if self.nlg_eval_mode == 'nli_sp':
                output_score = self.inference(premise_sent_mat, hypo_sents_mat)[2][:,0] ### use NLI head OR ALIGN head
            elif self.nlg_eval_mode == 'bin_sp':
                output_score = self.inference(premise_sent_mat, hypo_sents_mat)[1] ### use NLI head OR ALIGN head
            elif self.nlg_eval_mode == 'reg_sp':
                output_score = self.inference(premise_sent_mat, hypo_sents_mat)[0] ### use NLI head OR ALIGN head
            
            output_score = output_score.view(len(premise_sents), len(hypo_sents)).max(dim=0).values.mean().item() ### sum or mean depends on the task/aspect
            return output_score

        
        output_score = self.inference(premise_sent_mat, hypo_sents_mat)[2][:,0] ### use NLI head OR ALIGN head
        output_score = output_score.view(len(premise_sents), len(hypo_sents)).max(dim=0).values.mean().item() ### sum or mean depends on the task/aspect

        return output_score

        # logits = output.view(-1, self.config.vocab_size)
        # loss = self.loss_fct(self.lsm(logits), tgt_tokens.view(-1))
        # loss = loss.view(tgt_tokens.shape[0], -1)
        # loss = loss.sum(dim=1) / tgt_len
        # curr_score_list = [-x.item() for x in loss]



    def inference(self, premise, hypo):
        """
        inference a list of premise and hypo

        Standard aggregation
        """
        if isinstance(premise, str) and isinstance(hypo, str):
            premise = [premise]
            hypo = [hypo]
        
        batch = self.batch_tokenize(premise, hypo)
        output_score_reg = []
        output_score_bin = []
        output_score_tri = []

        for mini_batch in tqdm(batch, desc="Evaluating", disable=not self.verbose or self.disable_progress_bar_in_inference):
        # for mini_batch in batch:
            mini_batch = mini_batch.to(self.device)
            with torch.no_grad():
                model_output = self.model(mini_batch)
                model_output_reg = model_output.reg_label_logits.cpu()
                model_output_bin = model_output.seq_relationship_logits # Temperature Scaling / 2.5
                model_output_tri = model_output.tri_label_logits
                
                model_output_bin = self.softmax(model_output_bin).cpu()
                model_output_tri = self.softmax(model_output_tri).cpu()
            output_score_reg.append(model_output_reg[:,0])
            output_score_bin.append(model_output_bin[:,1])
            output_score_tri.append(model_output_tri[:,:])
        
        output_score_reg = torch.cat(output_score_reg)
        output_score_bin = torch.cat(output_score_bin)
        output_score_tri = torch.cat(output_score_tri)
        
        if self.nlg_eval_mode is not None:
            if self.nlg_eval_mode == 'nli':
                output_score_nli = output_score_tri[:,0]
                return None, output_score_nli, None
            elif self.nlg_eval_mode == 'bin':
                return None, output_score_bin, None
            elif self.nlg_eval_mode == 'reg':
                return None, output_score_reg, None
            else:
                ValueError("unrecognized nlg eval mode")

        
        return output_score_reg, output_score_bin, output_score_tri
    
    def inference_reg(self, premise, hypo):
        """
        inference a list of premise and hypo

        Standard aggregation
        """
        self.model.is_reg_finetune = True
        if isinstance(premise, str) and isinstance(hypo, str):
            premise = [premise]
            hypo = [hypo]
        
        batch = self.batch_tokenize(premise, hypo)
        output_score = []

        for mini_batch in tqdm(batch, desc="Evaluating", disable=self.disable_progress_bar_in_inference):
        # for mini_batch in batch:
            mini_batch = mini_batch.to(self.device)
            with torch.no_grad():
                model_output = self.model(mini_batch).seq_relationship_logits.cpu().view(-1)
            output_score.append(model_output)
        output_score = torch.cat(output_score)
        return output_score
    
    def batch_tokenize(self, premise, hypo):
        """
        input premise and hypos are lists
        """
        assert isinstance(premise, list) and isinstance(hypo, list)
        assert len(premise) == len(hypo), "premise and hypo should be in the same length."

        batch = []
        for mini_batch_pre, mini_batch_hypo in zip(self.chunks(premise, self.batch_size), self.chunks(hypo, self.batch_size)):
            try:
                mini_batch = self.tokenizer(mini_batch_pre, mini_batch_hypo, truncation='only_first', padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')
            except:
                warning('text_b too long...')
                mini_batch = self.tokenizer(mini_batch_pre, mini_batch_hypo, truncation=True, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')
            batch.append(mini_batch)

        return batch
    def smart_doc(self, premise: list, hypo: list):
        """
        inference a example,
        premise: list
        hypo: list
        using self.inference to batch the process

        SMART Style aggregation
        """
        self.disable_progress_bar_in_inference = True
        assert len(premise) == len(hypo), "Premise must has the same length with Hypothesis!"
        assert self.smart_type in ['smart-n', 'smart-l']

        out_score = []
        for one_pre, one_hypo in tqdm(zip(premise, hypo), desc="Evaluating SMART", total=len(premise)):
            out_score.append(self.smart_l(one_pre, one_hypo)[1] if self.smart_type == 'smart-l' else self.smart_n(one_pre, one_hypo)[1])
        
        return None, torch.tensor(out_score), None

    def smart_l(self, premise, hypo):
        premise_sents = [each.text for each in self.spacy(premise).sents]
        hypo_sents = [each.text for each in self.spacy(hypo).sents]
        # premise_sents = [premise_sents[i]+" "+premise_sents[i+1] for i in range(len(premise_sents)-1)] if len(premise_sents) > 1 else premise_sents

        premise_sent_mat = []
        hypo_sents_mat = []
        for i in range(len(premise_sents)):
            for j in range(len(hypo_sents)):
                premise_sent_mat.append(premise_sents[i])
                hypo_sents_mat.append(hypo_sents[j])
        
        output_score = self.inference(premise_sent_mat, hypo_sents_mat)[2][:,0]
        output_score = output_score.view(len(premise_sents), len(hypo_sents))

        ### smart-l
        lcs = [[0] * (len(hypo_sents)+1)] * (len(premise_sents)+1)
        for i in range(len(premise_sents)+1):
            for j in range(len(hypo_sents)+1):
                if i != 0 and j != 0:
                    m = output_score[i-1, j-1]
                    lcs[i][j] = max([lcs[i-1][j-1]+m,
                                        lcs[i-1][j]+m,
                                        lcs[i][j-1]])

        return None, lcs[-1][-1] / len(premise_sents), None
    
    def smart_n(self, premise, hypo):
        ### smart-n
        n_gram = 1

        premise_sents = [each.text for each in self.spacy(premise).sents]
        hypo_sents = [each.text for each in self.spacy(hypo).sents]
        # premise_sents = [premise_sents[i]+" "+premise_sents[i+1] for i in range(len(premise_sents)-1)] if len(premise_sents) > 1 else premise_sents

        premise_sent_mat = []
        hypo_sents_mat = []
        for i in range(len(premise_sents)):
            for j in range(len(hypo_sents)):
                premise_sent_mat.append(premise_sents[i])
                hypo_sents_mat.append(hypo_sents[j])
        
        output_score = self.inference(premise_sent_mat, hypo_sents_mat)[2][:,0]
        output_score = output_score.view(len(premise_sents), len(hypo_sents))
        
        prec = sum([max([sum([output_score[i+n, j+n]/n_gram for n in range(0, n_gram)]) for i in range(len(premise_sents)-n_gram+1)]) for j in range(len(hypo_sents)-n_gram+1)])
        prec = prec / (len(hypo_sents) - n_gram + 1) if (len(hypo_sents) - n_gram + 1) > 0 else 0.


        premise_sents = [each.text for each in self.spacy(hypo).sents]# simple change
        hypo_sents = [each.text for each in self.spacy(premise).sents]#
        # premise_sents = [premise_sents[i]+" "+premise_sents[i+1] for i in range(len(premise_sents)-1)] if len(premise_sents) > 1 else premise_sents

        premise_sent_mat = []
        hypo_sents_mat = []
        for i in range(len(premise_sents)):
            for j in range(len(hypo_sents)):
                premise_sent_mat.append(premise_sents[i])
                hypo_sents_mat.append(hypo_sents[j])
        
        output_score = self.inference(premise_sent_mat, hypo_sents_mat)[2][:,0]
        output_score = output_score.view(len(premise_sents), len(hypo_sents))

        recall = sum([max([sum([output_score[i+n, j+n]/n_gram for n in range(0, n_gram)]) for i in range(len(premise_sents)-n_gram+1)]) for j in range(len(hypo_sents)-n_gram+1)])
        recall = prec / (len(hypo_sents) - n_gram + 1) if (len(hypo_sents) - n_gram + 1) > 0 else 0.

        f1 = 2 * prec * recall / (prec + recall)

        if self.smart_n_metric == 'f1':
            return None, f1, None
        elif self.smart_n_metric == 'precision':
            return None, prec, None
        elif self.smart_n_metric == 'recall':
            return None, recall, None
        else:
            ValueError("SMART return type error")

    def chunks(self, lst, n):
        """Yield successive n-sized chunks from lst."""
        for i in range(0, len(lst), n):
            yield lst[i:i + n]
    
    def nlg_eval(self, premise, hypo):
        assert self.nlg_eval_mode is not None, "Select NLG Eval mode!"
        if (self.nlg_eval_mode == 'bin') or (self.nlg_eval_mode == 'nli') or (self.nlg_eval_mode == 'reg'):
            return self.inference(premise, hypo)
        
        elif (self.nlg_eval_mode == 'bin_sp') or (self.nlg_eval_mode == 'nli_sp') or (self.nlg_eval_mode == 'reg_sp'):
            return self.inference_example_batch(premise, hypo)
        
        else:
            ValueError("Unrecognized NLG Eval mode!")

if __name__ == "__main__":
    import json
    from scipy.stats import pearsonr, kendalltau, spearmanr
    infer = Inferencer(ckpt_path='checkpoints/bert/bert_cnndm_6x2x8-v1.ckpt')
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    """
    inference CTC datasets
    """

    print(infer.inference('LONDON, England (Reuters) -- Harry Potter star Daniel Radcliffe gains access to a reported £20 million ($41.1 million) fortune as he turns 18 on Monday, but he insists the money won\'t cast a spell on him. Daniel Radcliffe as Harry Potter in \\"Harry Potter and the Order of the Phoenix\\" To the disappointment of gossip columnists around the world, the young actor says he has no plans to fritter his cash away on fast cars, drink and celebrity parties. \\"I don\'t plan to be one of those people who, as soon as they turn 18, suddenly buy themselves a massive sports car collection or something similar,\\" he told an Australian interviewer earlier this month. \\"I don\'t think I\'ll be particularly extravagant. \\"The things I like buying are things that cost about 10 pounds -- books and CDs and DVDs.\\" At 18, Radcliffe will be able to gamble in a casino, buy a drink in a pub or see the horror film \\"Hostel: Part II,\\" currently six places below his number one movie on the UK box office chart. Details of how he\'ll mark his landmark birthday are under wraps. His agent and publicist had no comment on his plans. \\"I\'ll definitely have some sort of party,\\" he said in an interview. \\"Hopefully none of you will be reading about it.\\" Radcliffe\'s earnings from the first five Potter films have been held in a trust fund which he has not been able to touch. Despite his growing fame and riches, the actor says he is keeping his feet firmly on the ground. \\"People are always looking to say \'kid star goes off the rails,\'\\" he told reporters last month. \\"But I try very hard not to go that way because it would be too easy for them.\\" His latest outing as the boy wizard in \\"Harry Potter and the Order of the Phoenix\\" is breaking records on both sides of the Atlantic and he will reprise the role in the last two films.  Watch I-Reporter give her review of Potter\'s latest » . There is life beyond Potter, however. The Londoner has filmed a TV movie called \\"My Boy Jack,\\" about author Rudyard Kipling and his son, due for release later this year. He will also appear in \\"December Boys,\\" an Australian film about four boys who escape an orphanage. Earlier this year, he made his stage debut playing a tortured teenager in Peter Shaffer\'s \\"Equus.\\" Meanwhile, he is braced for even closer media scrutiny now that he\'s legally an adult: \\"I just think I\'m going to be more sort of fair game,\\" he told Reuters. E-mail to a friend . Copyright 2007 Reuters. All rights reserved.This material may not be published, broadcast, rewritten, or redistributed.', 
    'Harry Potter star Daniel Radcliffe gets £20M fortune as he turns 18 Monday .\\nYoung actor says he has no plans to fritter his cash away .\\nRadcliffe\'s earnings from first five Potter films have been held in trust fund .'))