import sys
sys.path.insert(0, "..")

from datasets.datasets import Datasets
import metrics
from models.autoencoder_som import AutoEncoderSOM
from argument_parser import argument_parser
import os
import torch
from utils import utils
import pandas as pd

def converted_model_path(model_path):
    """Return the path the converted checkpoint is written to.

    The extension of *model_path* (whatever its length, or none at all) is
    dropped and ``_new.pth`` is appended, e.g.
    ``results/mnist_0.pth`` -> ``results/mnist_0_new.pth``.
    """
    root, _ext = os.path.splitext(model_path)
    return root + "_new.pth"


# Model Converter
#
# Loads an AutoEncoderSOM checkpoint and re-saves it in a new format. The new
# format keeps the model/optimizer state and loss, and adds the hyper-parameters
# taken from the instantiated AutoEncoderSOM and from the selected lhs row.
#
# Usage:
#     python model_converter.py
#         --params-file arguments/autoencoder50_full.lhs
#         --model results/mnist_0.pth
#         --start-idx 0
#
# Input:
#     --params-file : lhs file that was used to generate the model
#                     (defaults to arguments/default_som.lhs)
#     --model       : .pth checkpoint of the model
#     --start-idx   : row index of the parameter set inside the lhs file
#
# Output:
#     A .pth file written next to the input model. Its name is the model's path
#     with the extension replaced by "_new.pth"
#     (results/mnist_0.pth -> results/mnist_0_new.pth).
if __name__ == '__main__':
    args = argument_parser()

    # NOTE(review): out_folder is created here, but the converted checkpoint is
    # written next to args.model, not into out_folder.
    out_folder = args.out_folder if args.out_folder.endswith("/") else args.out_folder + "/"
    if not os.path.exists(os.path.dirname(out_folder)):
        os.makedirs(os.path.dirname(out_folder), exist_ok=True)

    use_cuda = torch.cuda.is_available() and args.cuda

    if use_cuda:
        torch.cuda.init()

    device = torch.device('cuda:0' if use_cuda else 'cpu')

    params_file_som = args.params_file if args.params_file is not None else "arguments/default_som.lhs"

    parameters = utils.read_params(params_file_som)

    # Hyper-parameter row used when the model was trained.
    params = pd.Series(parameters.iloc[args.start_idx])

    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only
    # machine. torch.load unpickles the file, so only convert trusted checkpoints.
    checkpoint = torch.load(args.model, map_location=device)

    # The dataset is built only to obtain d_in / hw_in for the model constructor.
    dataset = Datasets(dataset=args.dataset, root_folder=args.root, debug=args.debug, n_samples=args.n_samples,
                       coil20_unprocessed=args.coil20_unprocessed)

    autoencoder_som = AutoEncoderSOM(d_in=dataset.d_in,
                                     hw_in=dataset.hw_in,
                                     som_input=int(params.som_in),
                                     n_max=int(params.n_max),
                                     at=params.at,
                                     ds_beta=params.ds_beta,
                                     lr=params.lr,
                                     eps_ds=params.eps_ds,
                                     ld=params.ld,
                                     gamma=params.gamma,
                                     semi=args.semi,
                                     device=device)

    # strict=False tolerates key mismatches between the old checkpoint and the
    # current model definition.
    autoencoder_som.load_state_dict(checkpoint['model_state_dict'], strict=False)

    model_state = checkpoint['model_state_dict']
    torch.save({
        'model_state_dict': model_state,
        'optimizer_state_dict': checkpoint['optimizer_state_dict'],
        'loss': checkpoint['loss'],
        'd_in': autoencoder_som.d_in,
        'hw_in': autoencoder_som.hw_in,
        'epochs': params.epochs,
        'input_size': autoencoder_som.som.input_size,
        'n_max': autoencoder_som.som.n_max,
        'at': autoencoder_som.som.at,
        'ds_beta': autoencoder_som.som.ds_beta,
        'lr': autoencoder_som.som.lr,
        'eps_ds': autoencoder_som.som.eps_ds,
        'ld': autoencoder_som.som.ld,
        'gamma': autoencoder_som.som.gamma,
        'seed': params.seed,
        'semi': args.semi,
        # Duplicated from the state dict as top-level keys.
        'som.node_control': model_state['som.node_control'],
        'som.life': model_state['som.life'],
    }, converted_model_path(args.model))