from __future__ import print_function
import torch
import torch.utils.data
import src.models.modules.nn_modules as NN_modules

# * monge map


def generate_monge_NN(cfg):
    """Build the two networks used by the Monge-map model.

    The first network ``h`` is a ``NN_modules.Fully_connected`` with
    ``output_dim=1`` (a scalar-valued network of an ``INPUT_DIM`` input).
    The second network depends on ``cfg.map_type``:

    * ``'nabla_g'`` -- an ``NN_modules.ICNN_LastInp_Quadratic``; presumably an
      input-convex network ``g`` whose gradient serves as the map
      (TODO confirm against NN_modules).
    * ``'T'`` -- a direct map network from ``INPUT_DIM`` to ``INPUT_DIM``.
      It is ``NN_modules.FC_linear`` when ``cfg.T_linear`` is truthy, and
      ``NN_modules.Fully_connected`` otherwise.

    Args:
        cfg: configuration object. Attributes read: INPUT_DIM, NUM_NEURON_h,
            nn_activation, NUM_LAYERS_h, final_actv, quadr, map_type,
            NUM_NEURON_map, g_activation, NUM_LAYERS_g, dropout. When
            ``map_type == 'T'``, it also reads T_linear and T_res, and, for the
            non-linear map, full_actv and batch_nml.

    Returns:
        tuple: ``(h, convex_g)`` for ``'nabla_g'`` or ``(h, map_t)`` for ``'T'``.

    Raises:
        ValueError: if ``cfg.map_type`` is neither ``'nabla_g'`` nor ``'T'``.
            Previously the function silently returned ``None`` in this case.
    """
    # Validate before building anything, so a bad config fails fast and
    # does not consume RNG state on network initialisation.
    if cfg.map_type not in ('nabla_g', 'T'):
        raise ValueError(
            f"unknown cfg.map_type {cfg.map_type!r}; expected 'nabla_g' or 'T'")

    # h is always built first, keeping initialisation order identical for
    # every valid configuration.
    h = NN_modules.Fully_connected(
        input_dim=cfg.INPUT_DIM,
        output_dim=1,
        hidden_dim=cfg.NUM_NEURON_h,
        activation=cfg.nn_activation,
        num_layer=cfg.NUM_LAYERS_h,
        final_actv=cfg.final_actv,
        quadr=cfg.quadr)

    if cfg.map_type == 'nabla_g':
        # NOTE: this architecture was marked as temporary. An alternative that
        # was tried previously:
        #   NN_modules.DenseICNN(dim=cfg.INPUT_DIM,
        #       hidden_layer_sizes=[cfg.NUM_NEURON_map] * cfg.NUM_LAYERS_g,
        #       rank=5, activation='softplus')
        convex_g = NN_modules.ICNN_LastInp_Quadratic(
            cfg.INPUT_DIM,
            cfg.NUM_NEURON_map,
            cfg.g_activation,
            cfg.NUM_LAYERS_g,
            dropout=cfg.dropout)
        return h, convex_g

    # map_type == 'T': the map network outputs a point in the input space.
    if cfg.T_linear:
        map_t = NN_modules.FC_linear(
            input_dim=cfg.INPUT_DIM,
            output_dim=cfg.INPUT_DIM,
            hidden_dim=cfg.NUM_NEURON_map,
            num_layer=cfg.NUM_LAYERS_g,
            res=cfg.T_res)
    else:
        map_t = NN_modules.Fully_connected(
            input_dim=cfg.INPUT_DIM,
            output_dim=cfg.INPUT_DIM,
            hidden_dim=cfg.NUM_NEURON_map,
            activation=cfg.g_activation,
            num_layer=cfg.NUM_LAYERS_g,
            full_activ=cfg.full_actv,
            batch_nml=cfg.batch_nml,
            dropout=cfg.dropout,
            res=cfg.T_res,
            quadr=0)  # the map network never uses the quadratic term
    return h, map_t

# * fully connected


# def generate_fully_connected(cfg):
#     return NN_modules.Fully_connected(cfg.INPUT_DIM, cfg.OUTPUT_DIM, cfg.NUM_NEURON_map, cfg.NUM_LAYERS, cfg.activation, cfg.final_actv)

#! load


def load_generator(results_save_path, generator_g, epochs, choice='g', device=None):
    """Load saved weights into ``generator_g`` and return it on a CUDA device.

    The checkpoint is read from
    ``{results_save_path}/storing_models/{choice}_epoch{epochs}.pt``.

    Args:
        results_save_path: run directory containing ``storing_models``.
        generator_g: ``torch.nn.Module`` to load the state dict into (in place).
        epochs: epoch number embedded in the checkpoint file name.
        choice: file-name prefix selecting which network to load (default ``'g'``).
        device: passed both as ``map_location`` to ``torch.load`` and to
            ``Module.cuda``.

    Returns:
        ``generator_g.cuda(device)``. If the checkpoint file does not exist,
        the network keeps its current weights (best-effort behaviour).

    Raises:
        RuntimeError: from ``load_state_dict`` when the checkpoint does not
            match the network. Previously this was swallowed and misreported
            as a missing file.
    """
    model_save_path = f'{results_save_path}/storing_models'
    checkpoint_path = f'{model_save_path}/{choice}_epoch{epochs}.pt'
    try:
        state_dict = torch.load(checkpoint_path, map_location=device)
    except FileNotFoundError:
        # Only a missing checkpoint is tolerated. Other failures (corrupt file,
        # architecture mismatch, Ctrl-C) must surface, not be hidden.
        print("no file for network")
        print(checkpoint_path)
    else:
        generator_g.load_state_dict(state_dict)
    return generator_g.cuda(device)
