import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv


class NPYGraphDataset(InMemoryDataset):
    """In-memory graph dataset built from a pickled ``.npy`` dictionary.

    The file must contain a dict (stored through NumPy's pickle support) with
    three parallel sequences under the keys ``'node_features'``,
    ``'edge_index'`` and ``'label'``; entry *k* of each describes graph *k*.

    Note:
        A class with this exact name is defined again further down this
        module. Python rebinds the name at that point, so the later
        definition is the one the rest of the module sees.

    Args:
        npy_path: Path of the ``.npy`` file to load.
        transform: Optional transform forwarded to ``InMemoryDataset``.
    """

    def __init__(self, npy_path, transform=None):
        super().__init__(None, transform)
        self.data, self.slices = self.load_npy(npy_path)
        self._num_features = self.data.x.size(1) if self.data.x is not None else 0

    @property
    def num_features(self):
        """Width of the collated node-feature matrix (0 when there is none)."""
        return self._num_features

    def load_npy(self, npy_path):
        """Read ``npy_path`` and collate its graphs.

        Args:
            npy_path: Path of the ``.npy`` file to load.

        Returns:
            Whatever ``self.collate`` returns for the list of ``Data``
            objects; ``__init__`` unpacks it as ``(data, slices)``.

        Raises:
            KeyError: If ``'node_features'``, ``'edge_index'`` or ``'label'``
                is missing from the stored dictionary.
        """
        # SECURITY: allow_pickle=True unpickles arbitrary objects, so only
        # load files that come from a trusted source.
        np_data = np.load(npy_path, allow_pickle=True).item()

        node_features = np_data.get('node_features')
        if node_features is None:
            # Fail with the same exception type as the other required keys
            # instead of an opaque "NoneType is not iterable" from zip().
            raise KeyError("'node_features' is missing from the .npy dictionary")
        edge_indices = np_data['edge_index']
        labels = np_data['label']

        data_list = []
        for feats, edges, label in zip(node_features, edge_indices, labels):
            x = torch.tensor(feats, dtype=torch.float)

            if isinstance(edges, torch.Tensor):
                # Cast as well as copy: the non-tensor branch always yields
                # int64, so tensor inputs must not keep another dtype.
                edge_index = edges.clone().detach().to(torch.long)
            else:
                edge_index = torch.tensor(edges, dtype=torch.long)

            if edge_index.numel() == 0:
                # A graph without edges still needs the (2, num_edges) layout.
                edge_index = edge_index.new_empty((2, 0))
            elif edge_index.ndim == 2 and edge_index.size(0) != 2:
                # Stored as (num_edges, 2); flip to (2, num_edges).
                # NOTE(review): a (2, 2) input is ambiguous and is kept as is.
                edge_index = edge_index.t()
            edge_index = edge_index.contiguous()

            y = torch.tensor(label, dtype=torch.long)
            data_list.append(Data(x=x, edge_index=edge_index, y=y))

        return self.collate(data_list)

from improved_filter import ImprovedCausalFilter


class NPYGraphDataset(InMemoryDataset):
    """In-memory graph dataset built from a pickled ``.npy`` dictionary.

    The file must contain a dict (stored through NumPy's pickle support) with
    three parallel sequences under the keys ``'node_features'``,
    ``'edge_index'`` and ``'label'``; entry *k* of each describes graph *k*.

    Note:
        This rebinds the name ``NPYGraphDataset`` that is already defined
        earlier in this module; this later definition is the one in effect
        for the code that follows.

    Args:
        npy_path: Path of the ``.npy`` file to load.
        transform: Optional transform forwarded to ``InMemoryDataset``.
    """

    def __init__(self, npy_path, transform=None):
        super().__init__(None, transform)
        self.data, self.slices = self.load_npy(npy_path)
        self._num_features = self.data.x.size(1) if self.data.x is not None else 0

    @property
    def num_features(self):
        """Width of the collated node-feature matrix (0 when there is none)."""
        return self._num_features

    def load_npy(self, npy_path):
        """Read ``npy_path`` and collate its graphs.

        Args:
            npy_path: Path of the ``.npy`` file to load.

        Returns:
            Whatever ``self.collate`` returns for the list of ``Data``
            objects; ``__init__`` unpacks it as ``(data, slices)``.

        Raises:
            KeyError: If ``'node_features'``, ``'edge_index'`` or ``'label'``
                is missing from the stored dictionary.
        """
        # SECURITY: allow_pickle=True unpickles arbitrary objects, so only
        # load files that come from a trusted source.
        np_data = np.load(npy_path, allow_pickle=True).item()

        node_features = np_data.get('node_features')
        if node_features is None:
            # Fail with the same exception type as the other required keys
            # instead of an opaque "NoneType is not iterable" from zip().
            raise KeyError("'node_features' is missing from the .npy dictionary")
        edge_indices = np_data['edge_index']
        labels = np_data['label']

        data_list = []
        for feats, edges, label in zip(node_features, edge_indices, labels):
            x = torch.tensor(feats, dtype=torch.float)

            if isinstance(edges, torch.Tensor):
                # Cast as well as copy: the non-tensor branch always yields
                # int64, so tensor inputs must not keep another dtype.
                edge_index = edges.clone().detach().to(torch.long)
            else:
                edge_index = torch.tensor(edges, dtype=torch.long)

            if edge_index.numel() == 0:
                # A graph without edges still needs the (2, num_edges) layout.
                edge_index = edge_index.new_empty((2, 0))
            elif edge_index.ndim == 2 and edge_index.size(0) != 2:
                # Stored as (num_edges, 2); flip to (2, num_edges).
                # NOTE(review): a (2, 2) input is ambiguous and is kept as is.
                edge_index = edge_index.t()
            edge_index = edge_index.contiguous()

            y = torch.tensor(label, dtype=torch.long)
            data_list.append(Data(x=x, edge_index=edge_index, y=y))

        return self.collate(data_list)


class GCNWithImprovedFilter(torch.nn.Module):
    """Two-layer GCN with optional ``ImprovedCausalFilter`` stages.

    When filtering is enabled, one filter is applied to the input node
    features before the first convolution and a second one to the ReLU
    activations before the second convolution.

    Args:
        num_features: Input width of the first convolution.
        num_classes: Output width of the second convolution.
        hidden_channels: Output width of the first convolution.
        use_causal_filter: Whether the two filters are created and applied.
        filter_config: Optional dict whose ``'input'`` and ``'hidden'``
            entries are keyword-argument dicts for the respective
            ``ImprovedCausalFilter``. A missing entry, or no dict at all,
            means the filter is constructed without extra keyword arguments.
    """

    def __init__(self, num_features, num_classes, hidden_channels=16, use_causal_filter=False, filter_config=None):
        super().__init__()
        self.use_causal_filter = use_causal_filter

        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)

        if self.use_causal_filter:
            # Build the filters whenever they are requested. Gating this on a
            # truthy filter_config left input_filter/hidden_filter undefined
            # for an empty or missing config, while forward() still used them.
            config = filter_config or {}
            self.input_filter = ImprovedCausalFilter(num_features, **config.get('input', {}))
            self.hidden_filter = ImprovedCausalFilter(hidden_channels, **config.get('hidden', {}))
            self.filters = [self.input_filter, self.hidden_filter]
        else:
            self.filters = []

    def forward(self, data):
        """Return the output of the second convolution for ``data``.

        Args:
            data: Object exposing ``x`` and ``edge_index`` attributes.
        """
        x, edge_index = data.x, data.edge_index

        if self.use_causal_filter:
            x = self.input_filter(x)

        x = F.relu(self.conv1(x, edge_index))

        if self.use_causal_filter:
            x = self.hidden_filter(x)

        return self.conv2(x, edge_index)

    def step_epoch(self):
        """Call ``step()`` on every active filter; a no-op without filters."""
        if self.use_causal_filter:
            for f in self.filters:
                f.step()

    def get_filter_info(self):
        """Return a printable summary built from each filter's ``get_stats()``."""
        if not self.use_causal_filter:
            return "Causal filter not in use."

        stats1 = self.input_filter.get_stats()
        stats2 = self.hidden_filter.get_stats()

        info_str = "Filter Stats:\n"
        info_str += f"  Input Filter:  lambda={stats1['lambda']:.4f}, gate_mean={stats1['gate_stats']['mean']:.4f}\n"
        info_str += f"  Hidden Filter: lambda={stats2['lambda']:.4f}, gate_mean={stats2['gate_stats']['mean']:.4f}"
        return info_str


def train_model(model, data_loader, optimizer, criterion, device):
    """Run one training epoch and return ``(avg_loss, accuracy)``.

    Args:
        model: Module called as ``model(batch)``; its output is treated as
            unnormalised class scores with classes along dim 1.
        data_loader: Iterable of batches exposing ``to(device)``, ``y`` and,
            when pooling is needed, ``batch``.
        optimizer: Optimizer stepped once per batch.
        criterion: Callable applied to ``(log_probs, target)``. ``None``
            falls back to ``F.nll_loss``. The running average assumes the
            criterion returns a per-batch mean.
        device: Device every batch is moved to.

    Returns:
        Tuple of the target-weighted mean loss and the fraction of correct
        predictions; ``(0.0, 0.0)`` when ``data_loader`` yields nothing.
    """
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_examples = 0
    for batch in data_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch)

        if out.size(0) != batch.y.size(0):
            # The row count differs from the label count, so the rows are
            # averaged per graph. Imported lazily, only when pooling is needed.
            from torch_geometric.nn import global_mean_pool
            out = global_mean_pool(out, batch.batch)

        log_probs = F.log_softmax(out, dim=1)

        # Flatten instead of squeeze(): squeeze() turns a (1, 1) label batch
        # into a 0-dim tensor, which the loss rejects for a 2-D input.
        target = batch.y.reshape(-1)

        if criterion is None:
            loss = F.nll_loss(log_probs, target)
        else:
            loss = criterion(log_probs, target)
        loss.backward()
        optimizer.step()

        # Weight by the number of scored targets so the denominator matches
        # what `pred == target` counts, whether or not pooling was applied.
        num_targets = target.numel()
        total_loss += loss.item() * num_targets
        pred = out.argmax(dim=1)
        total_correct += int((pred == target).sum().item())
        total_examples += num_targets

    if total_examples == 0:
        # Empty loader: nothing was trained on, avoid dividing by zero.
        return 0.0, 0.0

    avg_loss = total_loss / total_examples
    accuracy = total_correct / total_examples
    return avg_loss, accuracy


def evaluate_model(model, data_loader, device):
    """Return the classification accuracy of ``model`` over ``data_loader``.

    The model is switched to eval mode and run under ``torch.no_grad()``.

    Args:
        model: Module called as ``model(batch)``; its output is treated as
            class scores with classes along dim 1.
        data_loader: Iterable of batches exposing ``to(device)``, ``y`` and,
            when pooling is needed, ``batch``.
        device: Device every batch is moved to.

    Returns:
        Fraction of correct predictions, or ``0.0`` when ``data_loader``
        yields nothing.
    """
    model.eval()
    total_correct = 0
    total_examples = 0
    with torch.no_grad():
        for batch in data_loader:
            batch = batch.to(device)
            out = model(batch)
            if out.size(0) != batch.y.size(0):
                # The row count differs from the label count, so the rows are
                # averaged per graph. Imported lazily, only when needed.
                from torch_geometric.nn import global_mean_pool
                out = global_mean_pool(out, batch.batch)
            pred = out.argmax(dim=1)
            # Flatten so a (B, 1) label tensor compares element-wise with
            # the (B,) predictions instead of broadcasting.
            target = batch.y.reshape(-1)
            total_correct += int((pred == target).sum().item())
            # Count scored targets so the denominator matches the numerator.
            total_examples += target.numel()

    if total_examples == 0:
        # Empty loader: avoid dividing by zero.
        return 0.0
    return total_correct / total_examples


if __name__ == "__main__":
    # Script entry point: train and evaluate the GCN on each listed
    # experiment, printing progress every 10 epochs and a summary at the end.
    import argparse

    parser = argparse.ArgumentParser()
    # The filter stays on by default. `store_true` with `default=True` alone
    # could never be turned off, so --no_causal_filter writes False to the
    # same destination.
    parser.add_argument('--use_causal_filter', dest='use_causal_filter', action='store_true', default=True, help='Enable causal filter')
    parser.add_argument('--no_causal_filter', dest='use_causal_filter', action='store_false', help='Disable causal filter')
    args = parser.parse_args()

    for experiment in ['crcg']:
        print(f'Current experiment: {experiment}')

        train_dataset = NPYGraphDataset(f"./data/molecular/train_{experiment}.npy")
        val_dataset = NPYGraphDataset(f"./data/molecular/val_{experiment}.npy")
        test_dataset = NPYGraphDataset(f"./data/molecular/test_{experiment}.npy")

        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

        num_features = train_dataset.num_features
        # Class count is inferred from the largest label in the training set.
        num_classes = int(train_dataset.data.y.max().item()) + 1

        filter_config = {
            'input': {'lambda_init': 1.0, 'decay_rate': 0.99},
            'hidden': {'lambda_init': 1.0, 'decay_rate': 0.99},
        }

        model = GCNWithImprovedFilter(
            num_features,
            num_classes,
            hidden_channels=64,
            use_causal_filter=args.use_causal_filter,
            filter_config=filter_config if args.use_causal_filter else None
        )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        print(f"Using device: {device}, Causal filter module: {'Enabled' if args.use_causal_filter else 'Disabled'}")

        criterion = torch.nn.NLLLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)

        epochs = 100
        # Per-epoch accuracy histories, in percent.
        final_train_accuracies = []
        final_in_test_accuracies = []
        final_test_accuracies = []
        # Model selection on validation accuracy: remember the epoch with the
        # best validation score and the test accuracy observed at that epoch.
        best_val_acc = 0
        best_epoch = 0
        test_acc_at_best_val = 0.0

        for epoch in range(1, epochs + 1):
            train_loss, train_acc = train_model(model, train_loader, optimizer, criterion, device)
            val_acc = evaluate_model(model, val_loader, device)
            test_acc = evaluate_model(model, test_loader, device)

            final_train_accuracies.append(train_acc * 100)
            final_in_test_accuracies.append(val_acc * 100)
            final_test_accuracies.append(test_acc * 100)

            if args.use_causal_filter:
                model.step_epoch()

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_epoch = epoch
                test_acc_at_best_val = test_acc

            if epoch % 10 == 0:
                print(f"Epoch {epoch:03d} | Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f} | Test Acc: {test_acc:.4f}")
                if args.use_causal_filter:
                    print(model.get_filter_info())

        # These statistics are taken over ALL epochs, not only the last ones.
        final_train_mean = np.mean(final_train_accuracies)
        final_train_std = np.std(final_train_accuracies)
        final_test_mean = np.mean(final_test_accuracies)
        final_test_std = np.std(final_test_accuracies)

        print(f"Best Val Acc  {best_val_acc * 100:.2f} (epoch {best_epoch})  |  Test Acc at best val  {test_acc_at_best_val * 100:.2f}")
        print(f"Train Acc  {final_train_mean:.2f} ± {final_train_std:.2f}  |  Test Acc  {final_test_mean:.2f} ± {final_test_std:.2f}")
