# pip install accelerate
from typing import Any, Dict, List, Optional, Sequence, Union

import torch
from torch.nn.functional import log_softmax
from transformers import T5ForConditionalGeneration, T5TokenizerFast

class GPTScore:
    """Score claims by their average token log-likelihood under a T5 model.

    Each context is wrapped in a summarization instruction and fed to the
    encoder. The claim is used as the decoder target, and its per-token
    log-probabilities are averaged. Higher (less negative) scores mean the
    model finds the claim more likely given the context.
    """

    # Second segment of the source text pair. Special tokens are disabled for
    # the source, so the literal "</s>" supplies the end-of-sequence marker.
    PROMPT_SUFFIX = '\n\nTl;dr: </s>'

    def __init__(self, model='google/flan-t5-large', device='cuda') -> None:
        """Load the tokenizer and model and move the model to ``device``.

        Args:
            model: Hugging Face model id or local path of a T5-family checkpoint.
            device: Torch device to run on, e.g. ``'cuda'`` or ``'cpu'``.
        """
        self.tokenizer = T5TokenizerFast.from_pretrained(model)
        self.model = T5ForConditionalGeneration.from_pretrained(model)
        self.model.to(device)
        # Scoring is inference-only; make sure dropout is disabled.
        self.model.eval()

    def generate_prompt(self, context: str) -> str:
        """Wrap ``context`` in the summarization instruction fed to the encoder."""
        return f'Generate factually consistent summary for the following text: {context}'

    def score(
        self,
        contexts: Union[str, Sequence[str]],
        claims: Union[str, Sequence[str]],
        max_length: int = 512,
    ) -> torch.Tensor:
        """Return the mean log-probability of each claim given its context.

        Args:
            contexts: One source document, or a sequence of them.
            claims: One claim/summary per context. A single string is treated
                as a batch of one, as is a single string for ``contexts``.
            max_length: Token limit applied to both the source prompt and the
                claim. Only the prompt (not the suffix) is truncated.

        Returns:
            1-D float tensor with one score per (context, claim) pair, located
            on the model's device. Each score is the mean log-probability over
            the claim's non-padding tokens, including any special tokens the
            tokenizer appends. Empty input yields an empty tensor.

        Raises:
            ValueError: If ``contexts`` and ``claims`` differ in length.
        """
        # A bare string would otherwise be iterated character by character.
        if isinstance(contexts, str):
            contexts = [contexts]
        if isinstance(claims, str):
            claims = [claims]
        contexts, claims = list(contexts), list(claims)
        if len(contexts) != len(claims):
            raise ValueError(
                f'got {len(contexts)} contexts but {len(claims)} claims')
        if not contexts:
            return torch.empty(0)

        # Follow the model instead of hard-coding "cuda", so CPU/other devices work.
        device = self.model.device
        prompts = [self.generate_prompt(context) for context in contexts]
        source = self.tokenizer(
            prompts, [self.PROMPT_SUFFIX] * len(prompts),
            add_special_tokens=False,
            max_length=max_length, truncation='only_first', padding=True,
            return_tensors="pt")
        target = self.tokenizer(
            claims, max_length=max_length, truncation=True, padding=True,
            return_tensors="pt")
        input_ids = source.input_ids.to(device)
        # Without this mask, padded prompts in a batch would attend to pad tokens.
        attention_mask = source.attention_mask.to(device)
        target_ids = target.input_ids.to(device)
        target_mask = target.attention_mask.to(device)

        # Pure inference: don't build an autograd graph.
        with torch.no_grad():
            # Passing `labels` makes the model derive decoder inputs from the
            # targets, so logits[:, i] is the prediction for target_ids[:, i].
            outputs = self.model(
                input_ids=input_ids, attention_mask=attention_mask,
                labels=target_ids)
        log_probs = log_softmax(outputs.logits, dim=-1)
        token_log_p = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
        # Replace padding positions with NaN so nanmean averages real tokens only.
        return token_log_p.masked_fill(target_mask == 0, torch.nan).nanmean(dim=-1)


# document = "Neil Aspin's promotion-chasing hosts have not lost in nine National League matches while Adam Lockwood's side are unbeaten in five.\nGuiseley went ahead on 15 minutes against the run of play when a throw-in found James Hurst who squared to Jake Lawlor to stroke into an empty net.\nGateshead defender Liam Hogan superbly blocked Jordan Preston's effort and Guiseley keeper Jonny Maxted then saved well from Wesley York's shot just before the break.\nThe hosts, who started the second half well, levelled on 62 minutes when a slip by half-time substitute Derek Asamoah let York curl sweetly into the top-right corner from the edge of the box.\nMatch report supplied by the Press Association.\nMatch ends, Gateshead 1, Guiseley 1.\nSecond Half ends, Gateshead 1, Guiseley 1.\nSubstitution, Guiseley. Michael Rankine replaces Jordan Preston.\nSubstitution, Gateshead. Luke Hannant replaces Gus Mafuta.\nGus Mafuta (Gateshead) is shown the yellow card.\nSubstitution, Guiseley. Adam Boyes replaces Jake Cassidy.\nGoal!  Gateshead 1, Guiseley 1. Wes York (Gateshead).\nSubstitution, Guiseley. Derek Asamoah replaces Kevan Hurst.\nSecond Half begins Gateshead 0, Guiseley 1.\nFirst Half ends, Gateshead 0, Guiseley 1.\nSimon Walton (Guiseley) is shown the yellow card.\nGoal!  Gateshead 0, Guiseley 1. Jake Lawlor (Guiseley).\nFirst Half begins.\nLineups are announced and players are warming up."
# summary = "gateshead remain unbeaten in the national league after being held to a draw by guiseley."
# summary2 = "tom remain unbeaten in the national league after being held to a draw by guiseley."


# g = GPTScore()
# print(g.score(document, summary))
# print(g.score(document, summary2))

# model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large")
# tokenizer = T5TokenizerFast.from_pretrained("google/flan-t5-large")
# inputs = tokenizer("A step by step recipe to make bolognese pasta:", return_tensors="pt")
# outputs = model.generate(**inputs)