import numpy as np
import torch as th
from torch import nn

from networks.simple import TimeConder


def get_reg_fns(fns=None):
    """Resolve an iterable of regulariser specifiers into loaded objects.

    Args:
        fns: Iterable of specifiers, each handed unchanged to
            ``imp.load_class``, or ``None`` for "no regularisers".
            (Presumably dotted import paths -- confirm against
            ``ANONYMOUS.utils.imp``.)

    Returns:
        A new list with one loaded object per entry of ``fns``, in the same
        order; an empty list when ``fns`` is ``None`` or empty.
    """
    if fns is None:
        return []

    # Materialise once so one-shot iterables can be both tested and consumed.
    specs = list(fns)
    if not specs:
        return []

    # Imported lazily, and only when there is something to load, so that the
    # "no regularisers" path does not depend on the project loader at all.
    from ANONYMOUS.utils import imp

    return [imp.load_class(spec) for spec in specs]


class RegNN(nn.Module):  # pylint: disable=abstract-method, too-many-instance-attributes
    """Drift/diffusion wrapper that carries regulariser terms next to the state.

    The integrated state is laid out as ``B x (n_state + n_reg)``: the leading
    columns are the actual state ``x`` and the trailing ``nreg`` columns hold
    one running value per regulariser. ``f`` returns the drift of ``x``
    followed by each regulariser's output; ``g`` returns the diffusion of ``x``
    padded with zeros for the regulariser columns.

    NOTE(review): the ``f``/``g`` methods and the ``sde_type``/``noise_type``
    attributes look like the interface expected by an SDE solver such as
    torchsde -- confirm against the caller.

    Attributes set by callers:
        dataset: must be assigned before using any ``f_format`` other than
            ``"f"``; those formats call ``self.dataset.lgv_gradient(x)``.
    """

    def __init__(
        self,
        f_func,
        g_func,
        reg_fns,
        f_format="f",
        data_dim=2,
        sde_type="stratonovich",
        noise_type="diagonal",
        nn_clip=1e2,
        lgv_clip=1e2,
    ):  # pylint: disable=too-many-arguments
        """Build the wrapper.

        Args:
            f_func: callable ``(t, x)`` giving the raw network output; must
                expose ``data_shape`` and ``before_int``.
            g_func: callable ``(t, x)`` giving the diffusion of ``x``.
            reg_fns: iterable of regulariser specifiers passed to
                ``get_reg_fns``, or ``None`` for no regularisers.
            f_format: name of the drift parameterisation, see ``select_f``.
            data_dim: data dimension (scalar or shape); its product sizes the
                learnable coefficient of the ``"f_tcoef_grad"`` format.
            sde_type: stored as ``self.sde_type``.
            noise_type: stored as ``self.noise_type``.
            nn_clip: symmetric clipping bound applied to ``f_func`` outputs
                (in the formats that clip them).
            lgv_clip: symmetric clipping bound applied to the dataset gradient.
        """
        super().__init__()
        self.f_func = f_func
        self.g_func = g_func
        self.reg_fns = get_reg_fns(reg_fns)
        self.ndim = np.prod(f_func.data_shape)
        # Count the loaded callables rather than the raw argument, which may
        # legitimately be None (``len(None)`` would raise TypeError).
        self.nreg = len(self.reg_fns)
        self.sde_type = sde_type
        self.noise_type = noise_type
        # Multiplying by 1.0 turns integer bounds into floats.
        self.nn_clip = nn_clip * 1.0
        self.lgv_clip = lgv_clip * 1.0
        self.dataset = None
        self.data_dim = np.prod(data_dim)

        self.select_f(f_format)

    def _clipped_net(self, t, x):
        """Return ``f_func(t, x)`` clipped element-wise to ``[-nn_clip, nn_clip]``."""
        return th.clip(self.f_func(t, x), -self.nn_clip, self.nn_clip)

    def _clipped_lgv_grad(self, x):
        """Return ``dataset.lgv_gradient(x)`` clipped to ``[-lgv_clip, lgv_clip]``."""
        return th.clip(self.dataset.lgv_gradient(x), -self.lgv_clip, self.lgv_clip)

    def _state_part(self, state):
        """Return the leading state columns, dropping the trailing ``nreg`` ones.

        The end index is computed from the column count instead of using
        ``-self.nreg``: with no regularisers ``state[:, :-0]`` is
        ``state[:, :0]``, i.e. an empty slice rather than the full state.
        """
        return state[:, : state.shape[1] - self.nreg]

    def select_f(self, f_format=None):
        """Install the drift parameterisation as ``self.param_fn``.

        Supported ``f_format`` values (``net`` is ``f_func(t, x)``, ``grad`` is
        the clipped dataset gradient):

        * ``"f"``: clipped ``net``.
        * ``"f_tnet_grad"``: clipped ``net`` minus ``lgv_coef(t) * grad`` where
          ``lgv_coef`` is a ``TimeConder(64, 1, 3)`` module.
        * ``"f_tcoef_grad"``: unclipped ``net`` minus ``lgv_coef * grad`` where
          ``lgv_coef`` is a learnable ``(1, data_dim)`` parameter initialised
          to ones.
        * ``"lgv"``: ``-grad / sqrt(2)``; ``t`` and ``f_func`` are unused.
        * ``"sigmod_fdotgrad"``: ``-2 * sigmoid(net) * grad``.
        * ``"fdotgrad"``: ``-net * grad``.

        Raises:
            RuntimeError: if ``f_format`` is not one of the names above.
        """
        if f_format == "f":

            def _fn(t, x):
                return self._clipped_net(t, x)

        elif f_format == "f_tnet_grad":
            self.lgv_coef = TimeConder(64, 1, 3)

            def _fn(t, x):
                grad = self._clipped_lgv_grad(x)
                f = self._clipped_net(t, x)
                return f - self.lgv_coef(t) * grad

        elif f_format == "f_tcoef_grad":
            self.lgv_coef = nn.Parameter(
                th.ones((1, self.data_dim)).float(), requires_grad=True
            )

            def _fn(t, x):
                grad = self._clipped_lgv_grad(x)
                # Unlike "f_tnet_grad", the network output is not clipped here.
                return self.f_func(t, x) - self.lgv_coef * grad

        elif f_format == "lgv":

            def _fn(t, x):  # pylint: disable= unused-argument
                return -self._clipped_lgv_grad(x) / np.sqrt(2)

        elif f_format == "sigmod_fdotgrad":
            # NOTE(review): "sigmod" looks like a misspelling of "sigmoid";
            # the key is kept as-is because callers select formats by name.

            def _fn(t, x):
                grad = self._clipped_lgv_grad(x)
                return -2 * th.sigmoid(self.f_func(t, x)) * grad

        elif f_format == "fdotgrad":

            def _fn(t, x):
                grad = self._clipped_lgv_grad(x)
                return -self.f_func(t, x) * grad

        else:
            raise RuntimeError(
                f"Unknown f_format {f_format!r}; expected one of 'f', "
                "'f_tnet_grad', 'f_tcoef_grad', 'lgv', 'sigmod_fdotgrad', "
                "'fdotgrad'"
            )

        self.param_fn = _fn

    def before_int(self, *args, **kwargs):
        """Forward the pre-integration hook to ``f_func.before_int``."""
        self.f_func.before_int(*args, **kwargs)

    def f(self, t, state):
        """Return the drift of the augmented state.

        Args:
            t: scalar time.
            state: ``B x (n_state + n_reg)`` tensor.

        Returns:
            ``dx`` concatenated along dim 1 with the output of every
            regulariser, in ``reg_fns`` order.
        """

        # A fresh class per call gives the regularisers of one evaluation a
        # shared scratch namespace that is not visible to the next call.
        class SharedContext:  # pylint: disable=too-few-public-methods
            pass

        # NaNs in the incoming state are replaced before evaluating the drift.
        x = th.nan_to_num(self._state_part(state))
        dx = self.param_fn(t, x)
        dreg = tuple(reg_fn(x, dx, SharedContext) for reg_fn in self.reg_fns)
        return th.cat((dx,) + dreg, dim=1)

    def g(self, t, state):
        """Return the diffusion of the augmented state.

        The regulariser columns receive zero diffusion, so the result is
        ``g_func(t, x)`` followed by ``nreg`` zero columns.
        """
        origin_g = self.g_func(t, self._state_part(state))
        # new_zeros allocates directly with origin_g's dtype and device.
        reg_g = origin_g.new_zeros((state.shape[0], self.nreg))
        return th.cat((origin_g, reg_g), dim=1)


def quad_reg(x, dx, context):
    """Quadratic regulariser: half the squared L2 norm of ``dx`` per sample.

    Args:
        x: unused; accepted so the signature matches the regulariser call
            ``reg_fn(x, dx, context)``.
        dx: tensor whose first dimension indexes samples; all remaining
            dimensions are flattened together.
        context: unused.

    Returns:
        Tensor of shape ``(dx.shape[0], 1)`` holding ``0.5 * sum(dx ** 2)``
        for each sample.
    """
    del x, context
    batch = dx.size(0)
    flat = dx.view(batch, -1)
    sq_norm = (flat ** 2).sum(dim=-1, keepdim=True)
    return 0.5 * sq_norm
