
from regression.session_story_configs import subject_train_configs, subject_test_configs
from regression.load_meg_targets import load_meg_targets
from regression.sgd_regression import LinearModel, SGDRegression, LowRankTensorLinearModel, StandardizeData
from regression.numpy_dataset import RegressionNumpyDataset
from regression.onset_linear_model import OnsetLinearModel
from regression.lm_embeddings.embeddings_store import SessionStoryEmbeddingsFeatureLoader
from regression.baseline_features.feature_mixer import feature_mixer
from regression.feature_loader import train_feature_loader, control_subtracted_train_meg
from regression.baseline_features.mel_spectrogram import delayed_spectrogram_features
from regression.regression_utils import save_pkl
import numpy as np
import json
import torch
import os
import shutil
import sys
from tqdm import tqdm
from regression.regression_closed_form import ridge_regression

#SBATCH --array=0-6

if __name__ == "__main__":
    # For each control configuration, fit one closed-form ridge regression from
    # LLM features (+ the chosen control features and their delays) onto the MEG
    # targets of the train and test stories, and save the weights under
    # ./runs/controls/<save_loc>.pkl via save_pkl.
    print("started")
    llm_name = "llama2"
    layer = 3
    subject = "A"
    train_stories = subject_train_configs(subject, llm_name)
    test_stories = subject_test_configs(subject, llm_name)
    freq_cutoff = None

    train_config = {
        "train-stories": [
            f"{train_story.subject}_{train_story.session}_{train_story.session_story}"
            for train_story in train_stories
        ],
        # Reuse llm_name/layer so the feature config cannot silently drift from
        # the values used to select stories above.
        "llm-features": {"name": llm_name, "layer": layer, "context": 20, "pca": 0.95, "load": False, "delays": 40},
        "llm_name": llm_name,
        "delays": [],
    }
    llm_features = train_config["llm-features"]
    lm_feature_map_loc = f"./embeddings/embeddings_sweep/layer_{llm_features['layer']}_context_{llm_features['context']}"

    save_folder = "./runs/controls"
    os.makedirs(save_folder, exist_ok=True)

    # The MEG targets do not depend on the control set, so load them once
    # instead of re-reading them on every loop iteration.
    meg = load_meg_targets(train_stories + test_stories, frequency_cutoff=freq_cutoff, load_noise=False)

    # Parallel lists: control features to include, output file stem, and one
    # delay value per control feature (hence 3 delays for the combined set).
    control_sets = [["spectrogram"], ["word_onset"], ["sentence_start"], ["spectrogram", "word_onset", "sentence_start"]]
    save_locs = ["spectrogram_subtracted", "word_onset_subtracted", "sentence_start_subtracted", "all_controls_subtracted"]
    delay_sets = [[15], [40], [40], [15, 40, 40]]
    for control_set, save_loc, delay_set in zip(control_sets, save_locs, delay_sets):
        # The returned feature size is not needed here.
        _, regression_features = train_feature_loader(
            llm_features,
            lm_feature_map_loc=lm_feature_map_loc,
            subject=subject,
            controls=control_set,
            delays=delay_set,
            force_load=True,
            add_test=True,
            load_as_control=False,
        )
        # np.concatenate works on every NumPy version; the np.concat alias
        # only exists from NumPy 2.0 onwards.
        X = np.concatenate(regression_features, axis=0)
        # Rebuilt each iteration so ridge_regression always receives a fresh
        # array, even if it modifies its inputs in place.
        Y = np.concatenate(meg, axis=0)
        # NOTE(review): [0.0] is presumably the list of ridge penalties, i.e. an
        # unregularized fit — confirm against ridge_regression's signature.
        Ws = ridge_regression(X, Y, [0.0], block_size=1000, mean_mse_loss=False)
        save_pkl(Ws, os.path.join(save_folder, f"{save_loc}.pkl"))
    