import torch as tc
import numpy as np
import torch.nn as nn
from math import sqrt
from copy import deepcopy
from dataclasses import dataclass
from typing import override, Any
from src.utils import parse_activ_f_TC, parse_activ_f, SamplingArgs, load_linear_params_from_numpy, unflatten_TC, memusage
from src.swimnetworks.swimnetworks import Dense
from src.model.modules import Normalizer, LocalPooler

@dataclass(eq=False)
class GNN(nn.Module):
    """
    Graph neural network composed of a node encoder, an edge encoder, a message encoder,
    a local (per-node) pooling step, a global pooling step and a final linear layer with
    a single output.

    The dataclass fields below are the constructor arguments; the sub-modules themselves
    are created in __post_init__.
    """
    n_obj: list[int]                    # Number of nodes along dimensions
    dof: int                            # Degrees of freedom
    edge_index: tc.Tensor | list        # Edge index of shape (2, n_edges) (static for all graphs) or list[2, n_edges] with entries for each graph (dynamic)

    # forwarded to the Normalizer, which is applied to (q, p) before any features are built
    # (adds translation and rotation invariance to the input)
    make_translation_invariant: bool = True
    make_rotation_invariant: bool = True
    direct: bool = False        # only aggregate towards lower index in the local pooling (message passing) step
    enc_width: int = 16                     # output width of both the node encoder and the edge encoder
    msg_width: int = 128                    # output width of the message encoder
    local_pooling: str = "sum"              # passed to LocalPooler as its `reduce` argument
    global_pooling: str = "sum"             # "sum" or "avg", see global_pool
    activ_str: str = "softplus"             # activation name, resolved with parse_activ_f_TC / parse_activ_f
    init_method: str = "relu"               # selects the bias initialization in init_params ("tanh" -> zeros, "relu" -> 0.01), for gradient-descent based training
    seed: int = -1                          # passed to tc.manual_seed / tc.cuda.manual_seed_all in __post_init__ (for reproducibility)
    rcond: float = 1e-10                    # not read anywhere in this class; presumably used by external fitting code -- TODO confirm
    dtype: Any = tc.float32                 # dtype of all nn.Linear layers
    take_absolute_diff: bool = False        # if True, edge features use |q_src - q_dst| instead of the signed difference
    n_msg_passes: int = 1                   # locally_pool runs (n_msg_passes - 1) pooling rounds before building the final message

    def __post_init__(self):
        super(GNN, self).__init__()
        tc.manual_seed(self.seed)
        tc.cuda.manual_seed_all(self.seed)

        self.node_features_dim = 2 * self.dof    # [q_bar, p_bar]
        self.edge_features_dim = self.dof + 1    # [delta_q_bar, delta_q_bar_norm]

        # scalar (needed for fourier features)
        self.encoder_scaler = 1.0
        self.message_scaler = 1.0

        # activation
        self.activ = parse_activ_f_TC(self.activ_str)

        # normalizer (to make input translation and rotation invariant)
        self.normalizer = Normalizer(self.dof, self.make_translation_invariant, self.make_rotation_invariant)

        # to transform node_features
        self.node_encoder = nn.Linear(self.node_features_dim, self.enc_width, dtype=self.dtype)

        # to transform edge_features
        self.edge_encoder = nn.Linear(self.edge_features_dim, self.enc_width, dtype=self.dtype)

        # to encoder the message
        self.msg_encoder = nn.Linear(2 * self.enc_width, self.msg_width, dtype=self.dtype)
        self.width = self.enc_width + self.msg_width # concat node_enc with the message

        # to locally pool the message
        self.local_pooler = LocalPooler(direct=self.direct, reduce=self.local_pooling)

        # linear map
        self.linear = nn.Linear(self.width, 1, dtype=self.dtype)

    @override
    def forward(self, x, apply_linear=True):
        """
        Args:
            x                       of shape (n_samples, n_features)
            apply_linear            whether to apply the final linear layer
        """
        def _forward(x):
            # if self.training: print("== Forward pass"); memusage()
            # Step 1: prepare input
            x = unflatten_TC(x, self.n_obj, self.dof)
            n_samples = len(x)

            # Extract q and p
            q, p = x[..., :self.dof], x[..., self.dof:]  # (n_samples, *n_obj, dof)

            # Get translation and rotation invariant representations
            q, p = self.normalizer(q, p)
            # if self.training: print("   Get translation and rotation invariant representations q and p"); memusage()

            # Step 2: Prepare node (V) and node encodings (lifted representation)
            node_features = self.get_node_features(q, p)                    # of shape (n_samples * n_nodes, 2*dof)
            node_enc = self.encoder_scaler * self.activ(self.node_encoder(node_features))
            node_enc = node_enc.view(n_samples, *self.n_obj, self.enc_width)

            # Prepare edge features and message logic only if there is an edge in the graph
            if len(self.edge_index) != 0:
                n_edges = self.edge_index.shape[1]
                n_edges_total = self.edge_index.shape[1] if self.direct else 2*self.edge_index.shape[1]

                edge_features = self.get_edge_features(q)                       # of shape (n_samples* n_edges, dof+1)
                # if self.training: print("   Prepare node and edge features"); memusage()

                # Step 3: Encode edge features
                edge_enc = self.encoder_scaler * self.activ(self.edge_encoder(edge_features))
                edge_enc = edge_enc.view(n_samples, n_edges, self.enc_width)
                # if self.training: print("   Node and edge features"); memusage()

                # Step 4: Constructed aggregated features (message) with message passing
                msg = self.locally_pool(node_enc, edge_enc)
                # Step 5: Encode the message
                msg_enc = self.message_scaler * self.activ(self.msg_encoder(msg))     # of shape (n_samples * n_edges_total, msg_width)
                msg_enc = msg_enc.view(n_samples, n_edges_total, self.msg_width)
                received_msg = self.local_pooler(self.n_obj, msg_enc, self.edge_index)  # of shape (n_samples, *n_obj, msg_width)
            else:
                received_msg = tc.zeros((n_samples, *self.n_obj, self.msg_width), device=x.device, dtype=x.dtype)

            # Step 6: Build concatenated final state
            pooled_node_feature = tc.cat([node_enc, received_msg], dim=-1) # of shape (n_samples, *n_obj, width)
            # print("pooled_node_feature.shape", pooled_node_feature.shape)
            # if self.training: print("   Local pooling encoding"); memusage()

            # Step 7: Global pooling to reduce pooled_node_features into a global graph_feature
            global_reduce_dims = list(range(1, pooled_node_feature.ndim - 1))
            graph_feature = self.global_pool(pooled_node_feature, global_reduce_dims, self.global_pooling) # of shape (n_samples, width)
            # if self.training: print("   Global pooling encoding"); memusage()

            if apply_linear:
                # Step 8: Linear map the graph feature into a scalar graph feature (Hamiltonian value of the graph)
                graph_feature = self.linear(graph_feature) # of shape (n_samples, 1)
            return graph_feature

        # if all the inputs are using the same edge_index then forward all of them
        if isinstance(self.edge_index, tc.Tensor) and len(self.edge_index.shape) == 2:
            return _forward(x) # static edge index is specified, we can pass all the graphs at once
        elif len(self.edge_index) == 0:
            return _forward(x) # this scenario might happen during inference
        else:
            edge_indices = self.edge_index
            outputs = []
            for idx, graph in enumerate(x): # processing each graph separately because they have different edge indices
                self.edge_index = edge_indices[idx]
                outputs.append(_forward(graph[tc.newaxis, ...]))
            self.edge_index = edge_indices  # set back for logging purposes
            return tc.cat(outputs, dim=0)

    def sample_hidden(self, x: tc.Tensor, sampling_args: SamplingArgs,
                      e_pred: tc.Tensor | None = None):
        """
        Samples hidden layer parameters using SWIM algorithm.

        Args:
            x                   of shape (n_samples, n_features)
            e_pred              corresponding Hamiltonian values of shape (n_samples, 1) for SWIM algorithm.
                                if set to None then we fallback to uniform sampling in the data space.
            sampling_args       SamplingArgs
        """
        node_encoder = Dense(layer_width=self.enc_width, activation=parse_activ_f(self.activ_str), activ_str=self.activ_str,
                             parameter_sampler=sampling_args.param_sampler, random_seed=sampling_args.seed,
                             sample_uniformly=sampling_args.sample_uniformly, is_classifier=False,
                             prune_duplicates=False, resample_duplicates=sampling_args.resample_duplicates,
                             swim_dy_norm_ord=sampling_args.swim_dy_norm_ord,
                             elm_weight_loc=sampling_args.elm_weight_loc, elm_weight_std=sampling_args.elm_weight_std,
                             elm_bias_start=sampling_args.elm_bias_start, elm_bias_end=sampling_args.elm_bias_end,
                             dtype=sampling_args.dtype)
        edge_encoder = deepcopy(node_encoder)
        edge_encoder.random_seed = node_encoder.random_seed + 1

        msg_encoder = deepcopy(node_encoder)
        msg_encoder.random_seed = edge_encoder.random_seed + 1
        msg_encoder.layer_width = self.msg_width

        if sampling_args.param_sampler == "fourier":
            node_encoder.fourier_sigma = sampling_args.enc_sigma
            edge_encoder.fourier_sigma = sampling_args.enc_sigma
            msg_encoder.fourier_sigma = sampling_args.msg_sigma

        def sample_dense(layer, x, prepare_y=lambda y: y, disable_aswim=False):
            if not sampling_args.sample_uniformly and e_pred is None: # Initial sampling in Approximate-SWIM
                print("***(Approximate)-SWIM INITIAL-APPROXIMATION***")
                layer.sample_uniformly = True # for the initial fit of Approximate-SWIM, temporary only
                layer.fit(x)
                layer.sample_uniformly = False
            elif not sampling_args.sample_uniformly and e_pred is not None: # re-sampling for Approximate-SWIM
                if not disable_aswim:
                    print("***(Approximate)-SWIM RESAMPLING***")
                    print("x shape", x.shape)
                    print("e_pred shpae", e_pred.shape)
                    print("prepared shape", prepare_y(e_pred).shape)
                    # print("")
                    layer.fit(x, prepare_y(e_pred)) # Using approximate y values
                    # layer.fit(x, e_pred) # Using approximate y values
                else:
                    print("***A-SWIM IS DISABLED***")
                    # bypass aswim for this layer
                    layer.sample_uniformly = True
                    layer.fit(x)
                    layer.sample_uniformly = False
            elif sampling_args.sample_uniformly and e_pred is None: # uniform sampling
                # print("***Sampling uniformly***")
                layer.fit(x)
            else:
                raise ValueError("y values are provided but you specified sample_uniformly=True")

        # In order to sample the dense layers we need their inputs, this is step 1.
        # Step 2 is applying the sampling algorithm TODO: for data-agnostic sampling we can just sample the weights directly without doing step 1, this however shouldn't affect the time-to-solution much as the most of the time is spent to compute the least-squares (fitting the linear head)

        # Step 1: Prepare input for sampling node, edge and message encoders
        x = unflatten_TC(x, self.n_obj, self.dof)
        n_samples = len(x)

        # Extract q and p
        q, p = x[..., :self.dof], x[..., self.dof:]  # (n_samples, *n_obj, dof)

        # Get translation and rotation invariant representations
        q, p = self.normalizer(q, p)

        # Prepare node features (V), this is the input we need for sampling the node encoder
        node_features = self.get_node_features(q, p)        # of shape (n_samples * n_nodes, 2*dof)

        # Step 2: Sample node_encoder and get node encodings
        node_features = node_features.detach().numpy()
        sample_dense(node_encoder, node_features, disable_aswim=True)
        node_enc = node_encoder.transform(node_features).reshape(n_samples, *self.n_obj, self.enc_width)

        # Similar to the forward pass, we have different cases for the static and dynamic edge index cases
        if isinstance(self.edge_index, tc.Tensor) and len(self.edge_index.shape) == 2:
            # Step 1 for the edge encoder: prepare edge_features, this is the input to the edge encoder
            # static edge index
            edge_features = self.get_edge_features(q)       # of shape (n_samples * n_edges, dof + 1)
            edge_features = edge_features.detach().numpy()
            # Step 2: Sample edge_encoder and get edge encodings
            sample_dense(edge_encoder, edge_features, disable_aswim=True)
            edge_enc = edge_encoder.transform(edge_features).reshape(n_samples, -1, self.enc_width)

            # Sample the message encoder
            msg = self.locally_pool(tc.from_numpy(node_enc), tc.from_numpy(edge_enc)).detach().numpy()
            sample_dense(msg_encoder, msg, disable_aswim=True)
        else:
            # For Step 1 we need to pass the graphs one by one (due to dynamic edge indexing)
            edge_indices = self.edge_index
            edge_features = []
            for idx, q_item in enumerate(q):
                self.edge_index = edge_indices[idx]
                if len(self.edge_index) != 0:
                    edge_features.append(self.get_edge_features(q_item[tc.newaxis, ...]))
            if len(edge_features) == 0:
                raise ValueError("None of the graphs have any edges, therefore can't sample data-driven random features for the edge encoder. Make sure your data has at least one graph with a non-empty edge index.")
            self.edge_index = edge_indices  # set back for logging purposes
            edge_features = tc.cat(edge_features, dim=0).detach().numpy()

            # Step 2: Sample edge_encoder and get edge encodings
            sample_dense(edge_encoder, edge_features, disable_aswim=True)

            # Step 1 of computing the input to the message encoder
            msgs = []
            for idx, q_item in enumerate(q):
                self.edge_index = edge_indices[idx]
                if len(self.edge_index) != 0:
                    edge_features_item = self.get_edge_features(q_item[tc.newaxis, ...])
                    edge_features_item = edge_features_item.detach().numpy()
                    edge_enc_item = edge_encoder.transform(edge_features_item).reshape(1, -1, self.enc_width)
                    node_enc_item = tc.from_numpy(node_enc[0][np.newaxis, ...])
                    edge_enc_item = tc.from_numpy(edge_enc_item)
                    # TODO: here apply message passing if specified
                    msg = self.locally_pool(node_enc_item, edge_enc_item)
                    msgs.append(msg)
            if msgs == []:
                raise ValueError("There is no message passing geometry present in none of the inputted graphs.")
            self.edge_index = edge_indices
            msg = tc.cat(msgs, dim=0).detach().numpy()

            # Step 2: Sample message encoder
            sample_dense(msg_encoder, msg, disable_aswim=True)

        self.load_params({ "node_encoder_weight": node_encoder.weights, "node_encoder_bias": node_encoder.biases, "edge_encoder_weight": edge_encoder.weights, "edge_encoder_bias": edge_encoder.biases, "msg_encoder_weight": msg_encoder.weights, "msg_encoder_bias": msg_encoder.biases })

        # set scalar if random fourier features are specified
        if sampling_args.param_sampler == "fourier":
            self.encoder_scaler = sqrt(2.0 / self.enc_width)
            self.message_scaler = sqrt(2.0 / self.msg_width)
            self.activ = tc.cos     # random fourier feature specific activation

    def locally_pool(self, node_enc, edge_enc):
        """
        Returns final message sent out from nodes after applying local pooling
        msg of shape (n_samples * n_edges_total, 2 * enc_width) is returned
        """
        local_feat = node_enc
        # Step 1: Locally pool node features
        for _ in range(self.n_msg_passes - 1):
            # Prepare message src->dst as [src_node_encoding, edge_encoding]
            msg = self.prepare_msg(local_feat, edge_enc)      # (n_samples, n_edges_total, 2 * enc_width)

            # Pool into the node features and then continue
            local_pool = self.local_pooler(self.n_obj, msg, self.edge_index)  # of shape (n_samples, *n_obj, 2*enc_width)
            local_node_update, _ = local_pool.tensor_split(2, dim=-1)
            local_feat = local_node_update

        # Prepare the final message
        msg = self.prepare_msg(local_feat, edge_enc)      # (n_samples, n_edges_total, 2 * enc_width)
        return msg.reshape(-1, 2 * self.enc_width)       # (n_samples* n_edges_total, 2 * enc_width)

    def freeze_hidden_layers(self):
        for hidden_layer in [self.node_encoder, self.edge_encoder, self.msg_encoder]:
            hidden_layer.weight.requires_grad_(False)
            hidden_layer.bias.requires_grad_(False)

    def unfreeze_hidden_layers(self):
        for hidden_layer in [self.node_encoder, self.edge_encoder, self.msg_encoder]:
            hidden_layer.weight.requires_grad_(True)
            hidden_layer.bias.requires_grad_(True)

    def get_node_features(self, q: tc.Tensor, p: tc.Tensor) -> tc.Tensor:
        """
        Returns node features
        Args:
            q: Translation and rotation invariant positions of shape (n_samples, *n_obj, dof)
            p: Rotation invariant momenta of shape (n_samples, *n_obj, dof)
        Returns:
            node_features: of shape (n_samples * n_nodes, 2 * dof)
        """
        node_features = tc.cat([q, p], dim=-1)                  # V of shape (n_samples, *n_obj, 2 * dof)
        return node_features.view(-1, self.node_features_dim)   # of shape (n_samples * n_nodes, 2 * dof)

    def get_edge_features(self, q):
        """
        Args:
            q: Translation and rotation invariant positions of shape (n_samples, *n_obj, 2*dof)
        Returns:
            edge_features of shape (n_samples * n_edges, dof + 1)
        """
        q = q.view(len(q), -1, self.dof)                        # of shape (n_samples, n_nodes, dof)
        src, dst = self.edge_index                              # (n_edges,) each,
        delta_q = q[:, src] - q[:, dst]                         # of shape (n_samples, n_edges, dof)

        if self.take_absolute_diff:
            delta_q = tc.abs(delta_q)

        # This logic is risky as it does not conserve the graph structure
        # if box_size is not None:
            # delta_q -= box_size * tc.round(delta_q / box_size)  # PBC
        norm = tc.linalg.norm(delta_q, dim=-1, keepdim=True)    # of shape (n_samples, n_edges, 1)

        # if cutoff < tc.inf:
            # # ignore distance above the speified cutoff
            # mask = norm < cutoff
            # delta_q = delta_q * mask
            # norm = norm * mask

        edge_features = tc.cat([delta_q, norm], dim=-1)         # E of shape (n_samples, n_edges, dof + 1)
        return edge_features.view(-1, self.edge_features_dim)   # of shape (n_samples * n_edges, dof + 1)

    def prepare_msg(self, node_enc, edge_enc):
        """
        Computes msg_ij on the edges using source node encoding and edge encoding information as
        msg_ij = [node_enc_i, edge_enc_ij]
        Args:
            node_enc of shape (n_samples, *n_obj, enc_width)
            edge_enc of shape (n_samples, n_edges, enc_width)
        Returns:
            msg_edge of shape (n_samples, n_edges_total, 2 * enc_width)
        """
        src, dst = self.edge_index                                          # (n_edges,)
        node_enc = node_enc.reshape(len(node_enc), -1, self.enc_width)      # (n_samples, n_nodes, enc_width)

        src_enc = node_enc[:, src, :]                                       # (n_samples, n_edges, enc_width)
        edge_enc_fwd = edge_enc

        if self.direct:
            msg = tc.cat([src_enc, edge_enc_fwd], dim=-1)                   # (n_samples, n_edges, 2*enc_width)
        else:
            dst_enc = node_enc[:, dst, :]                                   # (n_samples, n_edges, enc_width)
            edge_enc_bwd = edge_enc

            msg_fwd = tc.cat([src_enc, edge_enc_fwd], dim=-1)               # (n_samples, n_edges, 2*enc_width)
            msg_bwd = tc.cat([dst_enc, edge_enc_bwd], dim=-1)               # (n_samples, n_edges, 2*enc_width)
            msg = tc.cat([msg_fwd, msg_bwd], dim=1)                         # (n_samples, 2*n_edges, 2*enc_width)
        return msg                                                          # (n_samples, n_edges_total, 2*enc_width)

    def global_pool(self, node_msg: tc.Tensor, reduce_dims: list[int], pool="sum"):
        """
        Args:
            node_msg of shape (n_samples, Nx, width)
            reduce_dims indices of nodes to perform the global reduce
        Returns:
            global_node of shape (n_samples, width)
        """
        if pool == "sum": return node_msg.sum(dim=reduce_dims)
        elif pool == "avg": return node_msg.mean(dim=reduce_dims)
        else: raise ValueError(f"Unknown pooling {pool}")

    def load_params(self, params):
        """
        Loads weights and biases defined as numpy ndarrays
        Args:
            params      dict of parameters of 'node_encoder', 'edge_encoder', 'msg_encoder' and 'linear' layers
        """
        load_linear_params_from_numpy(self.node_encoder, params["node_encoder_weight"], params["node_encoder_bias"])
        load_linear_params_from_numpy(self.edge_encoder, params["edge_encoder_weight"], params["edge_encoder_bias"])
        load_linear_params_from_numpy(self.msg_encoder, params["msg_encoder_weight"], params["msg_encoder_bias"])
        if "dense_weight" in params and "dense_bias" in params:
            load_linear_params_from_numpy(self.dense, params["dense_weight"], params["dense_bias"])
        if "linear_weight" in params and "linear_bias" in params:
            load_linear_params_from_numpy(self.linear, params["linear_weight"], params["linear_bias"])

    def init_params(self, method="xavier_normal"):
        """
        Initializes encoder weights and biases according to the given initialization method.

        Args:
            method      'kaiming_normal', 'kaiming_uniform', 'xavier_normal', 'xavier_uniform', or 'orthogonal'
        """
        match method:
            case "kaiming_normal":
                init_fun = nn.init.kaiming_normal_
            case "kaiming_uniform":
                init_fun = nn.init.kaiming_uniform_
            case "xavier_normal":
                init_fun = nn.init.xavier_normal_
            case "xavier_uniform":
                init_fun = nn.init.xavier_uniform_
            case "orthogonal":
                init_fun = nn.init.orthogonal_
            case _:
                raise NotImplementedError("unknown weight initialization")

        init_fun(self.node_encoder.weight)
        init_fun(self.edge_encoder.weight)
        init_fun(self.msg_encoder.weight)
        init_fun(self.linear.weight)

        match self.init_method:
            case "tanh":
                nn.init.zeros_(self.node_encoder.bias)
                nn.init.zeros_(self.edge_encoder.bias)
                nn.init.zeros_(self.msg_encoder.bias)
                nn.init.zeros_(self.linear.bias)
            case "relu":
                nn.init.constant_(self.node_encoder.bias, 0.01)
                nn.init.constant_(self.edge_encoder.bias, 0.01)
                nn.init.constant_(self.msg_encoder.bias, 0.01)
                nn.init.constant_(self.linear.bias, 0.01)
            case _:
                raise NotImplementedError("Unknown init_method for the learnable parameters")
