from collections import OrderedDict
from models.RNN import RNNModel
from models.AR import AutoRegressive
from models.LSTM import LSTMModel
from models import LSTNet
from recon.learning_moe import fitModel
from recon.learning_moe import Gating_Network, Qunatile_Network
import torch
import logging

# Order passed as ``p`` to every AutoRegressive model built in
# get_models_optimizers (presumably the AR lag order); it is also returned as
# the last element of that function's result so callers see the same value.
p = 5
# Module logger; get_moe_optimizers uses it to report per-node progress.
logger = logging.getLogger('MECATS.Models')


def _build_sharq_pair(algs, params, sharq_params, Data):
    """Construct the (median model, quantile model) pair of family ``algs`` for one node."""
    mid = [sharq_params.mid_quantile]
    if algs == 'rnn':
        return (RNNModel(input_dim=1, hidden_dim=params.hidden_dim, layer_dim=params.layer_dim,
                         quantiles=mid, nonlinearity=params.nonlinearity),
                RNNModel(input_dim=1, hidden_dim=params.hidden_dim, layer_dim=params.layer_dim,
                         quantiles=[sharq_params.lower_quantile, sharq_params.upper_quantile],
                         nonlinearity=params.nonlinearity))
    if algs == 'lstm':
        return (LSTMModel(input_dim=1, hidden_dim=params.hidden_dim, layer_dim=params.layer_dim,
                          quantiles=mid),
                LSTMModel(input_dim=1, hidden_dim=params.hidden_dim, layer_dim=params.layer_dim,
                          quantiles=[sharq_params.lower_quantile, sharq_params.upper_quantile]))
    if algs == 'ar':
        return (AutoRegressive(quantiles=mid, p=p),
                AutoRegressive(quantiles=[sharq_params.lower_quantile, sharq_params.upper_quantile], p=p))
    if algs == 'lstnet':
        return (LSTNet.Model(Data, sharq_params.window, method='sharq', quantiles=mid),
                LSTNet.Model(Data, sharq_params.window, method='sharq', quantiles=sharq_params.other_quantiles))
    raise ValueError('Unsupported algs {!r}; expected one of rnn, lstm, ar, lstnet'.format(algs))


def get_models_optimizers(node_list, algs, cuda, params, sharq_params, Data):
    """Create SHARQ models and their optimizers for every node in ``node_list``.

    For each node, two models of the ``algs`` family are built. The first
    predicts ``[sharq_params.mid_quantile]``. The second predicts
    ``[sharq_params.lower_quantile, sharq_params.upper_quantile]``, or
    ``sharq_params.other_quantiles`` for LSTNet.

    Args:
        node_list: iterable of node names; each becomes a key in every returned dict.
        algs: model family, one of ``'rnn'``, ``'lstm'``, ``'ar'``, ``'lstnet'``.
        cuda: if truthy, both models are moved with ``.cuda()`` before their
            optimizers are built.
        params: must provide ``learning_rate``. The rnn/lstm families also read
            ``hidden_dim`` and ``layer_dim``, and rnn reads ``nonlinearity``.
        sharq_params: must have ``RECON == 'sharq'`` plus the quantile fields
            listed above. The lstnet family also reads ``window``.
        Data: forwarded to ``LSTNet.Model``; unused by the other families.

    Returns:
        ``(models, quantile_models, optimizers, quantile_optimizers,
        combined_optimizers, p)``: five OrderedDicts keyed by node name, plus
        the module-level AutoRegressive order ``p``.

    Raises:
        AssertionError: (via ``assert``) if ``sharq_params.RECON != 'sharq'``.
        ValueError: if ``algs`` is not a supported family.
    """
    assert sharq_params.RECON == 'sharq'
    supported = ('rnn', 'lstm', 'ar', 'lstnet')
    if algs not in supported:
        # Previously an unknown family left ``model_dict`` unbound and crashed
        # with an opaque NameError/UnboundLocalError; fail fast with a clear message.
        raise ValueError('Unsupported algs {!r}; expected one of {}'.format(algs, ', '.join(supported)))

    models, quantile_models, optimizers = OrderedDict(), OrderedDict(), OrderedDict()
    quantile_optimizers, combined_optimizers = OrderedDict(), OrderedDict()

    # LSTNet trains its median/quantile models with Adam; the other families use SGD.
    base_optim = torch.optim.Adam if algs == 'lstnet' else torch.optim.SGD
    lr = params.learning_rate

    for name in node_list:
        model, quantile_model = _build_sharq_pair(algs, params, sharq_params, Data)
        if cuda:
            # Move to the GPU *before* constructing optimizers so they reference
            # the parameters that will actually be trained (see torch.optim docs).
            model, quantile_model = model.cuda(), quantile_model.cuda()
        models[name], quantile_models[name] = model, quantile_model
        optimizers[name] = base_optim(model.parameters(), lr=lr)
        quantile_optimizers[name] = base_optim(quantile_model.parameters(), lr=lr)
        # NOTE(review): despite its name, the "combined" optimizer covers only
        # the quantile model's parameters (unchanged from the original) — confirm intent.
        combined_optimizers[name] = torch.optim.Adam(quantile_model.parameters(), lr=lr)
    return models, quantile_models, optimizers, quantile_optimizers, combined_optimizers, p
    

def get_moe_optimizers(data, params, node_list, date, dataset, quantile, cuda):
    '''
    Build per-node MoE experts, gating networks and gating-network optimizers.

    For every node in ``node_list``:
      * experts are produced by ``fitModel(data[node], date, params, node, dataset, cuda).pretrain()``;
      * a ``Gating_Network`` is created whose input width is 1 for 1-D data and
        ``data[node].shape[1]`` otherwise, configured by ``params['gating_network']``;
      * an Adam optimizer over that gating network uses ``params['gating_network'].learning_rate``.
    When ``quantile`` is truthy, a ``Qunatile_Network`` built from
    ``params['quantile_net']`` is also created per node (moved to the current CUDA
    device when ``cuda`` is truthy), together with its own Adam optimizer.

    Returns ``(models, gns, optimizers)``, or
    ``(models, gns, optimizers, q_nets, q_optimizers)`` when ``quantile`` is truthy;
    every element is an OrderedDict keyed by node name.
    '''
    models = OrderedDict()
    gns = OrderedDict()
    optimizers = OrderedDict()
    q_nets = OrderedDict() if quantile else None
    q_optimizers = OrderedDict() if quantile else None

    for node in node_list:
        logger.info('Preparing model for node {}'.format(node))
        gating_cfg = params['gating_network']
        series = data[node]
        input_dim = 1 if len(series.shape) == 1 else series.shape[1]

        experts = fitModel(series, date, params, node, dataset, cuda).pretrain()
        # NOTE(review): unlike the quantile network below, the gating network is
        # not moved to the GPU here — confirm device placement happens elsewhere.
        gating_net = Gating_Network(input_dim, gating_cfg)
        gns[node] = gating_net
        models[node] = experts
        optimizers[node] = torch.optim.Adam(gating_net.parameters(), lr=gating_cfg.learning_rate)

        if not quantile:
            continue
        q_cfg = params['quantile_net']
        q_net = Qunatile_Network(q_cfg)
        if cuda:
            q_net = q_net.to('cuda:{}'.format(torch.cuda.current_device()))
        q_nets[node] = q_net
        q_optimizers[node] = torch.optim.Adam(q_net.parameters(), lr=q_cfg.learning_rate)

    result = (models, gns, optimizers)
    if quantile:
        result += (q_nets, q_optimizers)
    return result