import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric.nn import ResGatedGraphConv, GINConv, GCNConv, MessagePassing
import torch_geometric as pyg
from torch.autograd.functional import jacobian
import time
import matplotlib.pyplot as plt
import numpy as np
from typing import Callable
import pandas as pd
from typing import Optional
import math


class GCN_SSM(nn.Module):
    """Graph-conv encoder, a stack of state-space graph layers, and a linear decoder.

    Each hidden layer ``i`` updates the node features as::

        x <- A_i(x) + B_i(act(conv_i(x, edge_index)))

    where ``A_i`` (``self.states[i]``) and ``B_i`` (``self.inputs[i]``) are
    bias-free ``nhid -> nhid`` linear maps, and ``act`` is ReLU unless
    ``lin=True``. With ``shared=True`` every layer reuses the layer-0
    conv / ``A`` / ``B``; batch norm (when ``bnorm=True``) is always per layer.
    """

    def __init__(
        self,
        nfeat: int,
        nhid: int,
        nclass: int,
        conv_func: str,
        nlayers: int = 4,
        dataset: str = "Cora",
        bnorm: bool = False,
        lin: bool = False,
        shared: bool = False,
        dyn: bool = False,
        gamma_a: float = 1.0,
        device: torch.device = 'cpu',
    ) -> None:
        """Build the encoder, ``nlayers`` state-space layers and the decoder.

        Args:
            nfeat: Input node-feature dimension.
            nhid: Hidden dimension used by every state-space layer.
            nclass: Output dimension of the linear decoder.
            conv_func: Name of the graph convolution to use (only ``"GCN_SSM"``).
            nlayers: Number of state-space layers.
            dataset: Dataset name; stored on the module, not used here.
            bnorm: Apply a per-layer ``BatchNorm1d`` after each update.
            lin: If True, skip the ReLU after the graph convolution.
            shared: If True, every layer reuses the layer-0 conv / A / B.
            dyn: If True, initialise every state matrix A_i to ``gamma_a * I``.
            gamma_a: Eigenvalue used by the ``dyn`` initialisation.
            device: Stored on the module for callers; move the model with
                ``model.to(device)``, which now also moves A_i / B_i.

        Raises:
            KeyError: If ``conv_func`` is not a supported convolution name.
        """
        super().__init__()

        conv_dict = {
            "GCN_SSM": self._gcn_conv,
        }
        try:
            make_conv = conv_dict[conv_func]
        except KeyError:
            raise KeyError(
                f"Unknown conv_func {conv_func!r}; expected one of {sorted(conv_dict)}"
            ) from None

        self.nlayers = nlayers
        self.bnorm = bnorm
        self.lin = lin
        self.shared = shared
        self.dyn = dyn
        self.gamma_a = gamma_a
        self.dataset = dataset
        self.device = device

        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        # ModuleList (not a plain list) so that A_i / B_i are registered
        # submodules. Registration puts them in model.parameters() (so the
        # optimizer trains them), in state_dict() (so checkpoints restore them)
        # and under model.to(...). Checkpoints saved before this change lack
        # the states.* / inputs.* keys and need load_state_dict(strict=False).
        # If the A_i / B_i are meant to stay frozen, call requires_grad_(False)
        # on them explicitly.
        self.states = nn.ModuleList()
        self.inputs = nn.ModuleList()

        # Creation order (dec, enc, then per layer conv/bn/state/input) matches
        # the original so seeded initialisation stays reproducible.
        self.dec = nn.Linear(nhid, nclass)
        self.enc = make_conv(nfeat, nhid)

        for _ in range(nlayers):
            self.convs.append(make_conv(nhid, nhid))
            self.bns.append(nn.BatchNorm1d(nhid))
            self.states.append(nn.Linear(nhid, nhid, bias=False))
            self.inputs.append(nn.Linear(nhid, nhid, bias=False))

        if self.dyn:
            self._initialize_dynamic()

    def _initialize_dynamic(self) -> None:
        """Set every state weight A_i to ``gamma_a * I`` (all eigenvalues = gamma_a).

        Rebuilding a matrix as ``V @ diag(gamma_a, ..., gamma_a) @ V^-1`` always
        collapses to ``gamma_a * I``. Writing that directly avoids an
        eigendecomposition and an ill-conditioned matrix inverse, both of which
        only added floating-point error.
        """
        with torch.no_grad():
            for state in self.states:
                weight = state.weight
                eye = torch.eye(weight.size(0), dtype=weight.dtype, device=weight.device)
                weight.copy_(self.gamma_a * eye)

    def _gcn_conv(self, nhid: int, nhid2: int) -> nn.Module:
        """Return a ``GCNConv`` layer mapping ``nhid`` -> ``nhid2`` features."""
        return GCNConv(nhid, nhid2)

    def conv(self, x: torch.Tensor, i: int, edge_index: torch.Tensor) -> torch.Tensor:
        """Apply state-space layer ``i`` (layer 0 when ``shared``).

        Computes ``A(x) + B(act(conv(x, edge_index)))``, where ``act`` is ReLU
        unless ``self.lin`` is True.
        """
        idx = 0 if self.shared else i

        conv_layer = self.convs[idx]
        state_lin = self.states[idx]
        input_lin = self.inputs[idx]

        out = conv_layer(x, edge_index)
        if not self.lin:
            out = torch.relu(out)

        return state_lin(x) + input_lin(out)

    def forward(
        self, data: "pyg.data.Data", eig: bool = False, plot_epoch: Optional[str] = None
    ) -> torch.Tensor:
        """Encode ``data.x``, run all state-space layers, and decode.

        Args:
            data: Graph object providing ``x`` and ``edge_index``.
            eig: Unused here; accepted for interface compatibility.
            plot_epoch: Unused here; accepted for interface compatibility.

        Returns:
            The decoder output with ``nclass`` features per node.
        """
        x = data.x
        # Kept as an attribute as before; presumably read by code outside this
        # view after a forward pass (NOTE(review): confirm before removing).
        self.edge_index = data.edge_index

        x = self.enc(x, self.edge_index)

        for i in range(self.nlayers):
            # conv() already maps i -> 0 when self.shared is set.
            x = self.conv(x, i, self.edge_index)
            if self.bnorm:
                x = self.bns[i](x)

        return self.dec(x)

