import torch
import torch.nn as nn

class node_pred(nn.Module):
    """Single-node predictor used inside InvarGC.

    The module embeds the observed series together with a learnable,
    per-environment latent confounder, adds an environment-specific
    embedding of the observed series, and maps the sum through a stack of
    1x1 convolutions down to one output channel per timestep.
    """

    def __init__(self, num_series: int, num_confound: int, num_env: int, timestep: int, hidden: list):
        """
        :param num_series: number of observed time series (d).
        :param num_confound: number of latent temporal confounders (p).
        :param num_env: number of environments; one confounder tensor and one
            intervention layer is created per environment.
        :param timestep: sequence length T expected by ``forward``.
        :param hidden: non-empty sequence of channel widths. ``hidden[0]`` is
            the embedding width shared by ``gc_layer`` and ``itv_layers``;
            consecutive entries define the 1x1 conv stack, which always ends
            in a single output channel.
        :raises ValueError: if ``hidden`` is empty.
        """
        super(node_pred, self).__init__()
        hidden = list(hidden)  # accept tuples as well as lists
        if not hidden:
            raise ValueError("hidden must contain at least one layer width")

        self.d = num_series
        self.p = num_confound
        self.T = timestep
        self.nums = num_series + num_confound

        # Shared embedding over [confounders, observed series] (in that channel order).
        self.gc_layer = nn.Conv1d(self.nums, hidden[0], kernel_size=1, bias=True)
        # Learnable latent confounders, one [1, p, T] slice per environment.
        self.conf = nn.Parameter(torch.randn(num_env, 1, num_confound, timestep), requires_grad=True)
        self.activation = nn.ReLU()

        # Intervention Identification Network: one embedding of the observed
        # series per environment.
        self.itv_layers = nn.ModuleList(
            [nn.Conv1d(self.d, hidden[0], kernel_size=1, bias=True) for _ in range(num_env)]
        )

        # Next-timestep Prediction Network: hidden[0] -> ... -> hidden[-1] -> 1.
        dims = hidden + [1]
        self.hidden_layers = nn.ModuleList(
            [nn.Conv1d(n_in, n_out, kernel_size=1, bias=True) for n_in, n_out in zip(dims[:-1], dims[1:])]
        )

    def forward(self, x, idx):
        """
        :param x: observed series of shape [B, T, d].
        :param idx: environment index selecting the confounder slice and the
            intervention layer.
        :return: prediction of shape [B, T, 1].
        """
        B, T, d = x.shape
        assert d == self.d, "Mismatch: last dimension of x must equal d (num observed)"
        assert T == self.T, "Mismatch: sequence length must equal T"

        x_ch = x.transpose(2, 1)  # [B, d, T]
        # conf[idx] is [1, p, T]; broadcast it over the batch so that the
        # concatenation below also works for B > 1 (expand makes no copy).
        z = self.conf[idx].expand(B, -1, -1)  # [B, p, T]

        # Embedding Function
        zx = torch.cat([z, x_ch], dim=1)  # [B, p + d, T]
        h = self.gc_layer(zx) + self.itv_layers[idx](x_ch)  # [B, hidden[0], T]

        # NOTE(review): no activation is applied between the embedding and the
        # first hidden layer, so with a single hidden width the whole model is
        # linear -- confirm this is intended.
        for k, conv in enumerate(self.hidden_layers):
            if k > 0:
                h = self.activation(h)
            h = conv(h)
        return h.transpose(2, 1)  # [B, T, 1]


class InvarGC(nn.Module):
    """Collection of per-series ``node_pred`` networks, one per observed series."""

    def __init__(self, num_series: int, num_confound: int, num_env: int, timestep: int, hidden: list):
        """
        :param num_series: number of observed time series (d).
        :param num_confound: number of latent confounders (p).
        :param num_env: number of environments (n).
        :param timestep: sequence length T expected by each sub-network.
        :param hidden: channel widths forwarded unchanged to every ``node_pred``.
        """
        super(InvarGC, self).__init__()
        self.d = num_series
        self.p = num_confound
        self.n = num_env
        self.T = timestep

        # One predictor per target series.
        self.networks = nn.ModuleList(
            [node_pred(num_series, num_confound, num_env, timestep, hidden) for _ in range(self.d)]
        )

    def forward(self, x, idx=0):
        """
        Run every per-series network and concatenate their outputs.

        :param x: input passed unchanged to each sub-network.
        :param idx: environment index passed to each sub-network (the
            sub-networks require it); defaults to the first environment.
        :return: sub-network outputs concatenated along dim 2.
        """
        return torch.cat([network(x, idx) for network in self.networks], dim=2)

    @staticmethod
    def _input_norms(conv):
        """Norm of a conv weight [out, in, k] taken over dims (0, 2) -> [in]."""
        return torch.norm(conv.weight, dim=[0, 2])

    @torch.no_grad()
    def est_gc(self, threshold=None):
        """
        Estimated Granger strengths among observed variables.

        Row i comes from network i; column j is the group norm of the
        ``gc_layer`` weights attached to observed input j.

        :param threshold: if given, return ``(G > threshold)`` as an int tensor
            instead of the raw strengths.
        :return: tensor of shape [d, d].
        """
        # gc_layer.weight: [hidden0, p + d, 1]; the first p input channels are
        # the confounders, so keep only the observed inputs.
        rows = [self._input_norms(net.gc_layer)[self.p:] for net in self.networks]
        G = torch.stack(rows, dim=0)  # [d, d]

        if threshold is None:
            return G
        return (G > threshold).int()

    @torch.no_grad()
    def est_lc(self):
        """
        Estimated Latent Confounder strengths among observed variables.

        :return: tensor of shape [d, p]; row i holds the group norms of the
            confounder input channels of network i's ``gc_layer``.
        """
        rows = [self._input_norms(net.gc_layer)[:self.p] for net in self.networks]
        return torch.stack(rows, dim=0)  # [d, p]

    @torch.no_grad()
    def est_itv(self):
        """
        Estimated edge-level interventions among observed variables.

        :return: list with one [d, d] tensor per environment; row i holds the
            group norms of network i's intervention layer for that environment.
        """
        return [
            torch.stack([self._input_norms(net.itv_layers[k]) for net in self.networks], dim=0)  # [d, d]
            for k in range(self.n)
        ]