from transformers import GPT2Model, GPT2Config
import copy
import random
from sklearn.decomposition import PCA
import psutil
import time
import os
import torch
import string
import numpy as np
import h5py
# import numpy.typing as npt

from torch.utils import data
from .trial_data import TrialData
from .subject_data import SubjectData
import pandas as pd
from .context_embeddings import get_embeddings, get_word_embeddings
from types import SimpleNamespace
from typing import Optional, List, Dict, Any, Tuple

# List of all language-responsive subjects in the dataset.
# NOTE: the misspelled name ("reponsive") is kept because it may be imported elsewhere.
lang_reponsive_subj_lst = ['m00183', 'm00184', 'm00185', 'm00188', 'm00193', 'm00194', 'm00195']

# Translation table deleting every ASCII punctuation character; built once
# rather than on every call.
_PUNCT_TABLE = str.maketrans('', '', string.punctuation)


def strip_punc(s):
    """Return ``s`` with every character in ``string.punctuation`` removed."""
    return s.translate(_PUNCT_TABLE)


def clean_word(s):
    """Normalize a word for lookup: strip ASCII punctuation, then lowercase."""
    return strip_punc(s).lower()

def lowercase_and_strip_punc(w):
    """Lowercase ``w`` and delete every '.' character.

    Despite the name, only full stops are removed; other punctuation is kept
    (``clean_word`` strips all ASCII punctuation).
    """
    without_periods = w.translate({ord("."): None})
    return without_periods.lower()

class EmbeddingDataset(data.Dataset):
    '''
        Torch dataset pairing each word's neural (sEEG) window with a target
        word embedding, for one subject.

        After construction:
            self.word_df   -- one row per example; this class reads its 'text',
                              'ecog_idx', 'pos' and 'movie_id' columns (and
                              'context' for the contextual targets).
            self.word_vecs -- array whose row i is the target embedding of
                              self.word_df.iloc[i].

        The glove/arbitrary/constant targets start from the word table returned
        by get_gpt_contexts (words GPT-2 gives a sentence context). The
        model-based targets re-filter using their own embeddings'
        'no_sent_cont' flags.
    '''

    # target_embedding name -> (contextual, scramble) passed to load_glove_embeds
    # TODO test these branches. The movie words are no longer in a list, but a dataframe.
    _GLOVE_TARGETS = {
        "glove": (False, False),
        "glove_contextual": (True, False),
        "glove_contextual_scramble": (True, True),
    }
    # target_embedding name -> (normal, contextual) passed to build_random_embeds
    # NOTE: this will be a different list than the GPT2 words, which may be
    # dropped due to gpt2 tokenization.
    _ARBITRARY_TARGETS = {
        "arbitrary": (False, False),
        "arbitrary_normal": (True, False),
        "arbitrary_contextual": (False, True),
        "arbitrary_contextual_normal": (True, True),
    }
    # Targets whose vectors are loaded through get_embeddings.
    _MODEL_TARGETS = ("bert", "gpt2", "wte",
                      "wpe", "no_wpe", "random_gpt",
                      "random_wpe", "random_wte",
                      "arbitrary_normal_tknizer")

    def __init__(self, subject_data: "SubjectData",
                 args: SimpleNamespace) -> None:
        '''
            Build the examples and target embeddings for one subject.

            Args:
                subject_data: object exposing ``trials``, ``words`` and
                    ``neural_data``.
                args: namespace read for ``seeg2embed``, ``target_embedding``
                    and the options of the chosen target (paths, layer, ...).

            Raises:
                ValueError: if ``args.target_embedding`` is not supported.
        '''
        self.subject_data = subject_data
        word_df = self.get_gpt_contexts(args)
        self.seeg2embed = args.seeg2embed
        target = args.target_embedding

        if target in self._GLOVE_TARGETS:
            contextual, scramble = self._GLOVE_TARGETS[target]
            self.word_vecs = self.load_glove_embeds(args.target_embedding_path, word_df,
                                                    contextual=contextual, scramble=scramble)
        elif target in self._ARBITRARY_TARGETS:
            normal, contextual = self._ARBITRARY_TARGETS[target]
            self.word_vecs = self.build_random_embeds(word_df, normal=normal,
                                                      contextual=contextual)  # [n_words, n_dim]
        elif target == "constant":
            random_vecs = self.build_random_embeds(word_df)  # also sets self.word_df; [n_words, 50]
            random_vecs[:, :30] = 0
            random_vecs[:, 30:40] = 1
            # NOTE(review): columns 40 onward keep their per-word random values,
            # so this target is not fully constant -- confirm this is intended.
            self.word_vecs = random_vecs
        elif target in self._MODEL_TARGETS:
            self.word_vecs, self.word_df = self._load_model_embeds(args)
        else:
            raise ValueError(f"Unknown target_embedding: {target!r}")

    def _load_trial_words(self, trial, args, target_embedding, sent_context, layer):
        '''
            Load one trial's word embeddings and align them with the subject's words.

            Returns (embed_df, word_embed_df, movie_words). word_embed_df is
            re-indexed so its index equals movie_words.index and its 'words'
            column equals movie_words['text'] (both asserted). Nothing is
            filtered here; callers drop the rows flagged in 'no_sent_cont'.
            movie_words is a copy, so callers may add columns to it.
        '''
        embed_df = get_embeddings(trial, target_embedding,
                                  args.target_embedding_model_dir,
                                  args.target_embedding_save_dir,
                                  sent_context=sent_context,
                                  layer=layer)
        word_embed_df = get_word_embeddings(embed_df)  # TODO refactor this? maybe get_word_embeddings is a wrapper?
        all_words = self.subject_data.words
        # NOTE the assumption is that each trial corresponds with a unique movie.
        # .copy() keeps later column assignments off subject_data.words and
        # avoids pandas' SettingWithCopyWarning.
        movie_words = all_words[all_words['movie_id'] == trial.movie_id].copy()
        # The subject_data word_df index is the source of truth for each
        # example's id; it is used here as positions into word_embed_df.
        word_embed_df = word_embed_df.iloc[movie_words.index]

        assert (word_embed_df.index == movie_words.index).all()
        assert (word_embed_df['words'] == movie_words['text']).all()
        return embed_df, word_embed_df, movie_words

    def _load_model_embeds(self, args):
        '''
            Build the targets for the model-based embeddings (see _MODEL_TARGETS).

            Returns (word_vecs, word_df): word_vecs has one row per kept word
            (50 columns after PCA when args.gpt_reduce_dim is set); word_df has
            a fresh 0..n-1 index and row i describes word_vecs[i]. Words without
            a sentence context ('no_sent_cont') are dropped.

            Raises:
                ValueError: if no trial keeps any word.
        '''
        trial_embeds, trial_words = [], []
        for trial in self.subject_data.trials:
            _, word_embed_df, movie_words = self._load_trial_words(
                trial, args, args.target_embedding,
                sent_context=args.sentence_context, layer=args.layer)
            keep = ~word_embed_df['no_sent_cont']  # remove everything that doesn't have a sentence context
            word_embed_df = word_embed_df[keep]
            if word_embed_df.empty:
                continue  # an empty (0,)-shaped array would break np.concatenate
            embeds = np.array(word_embed_df['embeds'].tolist())  # presumably [n_words, 1, 1, dim]
            # Drop singleton axes but never the word axis, so a trial that keeps
            # exactly one word still yields a 2-D [1, dim] block.
            squeezed_shape = (embeds.shape[0],) + tuple(s for s in embeds.shape[1:] if s != 1)
            trial_embeds.append(embeds.reshape(squeezed_shape))
            trial_words.append(movie_words[keep])
        if not trial_embeds:
            raise ValueError("no words with a sentence context were found for this subject")

        word_vecs = np.concatenate(trial_embeds)
        print(word_vecs.shape)
        if args.gpt_reduce_dim:
            # TODO hardcode #TODO should this only be fit to train data? should this be done per unique word or across all instances?
            word_vecs = PCA(n_components=50).fit_transform(word_vecs)
        word_df = pd.concat(trial_words)
        # Debug dump of the kept words, written to the current working directory.
        word_df.to_csv('filtered_dataframe.csv')
        word_df = word_df.reset_index(drop=True)  # TODO is resetting the index the right thing to do here
        assert word_vecs.shape[0] == len(word_df)
        return word_vecs, word_df

    def get_gpt_contexts(self, args):
        '''
            Return the subject's words that have a GPT-2 sentence context.

            Walks every trial with GPT-2 layer-8 sentence-context embeddings,
            copies each word's 'context' onto the subject's word table and drops
            words flagged 'no_sent_cont', so these words are excluded from all
            trials. The result keeps subject_data.words' index.
        '''
        kept_words = []
        for trial in self.subject_data.trials:
            embed_df, word_embed_df, movie_words = self._load_trial_words(
                trial, args, 'gpt2', sent_context=True, layer=8)
            # NOTE: assume that nothing has been lost in getting the word
            # embeddings, so embed_df's index lines up with movie_words'.
            movie_words['context'] = embed_df['context']
            kept_words.append(movie_words[~word_embed_df['no_sent_cont']])
        return pd.concat(kept_words)

    def build_random_embeds(self, word_df, normal=False, contextual=False):
        '''
            Give every distinct word an arbitrary vector; return one row per example.

            NOTE: we set the word_df in here. Not a pure function: word_df['text']
            is cleaned in place and self.word_df is set to word_df.

            normal=False: 50-d vectors drawn uniformly from [-1, 1).
            normal=True: rows of a 50257x768 N(0, 0.02) table, one per distinct
                word, reduced to 50-d by PCA fitted on the per-example rows.
            contextual=True: the per-word vectors are then mixed over each
                example's context by make_arbitrary_contextual.

            Returns an array with len(word_df) rows and 50 columns.
        '''
        word_df['text'] = [clean_word(w) for w in word_df['text']]
        self.word_df = word_df
        vocab = self.build_random_embed_dict(word_df)
        words = word_df['text']
        idxs = [vocab[w] for w in words]

        n_dim = 50  # TODO hardcode
        # Always drawn (even when normal=True) so the numpy RNG is consumed the
        # same way in every branch.
        uniform_table = np.random.rand(len(vocab), n_dim) * 2 - 1  # uniformly in [-1, 1)
        if normal:
            gpt_table = torch.nn.Embedding(50257, 768)  # Note from openai
            gpt_table.weight.data.normal_(mean=0.0, std=0.02)  # NOTE 0.02 comes from huggingface
            gpt_table = gpt_table.weight.detach().numpy()[:len(vocab), :]
            # TODO should PCA only be fit to train data?
            example_vecs = PCA(n_components=n_dim).fit_transform(gpt_table[idxs])
        else:
            example_vecs = uniform_table[idxs]

        if contextual:
            # Map each word to its own vector. example_vecs rows are per example,
            # so indexing them by vocab id would pick another example's row and
            # could give distinct words the same vector.
            word_embedding = dict(zip(words, example_vecs))
            return self.make_arbitrary_contextual(word_embedding)
        return example_vecs

    def build_random_embed_dict(self, word_df):
        '''
            Map each distinct word in word_df.text to a dense row id (0, 1, ...),
            in order of first occurrence so the mapping is reproducible.
        '''
        return {w: i for i, w in enumerate(dict.fromkeys(word_df.text))}

    def make_arbitrary_contextual(self, word_embedding, n_dim=50):
        '''
            Replace each example's vector with a position-weighted mean over its context.

            Each row's 'context' word list in self.word_df is cleaned with
            clean_word (the cleaned lists are written back to self.word_df).
            Each context word is looked up in word_embedding (unknown words ->
            zeros) and multiplied elementwise by a random N(0, 0.02) position
            vector. The products are then averaged over the context. An empty
            context yields a zero vector.

            Args:
                word_embedding: dict mapping cleaned word -> n_dim vector.
                n_dim: embedding width (50 for every caller in this class).

            Returns an array with len(self.word_df) rows and n_dim columns.
        '''
        zero_vector = np.zeros(n_dim)
        contexts = [[clean_word(w) for w in context] for context in self.word_df['context']]
        self.word_df['context'] = contexts

        # NOTE following openai initialization; 0.02 comes from huggingface.
        # At least 1024 positions, grown if a context is longer so the slice
        # below never comes up short.
        n_positions = max(1024, max((len(c) for c in contexts), default=0))
        positions = torch.nn.Embedding(n_positions, n_dim)
        positions.weight.data.normal_(mean=0.0, std=0.02)
        positions = positions.weight.detach().numpy()

        context_vecs = []
        for context in contexts:
            if not context:
                context_vecs.append(zero_vector)
                continue
            embeds = np.stack([word_embedding.get(w, zero_vector) for w in context])
            context_vecs.append(np.mean(positions[:len(context)] * embeds, axis=0))
        return np.stack(context_vecs)

    def load_glove_embeds(self, word_embedding_path: str, word_df, contextual=False, scramble=False):
        '''
            Look up 50-d static (GloVe-format) vectors for every example.

            word_df is copied and its 'text' cleaned with clean_word; the copy
            becomes self.word_df. Words missing from the file get zeros.
            scramble=True randomly reassigns vectors among the file's words
            before any lookup. contextual=True returns make_arbitrary_contextual
            over the (possibly scrambled) table instead of per-word vectors.
        '''
        word_embedding = self.load_static_embeddings(word_embedding_path)
        if scramble:
            # Shuffle before the lookup so scrambling also applies when contextual=False.
            shuffled_words = list(word_embedding.keys())
            random.shuffle(shuffled_words)
            word_embedding = dict(zip(shuffled_words, word_embedding.values()))

        words_df = word_df.copy()
        words_df['text'] = [clean_word(w) for w in words_df['text']]
        self.word_df = words_df
        zero_vector = np.zeros(50)  # TODO hardcode
        word_vecs = np.array([word_embedding.get(w, zero_vector) for w in words_df['text']])
        assert word_vecs.ndim == 2 and word_vecs.shape[1] == 50
        if contextual:
            word_vecs = self.make_arbitrary_contextual(word_embedding)
        return word_vecs

    def load_static_embeddings(self, path: str) -> Dict[str, np.ndarray]:
        '''
        Input:
            path = path to a UTF-8 embedding file where each line is a
                   word followed by the vector embedding; blank lines are skipped
        Returns:
            embeddings_dict = dictionary mapping word to float32 numpy
                              embedding vector
        '''
        embeddings_dict = {}
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:  # stream instead of readlines(): files can be large
                values = line.split()
                if not values:
                    continue
                embeddings_dict[values[0]] = np.asarray(values[1:], "float32")
        return embeddings_dict

    def __len__(self):
        '''
            returns:
                Number of words in the dataset
        '''
        return len(self.word_df)

    def __getitem__(self, idx: int):
        '''
            Returns the idx-th example as a dict:
                source, target -- the word's neural window and its target
                                  embedding, ordered according to self.seeg2embed
                pos, word      -- the word's 'pos' and 'text' fields
        '''
        # The subject_data object keeps a mapping from word in the transcript to
        # array index into the filtered data array. This is that index.
        ecog_idx = self.word_df['ecog_idx'].iloc[idx]
        ecog_data = torch.FloatTensor(self.subject_data.neural_data[:, ecog_idx, :])
        target_embedding = torch.FloatTensor(self.word_vecs[idx])

        # NOTE: remember not to load to cuda here
        if self.seeg2embed:
            source = ecog_data  # has shape [n_electrodes, n_samples]
            target = target_embedding  # has shape [embedding_dim]
        else:  # embed2seeg
            source = target_embedding.unsqueeze(0)
            # NOTE: assumes only n_electrodes=1 or some aggregation has been used
            target = ecog_data.squeeze(0)
        return {"source": source,
                "target": target,
                # Positional access, matching how word_vecs rows are indexed.
                "pos": self.word_df['pos'].iloc[idx],
                "word": self.word_df['text'].iloc[idx],
                }

    def get_movie_ids(self) -> pd.Series:
        '''Return the 'movie_id' column of self.word_df (one entry per example).'''
        return self.word_df['movie_id']

    def get_input_dim(self):
        '''Model input width: n_samples for seeg2embed, else the embedding width.'''
        if self.seeg2embed:
            return self.get_seeg_shape()[1]
        return self.get_embedding_dim()

    def get_output_dim(self):
        '''Model output width: the embedding width for seeg2embed, else n_samples.'''
        if self.seeg2embed:
            return self.get_embedding_dim()
        return self.get_seeg_shape()[1]

    def get_embedding_dim(self) -> int:
        '''Number of columns in the target embedding matrix.'''
        return self.word_vecs.shape[1]

    def get_seeg_shape(self) -> Tuple[int, int]:
        '''Axes 0 and 2 of neural_data (axis 1 indexes words, see __getitem__).'''
        return self.subject_data.neural_data.shape[0], self.subject_data.neural_data.shape[2]
