import torch as tc
import torch.nn as nn
from dataclasses import dataclass
from typing import override, Any
from src.utils import parse_activ_f_TC, parse_activ_f, SamplingArgs, load_linear_params_from_numpy
from src.swimnetworks.swimnetworks import Dense

@dataclass(eq=False)  # eq=False keeps nn.Module's identity-based __hash__ (eq=True would set __hash__ = None)
class FCNN(nn.Module):
    """
    Shallow fully-connected network with a scalar output:
    ``Linear(n_features -> width) -> activation -> Linear(width -> 1)``.

    The hidden ('dense') layer can be initialized conventionally via ``init_params`` or
    sampled from data via ``sample_hidden``.

    Fields:
        dof             only required for experiments, not used by the model itself
        n_obj           only required for experiments, not used by the model itself
        n_features      input dimension
        width           number of hidden neurons
        activ_str       activation name, resolved with parse_activ_f_TC (torch) / parse_activ_f (sampling)
        init_method     selects the bias initialization in init_params ("tanh" -> 0, "relu" -> 0.01)
        seed            seed passed to torch's CPU and CUDA RNGs at construction
        dtype           dtype of the two affine layers
    """
    dof: int  # only required for experiments, not used in the model explicitly
    n_obj: list[int]  # only required for experiments again
    n_features: int
    width: int = 128
    activ_str: str = "softplus"
    init_method: str = "relu"
    seed: int = 128347
    dtype: Any = tc.float32

    def __post_init__(self):
        # The dataclass-generated __init__ has already assigned the plain (non-Module,
        # non-Parameter) fields above, so nn.Module can be initialized afterwards.
        super().__init__()
        tc.manual_seed(self.seed)
        tc.cuda.manual_seed_all(self.seed)
        self.dense = nn.Linear(self.n_features, self.width, dtype=self.dtype)
        self.activ = parse_activ_f_TC(self.activ_str)
        self.linear = nn.Linear(self.width, 1, dtype=self.dtype)  # always scalar output for the Hamiltonian ("total energy")

    @override
    def forward(self, x, apply_linear=True):
        """
        Args:
            x                       of shape (n_points, n_features)
            apply_linear            whether to apply the final linear layer

        Returns:
            tensor of shape (n_points, 1) if apply_linear, else the hidden
            activations of shape (n_points, width)
        """
        x = self.dense(x)           # of shape (n_points, width)
        x = self.activ(x)           # of shape (n_points, width)
        if apply_linear:
            x = self.linear(x)      # of shape (n_points, 1)
        return x

    def sample_hidden(self, x: tc.Tensor, sampling_args: SamplingArgs,
                      e_pred: tc.Tensor | None = None):
        """
        Samples the hidden ('dense') layer parameters using the SWIM algorithm (or uniformly)
        and loads them into ``self.dense``. ``self.linear`` is not modified.

        Args:
            x                       of shape (n_points, n_features)
            sampling_args           SamplingArgs
            e_pred                  corresponding Hamiltonian values of shape (n_points, 1) for SWIM algorithm.
                                    If None, we fall back to uniform sampling in the data space
                                    (for Approximate-SWIM this is the initial approximation).

        Raises:
            ValueError              if e_pred is given while sampling_args.sample_uniformly is True
        """
        sample_uniformly = sampling_args.sample_uniformly
        # Reject the contradictory combination before doing any work.
        if sample_uniformly and e_pred is not None:
            raise ValueError("y values are provided but you specified sample_uniformly=True")

        dense = Dense(layer_width=self.width, activation=parse_activ_f(self.activ_str), activ_str=self.activ_str,
                      parameter_sampler=sampling_args.param_sampler, random_seed=sampling_args.seed,
                      sample_uniformly=sample_uniformly, is_classifier=False,
                      prune_duplicates=False, resample_duplicates=sampling_args.resample_duplicates,
                      swim_dy_norm_ord=sampling_args.swim_dy_norm_ord,
                      elm_weight_loc=sampling_args.elm_weight_loc, elm_weight_std=sampling_args.elm_weight_std,
                      elm_bias_start=sampling_args.elm_bias_start, elm_bias_end=sampling_args.elm_bias_end,
                      dtype=sampling_args.dtype)

        # .cpu() is required to convert CUDA tensors to numpy; it is a no-op for CPU tensors.
        x_np = x.detach().cpu().numpy()
        if sample_uniformly:
            dense.fit(x_np)
        elif e_pred is None:
            print("***(Approximate)-SWIM INITIAL-APPROXIMATION***")
            # No targets yet: sample uniformly for the initial fit only.
            dense.sample_uniformly = True
            try:
                dense.fit(x_np)
            finally:
                dense.sample_uniformly = False
        else:
            print("***(Approximate)-SWIM RESAMPLING***")
            dense.fit(x_np, e_pred.detach().cpu().numpy())
        self.load_params({"dense_weight": dense.weights, "dense_bias": dense.biases})

    def load_params(self, params):
        """
        Loads weights and biases defined as numpy ndarrays.

        Args:
            params      dict with required keys 'dense_weight', 'dense_bias' and optional
                        keys 'linear_weight', 'linear_bias' (both or neither)

        Raises:
            KeyError    if 'dense_weight' or 'dense_bias' is missing
            ValueError  if exactly one of 'linear_weight' / 'linear_bias' is given
        """
        load_linear_params_from_numpy(self.dense, params["dense_weight"], params["dense_bias"])
        has_linear_weight = "linear_weight" in params
        has_linear_bias = "linear_bias" in params
        if has_linear_weight != has_linear_bias:
            # Previously this case was silently ignored, leaving 'linear' untouched.
            raise ValueError("params must contain both 'linear_weight' and 'linear_bias', or neither")
        if has_linear_weight:
            load_linear_params_from_numpy(self.linear, params["linear_weight"], params["linear_bias"])

    def init_params(self, method="xavier_normal"):
        """
        Initializes affine layer weights and biases according to the given initialization method.
        Note: This is not sampling initialization, for sampling refer to sample_hidden

        Weights of both layers use ``method``; biases of both layers are set to a constant
        chosen by ``self.init_method`` ("tanh" -> 0, "relu" -> 0.01).

        Args:
            method      'kaiming_normal', 'kaiming_uniform', 'xavier_normal', 'xavier_uniform', or 'orthogonal'

        Raises:
            NotImplementedError     for an unknown ``method`` or ``self.init_method``; both are
                                    checked before any parameter is modified
        """
        weight_inits = {
            "kaiming_normal": nn.init.kaiming_normal_,
            "kaiming_uniform": nn.init.kaiming_uniform_,
            "xavier_normal": nn.init.xavier_normal_,
            "xavier_uniform": nn.init.xavier_uniform_,
            "orthogonal": nn.init.orthogonal_,
        }
        bias_values = {"tanh": 0.0, "relu": 0.01}

        if method not in weight_inits:
            raise NotImplementedError("unknown weight initialization")
        if self.init_method not in bias_values:
            raise NotImplementedError("Unknown init_method for the learnable parameters")

        init_fun = weight_inits[method]
        bias_value = bias_values[self.init_method]
        # Same RNG consumption order as before: dense.weight, then linear.weight
        # (bias initialization is deterministic and draws no random numbers).
        for layer in (self.dense, self.linear):
            init_fun(layer.weight)
            nn.init.constant_(layer.bias, bias_value)
