"""Contains the NNDE solver model. Handles multiple orders of error correction.
"""
import torch
import numpy as np
import torch.nn as nn

from functools import partial
from collections import OrderedDict

### Taken from Sitzmann et al., 2020 ###

def sine_init(m, **kwargs):
    """Re-draw ``m.weight`` in place from a symmetric uniform distribution.

    The bound is ``sqrt(6 / num_input) / scale`` where ``num_input`` is the
    size of the last weight dimension. Objects without a ``weight``
    attribute are left untouched, so this can be passed to
    ``nn.Module.apply``.

    Keyword Args:
        scale: divisor applied to the bound (default 1).
    """
    if not hasattr(m, 'weight'):
        return

    divisor = kwargs.get('scale', 1)
    fan_in = m.weight.size(-1)
    bound = np.sqrt(6 / fan_in) / divisor

    # in-place re-initialisation must not be tracked by autograd
    with torch.no_grad():
        m.weight.uniform_(-bound, bound)


def first_layer_sine_init(m):
    """Re-draw ``m.weight`` in place from ``U(-1 / num_input, 1 / num_input)``.

    ``num_input`` is the size of the last weight dimension. Objects without
    a ``weight`` attribute are left untouched, so this can be passed to
    ``nn.Module.apply``.
    """
    if not hasattr(m, 'weight'):
        return

    fan_in = m.weight.size(-1)
    lower, upper = -1 / fan_in, 1 / fan_in

    # in-place re-initialisation must not be tracked by autograd
    with torch.no_grad():
        m.weight.uniform_(lower, upper)

########################################


class Sin(nn.Module):
    """Sinusoidal activation: ``sin(scale * x)`` applied element-wise."""

    def __init__(self, scale=1):
        """Store the multiplier applied to the input before the sine."""
        super().__init__()
        self.scale = scale

    def forward(self, x):
        """Return ``sin(scale * x)``."""
        scaled = self.scale * x
        return torch.sin(scaled)


class LinearBlock(nn.Module):
    """A single ``nn.Linear`` layer followed by a named activation."""

    def __init__(self, dim_in, dim_out, act, **kwargs):
        """Build the linear layer and its activation.

        Args:
            dim_in: input feature size of the linear layer.
            dim_out: output feature size of the linear layer.
            act: one of ``"relu"``, ``"sinusoid"`` or ``"tanh"``; any other
                value raises ``KeyError``.
            **kwargs: forwarded to the activation constructor.
        """
        super().__init__()

        # name -> activation class; the class itself (not an instance) is kept
        activation_table = {
            "relu": nn.ReLU,
            "sinusoid": Sin,
            "tanh": nn.Tanh
        }
        self.activation = activation_table[act]

        linear = nn.Linear(dim_in, dim_out)
        nonlinearity = self.activation(**kwargs)
        self.layers = nn.Sequential(linear, nonlinearity)

    def forward(self, x):
        """Apply the linear layer, then the activation."""
        return self.layers(x)


class MLP(nn.Module):
    """Fully connected network configured from ``hparams.model``.

    Reads ``hparams.model["dim"]`` (keys ``"input"``, ``"hidden"``,
    ``"output"``), ``hparams.model["depth"]`` and
    ``hparams.model["activation"]``.
    """

    def __init__(self, hparams, **kwargs):
        """Cache the sizes from ``hparams`` and build the layer stack.

        Args:
            hparams: object whose ``model`` attribute supports item access
                for the keys listed in the class docstring.
            **kwargs: forwarded to every ``LinearBlock`` (and from there to
                the activation constructor).
        """
        super().__init__()
        model_cfg = hparams.model
        dims = model_cfg["dim"]
        self.dim_in = dims["input"]
        self.dim_hid = dims["hidden"]
        self.dim_out = dims["output"]
        self.depth = model_cfg["depth"]
        self.act = model_cfg["activation"]
        self.layers = self.create_layers(**kwargs)

    def create_layers(self, **kwargs):
        """Return the network as an ``nn.Sequential``.

        With ``depth == 0`` this is a single linear map from input to
        output. Otherwise it is one input block, ``depth`` hidden blocks and
        a final linear output layer.
        """
        if self.depth == 0:
            return nn.Sequential(nn.Linear(self.dim_in, self.dim_out))

        # Hidden blocks are constructed before the input block so that the
        # order in which layers draw their initial weights is kept.
        hidden_blocks = [
            LinearBlock(self.dim_hid, self.dim_hid, self.act, **kwargs)
            for _ in range(self.depth)
        ]
        input_block = LinearBlock(self.dim_in, self.dim_hid, self.act, **kwargs)
        output_layer = nn.Linear(self.dim_hid, self.dim_out)

        return nn.Sequential(input_block, *hidden_blocks, output_layer)

    def forward(self, x):
        """Run the stack and drop every size-1 dimension of the result.

        NOTE: ``squeeze()`` without an argument also removes a batch
        dimension of size 1.
        """
        out = self.layers(x)
        return out.squeeze()


class NNDESolver(nn.Module):
    """
    Handle estimating the solution to a DiffEq. Can handle multiple orders of error correction

    The solver owns a base network ``N`` and ``orders`` error networks kept in
    ``error_models`` under the keys ``N_e1`` ... ``N_e{orders}``. A forward
    pass sums the reparameterized base output with the scaled, reparameterized
    error estimates of every order up to the currently active ``self.order``.

    NOTE: Assume exact same architecture as estimate model
    NOTE: Model I/O is handled internally for ease of use
    """
    def __init__(self, hparams, reparam_fn=None, orders=0):
        """Build the base model and ``orders`` error models.

        Args:
            hparams: configuration object. ``hparams.model`` is read through
                item access (by ``MLP``), ``.get(...)`` and attribute access
                (``activation``, ``final_act``).
            reparam_fn: optional callable invoked as ``reparam_fn(out, x)``
                for the base model and ``reparam_fn(out, x, error=True)`` for
                error models. Defaults to returning ``out`` unchanged.
            orders: number of error models to create.
        """
        super().__init__()
        self.hparams = hparams
        self.orders = orders

        # function for reparameterizing the model outputs
        # this happens in a diffeq specific way, so this must be
        # defined in the DifferentialEquation class if it is non-trivial
        self.reparameterize = reparam_fn or (lambda x, _, error=False: x)

        # for debugging
        self.frozen = {"N": {"activated": True, "frozen": False}}
        self.frozen.update(
            {f"N_e{i+1}": {"activated": False, "frozen": False} for i in range(orders)}
        )

        # determines how many orders of error estimation to compute
        self.order = 0

        # base solver/estimator
        self.N = MLP(hparams)

        # magnitudes & frequencies to scale & modulate error model outputs
        # only applicable to sinusoidal activation networks
        self.mags = [hparams.model.get("error_magnitude", 1)] * orders
        self.freqs = [hparams.model.get("error_frequency", 1)] * orders

        # initialize error estimators
        self.error_models = nn.ModuleDict([
            (f"N_e{i+1}", MLP(hparams)) if hparams.model.activation != 'sinusoid'
            else (f"N_e{i+1}", MLP(hparams, scale=freq))
            for i, freq in zip(range(orders), self.freqs)
        ])

        # final activation; the attribute only exists when one is configured
        if hparams.model.final_act is not None:
            self.final_act = getattr(torch, hparams.model.final_act)

        # initialize sinusoidal activation error models
        if hparams.model.activation == 'sinusoid':
            for name, freq in zip(self.error_models, self.freqs):
                error_model = self.error_models[name]
                # The general init touches every layer, including the first
                # one, so it must run BEFORE the first-layer init. In the
                # opposite order the first-layer weights were overwritten and
                # the first-layer init had no effect.
                error_model.apply(partial(sine_init, scale=freq))
                error_model.layers[0].apply(first_layer_sine_init)

    def freeze_model(self, model, unfreeze=False):
        """Change gradient requirements of model

        Sets ``requires_grad`` of every parameter of ``model`` to ``unfreeze``
        (``False`` freezes, ``True`` re-enables gradients).
        """
        for param in model.parameters():
            param.requires_grad = unfreeze

    def load_model(self, order, model_path, logger=None):
        """Loads sub model parameters of specific order

        Args:
            order: 0 selects the base model, ``k > 0`` selects ``N_e{k}``.
            model_path: path handed to ``torch.load``.
            logger: optional logger; receives one ``info`` message on success.
        """
        # assuming all parameters are on the same device
        model = self.N if order == 0 else self.error_models[f"N_e{order}"]
        model_device = next(model.parameters()).device

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        state_dict = torch.load(model_path, map_location=model_device)
        model.load_state_dict(state_dict)

        if logger is not None:
            logger.info(f"Loaded order-{order} model {model_path}")

    def save_model(self, save_dir, model_name, logger=None):
        """Logic for saving the base solver and its error corrections

        Saves the state dict of the sub model selected by ``self.order``: the
        base model to ``{save_dir}/{model_name}`` when the order is 0,
        otherwise the error model of that order to
        ``{save_dir}/order-{order}_{model_name}``.
        """
        if self.order == 0:  # save the base model
            model_id = "N"
            # unwrap wrappers that expose the real model as ``.module``
            base_model = self.N.module if hasattr(self.N, "module") else self.N
            model_path = f"{save_dir}/{model_name}"
            torch.save(base_model.state_dict(), model_path)
        else:  # save the error model of this order
            model_id = f"N_e{self.order}"
            Ne = self.error_models[model_id]
            error_model = Ne.module if hasattr(Ne, "module") else Ne
            model_path = f"{save_dir}/order-{self.order}_{model_name}"
            torch.save(error_model.state_dict(), model_path)

        if logger is not None:
            logger.info(f"Saved {model_id}: {model_path}")

    def set_order(self, order):
        """Freeze all models from lower orders only if order > 0

        Freezes the base model and the error models below ``order``, then
        enables gradients for the error model of ``order`` and records the
        change in ``self.frozen``.

        NOTE: MUST be called in order to perform error correction at all

        Raises:
            AssertionError: if ``order`` is not in ``1..self.orders``.
        """
        assert (order > 0) and (order <= self.orders), f"Invalid order of correction: {order}"
        self.order = order

        # freeze previous error estimators
        self.freeze_model(self.N)
        self.frozen['N']['frozen'] = True
        for i in range(order-1):
            self.freeze_model(self.error_models[f"N_e{i+1}"])
            self.frozen[f"N_e{i+1}"]['frozen'] = True

        # activate gradients for the current order error model
        self.freeze_model(self.error_models[f"N_e{order}"], unfreeze=True)
        self.frozen[f"N_e{order}"]['frozen'] = False
        self.frozen[f"N_e{order}"]['activated'] = True

    def forward_models(self, x):
        """Return forward pass outputs from main model and error models

        The returned list holds the reparameterized base prediction first,
        followed by one magnitude-scaled error estimate per active order.
        ``x`` is marked as requiring gradients in place.
        """
        x.requires_grad_()
        prediction = self.reparameterize(self.N(x), x)
        error_estimates = [
            self.mags[i]*self.reparameterize(self.error_models[f"N_e{i+1}"](x), x, error=True)
            for i in range(self.order)
        ]
        error_estimates.insert(0, prediction)
        return error_estimates

    def forward(self, x):
        """Forward pass thru all models and sum + [optional] final activation
        """
        model_outputs = self.forward_models(x)
        full_prediction = sum(model_outputs)
        if self.hparams.model.final_act is not None:
            full_prediction = self.final_act(full_prediction)
        return full_prediction
