import sys
sys.path.append('/playpen-raid/Author/LucidAtlas/')
import matplotlib.pyplot as plt
plt.rcParams['figure.dpi']= 300
import torch
import numpy as np
from utilities import utils
import os
#from hetreg.marglik_lucidatlas_v15 import marglik_for_lucid_v15
#from hetreg.marglik_lucidatlas_v22 import marglik_for_lucid_v22
from hetreg.marglik_lucidatlas_v23 import marglik_for_lucid_v23
import argparse
from pipeline.load import *
import torch.utils.data as data_utils
import model.networks.basics.workspace as ws

def optimize(spec: dict,
             model: torch.nn.Module,
             train_dataloader: data_utils.DataLoader,
             val_dataloader: data_utils.DataLoader=None):
    """Train ``model`` with ``marglik_for_lucid_v23`` and save its checkpoints.

    Settings are read from ``spec``.

    Required keys:
        "Device", "LearningRate", "LearningRateMin", "EarlyStopping",
        "NumEpochs", "Optimizer", "UseWandb", "Loss",
        "SavedFinalCheckpointPath", "SavedBestCheckpointPath",
        "LoggingRoot", "ExperimentName".

    Optional keys:
        "LearningRate_corr" (default ``10 * LearningRate``),
        "RegConfig" (default ``{}``),
        "Seed" (default ``711``),
        "Scheduler" (default ``'cos'``).

    The ``state_dict`` of the final and best models is written to
    ``<LoggingRoot>/<ExperimentName>/<ws.model_params_subdir>/`` as
    ``<SavedFinalCheckpointPath>.pth`` and ``<SavedBestCheckpointPath>.pth``.

    Args:
        spec: experiment settings dictionary.
        model: network to train; it is moved to ``spec["Device"]``.
        train_dataloader: loader for the training split.
        val_dataloader: optional loader for the validation split.

    Returns:
        The model from the last training epoch.
    """
    # Seed for reproducibility; overridable per experiment, 711 kept as default.
    torch.manual_seed(spec.get("Seed", 711))

    # Read experiment settings (trailing comments are example values).
    device = spec["Device"]  # e.g. 'cuda'
    lr = spec["LearningRate"]  # e.g. 1e-2
    lr_min = spec["LearningRateMin"]  # e.g. 1e-5
    lr_corr = spec.get("LearningRate_corr", lr * 10)

    early_stopping = spec["EarlyStopping"]  # e.g. False
    n_epochs = spec["NumEpochs"]
    optimizer = spec["Optimizer"]  # e.g. 'Adam'
    scheduler = spec.get("Scheduler", 'cos')
    use_wandb = spec["UseWandb"]
    # Loss / likelihood specification and regularisation settings.
    likelihood = spec["Loss"]
    RegConfig = spec.get("RegConfig", {})
    # Checkpoint file names.
    filename_final_model = spec["SavedFinalCheckpointPath"] + '.pth'
    filename_best_model = spec["SavedBestCheckpointPath"] + '.pth'

    savedir = os.path.join(spec["LoggingRoot"], spec["ExperimentName"])
    utils.cond_mkdir(savedir)

    model = model.to(device)

    # Optimize. The last two return values of marglik_for_lucid_v23 are unused here.
    model_final, model_best, margliksh_trained, _, _ = marglik_for_lucid_v23(
        model=model,
        train_loader=train_dataloader,
        valid_loader=val_dataloader,
        likelihood=likelihood,
        reg_config=RegConfig,
        lr=lr,
        lr_corr=lr_corr,
        lr_min=lr_min,
        early_stopping=early_stopping,
        n_epochs=n_epochs,
        scheduler=scheduler,
        optimizer=optimizer,
        use_wandb=use_wandb,
    )

    savedir_ckpt = os.path.join(savedir, ws.model_params_subdir)
    utils.cond_mkdir(savedir_ckpt)
    # Save the last-epoch model.
    torch.save(model_final.state_dict(), os.path.join(savedir_ckpt, filename_final_model))
    # NOTE(review): guard in case marglik_for_lucid_v23 returns no best model
    # (e.g. when val_dataloader is None); confirm against its implementation.
    if model_best is not None:
        torch.save(model_best.state_dict(), os.path.join(savedir_ckpt, filename_best_model))
    return model_final



def train_per_fold(specs_filename: str, cv_idx: int=None):
    """Train a model for a single cross-validation fold.

    Loads the experiment spec for ``cv_idx`` via ``load_json`` and builds the
    model and the 'train'/'val' dataloaders. It then runs ``optimize`` on them.

    Args:
        specs_filename: path of the experiment spec file, passed to ``load_json``.
        cv_idx: fold index passed to ``load_json``. ``None`` is forwarded
            unchanged; numeric strings (e.g. from an argparse option without
            ``type=int``) are converted to ``int`` so the type is consistent
            with the default.

    Returns:
        The final-epoch model returned by ``optimize``.
    """
    if cv_idx is not None:
        cv_idx = int(cv_idx)

    specs = load_json(specs_filename, cv_idx)
    model = load_model(specs)
    # Only the dataloaders are needed; the dataset objects are discarded.
    _, ds_train_dataloader = load_dataset(specs, which_split='train')
    _, ds_val_dataloader = load_dataset(specs=specs, which_split='val')

    model_trained = optimize(spec=specs, model=model, train_dataloader=ds_train_dataloader, val_dataloader=ds_val_dataloader)
    # Alternative: train on the combined 'train_val' split without validation.
    # ds_train, ds_train_dataloader = load_dataset(specs, which_split='train_val')
    # model_trained = optimize(spec=specs, model=model, train_dataloader=ds_train_dataloader)
    return model_trained


def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for fold-wise LucidAtlas training.

    Options:
        --experiment / -e (dest ``experiment_directory``): path of the
            experiment spec JSON file.
        --fold / -f (dest ``which_fold``): cross-validation fold index,
            parsed as ``int`` (default ``1``).
    """
    parser = argparse.ArgumentParser(description="Train a LucidAtlas autodecoder")
    parser.add_argument(
        "--experiment",
        "-e",
        dest="experiment_directory",
        #default='/playpen-raid/Author/LucidAtlas/configs/airways/v1/airway_nam_part.json',
        #default='/playpen-raid/Author/LucidAtlas/configs/OASISBrain/v2/brain_lucidatlas_part.json',
        default='/playpen-raid/Author/LucidAtlas/configs/airways/cv5/airway_lucidatlas_part.json',
        help="Path to the experiment specification JSON file.",
    )
    parser.add_argument(
        "--fold",
        "-f",
        dest="which_fold",
        type=int,  # without this, CLI values arrive as str while the default is int
        default=1,
        help="Cross-validation fold index to train.",
    )
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    train_per_fold(args.experiment_directory, args.which_fold)
    # Printed once training has returned.
    print('1')



