from regression.session_story_configs import subject_train_configs, subject_test_configs
from regression.load_meg_targets import load_meg_targets
from regression.sgd_regression import LinearModel, SGDRegression, LowRankTensorLinearModel, StandardizeData
from regression.numpy_dataset import RegressionNumpyDataset
from regression.feature_loader import train_feature_loader, control_subtracted_train_meg, test_feature_loader
from regression.block_kfold import DatasetBlockKfold
from regression.losses import r2_loss
from regression.regression_closed_form import RidgeRegression, ridge_regression_ridge_per_channel, FullLinearModel, cross_val_ridge_regression
import numpy as np

import json
import torch
import os

def build_train_config(train_story_ids, subject, llm_name, *, layer=3, context=20, pca=0.95, delays=40, kfolds=6):
    """Build the JSON-serialisable description of one training run.

    The same ``delays`` value is written both at the top level and inside
    ``"llm-features"`` so the two copies cannot drift apart; likewise
    ``layer`` is supplied once by the caller and reused for every path that
    depends on it.

    Args:
        train_story_ids: Iterable of story identifier strings stored under
            ``"train-stories"``.
        subject: Subject identifier stored under ``"subject"``.
        llm_name: Language-model name stored under ``"llm_name"`` and
            ``"llm-features"["name"]``.
        layer: Value stored under ``"llm-features"["layer"]``.
        context: Value stored under ``"llm-features"["context"]``.
        pca: Value stored under ``"llm-features"["pca"]``.
        delays: Value stored under both ``"delays"`` entries.
        kfolds: Value stored under ``"kfolds"``.

    Returns:
        A new dict (with a fresh ``"ridge-params"`` list) whose key order
        matches the layout written to ``train-config.json``.
    """
    return {
        "train-stories": list(train_story_ids),
        "llm-features": {
            "name": llm_name,
            "layer": layer,
            "context": context,
            "pca": pca,
            "load": True,
            "delays": delays,
        },
        "ridge-params": [0.0, 0.01, 0.001, 0.1, 1.0, 10.0, 100.0, 1000.0],
        "kfolds": kfolds,
        "llm_name": llm_name,
        "subject": subject,
        "delays": delays,
    }


def main():
    """Run the layer-3 llama2 regression pipeline for subject "C".

    Steps, in order: resolve the story configs, write ``train-config.json``
    into the run directory, load the MEG targets and the LLM regression
    features, wrap them in a ``RegressionNumpyDataset`` and hand that to
    ``cross_val_ridge_regression``.
    """
    subject = "C"
    dataset_loc = "./data"
    llm_name = "llama2"
    layer = 3  # single source of truth for both the save path and the config
    freq_cutoff = None
    channels = 306

    train_stories = subject_train_configs(subject, dataset_loc)
    # NOTE(review): the test configs are not used below; the call is retained
    # in case it validates the dataset layout -- confirm, then drop.
    subject_test_configs(subject, dataset_loc)

    model_save_loc = f"./runs/full_models/subject_{subject}/layer_{layer}"

    story_ids = [f"{story.subject}_{story.session}_{story.session_story}" for story in train_stories]
    train_config = build_train_config(story_ids, subject, llm_name, layer=layer)

    # Persist the configuration next to the models it describes.
    os.makedirs(model_save_loc, exist_ok=True)
    with open(os.path.join(model_save_loc, "train-config.json"), "w", encoding="utf-8") as f:
        json.dump(train_config, f)

    meg = load_meg_targets(train_stories, frequency_cutoff=freq_cutoff, load_noise=False)

    llm_features = train_config["llm-features"]
    embeddings_loc = f"./embeddings/embeddings_sweep/layer_{llm_features['layer']}_context_{llm_features['context']}"
    # The first element of the returned pair (bound as feature_size in the
    # original script) is not needed here.
    _, regression_features = train_feature_loader(
        llm_features,
        lm_feature_map_loc=embeddings_loc,
        subject=train_config["subject"],
        controls=[],
        delays=[],
        force_load=True,
        load_as_control=False,
        use_cuda=True,
    )
    # Hand cached, currently-unused GPU memory back before the regression.
    torch.cuda.empty_cache()

    dataset = RegressionNumpyDataset(regression_features, meg)
    # NOTE(review): the fold object is never passed on (cross_val_ridge_regression
    # only receives the fold count). The construction is retained because, with
    # block_shuffle=True, it may consume random state -- confirm, then drop.
    DatasetBlockKfold(dataset.dataset_lens, n_splits=train_config["kfolds"], block_shuffle=True)

    # NOTE(review): the meaning of the positional 1000 is defined by
    # cross_val_ridge_regression, which is not visible here -- TODO confirm.
    cross_val_ridge_regression(
        dataset,
        train_config["kfolds"],
        channels,
        train_config["ridge-params"],
        model_save_loc,
        1000,
        mean_mse_loss=True,
    )


if __name__ == "__main__":
    main()
