import pickle as pkl
from transformers import GPT2LMHeadModel, AutoModelForCausalLM, AutoConfig, LlamaForCausalLM
from transformers import AutoTokenizer
import torch
import os
import re
from typing import List
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from regression.lm_embeddings.embeddings_store import EmbeddingsStore
from regression.lm_embeddings.words_to_transcript import load_pkl
from regression.session_story_configs import SessionStoryConfig
import numpy as np

def remove_non_letters(sentence:str):
    '''
    desc: strips every run of non-word characters from the string.
    Word characters (letters, digits and the underscore) are kept, so
    whitespace and punctuation disappear while e.g. "a_b 1" becomes "a_b1".
    '''
    # Splitting on the separator runs and gluing the pieces back together
    # leaves exactly the word characters, in their original order.
    kept_pieces = re.split(r'\W+', sentence)
    return "".join(kept_pieces)

def indices_where_in_range(indices, value_range_start, value_range_stop):
    '''
    desc: returns the positions (not the values) of the entries of `indices`
    whose value lies in the half-open range [value_range_start, value_range_stop).
    '''
    matching_positions = []
    for position, value in enumerate(indices):
        if value_range_start <= value < value_range_stop:
            matching_positions.append(position)
    return matching_positions

def verification_format(transcript_str):
    '''
    desc: normalises a transcript for comparison by dropping all non-word
    characters (via remove_non_letters) and upper-casing what is left.
    '''
    return remove_non_letters(transcript_str).upper()
    
def n_previous_words_context(n):
    '''
    desc: builds a context function that joins a word with up to `n` words
    preceding it.

    The returned callable takes (words_and_times_transcript, word_index),
    where each transcript entry is indexed with the key "transcript", and
    returns the selected entries' "transcript" values joined by single spaces.
    Near the start of the transcript, where fewer than `n` previous words
    exist, the context simply starts at the first word.
    '''
    def _context(words_and_times_transcript, word_index):
        # Clamp the window start at 0 so early words get a shorter context.
        window_start = max(0, word_index - n)
        window = words_and_times_transcript[window_start:word_index + 1]
        return " ".join(entry["transcript"] for entry in window)
    return _context

class BuildStoryEmbeddings():
    """Compute one language-model embedding per word of a story transcript.

    For every word a text context is produced by ``context_fn``; all contexts
    of a story are tokenized as one left-padded batch, run through ``lm`` and
    the hidden state at the final sequence position of ``layer`` is kept as
    that word's embedding.  The results are passed to an ``EmbeddingsStore``.
    """

    def __init__(self, lm_tokenizer, lm,
                 context_fn,
                 embeddings_save_folder = "./embeddings",
                 layer = 4,
                 batch_size = 64,
                 device = "cuda",
                 dataset_loc = "./data"):
        """Store the configuration and prepare tokenizer and model.

        Args:
            lm_tokenizer: tokenizer matching ``lm``.  It is modified in place:
                a ``<pad>`` token is added when it has none and its padding
                side is set to ``"left"``.
            lm: language model; moved to ``device`` and put in eval mode.  Its
                token embeddings are resized when a pad token had to be added.
            context_fn: callable ``(words_and_times_transcript, word_index)``
                returning the context string for that word.
            embeddings_save_folder: folder handed to ``EmbeddingsStore``.
            layer: index into the model's ``hidden_states`` tuple (for Hugging
                Face models index 0 is the embedding output).
            batch_size: number of contexts per forward pass.
            device: torch device the model and token tensors are moved to.
            dataset_loc: root folder containing ``words_and_times_transcripts``.
        """
        self.tokenizer = lm_tokenizer
        self.context_fn = context_fn
        self.lm = lm
        self.device = device
        self.layer = layer
        self.batch_size = batch_size
        self._configure_padding()
        self.lm.to(self.device)
        self.lm.eval()
        self.embedding_store = EmbeddingsStore(embeddings_save_folder)
        self.dataset_loc = dataset_loc

    def _configure_padding(self):
        """Make sure the tokenizer can pad, and that it pads on the left.

        ``make_embeddings`` reads the hidden state at position ``-1`` of each
        padded sequence.  That position only holds the real last token when
        padding is on the left, so the padding side is forced for every
        tokenizer, not just for those that lacked a pad token.
        """
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({"pad_token":"<pad>"})
            # The vocabulary grew by one token, so the model's embedding
            # matrix has to grow with it.
            self.lm.resize_token_embeddings(len(self.tokenizer), mean_resizing = False)
        self.tokenizer.padding_side = "left"

    def make_embeddings(self, story):
        """Build, save and return the embeddings for every word of ``story``.

        Args:
            story: story name; the transcript is loaded from
                ``{dataset_loc}/words_and_times_transcripts/{story}.pkl``.

        Returns:
            Tuple ``(contexts, context_tokens, all_word_embeddings)``: the
            list of context strings, the tokenizer output (on ``device``) and
            a NumPy array with one row per context, in transcript order.
        """
        words_and_times_transcript = load_pkl(f"{self.dataset_loc}/words_and_times_transcripts/{story}.pkl")
        contexts = [
            self.context_fn(words_and_times_transcript, word_index)
            for word_index in range(len(words_and_times_transcript))
        ]
        # All contexts are padded to a common length in a single call so the
        # batches below can be sliced from two plain tensors.
        context_tokens = self.tokenizer(contexts, return_tensors = "pt", padding = True).to(self.device)
        dataset = TensorDataset(context_tokens["input_ids"], context_tokens["attention_mask"])
        # shuffle=False / drop_last=False keep a one-to-one, in-order mapping
        # between embedding rows and transcript words.
        batch_loader = DataLoader(dataset, batch_size = self.batch_size, shuffle = False, drop_last=False)
        batch_embeddings = []
        with torch.no_grad():
            for (input_ids, attention_mask) in tqdm(batch_loader):
                out = self.lm(input_ids = input_ids, attention_mask = attention_mask, output_hidden_states=True)
                layer_states = out.hidden_states[self.layer]
                # With left padding the final position is the last real token.
                last_token_embeddings = layer_states[:,-1,:]
                # clone() so the stored array does not share memory with (and
                # keep alive) the full hidden-state tensor when running on CPU.
                batch_embeddings.append(last_token_embeddings.clone().detach().cpu().numpy())
        all_word_embeddings = np.concatenate(batch_embeddings, axis=0)
        self.embedding_store.save_story_data(story, all_word_embeddings, contexts, context_tokens)
        return contexts, context_tokens, all_word_embeddings

    def make_all_embeddings(self, stories):
        """Run ``make_embeddings`` for each story in ``stories``, with a progress bar."""
        for story in tqdm(stories):
            self.make_embeddings(story)

    
    
        
        