from regression.fir_seq_2_seq_ridge import FIRSeq2SeqRidge
from regression.lm_encoder.embedding_cache import load_embedding_cache
from regression.session_story_configs import subject_train_configs, subject_test_configs
from regression.load_meg_targets import load_meg_targets
from regression.sgd_regression import LinearModel, SGDRegression, LowRankTensorLinearModel, StandardizeData
from regression.numpy_dataset import RegressionNumpyDataset
from regression.onset_linear_model import OnsetLinearModel
from regression.lm_embeddings.embeddings_store import SessionStoryEmbeddingsFeatureLoader
from regression.baseline_features.feature_mixer import feature_mixer
from regression.feature_loader import train_feature_loader, control_subtracted_train_meg, test_feature_loader
from regression.baseline_features.mel_spectrogram import delayed_spectrogram_features
from regression.block_kfold import DatasetBlockKfold
from regression.losses import r2_loss
import numpy as np

import json
import torch
import os

def rank_for_array_task(ranks, environ=None):
    """Pick the sweep rank assigned to the current SLURM array task.

    Args:
        ranks: Sequence of candidate ranks, indexed by array task id.
        environ: Mapping to read ``SLURM_ARRAY_TASK_ID`` from. Defaults to
            ``os.environ``.

    Returns:
        The element of ``ranks`` at the task id's position.

    Raises:
        KeyError: ``SLURM_ARRAY_TASK_ID`` is not set.
        ValueError: Its value cannot be parsed as an integer.
        IndexError: The task id is outside ``0 .. len(ranks) - 1``.
    """
    if environ is None:
        environ = os.environ
    task_id = int(environ["SLURM_ARRAY_TASK_ID"])
    # Reject out-of-range ids explicitly: plain indexing would let a negative
    # id wrap around silently and train (and save over) a different rank's run.
    if not 0 <= task_id < len(ranks):
        raise IndexError(
            f"SLURM_ARRAY_TASK_ID={task_id} is outside the sweep range "
            f"0..{len(ranks) - 1}"
        )
    return ranks[task_id]


def main():
    """Fit one low-rank model of the subject "A" rank sweep.

    The rank is chosen from ``range(1, 20)`` by the SLURM array task id. The
    training configuration is written to ``train-config.json`` inside the run
    directory before any data is loaded, and the same directory is handed to
    ``SGDRegression.fit`` as ``model_save_loc``.
    """
    rank = rank_for_array_task(list(range(1, 20)))
    subject = "A"
    llm_name = "llama2"
    model_save_loc = f"./runs/subject_{subject}_rank_sweep_single/rank_{rank}"
    freq_cutoff = None
    onset_bias = False
    # Single source for the delay count; it is recorded both in the
    # "llm-features" section and at the top level of the config.
    delays = 40

    train_stories = subject_train_configs(subject, llm_name)
    # NOTE(review): test_stories is not used below; the call is kept because
    # the loader's side effects are not visible from this file - confirm and drop.
    test_stories = subject_test_configs(subject, llm_name)

    train_config = {
        "train-stories": [
            f"{story.subject}_{story.session}_{story.session_story}"
            for story in train_stories
        ],
        "lr": 5 * 1e-3,
        "batch-size": 300000,
        "llm-features": {
            "name": llm_name,
            "layer": 3,
            "context": 20,
            "pca": 0.95,
            "load": True,
            "delays": delays,
        },
        "ridge-params": 0.1,
        "num-workers": 64,
        "max-epochs": 1200,
        "rank": rank,
        "use-bias": True,
        "llm_name": llm_name,
        "subject": subject,
        "delays": delays,
    }

    # Persist the configuration next to the run so it can be inspected later.
    os.makedirs(model_save_loc, exist_ok=True)
    with open(f"{model_save_loc}/train-config.json", "w", encoding="utf-8") as f:
        json.dump(train_config, f)

    llm_features = train_config["llm-features"]
    meg = load_meg_targets(train_stories, frequency_cutoff=freq_cutoff, load_noise=False)
    feature_size, regression_features = train_feature_loader(
        llm_features,
        lm_feature_map_loc=(
            "./embeddings/embeddings_sweep/"
            f"layer_{llm_features['layer']}_context_{llm_features['context']}"
        ),
        subject=train_config["subject"],
        controls=[],
        delays=[],
        force_load=True,
        load_as_control=False,
        use_cuda=True,
    )
    # Release cached CUDA allocations left over from feature loading before
    # the model and optimizer are created.
    torch.cuda.empty_cache()

    dataset = RegressionNumpyDataset(regression_features, meg)
    model = LowRankTensorLinearModel(
        train_config["rank"],
        train_config["delays"],
        feature_size,
        306,  # presumably the number of MEG output channels - TODO confirm
        use_bias=train_config["use-bias"],
        mean=None,
        std=None,
        onset_bias=onset_bias,
    )
    optimization = SGDRegression(model, ridge_params=train_config["ridge-params"], device="cuda")
    optimization.fit(
        dataset,
        num_workers=train_config["num-workers"],
        batch_size=train_config["batch-size"],
        max_epochs=train_config["max-epochs"],
        model_save_loc=model_save_loc,
        lr=train_config["lr"],
        amsgrad=False,
    )


if __name__ == "__main__":
    main()
    