
from SemanticModel import SemanticSentenceModel
from matplotlib.pyplot import figure, cm
import numpy as np
import logging
import tqdm
from DataSequence import DataSequence
logging.basicConfig(level=logging.DEBUG)
from stimulus_utils import load_grids_for_stories
from stimulus_utils import load_generic_trfiles
from stimulus_utils import load_simulated_trfiles

from dsutils import make_word_ds, make_phoneme_ds

from joblib import Parallel, delayed
from sklearn.linear_model import RidgeCV
import numpy as np
from os.path import join as opj
import os
import tables
import json
import h5py

from os.path import join 
import transformers
import torch
from huggingface_hub import notebook_login, login
import os
import seaborn as sns
import pickle
import nibabel as nib
from npp import zscore
import torch
from transformers import AutoTokenizer, AutoModel
# from utils_ridge.ridge import ridge, bootstrap_ridge
from himalaya.ridge import RidgeCV


# Hugging Face access token: read from the environment (e.g. `export HF_TOKEN=...`)
# instead of being hard-coded. Secrets committed to source control are exposed to
# everyone with access to the repository, so any token that was ever written into
# this file should be treated as leaked and revoked.


# sub="sub-03"
sub = "S1"
# Output directory for the fitted encoding model and the voxel correlations.
models_path = os.path.join("models", sub)
os.makedirs(models_path, exist_ok = True)


# True: embed every story with the language model and cache the result under
# data_encoded/<sub>; False: load the previously cached embeddings instead.
encode_stories = False
# parallel =True

TRIM = 5
STIM_DELAYS = [1, 2, 3, 4]
RESP_DELAYS = [-4, -3, -2, -1]
ALPHAS = np.logspace(1, 3, 10)
NBOOTS = 50
VOXELS = 10000
CHUNKLEN = 40

DATA_DIR = "/home/matteo/tutorial_language_fmri/semantic-decoding/data_train"
EM_DATA_DIR="../deep-fMRI-dataset-master/em_data"

# Number of preceding words given to the language model as context.
context_window = 5

# Only the encoding step downloads the (gated) language model, so only then is
# a missing token worth warning about.
if encode_stories and not os.environ.get("HF_TOKEN"):
    logging.warning("HF_TOKEN is not set; downloading gated Hugging Face models may fail.")

print("sub", sub)



def fit_and_predict(voxel_idx, X_train, z_train, X_test, z_test, alphas=None, cv=5):
    """Fit a cross-validated ridge model for one target and score it on held-out data.

    Argument roles: ``z_*`` are the regressors (features) and ``X_*`` the target
    being predicted, i.e. the regression is ``z -> X``.

    Args:
        voxel_idx: identifier of the voxel being modelled; not used in the
            computation (kept for callers that map over voxels).
        X_train, X_test: target values for training / testing.
        z_train, z_test: feature matrices for training / testing.
        alphas: candidate regularisation strengths; defaults to
            ``[1e-3, 1e-2, 1e-1, 1, 10, 100, 1e3]``.
        cv: number of cross-validation folds used to pick alpha.

    Returns:
        ``(model, corr)``: the fitted ``RidgeCV`` model and the Pearson
        correlation between the test predictions and ``X_test``.
    """
    if alphas is None:
        alphas = [1e-3, 1e-2, 1e-1, 1, 10, 100, 1e3]
    model = RidgeCV(alphas=alphas, cv=cv)

    # NaNs are zeroed in train AND test data so the score is computed on data
    # cleaned the same way as the data the model was fit on.
    model.fit(np.nan_to_num(z_train), np.nan_to_num(X_train))
    y_pred = model.predict(np.nan_to_num(z_test))

    # ravel() so a (n, 1) prediction is compared element-wise with the target
    # instead of being treated as n separate variables by np.corrcoef.
    corr = np.corrcoef(np.ravel(y_pred), np.ravel(np.nan_to_num(X_test)))[0, 1]
    return model, corr

def get_response(stories, subject):
    """Load and stack the subject's fMRI responses for the given stories.

    Each story is read from ``<DATA_DIR>/train_response/<subject>/<story>.hf5``
    (HDF5 dataset ``"data"``). The per-story arrays are concatenated along the
    first axis, in the order the stories are given.

    Args:
        stories: iterable of story names.
        subject: subject identifier, used as a directory name.

    Returns:
        numpy array with the concatenated responses, or an empty float array
        when ``stories`` is empty.
    """
    subject_dir = join(DATA_DIR, "train_response", subject)

    resp = []
    for story in stories:
        resp_path = os.path.join(subject_dir, "%s.hf5" % story)
        # The context manager closes the file even if reading the dataset fails.
        with h5py.File(resp_path, "r") as hf:
            resp.append(np.asarray(hf["data"][:]))

    if not resp:
        return np.array([])
    # One concatenation instead of extending a Python list row by row.
    return np.concatenate(resp, axis=0)





# Story lists, TextGrids, TR files and the word/phoneme DataSequences are loaded
# in the "Load stories and stimuli" section below, right before they are used.

# Set up logging
logger = logging.getLogger(__name__)

class ContextualSemanticModel:
    """Semantic vector-space model backed by a pre-trained Hugging Face language model.

    There is no precomputed vocabulary or embedding table. Word vectors are
    computed on demand:

    1. The word is appended to (up to) ``context_window`` preceding context
       words.
    2. The resulting text is run through the model.
    3. The final-layer hidden states of its tokens are pooled into one vector.

    See ``get_word_embedding`` for details.
    """

    # Pooling strategies accepted by get_word_embedding / similarity.
    _POOLING_METHODS = ('mean', 'sum', 'concat', 'weighted_sum', 'last')

    def __init__(self, model_name: str,  context_window: int = 5, device ="cuda:0"):
        """Load the tokenizer and model for ``model_name`` and move the model to ``device``.

        Args:
            model_name: identifier or path passed to ``AutoTokenizer.from_pretrained``
                and ``AutoModel.from_pretrained``.
            context_window: maximum number of preceding context words placed
                before the target word; ``0`` (or less) means no context.
            device: torch device string the model is moved to.
        """
        self.context_window = context_window
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(device)

    def _build_input_text(self, word, context):
        """Return the model input: the last ``context_window`` context words, then ``word``."""
        if context is None or self.context_window <= 0:
            kept = []
        else:
            # A window of 0 must be handled above: ``context[-0:]`` is the WHOLE context.
            kept = context[-self.context_window:]
        return ' '.join(kept) + ' ' + word

    @staticmethod
    def _pool_token_embeddings(token_embeddings, method, attentions=None):
        """Collapse per-token embeddings (one row per token) into a single vector.

        Args:
            token_embeddings: 2-D array, one row per input token.
            method: one of ``_POOLING_METHODS``.
            attentions: last-layer attention weights for the same input as a
                (heads, tokens, tokens) array; only used by ``'weighted_sum'``.

        Raises:
            ValueError: unknown ``method``, or ``'weighted_sum'`` without attentions.
        """
        if method == 'mean':
            return token_embeddings.mean(axis=0)
        if method == 'sum':
            return token_embeddings.sum(axis=0)
        if method == 'concat':
            return token_embeddings.flatten()
        if method == 'last':
            return token_embeddings[-1]
        if method == 'weighted_sum':
            if attentions is None:
                raise ValueError("'weighted_sum' pooling requires attention weights")
            # Average over heads, then over the first token axis, leaving one
            # weight per token that lines up with the rows of token_embeddings.
            token_weights = attentions.mean(axis=0).mean(axis=0)
            return np.average(token_embeddings, axis=0, weights=token_weights)
        raise ValueError(f"Unknown aggregation method: {method}")

    def get_word_embedding(self, word: str, context: list[str] = None, method: str = 'last') -> np.ndarray:
        """Return the contextual embedding of ``word``.

        Args:
            word: target word; it is placed after the context.
            context: preceding words; only the last ``self.context_window`` are
                used. ``None`` means no context.
            method: how the per-token final hidden states are pooled: ``'mean'``,
                ``'sum'``, ``'concat'``, ``'weighted_sum'`` (weighted by the
                last-layer attention) or ``'last'`` (state of the final token).

        Returns:
            1-D numpy array.

        Raises:
            ValueError: if ``method`` is not a known pooling strategy.
        """
        # Validate before running the (expensive) forward pass.
        if method not in self._POOLING_METHODS:
            raise ValueError(f"Unknown aggregation method: {method}")
        need_attentions = method == 'weighted_sum'

        inputs = self.tokenizer(self._build_input_text(word, context), return_tensors='pt')
        input_ids = inputs.input_ids.to(self.device)

        with torch.no_grad():
            # Attentions for every layer are only materialised when actually needed.
            outputs = self.model(input_ids, output_attentions=need_attentions)
            # .float() keeps the numpy conversion working for half-precision models.
            token_embeddings = outputs.last_hidden_state[0].float().cpu().numpy()
            attentions = (
                outputs.attentions[-1][0].float().cpu().numpy() if need_attentions else None
            )

        return self._pool_token_embeddings(token_embeddings, method, attentions)

    def __getitem__(self, word: str) -> np.ndarray:
        """Return the embedding of ``word`` computed without context ('last' pooling).

        The class holds no precomputed vocabulary table, so the vector is
        computed on demand.
        """
        return self.get_word_embedding(word)

    def similarity(self, word1: str, context1: list[str], word2: str, context2: list[str], method: str = 'last') -> float:
        """Return the cosine similarity of the embeddings of ``word1`` and ``word2`` in their contexts."""
        vec1 = self.get_word_embedding(word1, context1, method)
        vec2 = self.get_word_embedding(word2, context2, method)
        return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))


def make_semantic_model(dataseq_story, semantic_model, window=None):
    """Embed every word of a story together with its preceding context.

    Args:
        dataseq_story: word-level DataSequence whose ``.data`` holds the story's
            words in order.
        semantic_model: object exposing ``get_word_embedding(word, context)``.
        window: number of preceding words used as context; defaults to the
            module-level ``context_window``.

    Returns:
        ``(DataSequence, sentences)``. The DataSequence stacks the per-word
        embeddings and reuses the input's ``split_inds``, ``data_times`` and
        ``tr_times``. ``sentences`` is the list of ``"<context> <word>"``
        strings that were embedded.
    """
    if window is None:
        window = context_window

    words = dataseq_story.data
    embeddings = []
    sentences = []

    for i in tqdm.trange(len(words)):
        word = words[i]
        # Clamp the start at 0: for the first words a negative start would wrap
        # around and usually give an empty slice, silently dropping their context.
        context = words[max(0, i - window):i]
        sentences.append(" ".join(context) + " " + word)

        # actually compute embeddings
        embeddings.append(semantic_model.get_word_embedding(word, context))

    return DataSequence(np.stack(embeddings), dataseq_story.split_inds, dataseq_story.data_times, dataseq_story.tr_times), sentences

# The language model is only instantiated when the stories are (re-)encoded;
# otherwise the cached embeddings are loaded from disk further down.
if encode_stories:
    model_name= "meta-llama/Meta-Llama-3-8B"
    # Initialize the ContextualSemanticModel
    semantic_model = ContextualSemanticModel(model_name=model_name, device = "cuda:2", context_window = context_window)


print("Load stories and stimuli")


# Build the ordered story list: sess_to_story.json maps each session number
# (as a string key) to the stories presented in that session.
sessions = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 18, 20]
stories = []
with open(os.path.join(DATA_DIR, "sess_to_story.json"), "r") as f:
    sess_to_story = json.load(f) 
for sess in sessions:
    stories.extend(sess_to_story[str(sess)])

# The first 70 stories (in session order) are used for training, the rest for testing.
train_stories_len = 70

train_stories = stories[:train_stories_len]
test_stories = stories[train_stories_len:]

# Aliases used below: Rstories -> training stories, Pstories -> test stories.
Rstories = train_stories
Pstories = test_stories
# Same stories, in the same order, as ``stories``.
allstories = train_stories + test_stories

print("load textgrids")


# Load TextGrids
# TextGrid annotations for every story; used below to build word and phoneme sequences.
grids = load_grids_for_stories(stories,grid_dir=opj(DATA_DIR, "train_stimulus"))

# Load TRfiles
# trfiles = load_generic_trfiles(allstories,grid_dir=opj(DATA_DIR, "ds003020/derivative/TextGrids"))
# TR timing objects are derived from respdict.json by load_simulated_trfiles
# (NOTE(review): presumably per-story response lengths — confirm in stimulus_utils).
with open(join(DATA_DIR, "ds003020/derivative/respdict.json"), "r") as f:
    respdict = json.load(f)
trfiles = load_simulated_trfiles(respdict)
# Make word and phoneme datasequences
wordseqs = make_word_ds(grids, trfiles) # dictionary of {storyname : word DataSequence}
phonseqs = make_phoneme_ds(grids, trfiles) # dictionary of {storyname : phoneme DataSequence}

## Create the semantic representation of the stories

# {story name: embedding DataSequence} and {story name: list of embedded
# "<context> <word>" strings}; filled by encoding or loaded from the cache.
semanticseqs = dict()
sentence_semanticseqs = dict()  # NOTE(review): not used in this script
sentences= dict()

tgt_dir=f"data_encoded/{sub}"
os.makedirs(tgt_dir,exist_ok=True)

semanticseqs_path = os.path.join(tgt_dir, 'semanticseqs.pkl')
sentences_path = os.path.join(tgt_dir, 'sentences.pkl')

if encode_stories:
    print("Encoding the stories")

    # Project stimuli
    for i, story in enumerate(allstories):
        print(f"Running {story}, {i+1}/{len(allstories)}")
        semanticseqs[story], sentences[story] = make_semantic_model(wordseqs[story], semantic_model)

    print("Done encoding stories, saving to disk")
    # Save the dictionaries as pickle files
    with open(semanticseqs_path, 'wb') as f:
        pickle.dump(semanticseqs, f)

    with open(sentences_path, 'wb') as f:
        pickle.dump(sentences, f)

    print(f'Dictionaries have been saved in {tgt_dir}')
else:
    missing = [p for p in (semanticseqs_path, sentences_path) if not os.path.exists(p)]
    if missing:
        raise FileNotFoundError(
            f"Cached embeddings not found: {missing}. "
            "Set encode_stories = True to compute them first."
        )

    # pickle.load can execute arbitrary code: only load caches written by this script.
    with open(semanticseqs_path, 'rb') as f:
        semanticseqs = pickle.load(f)

    with open(sentences_path, 'rb') as f:
        sentences = pickle.load(f)

    print(f'Dictionaries have been loaded from {tgt_dir}')


print("Downsampling")

# Downsample stimuli
# Resample each story's word-level embeddings with DataSequence.chunksums using
# a Lanczos filter (NOTE(review): presumably onto the fMRI TR times — see DataSequence).
interptype = "lanczos" # filter type
window = 3 # number of lobes in Lanczos filter

downsampled_semanticseqs = dict() # dictionary to hold downsampled stimuli
for story in tqdm.tqdm(allstories):
    downsampled_semanticseqs[story] = semanticseqs[story].chunksums(interptype, window=window)


print("Stack stimuli")

# Combine stimuli
# For every story: drop the first 5 + trim and the last trim rows, z-score the
# rest with npp.zscore, then stack training stories into Rstim and test stories
# into Pstim (in story order).
# NOTE(review): the extra 5 leading rows are presumably dropped to line up with
# trimming already applied to the stored responses — confirm. ``trim`` also
# duplicates the module-level TRIM constant.
trim = 5
Rstim = np.vstack([zscore(downsampled_semanticseqs[story][5+trim:-trim]) for story in Rstories])
Pstim = np.vstack([zscore(downsampled_semanticseqs[story][5+trim:-trim]) for story in Pstories])

# Print the sizes of these matrices
print ("Rstim shape: ", Rstim.shape)
print ("Pstim shape: ", Pstim.shape)

def make_delayed(stim, delays, circpad=False):
    """Concatenate time-shifted copies of a 2-D stimulus, one copy per delay.

    Copy k is ``stim`` shifted by ``delays[k]`` rows:
    - positive delays move rows later in time;
    - negative delays move rows earlier;
    - a delay of 0 is an unshifted copy.

    Rows vacated by the shift are zero-filled, unless ``circpad`` is true, in
    which case the rows shifted out wrap around to the other end. The copies
    are joined column-wise, so the result has ``len(delays) * stim.shape[1]``
    columns.
    """
    n_time, n_feat = stim.shape
    shifted_copies = []
    for lag in delays:
        if lag == 0:
            shifted_copies.append(stim.copy())
            continue

        shifted = np.zeros((n_time, n_feat))
        if lag > 0:
            shifted[lag:, :] = stim[:-lag, :]
            if circpad:
                shifted[:lag, :] = stim[-lag:, :]
        else:
            shifted[:lag, :] = stim[-lag:, :]
            if circpad:
                shifted[lag:, :] = stim[:-lag, :]
        shifted_copies.append(shifted)

    return np.hstack(shifted_copies)


# Delay stimuli: build the FIR design matrices from time-shifted copies of the stimulus.
ndelays = 4
delays = range(1, ndelays+1)

print ("FIR model delays: ", delays)

delRstim = make_delayed(Rstim, delays)
delPstim = make_delayed(Pstim, delays)

print("Loading response data")

# Responses are loaded story by story, in the same order as the stimuli above.
responses={}

responses[sub]={"train": [], "test":[]}
for story in tqdm.tqdm(train_stories):
    responses[sub]["train"].append(get_response([story],sub))

for story in test_stories:
    responses[sub]["test"].append(get_response([story],sub))

zRresp = np.concatenate(responses[sub]["train"])
zPresp = np.concatenate(responses[sub]["test"])

# be sure that no nan are on resp
zRresp = np.nan_to_num(zRresp)
zPresp = np.nan_to_num(zPresp)


# Print matrix shapes
print ("zRresp shape (num time points, num voxels): ", zRresp.shape)
print ("delRstim shape : ", delRstim.shape)

print ("zPresp shape (num time points, num voxels): ", zPresp.shape)
print ("delPstim shape : ", delPstim.shape)

# Stimulus and response rows must correspond one-to-one; fail early with a clear
# message instead of an obscure error inside the solver.
if delRstim.shape[0] != zRresp.shape[0] or delPstim.shape[0] != zPresp.shape[0]:
    raise ValueError(
        "Stimulus/response row mismatch: "
        f"train {delRstim.shape[0]} vs {zRresp.shape[0]}, "
        f"test {delPstim.shape[0]} vs {zPresp.shape[0]}"
    )


print("[INFO] Start training voxelwise encoding models")

# RidgeCV here is himalaya.ridge.RidgeCV (imported last, shadowing sklearn's).
model  = RidgeCV(alphas=[ 1, 10, 100,1e3, 1e4], cv=5).fit(delRstim, zRresp)

with open(opj(models_path,f"voxel_models_{sub}_llama_himalaya.pkl"),"wb") as f:
    pickle.dump(model,f)
print("[INFO] Encoding models saved..")

print("Running predictions")
zPpred = model.predict(delPstim)

print("Computing correlations")
# Pearson correlation between measured and predicted test responses, per voxel (column).
voxel_corrs = np.array([np.corrcoef(zPresp[:,i], zPpred[:,i])[0,1] for i in range(zPresp.shape[1])])


print("Saving output model")

os.makedirs(models_path, exist_ok=True)
np.save(opj(models_path,f"voxel_corrs_{sub}_llama_himalaya.npy"), voxel_corrs)
