from regression.lm_embeddings.embeddings_store import SessionStoryEmbeddingsFeatureLoader
from regression.session_story_configs import subject_train_configs, subject_test_configs
from regression.baseline_features.mel_spectrogram import delayed_spectrogram_features
from regression.baseline_features.word_onset import delayed_word_onset_features
from regression.baseline_features.sentence_effects import delayed_sentence_effects_features
from regression.baseline_features.feature_mixer import feature_mixer2 as feature_mixer
from regression.load_meg_targets import load_meg_targets
import os
from tqdm import tqdm
import pickle as pkl
import numpy as np
import torch
from regression.regression_closed_form import block_gpu_multiply

def save_pkl(obj, file_loc):
    """Serialize ``obj`` with pickle and write it to ``file_loc``.

    The file is opened in binary write mode, so an existing file at that
    path is overwritten.
    """
    with open(file_loc, "wb") as sink:
        pkl.dump(obj, sink)

def load_pkl(file_loc):
    """Read the file at ``file_loc`` and return the object pickled inside it.

    Unpickling can execute arbitrary code, so only point this at files
    produced by a trusted source.
    """
    with open(file_loc, "rb") as source:
        loaded = pkl.load(source)
    return loaded

def train_feature_loader(llm_features = {"name":"llama2", "layer":3, "context":20, "pca":0.95, "load":True, "delays":40}, controls = ["spectogram", "word_onset", "sentence_start"],
                         delays = [15, 40, 40],lm_feature_map_loc = "./embeddings", subject = "A", force_load = False,
                         llm_load_save_loc = "./embeddings_transform_cache", verbose = True,
                         add_test = False, load_as_control = False, use_cuda = False,
                         dataset_loc = "./dataset"):
    """Build the mixed (LLM + control) feature loader for a subject's training stories.

    Args:
        llm_features: dict describing the language-model embeddings. Keys read
            here are "name", "layer", "context", "pca", "load" and "delays".
            May be None, in which case no LLM features are loaded and the
            "sentence_start" control (which needs the embeddings path) is
            unavailable.
        controls: names of control feature sets to append, in order. Known
            names are "spectrogram" (the legacy spelling "spectogram" is
            accepted as an alias), "word_onset" and "sentence_start".
        delays: one delay value per entry of ``controls``.
        lm_feature_map_loc: root directory holding "meg_store" and the
            per-model embedding directories.
        subject: subject identifier passed to the config builders.
        force_load: when True, call ``force_load_array()`` on every element of
            the mixed loader and return the resulting list.
        llm_load_save_loc: directory used to cache the mean/std and PCA pickles
            (created if missing).
        verbose: when True and ``force_load`` is set, show a tqdm progress bar.
        add_test: when True, the subject's test configs are appended to the
            train configs.
        load_as_control: when True, cache file names use ``with_test_True``.
        use_cuda: forwarded to ``SessionStoryEmbeddingsFeatureLoader``.
        dataset_loc: dataset root passed to the config builders.

    Returns:
        Tuple ``(fixed_delay_feature_size, features)`` where the size is read
        from the first mixed loader's ``const_delay_feature_size`` and
        ``features`` is either the mixed loader itself or, with
        ``force_load``, the list of force-loaded arrays.

    Raises:
        RuntimeError: if a control name is not recognised.
        ValueError: if "sentence_start" is requested while ``llm_features`` is None.
    """
    # The dict/list defaults are only read, never mutated, so sharing them
    # across calls is harmless here.
    os.makedirs(llm_load_save_loc, exist_ok=True)
    train_configs = subject_train_configs(subject, dataset_loc=dataset_loc)
    if add_test:
        train_configs = train_configs + subject_test_configs(subject, dataset_loc=dataset_loc)
    if load_as_control:
        # NOTE(review): this only changes the cache file names below; the test
        # configs are NOT appended to train_configs here, unlike in
        # test_feature_loader. Behaviour kept as-is -- confirm it is intended.
        add_test = True
    meg_store_loc = lm_feature_map_loc + "/meg_store"
    # Only dereference llm_features once we know it is not None; previously the
    # path was built unconditionally, which made the None check below dead code.
    if llm_features is None:
        lm_embeddings_loc = None
    else:
        lm_embeddings_loc = lm_feature_map_loc+f"/{llm_features['name']}/layer_{llm_features['layer']}"
    feature_list = []
    if llm_features is not None and llm_features["load"]:
        llm_loader = SessionStoryEmbeddingsFeatureLoader(meg_store_loc, lm_embeddings_loc,
                                            use_cuda=use_cuda, fold_tensors=False, delays = llm_features['delays'])

        mean_std_save_loc = llm_load_save_loc + "/" +f"MEANSTD_{llm_features['name']}_{llm_features['layer']}_{llm_features['context']}_with_test_{add_test}_delay_{llm_features['delays']}.pkl"
        if not os.path.exists(mean_std_save_loc):
            print("No Mean/Std detected, building from embeddings")
            mean, std = llm_loader.load_normalization(train_configs)
            save_pkl([mean,std], mean_std_save_loc)
        else:
            # Reuse the cached statistics instead of recomputing them; this also
            # keeps them consistent with any cached PCA weights.
            mean, std = load_pkl(mean_std_save_loc)
        if llm_features["pca"] is not None:
            pca_save_loc = llm_load_save_loc + "/" +f"PCA_{llm_features['name']}_{llm_features['layer']}_{llm_features['context']}_pca{llm_features['pca']}_with_test_{add_test}_delay_{llm_features['delays']}.pkl"
            if not os.path.exists(pca_save_loc):
                print("No PCA weights detected, building from embeddings")
                # NOTE(review): CUDA is hard-coded here regardless of use_cuda.
                pca_weights = llm_loader.load_PCA_projection_weights(train_configs, mean, std, llm_features["pca"], use_cuda=True)
                save_pkl(pca_weights, pca_save_loc)
            else:
                pca_weights = load_pkl(pca_save_loc)
        else:
            pca_weights = None
        # The loader is handed the negated mean and reciprocal std.
        llm_feature_loader = llm_loader.load_configs(train_configs, -mean, 1/std, pca_weights)
        feature_list.append(llm_feature_loader)
    assert len(controls) == len(delays)
    for control_name, delay in zip(controls, delays):
        print(control_name)
        print(delay)
        # "spectogram" is the spelling used by this function's default argument;
        # accept it so the default call no longer fails with "Control Name Unknown".
        if control_name in ("spectrogram", "spectogram"):
            new_features = delayed_spectrogram_features(train_configs, delay, fold_tensors=False, zero_out_delays=None)
        elif control_name == "word_onset":
            new_features = delayed_word_onset_features(meg_store_loc, train_configs, delay, fold_tensors=False)
        elif control_name == "sentence_start":
            if lm_embeddings_loc is None:
                raise ValueError("The 'sentence_start' control needs llm_features to locate the embeddings")
            new_features = delayed_sentence_effects_features(train_configs, lm_embeddings_loc,
                                                             meg_store_loc,delay,fold_tensors=False)
        else:
            raise RuntimeError("Control Name Unknown")
        feature_list.append(new_features)

    mixed_feature_loader = feature_mixer(feature_list)
    fixed_delay_feature_size = mixed_feature_loader[0].const_delay_feature_size

    if force_load:
        if verbose:
            iterator = tqdm(mixed_feature_loader, desc="Building Train Features:")
        else:
            iterator = mixed_feature_loader
        return fixed_delay_feature_size, [x.force_load_array() for x in iterator]
    return fixed_delay_feature_size, mixed_feature_loader

def test_feature_loader(llm_features = {"name":"llama2", "layer":3, "context":20, "pca":0.95, "load":True, "delays":40}, controls = ["spectogram", "word_onset", "sentence_start"],
                   lm_feature_map_loc = "./embeddings", subject = "A", delays = [15, 40, 40], force_load = False, llm_load_save_loc = "./embeddings_transform_cache",
                   load_as_control = False, dataset_loc = "./data"):
    """Build the mixed (LLM + control) feature loader for a subject's test stories.

    Normalization statistics and PCA weights are fitted on (or loaded from the
    cache for) the training configs, then applied to the test configs.

    Args:
        llm_features: dict describing the language-model embeddings. Keys read
            here are "name", "layer", "context", "pca", "load" and "delays".
            May be None, in which case no LLM features are loaded and the
            "sentence_start" control (which needs the embeddings path) is
            unavailable.
        controls: names of control feature sets to append, in order. Known
            names are "spectrogram" (the legacy spelling "spectogram" is
            accepted as an alias), "word_onset" and "sentence_start".
        lm_feature_map_loc: root directory holding "meg_store" and the
            per-model embedding directories.
        subject: subject identifier passed to the config builders.
        delays: one delay value per entry of ``controls``.
        force_load: when True, call ``force_load_array()`` on every element of
            the mixed loader and return the resulting list.
        llm_load_save_loc: directory used to cache the mean/std and PCA pickles
            (created if missing).
        load_as_control: when True, the test configs are added to the configs
            used for fitting and cache file names use ``with_test_True``.
        dataset_loc: dataset root passed to the config builders.

    Returns:
        Tuple ``(fixed_delay_feature_size, features)`` where the size is read
        from the first mixed loader's ``const_delay_feature_size`` and
        ``features`` is either the mixed loader itself or, with
        ``force_load``, the list of force-loaded arrays.

    Raises:
        RuntimeError: if a control name is not recognised.
        ValueError: if "sentence_start" is requested while ``llm_features`` is None.
    """
    # The dict/list defaults are only read, never mutated, so sharing them
    # across calls is harmless here.
    os.makedirs(llm_load_save_loc, exist_ok=True)
    train_configs = subject_train_configs(subject, dataset_loc=dataset_loc)
    test_configs = subject_test_configs(subject, dataset_loc=dataset_loc)
    if load_as_control:
        add_test = True
        train_configs = train_configs + test_configs
    else:
        add_test = False
    meg_store_loc = lm_feature_map_loc + "/meg_store"
    # Only dereference llm_features once we know it is not None; previously the
    # path was built unconditionally, which made the None check below dead code.
    if llm_features is None:
        lm_embeddings_loc = None
    else:
        lm_embeddings_loc = lm_feature_map_loc+f"/{llm_features['name']}/layer_{llm_features['layer']}"
    feature_list = []
    if llm_features is not None and llm_features["load"]:
        # NOTE(review): CUDA is hard-coded for the test loader.
        llm_loader = SessionStoryEmbeddingsFeatureLoader(meg_store_loc, lm_embeddings_loc,
                                            use_cuda=True, fold_tensors=False, delays = llm_features['delays'])

        mean_std_save_loc = llm_load_save_loc + "/" +f"MEANSTD_{llm_features['name']}_{llm_features['layer']}_{llm_features['context']}_with_test_{add_test}_delay_{llm_features['delays']}.pkl"
        if not os.path.exists(mean_std_save_loc):
            print("No Mean/Std detected, building from embeddings")
            mean, std = llm_loader.load_normalization(train_configs)
            save_pkl([mean,std], mean_std_save_loc)
        else:
            # Reuse the cached statistics instead of recomputing them; this also
            # keeps them consistent with any cached PCA weights.
            mean, std = load_pkl(mean_std_save_loc)
        if llm_features["pca"] is not None:
            pca_save_loc = llm_load_save_loc + "/" +f"PCA_{llm_features['name']}_{llm_features['layer']}_{llm_features['context']}_pca{llm_features['pca']}_with_test_{add_test}_delay_{llm_features['delays']}.pkl"
            if not os.path.exists(pca_save_loc):
                print("No PCA weights detected, building from embeddings")
                pca_weights = llm_loader.load_PCA_projection_weights(train_configs, mean, std, llm_features["pca"], use_cuda=True)
                print(pca_weights.shape)
                save_pkl(pca_weights, pca_save_loc)
            else:
                pca_weights = load_pkl(pca_save_loc)
        else:
            pca_weights = None
        # The loader is handed the negated mean and reciprocal std.
        llm_feature_loader = llm_loader.load_configs(test_configs, -mean, 1/std, pca_weights)
        feature_list.append(llm_feature_loader)

    assert len(controls) == len(delays)
    for control_name, delay in zip(controls, delays):
        # "spectogram" is the spelling used by this function's default argument;
        # accept it so the default call no longer fails with "Control Name Unknown".
        if control_name in ("spectrogram", "spectogram"):
            new_features = delayed_spectrogram_features(test_configs, delay, fold_tensors=False, zero_out_delays=None)
        elif control_name == "word_onset":
            new_features = delayed_word_onset_features(meg_store_loc, test_configs, delay, fold_tensors=False)
        elif control_name == "sentence_start":
            if lm_embeddings_loc is None:
                raise ValueError("The 'sentence_start' control needs llm_features to locate the embeddings")
            new_features = delayed_sentence_effects_features(test_configs, lm_embeddings_loc,
                                                             meg_store_loc,delay,fold_tensors=False)
        else:
            raise RuntimeError("Control Name Unknown")
        feature_list.append(new_features)

    mixed_feature_loader = feature_mixer(feature_list)
    fixed_delay_feature_size = mixed_feature_loader[0].const_delay_feature_size
    if force_load:
        return fixed_delay_feature_size, [x.force_load_array() for x in tqdm(mixed_feature_loader, desc="Building Test Features:")]
    return fixed_delay_feature_size, mixed_feature_loader

def load_embedding_transform(llm_features = {"name":"llama2", "layer":3, "context":20, "pca":0.95}, llm_load_save_loc = "./embeddings_transform_cache",
                             add_test=False, use_torch=False, torch_device = "cuda"):
    """Return a function that applies the cached normalization (and PCA) to embeddings.

    The mean/std pickle, and the PCA pickle when ``llm_features["pca"]`` is not
    None, are read from ``llm_load_save_loc`` using the same file-name scheme
    the feature loaders use when writing them.

    Args:
        llm_features: dict with keys "name", "layer", "context", "pca" and,
            optionally, "delays". When "delays" is absent, 40 is used, which
            is the value the other defaults in this module carry.
        llm_load_save_loc: directory containing the cached pickles.
        add_test: selects the ``with_test_{add_test}`` variant of the cache files.
        use_torch: when True, the loaded arrays are converted with
            ``torch.from_numpy`` to float32 tensors on ``torch_device``.
        torch_device: device passed to ``.to(...)`` when ``use_torch`` is True.

    Returns:
        A callable ``_transform(embeddings)`` computing
        ``(embeddings - mean[None, :]) / std[None, :]`` and, when PCA weights
        were loaded, multiplying the result by ``pca_weights.T``.
    """
    # The default dict has no "delays" key, so indexing it directly raised
    # KeyError on a default call.
    delay_tag = llm_features.get("delays", 40)
    mean_std_save_loc = llm_load_save_loc + "/" +f"MEANSTD_{llm_features['name']}_{llm_features['layer']}_{llm_features['context']}_with_test_{add_test}_delay_{delay_tag}.pkl"
    mean, std = load_pkl(mean_std_save_loc)
    print("loaded means")
    # Always bind pca_weights so the torch conversion below cannot hit an
    # unbound name when PCA is disabled.
    pca_weights = None
    if llm_features["pca"] is not None:
        pca_save_loc = llm_load_save_loc + "/" +f"PCA_{llm_features['name']}_{llm_features['layer']}_{llm_features['context']}_pca{llm_features['pca']}_with_test_{add_test}_delay_{delay_tag}.pkl"
        pca_weights = load_pkl(pca_save_loc)
    if use_torch:
        mean = torch.from_numpy(mean).to(torch.float32).to(torch_device)
        std = torch.from_numpy(std).to(torch.float32).to(torch_device)
        if pca_weights is not None:
            pca_weights = torch.from_numpy(pca_weights).to(torch.float32).to(torch_device)

    def _transform(embeddings):
        # Standardize, then project with the PCA weights when they are available.
        out = (embeddings - mean[None,:])/std[None,:]
        if pca_weights is not None:
            if use_torch:
                out = torch.matmul(out, pca_weights.T)
            else:
                out = np.matmul(out, pca_weights.T)
        return out
    return _transform

def control_subtracted_train_meg(control_weight_loc, control_features = ["spectrogram", "word_onset", "sentence_start"],
                           subject = "A", delays = [15, 40, 40],
                           llm_features = {"name":"llama2", "layer":3, "context":20, "pca":0.95, "load":False, "delays":40},
                           use_cuda = True, llm_feature_map_loc = "./embeddings", dataset_loc = "./data", llm_load_save_loc = "./embeddings_transform_cache"):
    """Return the training MEG targets with the control-feature prediction removed.

    Args:
        control_weight_loc: path of the pickle holding the control regression
            weights; the weight matrix is taken from ``list(obj)[0][1]``.
        control_features: control feature names forwarded to ``train_feature_loader``.
        subject: subject identifier.
        delays: one delay per control feature.
        llm_features: LLM feature description forwarded to
            ``train_feature_loader`` (its default has "load" set to False).
        use_cuda: forwarded to ``train_feature_loader``.
        llm_feature_map_loc: embeddings root, forwarded as ``lm_feature_map_loc``.
        dataset_loc: dataset root used for both features and MEG targets.
        llm_load_save_loc: cache directory forwarded to ``train_feature_loader``.

    Returns:
        A list with one entry per training story:
        ``meg[story] - train_features[story] @ W``.
    """
    # train_feature_loader names this parameter ``lm_feature_map_loc``; passing
    # it as ``llm_feature_map_loc`` raised TypeError on every call.
    _, train_features = train_feature_loader(llm_features, control_features, subject=subject,
                                             delays=delays, force_load=True,
                                             load_as_control=True, add_test = False, use_cuda=use_cuda,
                                             llm_load_save_loc=llm_load_save_loc,
                                             lm_feature_map_loc=llm_feature_map_loc,
                                             dataset_loc=dataset_loc
                                             )
    # presumably the pickle is a sequence of (key, weights) pairs -- TODO confirm
    W = list(load_pkl(control_weight_loc))[0][1]
    meg = load_meg_targets(subject_train_configs(subject, dataset_loc=dataset_loc))
    # Subtract the control prediction story by story.
    return [meg[story_i] - train_features[story_i] @ W for story_i in range(len(meg))]

def control_subtracted_test_meg(control_weight_loc, control_features = ["spectrogram", "word_onset", "sentence_start"],
                           subject = "A", delays = [15, 40, 40],
                           llm_features = {"name":"llama2", "layer":3, "context":20, "pca":0.95, "load":False, "delays":40},
                           use_cuda = True,llm_feature_map_loc = "./embeddings", dataset_loc = "./data", llm_load_save_loc = "./embeddings_transform_cache"
                           ):
    """Return the test MEG targets with the control-feature prediction removed.

    Args:
        control_weight_loc: path of the pickle holding the control regression
            weights; the weight matrix is taken from ``list(obj)[0][1]``.
        control_features: control feature names forwarded to ``test_feature_loader``.
        subject: subject identifier.
        delays: one delay per control feature.
        llm_features: LLM feature description forwarded to
            ``test_feature_loader`` (its default has "load" set to False).
        use_cuda: kept for signature parity with the train variant; it is not
            forwarded because ``test_feature_loader`` takes no such argument.
        llm_feature_map_loc: embeddings root, forwarded as ``lm_feature_map_loc``.
        dataset_loc: dataset root used for both features and MEG targets.
        llm_load_save_loc: cache directory forwarded to ``test_feature_loader``.

    Returns:
        A list with one entry per test story:
        ``meg[story] - test_features[story] @ W``.
    """
    # The previous call passed ``use_cuda=`` and ``llm_feature_map_loc=``,
    # neither of which is a parameter of test_feature_loader, so it raised
    # TypeError on every call. The path goes in under its real name.
    _, test_features = test_feature_loader(llm_features, control_features, subject=subject, delays=delays,
                                           force_load=True, load_as_control=True,
                                           llm_load_save_loc=llm_load_save_loc,
                                           lm_feature_map_loc=llm_feature_map_loc,
                                           dataset_loc=dataset_loc)
    # presumably the pickle is a sequence of (key, weights) pairs -- TODO confirm
    W = list(load_pkl(control_weight_loc))[0][1]
    meg = load_meg_targets(subject_test_configs(subject, dataset_loc=dataset_loc))
    # Subtract the control prediction story by story.
    return [meg[story_i] - test_features[story_i] @ W for story_i in range(len(meg))]
