import math

import numpy as np
import torch
from torch.nn import ModuleList
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F

from relnet.agent.pytorch_agent import PyTorchAgent
from relnet.agent.gnn.feature_processing import to_pyg_datapoint_demand, to_pyt_datapoint, get_mlp_input_dim
from relnet.agent.gnn.demand_reg_gnn import DemandRegressorGNN
from relnet.state.graph_state import GraphState
from relnet.utils.config_utils import get_device_placement


class MLPAgent(PyTorchAgent):
    """Agent that predicts one scalar per graph using a plain MLP.

    Each graph hash is turned into a fixed-size feature vector by
    ``to_pyt_datapoint`` (layout chosen by ``self.demand_rep``) and scored by
    an ``MLPRegressor``. Subclasses select the feature layout via
    ``demand_rep``.
    """
    algorithm_name = "mlp"

    is_deterministic = False
    is_trainable = True

    def setup_predictor(self):
        """Build the MLP regressor sized to the current feature dimension.

        The network is moved to CUDA when ``get_device_placement()`` returns
        ``'GPU'``.
        """
        datapoint_dim = get_mlp_input_dim(self.demand_rep, self.hyperparams, self.gds_metadata)
        self.net = MLPRegressor(datapoint_dim, self.hyperparams)
        print("set up MLP regressor!")
        if get_device_placement() == 'GPU':
            self.net = self.net.cuda()

    def predict(self, g_list, **kwargs):
        """Return one prediction per graph hash in ``g_list``.

        When ``self.restore_model`` is set, the network is rebuilt and
        reloaded from its checkpoint before predicting.

        Args:
            g_list: sequence of graph hashes to score, in order.
            **kwargs: forwarded unchanged to the network's forward call.

        Returns:
            1-D tensor of length ``len(g_list)`` (empty for an empty input).
        """
        if self.restore_model:
            self.setup_predictor()
            self.restore_model_from_checkpoint()

        bs = self.hyperparams['batch_size']
        loader = self.load_data(g_list, bs)

        # MLPRegressor ends in a Linear(..., 1), so each batch yields a
        # (batch, 1) tensor; concatenate along the batch axis instead of
        # splitting into rows and re-stacking them.
        batch_preds = [self.net(data, **kwargs) for data in loader]
        if not batch_preds:
            # torch.cat / torch.stack both reject an empty list of tensors.
            return torch.empty(0)
        return torch.cat(batch_preds).flatten()

    def load_data(self, g_list, bs):
        """Featurize the graphs in ``g_list`` and wrap them in a DataLoader.

        Feature vectors are memoized in ``self.pyt_cache`` keyed by graph hash.

        Args:
            g_list: sequence of graph hashes; row ``i`` corresponds to
                ``g_list[i]``.
            bs: batch size for the returned loader.

        Returns:
            A non-shuffling ``DataLoader`` over a ``TensorDataset`` holding a
            single ``(len(g_list), datapoint_dim)`` tensor.
        """
        datapoint_dim = get_mlp_input_dim(self.demand_rep, self.hyperparams, self.gds_metadata)
        ds = torch.zeros(len(g_list), datapoint_dim)

        for i, g_hash in enumerate(g_list):
            if g_hash in self.pyt_cache:
                datapoint = self.pyt_cache[g_hash]
            else:
                graph = self.graph_ds.load_graph_file(g_hash)
                # NOTE(review): the original graph is reloaded on every cache
                # miss. Hoisting it out of the loop would save I/O, provided
                # to_pyt_datapoint does not mutate it -- confirm first.
                orig_graph = self.graph_ds.load_graph_file(self.gds_metadata['original_graph_hash'])
                datapoint = to_pyt_datapoint(graph, self.demand_rep, self.hyperparams, self.gds_metadata, orig_graph=orig_graph)
                # Drop graph references early to keep peak memory down.
                del graph
                del orig_graph
                self.pyt_cache[g_hash] = datapoint

            ds[i, :] = datapoint

        return DataLoader(TensorDataset(ds), batch_size=bs, shuffle=False)

    def setup(self, options, hyperparams):
        """Run base-class setup, then reset the per-agent feature cache."""
        super().setup(options, hyperparams)
        self.pyt_cache = {}

    @staticmethod
    def get_default_hyperparameters():
        """Return the default MLP hyperparameters.

        NOTE(review): ``predict`` also reads ``'batch_size'``, which is not
        defaulted here; presumably it is supplied by the base class or the
        experiment config -- confirm.
        """
        mlp_params = {
            "first_hidden_size": 1024,
            "learning_rate": 0.01,
        }
        return mlp_params

    def post_env_setup(self):
        """No-op for this agent."""
        pass

    def finalize(self):
        """No-op for this agent."""
        pass

class MLPRegressor(torch.nn.Module):
    """Funnel-shaped MLP mapping a feature vector to a single scalar.

    Layout: ``input_dim -> first_hidden_size``, then repeated halving of the
    width (rounding up), at least once, until it drops below 16, then a final
    ``-> 1`` layer. ReLU follows every layer except the last.

    Args:
        input_dim: size of each input feature vector.
        hyperparams: mapping that must contain ``'first_hidden_size'``.
    """

    def __init__(self, input_dim, hyperparams):
        super().__init__()
        self.input_dim = input_dim
        self.hyperparams = hyperparams

        self.layers = ModuleList()
        hidden_layer_dim = self.hyperparams['first_hidden_size']
        self.layers.append(torch.nn.Linear(self.input_dim, hidden_layer_dim))

        # Halve the width at least once, stopping once it falls below 16.
        while True:
            next_dim = math.ceil(hidden_layer_dim / 2)
            self.layers.append(torch.nn.Linear(hidden_layer_dim, next_dim))

            hidden_layer_dim = next_dim
            if hidden_layer_dim < 16:
                break

        self.layers.append(torch.nn.Linear(hidden_layer_dim, 1))

    def forward(self, data, **kwargs):
        """Score a batch.

        Args:
            data: indexable whose element 0 is the input tensor with trailing
                dimension ``input_dim`` (e.g. a batch from a ``TensorDataset``
                loader).
            **kwargs: accepted and ignored, so callers that forward extra
                keyword arguments (as ``MLPAgent.predict`` does) do not fail.

        Returns:
            Tensor with the same leading dimensions as the input and a
            trailing dimension of 1.
        """
        # MLPAgent.load_data builds batches on the CPU, while the module may
        # have been moved to CUDA; align the input with the parameters' device.
        x = data[0].to(self.layers[0].weight.device)

        last_index = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < last_index:
                x = F.relu(x)

        return x


class SumMLPAgent(MLPAgent):
    """MLP agent configured with the ``"summed"`` demand representation."""
    demand_rep = "summed"
    algorithm_name = "mlp_summed_default"


class RawMLPAgent(MLPAgent):
    """MLP agent configured with the ``"raw"`` demand representation."""
    demand_rep = "raw"
    algorithm_name = "mlp_raw_default"
