from regression.fir_seq_2_seq_ridge import FIRSeq2SeqRidge
from regression.lm_encoder.embedding_cache import load_embedding_cache
from regression.session_story_configs import subject_train_configs, subject_test_configs
from regression.load_meg_targets import load_meg_targets
from regression.sgd_regression import LinearModel, SGDRegression, LowRankTensorLinearModel, StandardizeData
from regression.numpy_dataset import RegressionNumpyDataset
from regression.onset_linear_model import OnsetLinearModel
from regression.lm_embeddings.embeddings_store import SessionStoryEmbeddingsFeatureLoader
from regression.baseline_features.feature_mixer import feature_mixer
from regression.feature_loader import train_feature_loader, control_subtracted_train_meg, test_feature_loader
from regression.baseline_features.mel_spectrogram import delayed_spectrogram_features
from regression.block_kfold import DatasetBlockKfold
from regression.losses import r2_loss
import numpy as np

import json
import torch
import os

# Rank-sweep training script.
#
# Each SLURM array task trains a LowRankTensorLinearModel at one rank (1..20)
# for a single subject:
#   1. load the MEG targets and the LLM regression features,
#   2. score every ridge value with block k-fold cross-validation,
#   3. refit `k_runs` final models on the full dataset with the best ridge.
# The run configuration and the selected ridge value are written as JSON next
# to the saved models.
if __name__ == "__main__":
    # One array task per rank; the task id is used as an index into `ranks`.
    task_id = int(os.environ["SLURM_ARRAY_TASK_ID"])
    ranks = list(range(1, 21))
    try:
        rank = ranks[task_id]
    except IndexError as err:
        # Same exception type as before, but with an actionable message.
        raise IndexError(
            f"SLURM_ARRAY_TASK_ID={task_id} is outside the valid index range "
            f"for {len(ranks)} ranks"
        ) from err
    subject = "D"

    model_save_name = f"subject_{subject}_rank_sweep/rank_{rank}"
    llm_name = "llama2"
    train_stories = subject_train_configs(subject, llm_name)
    test_stories = subject_test_configs(subject, llm_name)

    model_save_loc = f"./runs/{model_save_name}"
    freq_cutoff = None
    onset_bias = False

    train_config = {
        "train-stories": [
            f"{train_story.subject}_{train_story.session}_{train_story.session_story}"
            for train_story in train_stories
        ],
        "lr": 5 * 1e-3,
        "batch-size": 300000,
        "llm-features": {
            "name": llm_name,
            "layer": 3,
            "context": 20,
            "pca": 0.95,
            "load": True,
            "delays": 40,
        },
        "ridge-params": [0.0, 0.1, 1.0, 10.0],
        "kfolds": 6,
        "num-workers": 64,
        "max-epochs": 500,
        "rank": rank,
        "use-bias": True,
        "llm_name": llm_name,
        "subject": subject,
        "delays": 40,
        "k_runs": 5,
    }
    # NOTE(review): `num_layers` and `layer_size` are not read anywhere in the
    # code visible here; kept in case something outside this view uses them.
    num_layers = 1
    layer_size = 4096 if llm_name == "llama2" else 768

    # Persist the full configuration before any training starts.
    os.makedirs(model_save_loc, exist_ok=True)
    with open(model_save_loc + "/train-config.json", "w") as f:
        json.dump(train_config, f)

    llm_features = train_config["llm-features"]
    meg = load_meg_targets(train_stories, frequency_cutoff=freq_cutoff, load_noise=False)
    feature_size, regression_features = train_feature_loader(
        llm_features,
        lm_feature_map_loc=(
            "./embeddings/embeddings_sweep/"
            f"layer_{llm_features['layer']}_context_{llm_features['context']}"
        ),
        subject=train_config["subject"],
        controls=[],
        delays=[],
        force_load=True,
        load_as_control=False,
    )
    torch.cuda.empty_cache()

    dataset = RegressionNumpyDataset(regression_features, meg)
    dataset_folds = DatasetBlockKfold(
        dataset.dataset_lens, n_splits=train_config["kfolds"], block_shuffle=True
    )

    def _new_optimizer(ridge_value):
        """Build a fresh model and its SGD optimiser for one ridge value."""
        fresh_model = LowRankTensorLinearModel(
            train_config["rank"],
            train_config["delays"],
            feature_size,
            306,
            use_bias=train_config["use-bias"],
            mean=None,
            std=None,
            onset_bias=onset_bias,
        )
        return SGDRegression(fresh_model, ridge_params=ridge_value, device="cuda")

    # Fit arguments shared by the cross-validation fits and the final fits.
    fit_kwargs = {
        "num_workers": train_config["num-workers"],
        "batch_size": train_config["batch-size"],
        "max_epochs": train_config["max-epochs"],
        "lr": train_config["lr"],
        "amsgrad": False,
    }

    # Cross-validate each ridge value: sum the mean validation score over folds.
    ridge_losses = []
    for ridge_index, ridge in enumerate(train_config["ridge-params"]):
        total_validation_loss = 0
        for fold_index, (train_indices, validation_indices) in enumerate(dataset_folds.split()):
            fold_save_loc = f"{model_save_loc}/ridge_{ridge_index}_fold_{fold_index}"
            os.makedirs(fold_save_loc, exist_ok=True)
            train_dataset = dataset.subset(train_indices)
            validation_dataset = dataset.subset(validation_indices)
            optimization = _new_optimizer(ridge)
            optimization.fit(train_dataset, model_save_loc=fold_save_loc, **fit_kwargs)
            predicted, validation = optimization.predict(validation_dataset, 5000)
            total_validation_loss += np.mean(r2_loss(predicted, validation))
        ridge_losses.append(total_validation_loss)

    # NOTE(review): argmax selects the LARGEST accumulated value, i.e. it
    # assumes `r2_loss` returns a score where higher is better -- confirm.
    best_ridge_index = int(np.argmax(ridge_losses, axis=0))
    best_ridge = train_config["ridge-params"][best_ridge_index]

    # Refit on the full dataset `k_runs` times with the selected ridge value.
    for k in range(train_config["k_runs"]):
        final_save_loc = f"{model_save_loc}/final_model_run_{k}"
        os.makedirs(final_save_loc, exist_ok=True)
        optimization = _new_optimizer(best_ridge)
        optimization.fit(dataset, model_save_loc=final_save_loc, **fit_kwargs)
        # Record which ridge value this final model was trained with.
        with open(final_save_loc + "/train-config.json", "w") as f:
            json.dump({"ridge": best_ridge}, f)