
from regression.session_story_configs import subject_train_configs, subject_test_configs
from regression.load_meg_targets import load_meg_targets
from regression.sgd_regression import LinearModel, SGDRegression, LowRankTensorLinearModel, StandardizeData
from regression.numpy_dataset import RegressionNumpyDataset
from regression.onset_linear_model import OnsetLinearModel
from regression.lm_embeddings.embeddings_store import SessionStoryEmbeddingsFeatureLoader
from regression.baseline_features.feature_mixer import feature_mixer
from regression.feature_loader import train_feature_loader#, control_subtracted_train_meg, test_feature_loader
from regression.baseline_features.mel_spectrogram import delayed_spectrogram_features
from regression.block_kfold import DatasetBlockKfold
from regression.losses import r2_loss
import numpy as np
from regression.regression_utils import load_pkl

import json
import torch
import os

def control_subtracted_train_meg(control_weight_loc, control_features=None,
                                 subject="A", delays=None,
                                 llm_features=None,
                                 use_cuda=True):
    """Return training-story MEG with a fitted control model's prediction removed.

    Loads the control regression features for ``subject``'s training stories,
    loads the pickled control weights, and for every story computes the
    residual ``meg - features @ W``.

    Args:
        control_weight_loc: Path to a pickle read with ``load_pkl``. Its
            element ``[0]`` is used as the weight matrix ``W``.
        control_features: Names of the control feature sets to load. Defaults
            to ``["spectrogram", "word_onset", "sentence_start"]``.
        subject: Subject identifier passed to the loaders, e.g. ``"A"``.
        delays: Delay settings for the controls, presumably one entry per
            item in ``control_features`` (TODO confirm against
            ``train_feature_loader``). Defaults to ``[15, 40, 40]``.
        llm_features: LLM feature settings forwarded to
            ``train_feature_loader``. Defaults to
            ``{"name": "llama2", "layer": 3, "context": 20, "pca": 0.95,
            "load": False, "delays": 40}``.
        use_cuda: Forwarded to ``train_feature_loader``.

    Returns:
        A list with one residual array per training story, in the order
        returned by ``load_meg_targets``.

    Raises:
        ValueError: If the number of feature stories and MEG stories differ,
            because the per-story subtraction requires a one-to-one pairing.
    """
    # Fresh containers per call: mutable defaults would be shared between
    # calls and are handed to an external loader that may modify them.
    if control_features is None:
        control_features = ["spectrogram", "word_onset", "sentence_start"]
    if delays is None:
        delays = [15, 40, 40]
    if llm_features is None:
        llm_features = {"name": "llama2", "layer": 3, "context": 20,
                        "pca": 0.95, "load": False, "delays": 40}

    _, train_features = train_feature_loader(llm_features, control_features, subject=subject,
                                             delays=delays, force_load=True,
                                             load_as_control=True, add_test=False, use_cuda=use_cuda)
    W = load_pkl(control_weight_loc)[0]
    # NOTE(review): MEG targets are always looked up with the "llama2" story
    # configs, regardless of llm_features["name"]. Confirm this is intended.
    meg = load_meg_targets(subject_train_configs(subject, "llama2"))

    if len(train_features) != len(meg):
        raise ValueError(
            f"Got {len(train_features)} feature stories but {len(meg)} MEG stories "
            f"for subject {subject!r}; they must align one-to-one."
        )
    return [story_meg - story_features @ W
            for story_features, story_meg in zip(train_features, meg)]

def renormalize_meg(story_megs):
    """Z-score every story's MEG independently, column by column.

    Each story array is centred and scaled along axis 0, so every column ends
    up with zero mean and unit standard deviation. Axis 0 is presumably time
    and columns sensors (TODO confirm). A column whose standard deviation is
    exactly zero is only centred. It is not divided by zero, which would
    produce NaN/inf values that poison the downstream regression.

    Args:
        story_megs: Iterable of 2-D numpy arrays, one per story.

    Returns:
        A list of normalized arrays, in the same order as ``story_megs``.
    """
    normalized = []
    for story_meg in story_megs:
        mean = np.mean(story_meg, axis=0)[None, :]
        std = np.std(story_meg, axis=0)[None, :]
        # Constant columns would otherwise give 0/0 = NaN.
        std = np.where(std == 0, 1.0, std)
        normalized.append((story_meg - mean) / std)
    return normalized


if __name__ == "__main__":
    # One training run per SLURM array task:
    # (control features to subtract, run name, delays for those controls).
    control_runs = [
        (["spectrogram"], "spectrogram_subtracted", [15]),
        (["word_onset"], "word_onset_subtracted", [40]),
        (["sentence_start"], "sentence_start_subtracted", [40]),
        (["spectrogram", "word_onset", "sentence_start"], "all_controls_subtracted", [15, 40, 40]),
    ]

    task_id = os.environ.get("SLURM_ARRAY_TASK_ID")
    if task_id is None:
        raise SystemExit("SLURM_ARRAY_TASK_ID is not set; submit this script as a SLURM array job "
                         f"(--array=0-{len(control_runs) - 1}) or export the variable manually.")
    controls_index = int(task_id)
    # Check explicitly: a negative index would otherwise silently wrap around.
    if not 0 <= controls_index < len(control_runs):
        raise SystemExit(f"SLURM_ARRAY_TASK_ID={controls_index} is out of range "
                         f"0-{len(control_runs) - 1}.")
    control_set, save_loc, delay_set = control_runs[controls_index]

    controls_folder = "./runs/controls"
    # Pickle whose element [0] is used as the control weight matrix.
    controls_load_loc = os.path.join(controls_folder, save_loc + ".pkl")

    subject = "C"
    llm_name = "llama2"
    train_stories = subject_train_configs(subject, llm_name)
    rank = 10

    # NOTE(review): save_loc already ends in "_subtracted", so this directory
    # name ends in "_subtracted_subtracted". It is kept unchanged so existing
    # runs are still found; rename it deliberately if the suffix was a mistake.
    model_save_loc = f"./runs/control_subtracted_models/subject_{subject}_rank_{rank}_{save_loc}_subtracted"
    onset_bias = False

    train_config = {
        "train-stories": [f"{train_story.subject}_{train_story.session}_{train_story.session_story}"
                          for train_story in train_stories],
        "lr": 5 * 1e-3,
        "batch-size": 300000,
        "llm-features": {"name": llm_name, "layer": 3, "context": 20, "pca": 0.95, "load": True, "delays": 40},
        "ridge-params": 0.1,
        "num-workers": 64,
        "max-epochs": 1200,
        "rank": rank,
        "use-bias": True,
        "llm_name": llm_name,
        "subject": subject,
        "delays": 40,
    }

    os.makedirs(model_save_loc, exist_ok=True)
    with open(os.path.join(model_save_loc, "train-config.json"), "w") as f:
        json.dump(train_config, f)

    # Same LLM settings as training, except "load" is False for the control pass.
    control_llm_features = {**train_config["llm-features"], "load": False}
    unnormalized_meg = control_subtracted_train_meg(
        control_weight_loc=controls_load_loc,
        control_features=control_set,
        subject=subject,
        delays=delay_set,
        llm_features=control_llm_features,
    )
    # Subtracting the control prediction shifts each channel's mean and scale,
    # so re-standardize every story before regression.
    meg = renormalize_meg(unnormalized_meg)

    feature_size, regression_features = train_feature_loader(
        train_config["llm-features"],
        lm_feature_map_loc="./embeddings/embeddings_sweep/layer_3_context_20",
        subject=train_config["subject"], controls=[], delays=[],
        force_load=True, load_as_control=False, use_cuda=True,
    )
    torch.cuda.empty_cache()

    dataset = RegressionNumpyDataset(regression_features, meg)
    n_outputs = 306  # presumably the number of MEG sensor channels; confirm
    model = LowRankTensorLinearModel(train_config["rank"], train_config["delays"], feature_size, n_outputs,
                                     use_bias=train_config["use-bias"], mean=None, std=None,
                                     onset_bias=onset_bias)
    optimization = SGDRegression(model, ridge_params=train_config["ridge-params"], device="cuda")
    optimization.fit(dataset, num_workers=train_config["num-workers"], batch_size=train_config["batch-size"],
                     max_epochs=train_config["max-epochs"], model_save_loc=model_save_loc,
                     lr=train_config["lr"], amsgrad=False)