import numpy as np
import os
import pickle as pkl
from regression.session_story_configs import SessionStoryConfig
from typing import List
import time
import torch

def load_pkl(file_loc):
    """Read the pickle file at ``file_loc`` and return the object it contains.

    The file is opened in binary mode and closed before returning.
    Unpickling can execute arbitrary code, so only point this at files
    that come from a trusted source.
    """
    with open(file_loc, "rb") as handle:
        loaded = pkl.load(handle)
    return loaded

def save_pkl(obj, file_loc):
    """Pickle ``obj`` into the file at ``file_loc``.

    The file is opened in binary write mode, so any existing content at
    that path is replaced. Nothing is returned.
    """
    with open(file_loc, "wb") as handle:
        pkl.dump(obj, handle)

class EmbeddingsStore():
    """On-disk store holding, per story, an embeddings array plus pickled contexts and tokens.

    Every story gets its own sub-directory under ``save_loc`` containing
    ``embeddings.npy``, ``contexts.pkl`` and ``tokens.pkl``.
    """

    def __init__(self, save_loc, embeddings_save_type = np.float16):
        """Create the root directory if needed and remember the dtype used when saving."""
        os.makedirs(save_loc, exist_ok=True)
        self.save_loc = save_loc
        self.embeddings_save_type = embeddings_save_type

    def embedding_save_loc(self, story):
        """Path of the ``.npy`` file holding the story's embeddings."""
        return f"{self.save_loc}/{story}/embeddings.npy"

    def context_save_loc(self, story):
        """Path of the pickle holding the story's contexts."""
        return f"{self.save_loc}/{story}/contexts.pkl"

    def tokens_save_loc(self, story):
        """Path of the pickle holding the story's tokens."""
        return f"{self.save_loc}/{story}/tokens.pkl"

    def save_story_data(self, story, embeddings, contexts, tokens):
        """Write embeddings, contexts and tokens for ``story``.

        ``embeddings`` must be 2-D. One extra all-zero row is appended after
        the last row before saving, so index ``-1`` of the stored array is
        always a zero vector. The array is cast to ``embeddings_save_type``.
        """
        os.makedirs(f"{self.save_loc}/{story}", exist_ok = True)
        zero_row = np.zeros((1, embeddings.shape[1]))
        padded = np.concatenate((embeddings, zero_row), axis=0)
        np.save(self.embedding_save_loc(story), padded.astype(self.embeddings_save_type))
        save_pkl(contexts, self.context_save_loc(story))
        save_pkl(tokens, self.tokens_save_loc(story))

    def load_embeddings(self, story, mmap = False, as_feature_map = False):
        """Load the stored embeddings for ``story``.

        mmap: when true the file is memory-mapped read-only instead of read
            into memory.
        as_feature_map: when true the array is returned exactly as stored,
            including the trailing zero row; otherwise that last row is
            dropped.
        """
        stored = np.load(self.embedding_save_loc(story), mmap_mode="r" if mmap else None)
        if as_feature_map:
            return stored
        return stored[:-1]

    def load_contexts(self, story):
        """Return the unpickled contexts saved for ``story``."""
        return load_pkl(self.context_save_loc(story))

    def load_tokens(self, story):
        """Return the unpickled tokens saved for ``story``."""
        return load_pkl(self.tokens_save_loc(story))

class MEGFeatureMapStore():
    """Builds and stores, per session story, an array mapping MEG samples to transcript word indices."""

    def __init__(self, save_loc):
        self.save_loc = save_loc

    def meg_map_save_loc(self, session_story_config:SessionStoryConfig):
        """Path of the ``.npy`` map for the config's session, subject and session story."""
        return f"{self.save_loc}/Moth{session_story_config.session}/{session_story_config.subject}/{session_story_config.session_story}.npy"

    def save_session_story_data(self, session_story_config:SessionStoryConfig):
        """Compute the MEG-sample -> word-index map for one session story and save it.

        The map has one entry per MEG timing and starts filled with ``-1``
        ("no word"). Each transcript word's start time is multiplied by the
        alignment stretch; the word's index is then written at the first MEG
        sample whose timing is strictly greater than that stretched time.
        When several words land on the same sample the last one wins; words
        starting at or after the final timing are dropped.
        """
        # NOTE(review): the lag returned by load_alignment() is not applied
        # anywhere below -- confirm that is intended.
        _lag, stretch = session_story_config.load_alignment()
        meg_timings, _ = session_story_config.load_aligned_downsampled_meg()
        word_for_sample = np.full(len(meg_timings), -1, dtype=int)

        story = session_story_config.story
        transcript = load_pkl(f"{session_story_config.dataset_loc}/words_and_times_transcripts/{story}.pkl")

        for word_idx, entry in enumerate(transcript):
            onset = float(entry["start_time"]) * stretch
            later_samples = np.flatnonzero(onset < meg_timings)
            if later_samples.size:
                word_for_sample[later_samples[0]] = word_idx

        os.makedirs(f"{self.save_loc}/Moth{session_story_config.session}/{session_story_config.subject}", exist_ok=True)
        np.save(self.meg_map_save_loc(session_story_config), word_for_sample)

    def load_meg_map(self, session_story_config):
        """Load the previously saved map for this session story."""
        return np.load(self.meg_map_save_loc(session_story_config))

class EmbeddingsFeatureDelayReader():
    '''Builds time-delayed embedding features on demand; supports 1-d accessing along the meg index dimension.

    For MEG sample ``i`` the feature is made from the embeddings of the words
    mapped to samples ``i - delays + 1 .. i`` (oldest first). Samples with no
    word, and delay slots that would fall before sample 0, contribute a zero
    vector and are not touched by the affine bias/scaling.

    Attributes:
        shape: ``(meg_dim, embeddings_dim * delays)`` when ``fold_tensors`` is
            true, otherwise ``(meg_dim, embeddings_dim, delays)``.
        words_dim: number of stored embedding rows minus the trailing zero row.
        embeddings_dim: stored embedding width, or ``pca_weights.shape[0]``
            when PCA weights are given.
    '''
    def __init__(self, session_story_config:SessionStoryConfig,
                 embeddings_store:EmbeddingsStore, 
                 subject_map_store:MEGFeatureMapStore,
                 mmap = True,
                 delays=1,
                 embeddings_affine_scaling = None,
                 embeddings_affine_bias = None,
                 use_cuda = False,
                 fold_tensors = True,
                 pca_weights = None):
        '''
        session_story_config: identifies the story (``.story``) and the MEG map to read.
        embeddings_store: provides ``load_embeddings``; the array is loaded
            with its trailing zero row kept, so index ``-1`` means "no word".
        subject_map_store: provides ``load_meg_map`` (word index per MEG sample, ``-1`` for none).
        mmap: forwarded to ``load_embeddings``.
        delays: number of consecutive MEG samples stacked into one feature.
        embeddings_affine_bias / embeddings_affine_scaling: optional 1-d
            arrays; bias is added first, then the result is multiplied by scaling.
        use_cuda: run bias/scale/PCA/reshape on the GPU for batched reads.
        fold_tensors: flatten the (embedding, delay) axes into one.
        pca_weights: optional ``(k, embedding)`` projection applied after the affine step.
        '''
        self.pca_weights = pca_weights
  
        self.embeddings_readfile = embeddings_store.load_embeddings(session_story_config.story, mmap = mmap, as_feature_map=True)
        self.subject_map_readfile = subject_map_store.load_meg_map(session_story_config)
        self.words_dim, self.embeddings_dim = self.embeddings_readfile.shape
        if self.pca_weights is not None:
            self.embeddings_dim = self.pca_weights.shape[0]
        #accounts for an extra zero at the end of the embeddings
        self.words_dim = self.words_dim - 1
        self.meg_dim = self.subject_map_readfile.shape[0]
        if fold_tensors:
            self.shape = (self.meg_dim, self.embeddings_dim*delays)
        else:
            self.shape = (self.meg_dim, self.embeddings_dim, delays)
        self.delays = delays
        self.embeddings_affine_scaling = embeddings_affine_scaling
        self.embeddings_affine_bias = embeddings_affine_bias
        self.use_cuda = use_cuda
        self.fold_tensors = fold_tensors
        
    def __len__(self):
        return self.meg_dim
    
    def __getitem__(self, idxs):
        '''Return delayed features for an int, slice, list or 1-d integer array of MEG indices.

        An int returns one feature (``shape[1:]``); every other supported
        index returns a batch with a leading axis. Negative indices count
        from the end. The result keeps the dtype produced by the stored
        embeddings (no upcast is performed).

        Raises:
            TypeError: for unsupported index types.
            IndexError: when an index is outside ``[-len(self), len(self))``.
        '''
        # bool is an int subclass but is not a meaningful sample index
        if isinstance(idxs, (int, np.integer)) and not isinstance(idxs, bool):
            # a single index is a batch of one, so both paths share padding, affine and PCA handling
            return self._read_batch(self._normalize_indices([idxs]))[0]
        if isinstance(idxs, slice):
            idxs = np.arange(*idxs.indices(self.meg_dim))
        elif not isinstance(idxs, (list, np.ndarray)):
            raise TypeError("Type Not Implemented for Embeddings Reader")
        return self._read_batch(self._normalize_indices(idxs))

    def _normalize_indices(self, idxs):
        '''Convert indices to a 1-d non-negative integer array, validating type and range.'''
        idx_arr = np.asarray(idxs)
        if idx_arr.size == 0:
            return np.zeros(0, dtype=np.int64)
        if idx_arr.ndim != 1 or not np.issubdtype(idx_arr.dtype, np.integer):
            raise TypeError("Type Not Implemented for Embeddings Reader")
        idx_arr = idx_arr.astype(np.int64)
        idx_arr = np.where(idx_arr < 0, idx_arr + self.meg_dim, idx_arr)
        if idx_arr.min() < 0 or idx_arr.max() >= self.meg_dim:
            raise IndexError("index out of range for Embeddings Reader")
        return idx_arr

    def _read_batch(self, idx_arr):
        '''Gather, transform and lay out the delayed embeddings for validated indices.'''
        #MEG positions to read for each index, oldest delay first
        lookback = np.arange(self.delays - 1, -1, -1)
        read_positions = idx_arr[:, None] - lookback[None, :]
        #positions before sample 0 are padding; they must map to "no word" (-1)
        #explicitly, because indexing the map with -1 would read its last entry
        is_padding = read_positions < 0
        embeddings_indices = self.subject_map_readfile[np.where(is_padding, 0, read_positions)]
        embeddings_indices = np.where(is_padding, -1, embeddings_indices)
        #row -1 of the embeddings file is the stored zero vector; fancy indexing yields a writable copy
        embeddings = self.embeddings_readfile[embeddings_indices]
        #bias/scaling only apply where a real word is present
        idxs_affine_mask = embeddings_indices != -1
        if not self.use_cuda:
            if self.embeddings_affine_bias is not None:
                embeddings[idxs_affine_mask] += self.embeddings_affine_bias[None,:]
            if self.embeddings_affine_scaling is not None:
                embeddings[idxs_affine_mask] *= self.embeddings_affine_scaling[None,:]
            if self.pca_weights is not None:
                embeddings = np.einsum("ke,bte -> btk", self.pca_weights,embeddings)
            embeddings = np.ascontiguousarray(embeddings)
            #(batch, delay, embedding) -> (batch, embedding, delay)
            embeddings = np.transpose(embeddings, (0, 2, 1))
            if self.fold_tensors:
                embeddings = embeddings.reshape(embeddings.shape[0], -1)
            return embeddings
        else:
            #use cuda for faster adding of bias/scale and reshaping
            idxs_affine_mask_t = torch.from_numpy(idxs_affine_mask).cuda(non_blocking=True)
            embeddings_t = torch.from_numpy(embeddings).cuda(non_blocking=True)
            if self.embeddings_affine_bias is not None:
                bias_t = torch.from_numpy(self.embeddings_affine_bias[None,:]).cuda(non_blocking=True)
                embeddings_t[idxs_affine_mask_t] += bias_t
            if self.embeddings_affine_scaling is not None:
                scaling_t = torch.from_numpy(self.embeddings_affine_scaling[None,:]).cuda(non_blocking=True)
                embeddings_t[idxs_affine_mask_t] *= scaling_t
            if self.pca_weights is not None:
                pca_weights_t = torch.from_numpy(self.pca_weights).cuda(non_blocking=True)
                embeddings_t = torch.einsum("ke,bte -> btk", pca_weights_t,embeddings_t)
            embeddings_t = embeddings_t.permute((0,2,1))
            if self.fold_tensors:
                embeddings_t = embeddings_t.reshape(embeddings_t.shape[0],-1)
            return embeddings_t.cpu().numpy()

class SessionStoryEmbeddingsFeatureLoader():
    """Creates delayed-embedding readers and computes normalization / PCA statistics over stories."""
    def __init__(self, meg_feature_map_loc, embedding_store_loc,
                 delays = 40, mmap = True, use_cuda = True, fold_tensors = True):
        """
        meg_feature_map_loc: root directory of the MEG feature map store.
        embedding_store_loc: root directory of the embeddings store.
        delays, mmap, use_cuda, fold_tensors: forwarded to every reader this loader creates.
        """
        self.embeddings_store = EmbeddingsStore(embedding_store_loc)
        self.meg_store = MEGFeatureMapStore(meg_feature_map_loc)
        self.delays = delays
        self.mmap = mmap
        self.use_cuda = use_cuda
        self.fold_tensors = fold_tensors
    
    def load_normalization(self, session_story_configs:List[SessionStoryConfig]):
        """Return per-dimension ``(means, stds)`` over the embeddings of all given stories.

        Uses two passes (mean first, then squared deviations) and the
        population standard deviation (divides by the total row count).
        A story listed in several configs is counted once per config.
        """
        sums = []
        counts = []
        for session_story_config in session_story_configs:
            embeddings = self.embeddings_store.load_embeddings(session_story_config.story, mmap=True, as_feature_map=False).astype(np.float32)
            sums.append(np.sum(embeddings, axis = 0))
            counts.append(embeddings.shape[0])
        total_count = sum(counts)
        means = np.sum(sums, axis = 0) / total_count
        
        squared_diffs = []
        for session_story_config in session_story_configs:
            embeddings = self.embeddings_store.load_embeddings(session_story_config.story, mmap=True, as_feature_map=False).astype(np.float32)
            squared_diffs.append(np.sum(np.square(embeddings - means[None,:]), axis = 0))
        stds = np.sqrt(np.sum(squared_diffs, axis = 0) / total_count)
        return means, stds
    
    def load_PCA_projection_weights(self, session_story_configs, means,stds, components = 0.95, use_cuda = True):
        """Return PCA projection weights of shape ``(k, embedding_dim)`` as float16.

        Embeddings of all stories are normalized with ``means``/``stds``,
        stacked, and decomposed with an SVD. Row ``j`` of the result is the
        ``j``-th principal axis, so ``weights @ x`` projects a normalized
        embedding ``x``.

        components: when below 1.0, the fraction of variance to retain; ``k``
            is the smallest number of leading components whose cumulative
            variance exceeds it. Otherwise it is the number of components
            to keep.
        use_cuda: run the SVD on the GPU.
        """
        # NOTE: a zero entry in stds yields inf/nan here; callers must supply non-zero stds.
        all_normalized_embeddings = []
        
        for config in session_story_configs:
            embeddings = self.embeddings_store.load_embeddings(config.story, mmap=False, as_feature_map=False).astype(np.float32)
            normalized_embeddings = (embeddings - means[None,:])/stds[None,:]
            all_normalized_embeddings.append(normalized_embeddings)
        # np.concatenate is available in every NumPy release (np.concat only from 2.0)
        full_embeddings = np.concatenate(all_normalized_embeddings, axis = 0)
        if use_cuda:
            full_embeddings_t = torch.from_numpy(full_embeddings).cuda()
            _, S, V_T = torch.linalg.svd(full_embeddings_t, full_matrices=False)
            S, V_T = S.cpu().numpy(), V_T.cpu().numpy()
        else:
            # reduced SVD: avoids materializing the (rows x rows) U matrix
            _, S, V_T = np.linalg.svd(full_embeddings, full_matrices=False)
        if components < 1.0:
            variances = S**2
            cumulative_fraction = np.cumsum(variances)/np.sum(variances)
            exceeds = cumulative_fraction > components
            # argmax is the index of the first component crossing the threshold,
            # so the number of components to keep is that index + 1
            top_components = int(np.argmax(exceeds)) + 1 if exceeds.any() else len(S)
        else:
            top_components = int(components)
        # principal axes are the ROWS of V_T
        return V_T[:top_components].astype(np.float16)
    
    def load_configs(self, session_story_configs, embeddings_affine_bias = None, embeddings_affine_scaling = None, pca_weights = None):
        """Return one delay reader per config, all sharing the given affine and PCA parameters."""
        return [
            EmbeddingsFeatureDelayReader(session_story_config, self.embeddings_store,
                                         self.meg_store, mmap = self.mmap,
                                         delays = self.delays,
                                         embeddings_affine_bias=embeddings_affine_bias,
                                         embeddings_affine_scaling=embeddings_affine_scaling,
                                         use_cuda=self.use_cuda,fold_tensors = self.fold_tensors,
                                         pca_weights = pca_weights)
            for session_story_config in session_story_configs
        ]
        
    def load(self, session_story_config:SessionStoryConfig):
        """Return a delay reader for one config, without affine normalization or PCA."""
        # fold_tensors is forwarded so this agrees with load_configs
        return EmbeddingsFeatureDelayReader(session_story_config, self.embeddings_store, self.meg_store, mmap = self.mmap, delays = self.delays, use_cuda=self.use_cuda, fold_tensors = self.fold_tensors)
    
    
    
        
        