import torch
import pandas as pd
from models.autoencoder_som import AutoEncoderSOM
from models.som import SOMUnsupervised, SOMSemiSupervised
import collections


def load_autoencodersom_model(model, device):
    """Rebuild an ``AutoEncoderSOM`` from a saved checkpoint.

    Architecture and SOM hyperparameters are read from the checkpoint, falling
    back to hard-coded defaults for any key that is absent. The stored
    ``model_state_dict`` is then loaded into the freshly built instance.

    Args:
        model: Path or file-like object passed straight to ``torch.load``.
        device: ``map_location`` for ``torch.load``; also forwarded to the
            ``AutoEncoderSOM`` constructor.

    Returns:
        Tuple ``(autoencoder_som, epochs, manual_seed, alpha, beta, lr_decay)``.
        ``epochs`` defaults to 5 and ``manual_seed`` to 5555 when not stored;
        ``alpha``, ``beta`` and ``lr_decay`` are ``None`` when not stored.
    """
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    checkpoint = torch.load(model, map_location=device)
    input_size = checkpoint.get('input_size', 2)
    n_max_som = checkpoint.get('n_max', 20)
    d_in = checkpoint.get('d_in', 1)
    hw_in = checkpoint.get('hw_in', 28)
    semi = checkpoint.get('semi', False)

    params = load_som_params(checkpoint)

    # NOTE(review): when 'lr0' is missing this falls back to the stored/default
    # 'lr', whereas params.lr0 would fall back to 0.1. Kept as-is so existing
    # checkpoints load exactly as before — confirm which fallback is intended.
    use_lr0 = checkpoint.get('lr0', params.lr)

    autoencoder_som = AutoEncoderSOM(d_in=d_in,
                                     hw_in=hw_in,
                                     som_input=input_size,
                                     n_max=n_max_som,
                                     at=params.at,
                                     ds_beta=params.ds_beta,
                                     lr0=use_lr0,
                                     lr=params.lr,
                                     lr_push=params.lr_push,
                                     eps_ds=params.eps_ds,
                                     ld=params.ld,
                                     gamma=params.gamma,
                                     semi=semi,
                                     device=device)

    autoencoder_som.load_state_dict(checkpoint['model_state_dict'])

    # Training metadata stored alongside the weights (optional in older files).
    epochs = checkpoint.get('epochs', 5)
    manual_seed = checkpoint.get('seed', 5555)
    alpha = checkpoint.get('alpha')
    beta = checkpoint.get('beta')
    lr_decay = checkpoint.get('lr_decay')

    return autoencoder_som, epochs, manual_seed, alpha, beta, lr_decay


def load_som_model(model, device, semi):
    """Rebuild a SOM from a saved checkpoint.

    Args:
        model: Path or file-like object passed straight to ``torch.load``.
        device: ``map_location`` for ``torch.load``; also forwarded to the SOM
            constructor.
        semi: Selects the semi-supervised SOM when truthy, the unsupervised
            one otherwise. Supplied by the caller, not read from the checkpoint.

    Returns:
        Tuple ``(som, som_epochs, manual_seed)``. ``som_epochs`` defaults to 5
        and ``manual_seed`` to 5555 when not stored in the checkpoint.
    """
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    checkpoint = torch.load(model, map_location=device)
    input_size = checkpoint.get('input_size', 2)
    n_max_som = checkpoint.get('n_max', 20)

    params = load_som_params(checkpoint)

    som = choose_som(input_size, n_max_som, params, device, semi)
    som.load_state_dict(checkpoint['model_state_dict'])

    som_epochs = checkpoint.get('epochs', 5)
    manual_seed = checkpoint.get('seed', 5555)

    return som, som_epochs, manual_seed


# SOM hyperparameter bundle. Defined once at module level so that every
# load_som_params() result shares one type (comparable and picklable).
Params = collections.namedtuple(
    'Params', ['at', 'ds_beta', 'lr0', 'lr', 'lr_push', 'eps_ds', 'ld', 'gamma'])

# Fallback value for each hyperparameter when a checkpoint does not store it.
_SOM_PARAM_DEFAULTS = {
    'at': 0.985,
    'ds_beta': 0.5,
    'lr0': 0.1,
    'lr': 0.1,
    'lr_push': 1.0,
    'eps_ds': 1.0,
    'ld': 0.05,
    'gamma': 3.0,
}


def load_som_params(checkpoint):
    """Extract SOM hyperparameters from a checkpoint mapping.

    Args:
        checkpoint: Mapping (e.g. the dict returned by ``torch.load``) that may
            contain any of the hyperparameter keys; extra keys are ignored.

    Returns:
        A ``Params`` namedtuple with fields ``at``, ``ds_beta``, ``lr0``,
        ``lr``, ``lr_push``, ``eps_ds``, ``ld`` and ``gamma``. Each field is
        taken from ``checkpoint`` when present, otherwise from
        ``_SOM_PARAM_DEFAULTS``.
    """
    return Params(**{name: checkpoint.get(name, default)
                     for name, default in _SOM_PARAM_DEFAULTS.items()})


def choose_som(input_size, n_max, param_set, device, semi):
    """Construct a SOM of the requested flavour from a hyperparameter set.

    Args:
        input_size: Forwarded to the SOM as ``input_dim``.
        n_max: Forwarded to the SOM as ``n_max``.
        param_set: Object exposing ``at``, ``ds_beta``, ``lr0``, ``lr``,
            ``lr_push``, ``eps_ds``, ``ld`` and ``gamma`` attributes (such as
            the namedtuple from ``load_som_params``).
        device: Forwarded to the SOM constructor.
        semi: Truthy selects ``SOMSemiSupervised``; falsy selects
            ``SOMUnsupervised``.

    Returns:
        The newly constructed SOM instance.
    """
    # Both SOM classes take the same constructor arguments; only the class differs.
    som_class = SOMSemiSupervised if semi else SOMUnsupervised
    return som_class(input_dim=input_size,
                     n_max=n_max,
                     at=param_set.at,
                     ds_beta=param_set.ds_beta,
                     lr0=param_set.lr0,
                     lr=param_set.lr,
                     lr_push=param_set.lr_push,
                     eps_ds=param_set.eps_ds,
                     ld=param_set.ld,
                     gamma=param_set.gamma,
                     device=device)

