import torch
import torch.nn as nn
import numpy as np
from typing import Optional, Union
import torch.nn.functional as F
class SmODECell(nn.Module):
    """Recurrent cell that advances a neuron state with an unfolded ODE update.

    Each call maps the inputs element-wise, runs ``ode_unfolds`` fused update
    steps on the state, and maps the leading ``motor_size`` state entries to
    the output. A learned gate ``tanh(func_h(inputs))`` modulates both the
    sensory and the recurrent synaptic activations.

    The ``wiring`` object is only used through the attributes and methods this
    class touches: ``units``, ``input_dim``, ``output_dim``,
    ``adjacency_matrix``, ``sensory_adjacency_matrix``, ``build()``,
    ``is_built()``, ``erev_initializer()`` and ``sensory_erev_initializer()``.

    Args:
        wiring: Connectivity description (see above).
        in_features: Number of input features. When given, ``wiring.build``
            is called with it; otherwise the wiring must already be built.
        input_mapping: ``"affine"`` (scale + bias), ``"linear"`` (scale only)
            or anything else for no input mapping.
        output_mapping: Same choices as ``input_mapping``, for the output.
        ode_unfolds: Number of solver iterations per call.
        epsilon: Small constant added to denominators for stability.
        implicit_param_constraints: If true, positive parameters are passed
            through ``Softplus``; otherwise they are used as-is and
            ``apply_weight_constraints`` performs explicit clipping.
        **kwargs: ``lambda1`` and ``lambda2`` weight the two regularisation
            terms accumulated in ``para_loss``. Both are optional and default
            to 0.0 (term disabled).

    Raises:
        ValueError: If the wiring is not built and ``in_features`` is None.

    NOTE(review): ``gleak``, ``erev``, ``sensory_erev`` and both sparsity masks
    are registered as parameters but are not read by ``_ode_solver``; the leak
    term there is built from ``cm``. Confirm this is intended.
    """

    def __init__(
            self,
            wiring,
            in_features: Optional[int] = None,
            input_mapping="affine",
            output_mapping="affine",
            ode_unfolds=6,
            epsilon=1e-8,
            implicit_param_constraints=False,
            **kwargs
    ):
        super().__init__()
        if in_features is not None:
            wiring.build(in_features)
        if not wiring.is_built():
            raise ValueError(
                "Wiring error! Unknown number of input features. Please pass the parameter 'in_features' or call the 'wiring.build()'."
            )
        self.make_positive_fn = (
            nn.Softplus() if implicit_param_constraints else nn.Identity()
        )
        self._implicit_param_constraints = implicit_param_constraints
        # (low, high) bounds of the uniform initialisation of each parameter.
        self._init_ranges = {
            "gleak": (0.001, 1.0),
            "vleak": (-0.2, 0.2),
            "cm": (0.4, 0.6),
            "w": (0.001, 1.0),
            "sigma": (3, 8),
            "mu": (0.3, 0.8),
            "sensory_w": (0.001, 1.0),
            "sensory_sigma": (3, 8),
            "sensory_mu": (0.3, 0.8),
        }
        # Take the input width from the built wiring so that a pre-built
        # wiring works even when ``in_features`` is left as None.
        input_dim = wiring.input_dim
        self.func_h = nn.Linear(input_dim, wiring.units)
        self.tanh = nn.Tanh()
        self.in_features = input_dim
        self.out_features = wiring.units
        self.hidden_state = None   # gate of the most recent call
        self.w_activation = None   # un-gated recurrent activation, last unfold
        self.para_loss = 0.
        self.lambda1 = kwargs.get("lambda1", 0.0)
        self.lambda2 = kwargs.get("lambda2", 0.0)
        self._wiring = wiring
        self._input_mapping = input_mapping
        self._output_mapping = output_mapping
        self._ode_unfolds = ode_unfolds
        self._epsilon = epsilon
        self._clip = torch.nn.ReLU()
        self._allocate_parameters()

    @property
    def state_size(self):
        """Number of state entries (``wiring.units``)."""
        return self._wiring.units

    @property
    def sensory_size(self):
        """Number of input features (``wiring.input_dim``)."""
        return self._wiring.input_dim

    @property
    def motor_size(self):
        """Number of output entries (``wiring.output_dim``)."""
        return self._wiring.output_dim

    @property
    def output_size(self):
        """Alias of ``motor_size``."""
        return self.motor_size

    @property
    def synapse_count(self):
        """Sum of absolute entries of the recurrent adjacency matrix."""
        return np.sum(np.abs(self._wiring.adjacency_matrix))

    @property
    def sensory_synapse_count(self):
        """Sum of absolute entries of the sensory adjacency matrix."""
        return np.sum(np.abs(self._wiring.sensory_adjacency_matrix))

    def add_weight(self, name, init_value, requires_grad=True):
        """Register ``init_value`` as a parameter called ``name`` and return it."""
        param = torch.nn.Parameter(init_value, requires_grad=requires_grad)
        self.register_parameter(name, param)
        return param

    def _get_init_value(self, shape, param_name):
        """Return a tensor of ``shape`` drawn uniformly from the range of ``param_name``.

        A degenerate range (low == high) yields a constant tensor.
        """
        minval, maxval = self._init_ranges[param_name]
        if minval == maxval:
            return torch.ones(shape) * minval
        return torch.rand(*shape) * (maxval - minval) + minval

    def _allocate_parameters(self):
        """Create and register every parameter, mirroring them in ``self._params``.

        The registration order is kept fixed so that parameter ordering and
        seeded random initialisation stay reproducible.
        """
        n_state = self.state_size
        n_sensory = self.sensory_size
        n_motor = self.motor_size
        self._params = {}

        def add_random(name, shape):
            # Uniformly initialised, trainable parameter.
            self._params[name] = self.add_weight(
                name=name, init_value=self._get_init_value(shape, name)
            )

        def add_fixed(name, value, requires_grad=True):
            # Parameter with an explicitly supplied initial value.
            self._params[name] = self.add_weight(
                name, value, requires_grad=requires_grad
            )

        add_random("gleak", (n_state,))
        add_random("vleak", (n_state,))
        add_random("cm", (n_state,))
        add_random("sigma", (n_state, n_state))
        add_random("mu", (n_state, n_state))
        add_random("w", (n_state, n_state))
        add_fixed("erev", torch.Tensor(self._wiring.erev_initializer()))
        add_random("sensory_sigma", (n_sensory, n_state))
        add_random("sensory_mu", (n_sensory, n_state))
        add_random("sensory_w", (n_sensory, n_state))
        add_fixed(
            "sensory_erev", torch.Tensor(self._wiring.sensory_erev_initializer())
        )
        # The masks are stored as non-trainable parameters.
        add_fixed(
            "sparsity_mask",
            torch.Tensor(np.abs(self._wiring.adjacency_matrix)),
            requires_grad=False,
        )
        add_fixed(
            "sensory_sparsity_mask",
            torch.Tensor(np.abs(self._wiring.sensory_adjacency_matrix)),
            requires_grad=False,
        )

        if self._input_mapping in ["affine", "linear"]:
            add_fixed("input_w", torch.ones((n_sensory,)))
        if self._input_mapping == "affine":
            add_fixed("input_b", torch.zeros((n_sensory,)))
        if self._output_mapping in ["affine", "linear"]:
            add_fixed("output_w", torch.ones((n_motor,)))
        if self._output_mapping == "affine":
            add_fixed("output_b", torch.zeros((n_motor,)))

    def _sigmoid(self, v_pre, mu, sigma):
        """Return ``sigmoid(sigma * (v_pre[..., None] - mu))``."""
        v_pre = torch.unsqueeze(v_pre, -1)  # trailing axis broadcasts against mu
        return torch.sigmoid(sigma * (v_pre - mu))

    def _ode_solver(self, inputs, state, elapsed_time):
        """Run ``ode_unfolds`` fused update steps and return the new state.

        Assumes ``inputs`` is (batch, sensory_size) and ``state`` is
        (batch, state_size) -- TODO confirm against callers.

        Side effects: stores the gate in ``self.hidden_state``, the last
        un-gated recurrent activation in ``self.w_activation`` and, while
        training with gradients enabled, adds two regularisation terms to
        ``self.para_loss``.
        """
        v_pre = state

        # Gate derived from the inputs; the singleton axis broadcasts over the
        # presynaptic dimension of the activations below.
        self.hidden_state = torch.unsqueeze(
            self.tanh(self.func_h(inputs)), dim=1
        )
        gate = self.hidden_state
        gate_abs = torch.abs(gate)

        # Out-of-place clamps: clamping in place would modify (or, under
        # autograd, fail on) the leaf parameters when make_positive_fn is the
        # identity.
        sensory_w = torch.clamp(
            self.make_positive_fn(self._params["sensory_w"]), 0.001, 1.0
        )
        cm = torch.clamp(self.make_positive_fn(self._params["cm"]), 0.4, 0.6)
        w_param = torch.clamp(
            self.make_positive_fn(self._params["w"]), 0.001, 1.0
        )

        # The sensory contribution does not depend on the state, so it is
        # computed once outside the unfold loop.
        sensory_w_activation = sensory_w * self._sigmoid(
            inputs, self._params["sensory_mu"], self._params["sensory_sigma"]
        )
        sensory_w_activation = sensory_w_activation * gate_abs
        sensory_rev_activation = sensory_w_activation * gate
        w_numerator_sensory = torch.sum(sensory_rev_activation, dim=1)
        w_denominator_sensory = torch.sum(sensory_w_activation, dim=1)

        cm_t = cm / (elapsed_time / self._ode_unfolds)
        leak_term = cm * self._params["vleak"]

        for _ in range(self._ode_unfolds):
            self.w_activation = w_param * self._sigmoid(
                v_pre, self._params["mu"], self._params["sigma"]
            )
            w_activation = self.w_activation * gate_abs
            rev_activation = w_activation * gate
            w_numerator = torch.sum(rev_activation, dim=1) + w_numerator_sensory
            w_denominator = torch.sum(w_activation, dim=1) + w_denominator_sensory
            numerator = cm_t * v_pre + leak_term + w_numerator
            denominator = cm_t + cm + w_denominator
            v_pre = numerator / (denominator + self._epsilon)

        if self.training and torch.is_grad_enabled():
            # lambda1 penalises the gate magnitude; lambda2 penalises the
            # summed activation of the last unfold relative to cm.
            self.para_loss += self.lambda1 * (torch.mean(self.hidden_state ** 2))
            self.para_loss += self.lambda2 * (torch.mean(w_denominator / cm))
        return v_pre

    def _map_inputs(self, inputs):
        """Apply the configured element-wise scale and bias to ``inputs``."""
        if self._input_mapping in ["affine", "linear"]:
            inputs = inputs * self._params["input_w"]
        if self._input_mapping == "affine":
            inputs = inputs + self._params["input_b"]
        return inputs

    def _map_outputs(self, state):
        """Select the first ``motor_size`` state entries and apply the output mapping.

        The output scale is divided by its norm as estimated by
        ``spectral_norm`` before being applied.
        """
        output = state
        if self.motor_size < self.state_size:
            output = output[:, 0: self.motor_size]
        if self._output_mapping in ["affine", "linear"]:
            sigma = self.spectral_norm(self._params["output_w"])
            # Floor the divisor so an all-zero weight yields zeros, not NaN.
            normalized_weight = self._params["output_w"] / sigma.clamp_min(
                self._epsilon
            )
            output = output * normalized_weight
        if self._output_mapping == "affine":
            output = output + self._params["output_b"]
        return output

    def spectral_norm(self, weight_matrix, n_iter=1):
        """Estimate the largest singular value of ``weight_matrix`` by power iteration.

        The weight is reshaped to ``(shape[0], -1)`` first, so a 1-D weight is
        treated as a single column.

        Args:
            weight_matrix: Tensor with at least one dimension.
            n_iter: Number of power iterations; must be at least 1.

        Returns:
            A 0-dim tensor holding the estimate.

        Raises:
            ValueError: If ``n_iter`` is smaller than 1.
        """
        if n_iter < 1:
            raise ValueError("n_iter must be at least 1")
        weight_2d = weight_matrix.reshape(weight_matrix.shape[0], -1)
        # Start vector lives on the weight's device/dtype so the products
        # below also work when the module is not on the default device.
        u = torch.randn(
            weight_2d.size(0), 1, device=weight_2d.device, dtype=weight_2d.dtype
        )
        u = F.normalize(u, dim=0, eps=1e-12)
        for _ in range(n_iter):
            v = F.normalize(torch.matmul(weight_2d.t(), u), dim=0, eps=1e-12)
            u = F.normalize(torch.matmul(weight_2d, v), dim=0, eps=1e-12)
        # flatten() keeps the operands 1-D even for a single-row weight, which
        # torch.dot requires (squeeze() would collapse them to 0-d).
        sigma = torch.dot(u.flatten(), torch.matmul(weight_2d, v).flatten())
        return sigma

    def apply_weight_constraints(self):
        """Clip ``w``, ``sensory_w``, ``cm`` and ``gleak`` to be non-negative.

        Does nothing when implicit (Softplus) constraints are enabled.
        """
        if self._implicit_param_constraints:
            return
        for name in ("w", "sensory_w", "cm", "gleak"):
            self._params[name].data = self._clip(self._params[name].data)

    def forward(self, inputs, states, elapsed_time=1.0, sample=False):
        """Advance the cell by one step.

        Args:
            inputs: Input features for this step.
            states: Current state.
            elapsed_time: Time span covered by this step.
            sample: Accepted but not used by this implementation.

        Returns:
            Tuple ``(outputs, next_state, para_loss)``. ``para_loss`` is reset
            on every call and stays ``0.`` unless the module is training with
            gradients enabled.
        """
        self.para_loss = 0.
        inputs = self._map_inputs(inputs)
        next_state = self._ode_solver(inputs, states, elapsed_time)
        outputs = self._map_outputs(next_state)
        return outputs, next_state, self.para_loss
