import os
import json
from functools import partial

import torch

from tqdm import tqdm
from scipy.spatial.distance import cosine
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel

from utils.eval_utils import seq_rep_n, tok_repeat_l, diversity, finalize


class LengthScorer:
    """Score a single sequence by its length.

    The ids are first passed through ``finalize`` (from ``utils.eval_utils``;
    presumably strips BOS/EOS padding -- not visible here). With
    ``requires_detokenize`` the ids are decoded and the text is split on
    whitespace, so the score is a word count; otherwise it is the number of
    token ids.
    """

    def __init__(self, requires_detokenize=False):
        self.requires_detokenize = requires_detokenize

    def run(self, token_ids, tokenizer):
        """Return the length of one sample.

        Args:
            token_ids: list of token ids for one sample.
            tokenizer: object providing ``eos_token_id``, ``bos_token_id`` and
                ``decode``.
        """
        trimmed = finalize(token_ids, eos_idx=tokenizer.eos_token_id, bos_idx=tokenizer.bos_token_id)
        if not self.requires_detokenize:
            return len(trimmed)
        words = tokenizer.decode(trimmed).split()
        return len(words)

    def run_batch(self, batched_token_ids, tokenizer):
        """Score every element of ``batched_token_ids`` (each must support
        ``.tolist()``) and return the lengths as a 1-D tensor."""
        lengths = []
        for row in batched_token_ids:
            lengths.append(self.run(row.tolist(), tokenizer))
        return torch.tensor(lengths)

class SeqRepScorer:
    """Score n-gram repetition of a single sequence with ``seq_rep_n``.

    Units are token ids, or whitespace-separated words when
    ``requires_detokenize`` is set. Sequences shorter than ``n`` units score 0.
    """

    def __init__(self, n, requires_detokenize=False):
        self.n = n
        self.requires_detokenize = requires_detokenize

    def run(self, token_ids, tokenizer):
        """Return the ``seq_rep_n`` score of one sample (a list of token ids)."""
        ids = finalize(token_ids, eos_idx=tokenizer.eos_token_id, bos_idx=tokenizer.bos_token_id)
        units = tokenizer.decode(ids).split() if self.requires_detokenize else ids

        # Too short to hold even one n-gram: report no repetition.
        if len(units) < self.n:
            return 0

        # seq_rep_n is called with a list of sequences, so wrap the single sample.
        return seq_rep_n([units], self.n)

    def run_batch(self, batched_token_ids, tokenizer):
        """Score each element of ``batched_token_ids`` (each must support
        ``.tolist()``) and return the scores as a 1-D tensor."""
        per_row = [self.run(row.tolist(), tokenizer) for row in batched_token_ids]
        return torch.tensor(per_row)

class TokRepScorer:
    """Score token repetition within a context window using ``tok_repeat_l``.

    ``l`` is forwarded as ``context_len``. With ``requires_detokenize`` the
    decoded text is split on whitespace and each distinct word is mapped to an
    integer id (numbered by first appearance), so ``tok_repeat_l`` always
    receives a sequence of integers.
    """

    def __init__(self, l, requires_detokenize=False):
        self.l = l
        self.requires_detokenize = requires_detokenize

    def run(self, token_ids, tokenizer):
        """Return the ``tok_repeat_l`` score of one sample (a list of token ids)."""
        ids = finalize(token_ids, eos_idx=tokenizer.eos_token_id, bos_idx=tokenizer.bos_token_id)

        if self.requires_detokenize:
            vocab = {}
            # setdefault assigns len(vocab) only to words not seen before,
            # yielding ids 0, 1, 2, ... in order of first appearance.
            units = [vocab.setdefault(word, len(vocab)) for word in tokenizer.decode(ids).split()]
        else:
            units = ids

        return tok_repeat_l(units, context_len=self.l)

    def run_batch(self, batched_token_ids, tokenizer):
        """Score each element of ``batched_token_ids`` (each must support
        ``.tolist()``) and return the scores as a 1-D tensor."""
        scores = []
        for row in batched_token_ids:
            scores.append(self.run(row.tolist(), tokenizer))
        return torch.tensor(scores)

class DiversityScorer:
    """Score the lexical diversity of a single sequence via ``diversity``.

    Units are token ids, or whitespace-separated words when
    ``requires_detokenize`` is set. Sequences with fewer than 4 units score 0.
    """

    def __init__(self, requires_detokenize=False):
        self.requires_detokenize = requires_detokenize

    def run(self, token_ids, tokenizer):
        """Return the ``diversity`` score of one sample (a list of token ids)."""
        ids = finalize(token_ids, eos_idx=tokenizer.eos_token_id, bos_idx=tokenizer.bos_token_id)
        if self.requires_detokenize:
            units = tokenizer.decode(ids).split()
        else:
            units = ids

        # NOTE(review): the 4-unit minimum presumably matches the largest n-gram
        # size used by diversity(); confirm in utils.eval_utils.
        if len(units) >= 4:
            return diversity([units])
        return 0

    def run_batch(self, batched_token_ids, tokenizer):
        """Score each element of ``batched_token_ids`` (each must support
        ``.tolist()``) and return the scores as a 1-D tensor."""
        return torch.tensor([self.run(sample.tolist(), tokenizer) for sample in batched_token_ids])
        

class CoherenceScorer:
    """Score coherence as the cosine similarity of SimCSE embeddings.

    Each sample is passed through ``finalize`` and split at ``prefix_len`` token
    ids. Both halves are decoded to text and embedded with the model loaded from
    ``simcse_model_path`` (its ``pooler_output``). The score is the cosine
    similarity between the prefix embedding and the completion embedding.
    """

    def __init__(self, prefix_len, max_length, simcse_model_path, device, batch_size=512, dtype="float16"):
        """
        Args:
            prefix_len: number of (finalized) token ids treated as the prefix.
            max_length: truncation length passed to the SimCSE tokenizer.
            simcse_model_path: name or path given to ``from_pretrained`` for both
                the tokenizer and the model.
            device: device the SimCSE model and its inputs are moved to.
            batch_size: number of texts (two per sample) embedded per forward
                pass in ``run_batch``; 0 embeds all texts in one pass.
            dtype: ``"float16"`` casts the model to half precision and autocasts
                in float16; any other value autocasts in float32.
        """
        self.coh_tokenizer = AutoTokenizer.from_pretrained(simcse_model_path)
        self.model = AutoModel.from_pretrained(simcse_model_path).to(device)
        if dtype == "float16":
            self.model.half()
        self.device = device
        self.prefix_len = prefix_len
        self.max_length = max_length
        self.batch_size = batch_size
        # NOTE(review): autocast is always created for "cuda", independent of `device`.
        self.ctx = torch.amp.autocast(device_type="cuda", dtype=torch.float16 if dtype == "float16" else torch.float32)

    def _decode_pair(self, token_ids, tokenizer):
        """Finalize ``token_ids`` (a list) and return (prefix_text, completion_text)."""
        token_ids = finalize(token_ids, eos_idx=tokenizer.eos_token_id, bos_idx=tokenizer.bos_token_id)
        prefix_ids, cmpl_ids = token_ids[:self.prefix_len], token_ids[self.prefix_len:]
        return tokenizer.decode(prefix_ids), tokenizer.decode(cmpl_ids)

    def _embed(self, texts):
        """Embed ``texts`` in one forward pass; return float32 ``pooler_output`` on CPU."""
        encoded = self.coh_tokenizer(texts, padding=True, truncation=True, return_tensors="pt",
                                     max_length=self.max_length).to(self.device)
        # Every forward pass (single, chunked or all-at-once) runs without
        # autograd and under the same autocast context.
        with torch.no_grad(), self.ctx:
            pooled = self.model(**encoded, output_hidden_states=True, return_dict=True).pooler_output
        # Upcast so run() and run_batch() compare embeddings at the same precision.
        return pooled.cpu().float()

    def run(self, token_ids, tokenizer):
        """Return the prefix/completion cosine similarity for one sample (a list of token ids)."""
        prefix, completion = self._decode_pair(token_ids, tokenizer)
        embeddings = self._embed([prefix, completion])
        return 1.0 - cosine(embeddings[0], embeddings[1])

    def run_batch(self, batched_token_ids, tokenizer):
        """Return a 1-D tensor with one prefix/completion cosine similarity per sample.

        Each element of ``batched_token_ids`` must support ``.tolist()``.
        """
        texts = []
        for token_ids in batched_token_ids:
            # Interleaved layout: [prefix_0, completion_0, prefix_1, completion_1, ...]
            texts.extend(self._decode_pair(token_ids.tolist(), tokenizer))

        # batch_size == 0 means "embed every text in a single pass".
        step = self.batch_size if self.batch_size != 0 else max(len(texts), 1)
        embeddings = torch.cat([self._embed(texts[i:i + step]) for i in range(0, len(texts), step)], dim=0)

        embeddings = embeddings.view(len(batched_token_ids), 2, -1)
        return torch.nn.CosineSimilarity(dim=-1, eps=1e-6)(embeddings[:, 0], embeddings[:, 1])

        
class InformationScorer:
    """Score a sample by the mean next-token cross-entropy under a causal LM.

    Each sample is passed through ``finalize`` and decoded with the caller's
    tokenizer. It is then re-encoded with the LM's own tokenizer
    (``info_tokenizer``), wrapped as ``[bos] + ids + [eos]``, and scored with
    teacher forcing. Label positions equal to ``info_tokenizer.eos_token_id``
    are ignored (set to -100).
    """

    def __init__(self, max_length, lm_model_path, device, batch_size=32, dtype="float16"):
        """
        Args:
            max_length: stored as ``self.max_length``. NOTE(review): not used
                by the scoring code in this class.
            lm_model_path: name or path given to ``from_pretrained`` for both
                the causal LM and its tokenizer.
            device: device the LM and its inputs are moved to.
            batch_size: rows per forward pass in ``run_batch``; 0 runs all rows
                in one pass.
            dtype: ``"float16"`` casts the model to half precision and autocasts
                in float16; any other value autocasts in float32.
        """
        self.model = AutoModelForCausalLM.from_pretrained(lm_model_path).to(device)
        self.info_tokenizer = AutoTokenizer.from_pretrained(lm_model_path)
        if dtype == "float16":
            self.model.half()
        self.device = device
        self.max_length = max_length
        self.batch_size = batch_size
        # NOTE(review): autocast is always created for "cuda", independent of `device`.
        self.ctx = torch.amp.autocast(device_type="cuda", dtype=torch.float16 if dtype == "float16" else torch.float32)

    def _mean_nll(self, inputs, labels):
        """Return the per-row mean cross-entropy of ``labels`` given ``inputs``.

        ``inputs`` and ``labels`` are 2-D LongTensors of the same shape on
        ``self.device``; positions whose label is -100 are excluded. A row with
        no unmasked label divides 0 by 0 and yields NaN.
        """
        attention_mask = torch.ones_like(inputs, dtype=torch.bool)
        with torch.no_grad():
            with self.ctx:
                logits = self.model(inputs, attention_mask=attention_mask).logits
            # Kept outside autocast so the loss dtype follows the logits
            # (autocast may upcast cross-entropy to float32).
            loss_per_pos = torch.nn.CrossEntropyLoss(reduction="none")(
                logits.view(-1, logits.size(-1)), labels.view(-1)
            ).view(-1, logits.size(1))
        return loss_per_pos.sum(-1) / labels.ne(-100).to(loss_per_pos.dtype).sum(-1)

    def run(self, token_ids, tokenizer):
        """Return the mean next-token cross-entropy (a Python float) of one sample."""
        token_ids = finalize(token_ids, eos_idx=tokenizer.eos_token_id, bos_idx=tokenizer.bos_token_id)
        text = tokenizer.decode(token_ids)

        bos, eos = self.info_tokenizer.bos_token_id, self.info_tokenizer.eos_token_id
        ids = torch.LongTensor([[bos] + self.info_tokenizer.encode(text) + [eos]]).to(self.device)

        inputs = ids[:, :-1].contiguous()
        labels = ids[:, 1:].contiguous()
        labels = labels.masked_fill(labels.eq(eos), -100)

        return self._mean_nll(inputs, labels).item()

    def run_batch(self, batched_token_ids, tokenizer):
        """Return a 1-D CPU tensor with the mean next-token cross-entropy of each row.

        Assumes every row of ``batched_token_ids`` starts with token id 50256.
        """
        bos, eos = self.info_tokenizer.bos_token_id, self.info_tokenizer.eos_token_id

        texts = []
        for token_ids in batched_token_ids:
            # Convert tensor rows to lists before finalize(), as every other
            # scorer in this module does.
            if torch.is_tensor(token_ids):
                token_ids = token_ids.tolist()
            token_ids = finalize(token_ids, eos_idx=tokenizer.eos_token_id, bos_idx=tokenizer.bos_token_id)
            texts.append(tokenizer.decode(token_ids))

        encoded = [self.info_tokenizer.encode(text) for text in texts]
        longest = max(len(ids) for ids in encoded)
        # Right-pad with eos so every row has at least one trailing eos; those
        # label positions are masked out below.
        ids = torch.LongTensor([[bos] + x + [eos] * (longest - len(x) + 1) for x in encoded])

        inputs = ids[..., :-1].to(self.device).contiguous()
        labels = ids[..., 1:].to(self.device).contiguous()
        labels = labels.masked_fill(labels.eq(eos), -100)

        # batch_size == 0 means a single forward pass over all rows.
        step = self.batch_size if self.batch_size != 0 else inputs.size(0)
        # Reduce the loss per chunk so only one chunk's logits are alive at a time.
        losses = [
            self._mean_nll(chunk_inputs, chunk_labels)
            for chunk_inputs, chunk_labels in zip(inputs.split(step, dim=0), labels.split(step, dim=0))
        ]
        return torch.cat(losses, dim=0).cpu()







class ScorerManager:
    """Build scorers from a config and apply them to token-id samples.

    The config maps scorer names to keyword-argument dicts (or ``None`` for no
    arguments). The part of a name before the first ``"/"`` selects the class
    from ``scorer_name_maps``; anything after it only distinguishes several
    entries of the same kind (e.g. ``"tok_rep/8"`` and ``"tok_rep/16"``).
    """

    scorer_name_maps = {"seq_rep": SeqRepScorer,
                        "tok_rep": TokRepScorer,
                        "coherence": CoherenceScorer,
                        "diversity": DiversityScorer,
                        "information": InformationScorer,
                        "length": LengthScorer}

    def __init__(self, config, tokenizer):
        """
        Args:
            config: dict of scorer configs, or a path (str) to a JSON file
                containing such a dict.
            tokenizer: tokenizer passed to every scorer and used by
                ``get_named_scores_data`` to encode text.
        """
        # isinstance, not `type(config) == "str"`: comparing a type object to
        # a string is always False, so a path would never be loaded.
        if isinstance(config, str):
            with open(config, "r", encoding="utf-8") as f:
                config = json.load(f)

        self.config = config
        self.tokenizer = tokenizer

    def get_scorers_with_names(self):
        """Instantiate every configured scorer.

        Returns:
            (scorers, names): parallel lists of scorer instances and their
            config names.

        Raises:
            NotImplementedError: if a name's prefix is not in ``scorer_name_maps``.
        """
        scorers = []
        names = []

        for name, kwargs in self.config.items():
            # None means "use the scorer's defaults", as the message below promises.
            if kwargs is None:
                kwargs = {}
            assert isinstance(kwargs, dict), "args of scorer must be dictionary or None"

            kind = name.split("/")[0]
            if kind not in self.scorer_name_maps:
                raise NotImplementedError(f"scorer {name} is not implemented in utils.eval_utils")

            names.append(name)
            scorers.append(self.scorer_name_maps[kind](**kwargs))

        return scorers, names

    def get_scores(self, scorers, sample):
        """Score one sample (a list of token ids); return a 1-D tensor, one entry per scorer."""
        scores = torch.zeros(len(scorers))
        for i, scorer in enumerate(scorers):
            scores[i] = scorer.run(sample, self.tokenizer)
        return scores

    def get_scores_batch(self, scorers, batched_samples):
        """Score a batch of token-id rows; return a (len(batched_samples), len(scorers)) tensor."""
        scores = torch.zeros(len(batched_samples), len(scorers))
        for i, scorer in enumerate(scorers):
            scores[:, i] = scorer.run_batch(batched_samples, self.tokenizer)
        return scores

    def get_named_scores_data(self, names, scorers, data, batch_size=1, max_length=256, save_res=False):
        """Score a list of texts and average the scores per scorer.

        Each text is encoded, truncated to ``max_length`` ids, prefixed with
        ``bos_token_id`` and right-padded with ``eos_token_id`` to
        ``max_length``, then scored ``batch_size`` texts at a time.

        Returns:
            dict mapping each name to its mean score. For every scorer whose
            name prefix is ``"information"`` the mean is exponentiated.
            If ``save_res`` is true, also returns a list of
            ``[sample_index, score_0, score_1, ...]`` rows.
        """
        bos, eos = self.tokenizer.bos_token_id, self.tokenizer.eos_token_id
        data_tokens = [self.tokenizer.encode(x) for x in data]
        data_scores = []
        for i in tqdm(range(0, len(data), batch_size)):
            batched_samples = torch.LongTensor(
                [[bos] + x[:max_length] + [eos] * (max_length - len(x)) for x in data_tokens[i:i + batch_size]]
            )
            data_scores.append(self.get_scores_batch(scorers, batched_samples))
        data_scores = torch.cat(data_scores, dim=0)
        data_scores_mean = data_scores.mean(0)

        # InformationScorer returns per-sequence mean cross-entropy; report
        # exp(mean) for every information scorer, including suffixed names
        # such as "information/xxx" that get_scorers_with_names accepts.
        for idx, name in enumerate(names):
            if name.split("/")[0] == "information":
                data_scores_mean[idx] = data_scores_mean[idx].exp()

        named_means = dict(zip(names, data_scores_mean.tolist()))
        if save_res:
            per_sample = [[i] + row for i, row in enumerate(data_scores.tolist())]
            return named_means, per_sample

        return named_means

        
if __name__ == "__main__":
    # Smoke test: score one hand-written sample with several scorers.
    tokenizer = AutoTokenizer.from_pretrained("../models/gpt2-large")
    # Keys use "/" to separate the scorer kind from a suffix: ScorerManager looks
    # the class up with name.split("/")[0], so "seq_rep:4" would raise
    # NotImplementedError.
    config = {"seq_rep/4": {"n": 4, "requires_detokenize": False},
              "tok_rep/8": {"l": 8, "requires_detokenize": False},
              "tok_rep/16": {"l": 16, "requires_detokenize": False},
              "tok_rep/32": {"l": 32, "requires_detokenize": False},
              # CoherenceScorer requires max_length and simcse_model_path.
              # NOTE(review): 256 and the default model id are placeholders;
              # set SIMCSE_MODEL_PATH to use a local SimCSE checkpoint.
              "coherence": {"prefix_len": 32,
                            "max_length": 256,
                            "simcse_model_path": os.environ.get("SIMCSE_MODEL_PATH",
                                                                "princeton-nlp/sup-simcse-bert-base-uncased"),
                            "device": "cuda:0"},
              "diversity": {"requires_detokenize": False}}
    manager = ScorerManager(config, tokenizer)
    scorers, names = manager.get_scorers_with_names()
    sample = "Segun and Kemi's introduction programme looked like every other from the start and everything went well until it was the turn of Segun's father to address the gathering. Allowing the father of the future groom to speak seemed the gravest mistake made at the function."
    token_ids = tokenizer.encode(sample)
    scores = manager.get_scores(scorers, token_ids)

    for n, s in zip(names, scores):
        print(n)
        print(s)