import os
from pathlib import Path
from typing import OrderedDict
import yaml
import model 
import data_handler_inet
import train 
import utils 
import torch
import argparse

from torch.utils.tensorboard import SummaryWriter

def pretrain_hvae(config, data_handler, subject_file_idx, ROI_split, writter, hvae_params, train_params):
    """Pre-train the HVAE encoder/decoder for one ROI split and save its weights.

    Relies on the module-level ``RUN`` (assigned in ``main``) to pick the
    per-dataset entries out of ``config``.

    Args:
        config: Parsed configuration mapping; the keys ``splits``,
            ``splits_config``, ``img_size`` and ``models_dir`` are read.
        data_handler: Object exposing ``make_dataloaders``.
        subject_file_idx: Subject index forwarded to ``make_dataloaders``.
            Not really needed for pre-training, but the current data handler
            implementation requires it (same for the ROI list).
        ROI_split: Key of the ROI split inside ``config["splits"][RUN]``.
        writter: Summary writer forwarded to ``train.train``.
        hvae_params: HVAE hyper-parameters (``fc_hidden1``, ``fc_hidden2``,
            ``CNN_embed_dim``, ``dropout_rate``, ``train_vgg``).
        train_params: Training options forwarded to ``train.train``.
    """
    # Dataloaders in encoder/decoder mode.
    rois = config["splits"][RUN][ROI_split]
    train_dl, test_dl = data_handler.make_dataloaders(
        config, subject_file_idx, rois, mode="encdec"
    )

    # Encoder split layout configured for this ROI split.
    split_cfg = config["splits_config"][RUN][ROI_split]

    # Assemble the constructor arguments, then build the model on the GPU.
    hvae_kwargs = {
        "out_size": config["img_size"][RUN],
        "fc_hidden1": hvae_params["fc_hidden1"],
        "fc_hidden2": hvae_params["fc_hidden2"],
        "CNN_embed_dim": hvae_params["CNN_embed_dim"],
        "dropout_rate": hvae_params["dropout_rate"],
        "train_vgg": hvae_params["train_vgg"],
        "encoder_split_inds": split_cfg["encoder_split_inds"],
        "encoder_filters_at_split": split_cfg["encoder_filters_at_split"],
    }
    hvae = model.HVAE(**hvae_kwargs).cuda()

    # Optimise with the VAE loss.
    train.train(
        hvae,
        loss_fn=utils.vae_loss,
        ds=RUN,
        train_dl=train_dl,
        test_dl=test_dl,
        params=train_params,
        writter=writter,
    )

    # Persist the weights under <models_dir>/run_<RUN>/.
    out_dir = os.path.join(config["models_dir"], "run_{}".format(RUN))
    out_name = "encdec_ROIS_{}.h5".format(ROI_split)
    os.makedirs(out_dir, exist_ok=True)
    torch.save(hvae.state_dict(), os.path.join(out_dir, out_name))

def train_neural_decoder(config, data_handler, subject_file_idx, ROI_split, writter, ndec_params, hvae_params, train_params):
    """Train the neural decoder for one subject and ROI split, then save it.

    Relies on the module-level ``RUN`` (assigned in ``main``) to pick the
    per-dataset entries out of ``config``.

    Args:
        config: Parsed configuration mapping; the keys ``splits``,
            ``splits_config``, ``img_size`` and ``models_dir`` are read.
        data_handler: Object exposing ``make_dataloaders``.
        subject_file_idx: Zero-based subject index; the saved file name uses
            ``subject_file_idx + 1``.
        ROI_split: Key of the ROI split inside ``config["splits"][RUN]``.
        writter: Summary writer forwarded to ``train.train``.
        ndec_params: Neural decoder hyper-parameters passed to
            ``model.DecodeModel``.
        hvae_params: HVAE hyper-parameters. The caller's dict is left
            untouched; a copy extended with the split layout and output size
            is passed to ``model.DecodeModel``.
        train_params: Training options forwarded to ``train.train``.
    """
    run_dir = os.path.join(config["models_dir"], "run_{}".format(RUN))

    # Pre-trained encoder/decoder weights, at the exact path pretrain_hvae
    # writes to. models_dir is joined only once: joining it a second time
    # yields "<models_dir>/<models_dir>/..." whenever models_dir is relative.
    weights_file = os.path.join(run_dir, "encdec_ROIS_{}.h5".format(ROI_split))

    # Get dataloaders
    train_dl, test_dl = data_handler.make_dataloaders(
        config, subject_file_idx, config["splits"][RUN][ROI_split], mode="dec"
    )

    # Number of voxels in each ROI (second dimension of each entry of the
    # first training batch); this is the decoder input size.
    num_voxels_per_roi = OrderedDict()
    sample, _ = next(iter(train_dl))
    for roi in sample:
        num_voxels_per_roi[roi] = sample[roi].shape[1]

    # Extend a copy of the HVAE parameters so the dict shared by the caller
    # across subjects and ROI splits is not mutated.
    split_cfg = config["splits_config"][RUN][ROI_split]
    hvae_params = dict(hvae_params)
    hvae_params["encoder_split_inds"] = split_cfg["encoder_split_inds"]
    hvae_params["encoder_filters_at_split"] = split_cfg["encoder_filters_at_split"]
    hvae_params["out_size"] = config["img_size"][RUN]

    # Init model
    neural_model = model.DecodeModel(
        hvae_weights=weights_file,
        num_voxels_per_roi=num_voxels_per_roi,
        hvae_params=hvae_params,
        ndec_params=ndec_params,
    ).cuda()

    # Loss function
    loss_fn = utils.neural_decoder_loss()

    # Train
    train.train(
        neural_model,
        loss_fn=loss_fn,
        ds=RUN,
        train_dl=train_dl,
        test_dl=test_dl,
        params=train_params,
        writter=writter,
    )

    # Save it
    fname = "ndec_S_{}_ROIS_{}.h5".format(subject_file_idx + 1, ROI_split)
    Path(run_dir).mkdir(parents=True, exist_ok=True)
    torch.save(neural_model.state_dict(), os.path.join(run_dir, fname))

def main(args):
    """Run HVAE pre-training and/or neural decoder training for every ROI split.

    Reads ``./config.yaml`` and sets the module-level ``RUN`` from
    ``args.dataset``. For each ROI split listed under
    ``config["splits"][RUN]`` it pre-trains the HVAE (``args.train`` is
    ``"hvae"`` or ``"both"``) and/or trains one neural decoder per subject
    (``args.train`` is ``"ndec"`` or ``"both"``).

    Args:
        args: Parsed command-line namespace; ``args.dataset`` and
            ``args.train`` are read.

    Raises:
        AssertionError: If ``./config.yaml`` is missing or the dataset is
            not ``"inet"``.
    """
    global RUN

    assert os.path.exists("./config.yaml"), "No config.yaml found"

    hvae_params = {
        "fc_hidden1": 1024,
        "fc_hidden2": 512,
        "CNN_embed_dim": 128,
        "dropout_rate": 0,
        "train_vgg": False,
    }
    ndec_params = {
        "fc_hidden1": 1024,
        "fc_hidden2": 512,
        "dropout_rate": 0.1,
    }
    train_params = {
        "epochs": 1001,
        "lr": 5e-5,
        "test_every": 50,
        "save_every": 50,
    }

    RUN = args.dataset
    assert RUN == "inet", "Only inet"

    data_handler = data_handler_inet
    subjects = [0, 1, 2, 3, 4]

    # Parse the config and release the file handle before the (long) training
    # loops start.
    # NOTE: yaml.FullLoader is kept as-is; config.yaml is assumed to be a
    # trusted local file.
    with open("./config.yaml", "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    for ROI_split in config["splits"][RUN]:

        if args.train == "hvae" or args.train == "both":
            # Pretrain the HVAE
            print("Training HVAE for ROI split = {}".format(ROI_split))
            writter = SummaryWriter("results/{}/encdec/ROI_{}".format(RUN, ROI_split))
            try:
                pretrain_hvae(config,
                              data_handler=data_handler,
                              subject_file_idx=1,
                              ROI_split=ROI_split,
                              writter=writter,
                              hvae_params=hvae_params,
                              train_params=train_params)
            finally:
                # Flush pending events and release the writer even if
                # training fails.
                writter.close()

        if args.train == "ndec" or args.train == "both":
            for sub_idx in subjects:
                print("Training Neural Decoder for Subject = {} and ROI split = {}".format(sub_idx, ROI_split))
                tb_path = "results/{}/ndec/S_{}/ROI_{}".format(RUN, sub_idx, ROI_split)
                writter = SummaryWriter(tb_path)
                try:
                    train_neural_decoder(config,
                                         data_handler=data_handler,
                                         subject_file_idx=sub_idx,
                                         ROI_split=ROI_split,
                                         writter=writter,
                                         train_params=train_params,
                                         hvae_params=hvae_params,
                                         ndec_params=ndec_params)
                finally:
                    # One writer per subject/ROI split; close it before
                    # creating the next one.
                    writter.close()
    
if __name__ == '__main__':
    # Command-line entry point: parse the options, optionally restrict the
    # visible GPUs, then run main().
    parser = argparse.ArgumentParser()
    parser.add_argument("-ds", "--dataset", help="dataset to use, inet or vim")
    parser.add_argument("-g", "--gpus", help="gpus to use, set to -1 to use all", default="-1")
    parser.add_argument("-t", "--train", help="train hvae, ndec or both")
    args = parser.parse_args()
    # args.gpus is always a string, so compare against "-1". Comparing with
    # the integer -1 was always true and exported CUDA_VISIBLE_DEVICES="-1"
    # by default, hiding every GPU instead of using all of them.
    if args.gpus != "-1":
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

    main(args)



