import os
import time
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms, models
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader as PyGDataLoader
from torch_geometric.nn import GCNConv, GINConv, GATConv, global_mean_pool, GPSConv,global_add_pool
from sklearn.model_selection import train_test_split
from PIL import Image
from transformers import ConvNextV2ForImageClassification, ViTForImageClassification, SwinForImageClassification
import networkx as nx
from tqdm import tqdm
from torch_geometric.utils import degree

class ConvNeXtV2Encoder(nn.Module):
    """Pretrained ConvNeXt V2 backbone that yields pooled image features.

    The Hugging Face classification head is swapped for ``nn.Identity`` so the
    model's ``logits`` field carries the classifier's input (the pooled
    backbone output) instead of class scores.

    NOTE(review): a class with this same name is defined again further down in
    this module; the later definition is the one bound at import time.

    Args:
        model_variant: Size tag substituted into the checkpoint name
            ``facebook/convnextv2-{model_variant}-1k-224``.
    """

    def __init__(self, model_variant="tiny"):
        super().__init__()
        checkpoint = f"facebook/convnextv2-{model_variant}-1k-224"
        backbone = ConvNextV2ForImageClassification.from_pretrained(checkpoint)
        # Width of the last stage == size of each pooled feature vector.
        self.feature_dim = backbone.config.hidden_sizes[-1]
        backbone.classifier = nn.Identity()
        self.model = backbone

    def forward(self, x):
        """Return pooled features (presumably ``(batch, feature_dim)``)."""
        # With an identity head, ``logits`` is just the pooled feature tensor.
        return self.model(x).logits
        
class ViTEncoder(nn.Module):
    """Pretrained Vision Transformer returning the classifier's input features.

    The classification head is replaced by ``nn.Identity`` so ``logits`` holds
    the representation the head would normally consume.

    NOTE(review): a class with this same name is defined again further down in
    this module; the later definition is the one bound at import time.

    Args:
        model_variant: Size tag for ``google/vit-{model_variant}-patch16-224``.
    """

    def __init__(self, model_variant="base"):
        super().__init__()
        checkpoint = f"google/vit-{model_variant}-patch16-224"
        backbone = ViTForImageClassification.from_pretrained(checkpoint)
        self.feature_dim = backbone.config.hidden_size
        backbone.classifier = nn.Identity()
        self.model = backbone

    def forward(self, x):
        """Return one feature vector per image."""
        return self.model(x).logits

class SwinEncoder(nn.Module):
    """Pretrained Swin-Tiny backbone returning pooled image features.

    Uses the ``microsoft/swin-tiny-patch4-window7-224`` checkpoint with its
    classifier replaced by ``nn.Identity``, so ``logits`` holds features.

    NOTE(review): a class with this same name is defined again further down in
    this module; the later definition is the one bound at import time.
    """

    def __init__(self):
        super().__init__()
        backbone = SwinForImageClassification.from_pretrained(
            "microsoft/swin-tiny-patch4-window7-224"
        )
        self.feature_dim = backbone.config.hidden_size
        backbone.classifier = nn.Identity()
        self.model = backbone

    def forward(self, x):
        """Return one pooled feature vector per image."""
        return self.model(x).logits

class ResNetEncoder(nn.Module):
    """ImageNet-pretrained torchvision ResNet-50 with its ``fc`` layer removed.

    ``forward`` returns the pooled vector that ``fc`` would have received;
    ``feature_dim`` records its width (the original ``fc.in_features``).

    NOTE(review): ``pretrained=True`` is deprecated since torchvision 0.13 in
    favour of ``weights=...``; left as-is for compatibility.
    NOTE(review): a class with this same name is defined again further down in
    this module; the later definition is the one bound at import time.
    """

    def __init__(self):
        super().__init__()
        resnet = models.resnet50(pretrained=True)
        self.feature_dim = resnet.fc.in_features
        resnet.fc = nn.Identity()
        self.model = resnet

    def forward(self, x):
        """Return one feature vector per image."""
        return self.model(x)


class GCNEncoder(nn.Module):
    """Stack of GCN layers followed by mean pooling into one vector per graph.

    NOTE(review): a class with this same name is defined again further down in
    this module; the later definition is the one bound at import time.

    Args:
        input_dim: Node feature width. When ``data.x`` is missing, constant
            all-ones features of this width are used instead.
        hidden_dim: Width of every GCN layer and of the pooled output
            (exposed as ``feature_dim``).
        dropout_rate: Dropout probability applied between conv layers.
        num_layers: Total number of ``GCNConv`` layers; values below 1 still
            build a single layer.
    """

    def __init__(self, input_dim=1, hidden_dim=64, dropout_rate=0.5, num_layers=5):
        super().__init__()
        self.input_dim = input_dim
        self.feature_dim = hidden_dim
        self.dropout_rate = dropout_rate

        # First layer maps input_dim -> hidden_dim; the rest stay at hidden_dim.
        self.convs = nn.ModuleList([GCNConv(input_dim, hidden_dim)])
        for _ in range(num_layers - 1):
            self.convs.append(GCNConv(hidden_dim, hidden_dim))

    def forward(self, data):
        """Encode a PyG graph (batched or single) into per-graph embeddings.

        Args:
            data: Object exposing ``x`` (may be None), ``edge_index`` and
                ``batch`` (None for a single, un-batched graph).

        Returns:
            Pooled tensor with one ``hidden_dim`` row per graph.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        if x is None:
            # Featureless graphs get constant ones. ``batch`` is None for a
            # single un-batched graph, so fall back to ``data.num_nodes``.
            num_nodes = batch.size(0) if batch is not None else data.num_nodes
            x = torch.ones(num_nodes, self.input_dim, device=edge_index.device)

        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < last:
                # Dropout + ReLU between layers; the final conv output is left raw.
                x = F.relu(F.dropout(x, p=self.dropout_rate, training=self.training))

        # global_mean_pool treats batch=None as "all nodes belong to one graph".
        return global_mean_pool(x, batch)

class GINEncoder(nn.Module):
    """Stack of GIN layers followed by mean pooling into one vector per graph.

    Each ``GINConv`` wraps a two-layer MLP (``Linear -> ReLU -> Linear``).

    NOTE(review): a class with this same name is defined again further down in
    this module; the later definition is the one bound at import time.

    Args:
        input_dim: Node feature width. When ``data.x`` is missing, constant
            all-ones features of this width are used instead.
        hidden_dim: Width of every GIN layer and of the pooled output
            (exposed as ``feature_dim``).
        dropout_rate: Dropout probability applied between conv layers.
        num_layers: Total number of ``GINConv`` layers; values below 1 still
            build a single layer.
    """

    def __init__(self, input_dim=1, hidden_dim=64, dropout_rate=0.5, num_layers=5):
        super().__init__()
        self.input_dim = input_dim
        self.feature_dim = hidden_dim
        self.dropout_rate = dropout_rate

        # First layer lifts input_dim -> hidden_dim; the rest keep hidden_dim.
        self.convs = nn.ModuleList([GINConv(self._make_mlp(input_dim, hidden_dim))])
        for _ in range(num_layers - 1):
            self.convs.append(GINConv(self._make_mlp(hidden_dim, hidden_dim)))

    @staticmethod
    def _make_mlp(in_dim, out_dim):
        """Build the ``Linear -> ReLU -> Linear`` MLP used inside one GINConv."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        )

    def forward(self, data):
        """Encode a PyG graph (batched or single) into per-graph embeddings.

        Args:
            data: Object exposing ``x`` (may be None), ``edge_index`` and
                ``batch`` (None for a single, un-batched graph).

        Returns:
            Pooled tensor with one ``hidden_dim`` row per graph.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        if x is None:
            # Featureless graphs get constant ones. ``batch`` is None for a
            # single un-batched graph, so fall back to ``data.num_nodes``.
            num_nodes = batch.size(0) if batch is not None else data.num_nodes
            x = torch.ones(num_nodes, self.input_dim, device=edge_index.device)

        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < last:
                # Dropout + ReLU between layers; the final conv output is left raw.
                x = F.relu(F.dropout(x, p=self.dropout_rate, training=self.training))

        # global_mean_pool treats batch=None as "all nodes belong to one graph".
        return global_mean_pool(x, batch)

class GATEncoder(nn.Module):
    """Stack of multi-head GAT layers followed by mean pooling per graph.

    Every ``GATConv`` emits ``heads`` concatenated outputs of width
    ``hidden_dim // heads``, so ``hidden_dim`` must be divisible by ``heads``
    for layer widths (and ``feature_dim``) to line up.

    NOTE(review): a class with this same name is defined again further down in
    this module; the later definition is the one bound at import time.

    Args:
        input_dim: Node feature width. When ``data.x`` is missing, constant
            all-ones features of this width are used instead.
        hidden_dim: Total (concatenated) width of every GAT layer.
        dropout_rate: Dropout probability applied between conv layers.
        heads: Number of attention heads per layer (>= 1, divides hidden_dim).
        num_layers: Total number of ``GATConv`` layers; values below 1 still
            build a single layer.

    Raises:
        ValueError: If ``heads < 1`` or ``hidden_dim`` is not divisible by
            ``heads``. Without this check the next layer would fail later with
            an opaque shape mismatch.
    """

    def __init__(self, input_dim=1, hidden_dim=64, dropout_rate=0.5, heads=2, num_layers=5):
        super().__init__()
        if heads < 1:
            raise ValueError(f"heads must be >= 1, got {heads}")
        if hidden_dim % heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by heads ({heads})"
            )

        self.input_dim = input_dim
        self.feature_dim = hidden_dim
        self.dropout_rate = dropout_rate

        per_head = hidden_dim // heads
        self.convs = nn.ModuleList([GATConv(input_dim, per_head, heads=heads)])
        for _ in range(num_layers - 1):
            self.convs.append(GATConv(hidden_dim, per_head, heads=heads))

    def forward(self, data):
        """Encode a PyG graph (batched or single) into per-graph embeddings.

        Args:
            data: Object exposing ``x`` (may be None), ``edge_index`` and
                ``batch`` (None for a single, un-batched graph).

        Returns:
            Pooled tensor with one ``hidden_dim`` row per graph.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        if x is None:
            # Featureless graphs get constant ones. ``batch`` is None for a
            # single un-batched graph, so fall back to ``data.num_nodes``.
            num_nodes = batch.size(0) if batch is not None else data.num_nodes
            x = torch.ones(num_nodes, self.input_dim, device=edge_index.device)

        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < last:
                # Dropout + ReLU between layers; the final conv output is left raw.
                x = F.relu(F.dropout(x, p=self.dropout_rate, training=self.training))

        # global_mean_pool treats batch=None as "all nodes belong to one graph".
        return global_mean_pool(x, batch)

class ConvNeXtV2Model(nn.Module):
    """ConvNeXt V2 image classifier with a two-layer MLP head.

    The checkpoint is loaded with ``num_labels=num_classes`` (mismatched head
    weights are ignored), then its ``classifier`` is replaced by
    ``Linear -> ReLU -> Dropout -> Linear`` with a half-width bottleneck.

    Args:
        num_classes: Number of output classes.
        model_variant: Size tag for ``facebook/convnextv2-{model_variant}-1k-224``.
        dropout_rate: Dropout probability inside the MLP head.
    """

    def __init__(self, num_classes, model_variant="tiny", dropout_rate=0.5):
        super().__init__()
        checkpoint = f"facebook/convnextv2-{model_variant}-1k-224"
        self.model = ConvNextV2ForImageClassification.from_pretrained(
            checkpoint,
            num_labels=num_classes,
            ignore_mismatched_sizes=True,
        )

        width = self.model.config.hidden_sizes[-1]
        bottleneck = width // 2
        self.model.classifier = nn.Sequential(
            nn.Linear(width, bottleneck),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(bottleneck, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.model(x).logits

    def extract_features(self, x):
        """Return the backbone's pooled output (the input to ``classifier``)."""
        backbone_out = self.model.convnextv2(x)
        return backbone_out.pooler_output
    

class ViTModel(nn.Module):
    """Vision Transformer image classifier with a two-layer MLP head.

    The checkpoint is loaded with ``num_labels=num_classes`` (mismatched head
    weights are ignored), then its ``classifier`` is replaced by
    ``Linear -> ReLU -> Dropout -> Linear`` with a half-width bottleneck.

    Args:
        num_classes: Number of output classes.
        model_variant: Size tag for ``google/vit-{model_variant}-patch16-224``.
        dropout_rate: Dropout probability inside the MLP head.
    """

    def __init__(self, num_classes, model_variant="base", dropout_rate=0.5):
        super().__init__()
        checkpoint = f"google/vit-{model_variant}-patch16-224"
        self.model = ViTForImageClassification.from_pretrained(
            checkpoint,
            num_labels=num_classes,
            ignore_mismatched_sizes=True,
        )

        width = self.model.config.hidden_size
        bottleneck = width // 2
        self.model.classifier = nn.Sequential(
            nn.Linear(width, bottleneck),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(bottleneck, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.model(x).logits

    def extract_features(self, x):
        """Return the [CLS]-token embedding that the classifier head consumes.

        Handles both output objects exposing ``last_hidden_state`` and plain
        tuples whose first element is the token sequence.
        """
        outputs = self.model.vit(x, return_dict=True)
        sequence = (
            outputs.last_hidden_state
            if hasattr(outputs, 'last_hidden_state')
            else outputs[0]
        )
        # Token position 0 is the [CLS] token.
        return sequence[:, 0, :]


class SwinModel(nn.Module):
    """Swin-Tiny image classifier with a two-layer MLP head.

    Loads ``microsoft/swin-tiny-patch4-window7-224`` with
    ``num_labels=num_classes`` (mismatched head weights ignored), then replaces
    its ``classifier`` by ``Linear -> ReLU -> Dropout -> Linear``.

    Args:
        num_classes: Number of output classes.
        dropout_rate: Dropout probability inside the MLP head.
    """

    def __init__(self, num_classes, dropout_rate=0.5):
        super().__init__()
        self.model = SwinForImageClassification.from_pretrained(
            "microsoft/swin-tiny-patch4-window7-224",
            num_labels=num_classes,
            ignore_mismatched_sizes=True,
        )

        width = self.model.config.hidden_size
        bottleneck = width // 2
        self.model.classifier = nn.Sequential(
            nn.Linear(width, bottleneck),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(bottleneck, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.model(x).logits

    def extract_features(self, x):
        """Return the Swin backbone's pooled output (the classifier's input)."""
        backbone_out = self.model.swin(x)
        return backbone_out.pooler_output

class ResNetModel(nn.Module):
    """ImageNet-pretrained ResNet-50 classifier with a two-layer MLP head.

    The torchvision ``fc`` layer is replaced by
    ``Linear -> ReLU -> Dropout -> Linear`` with a half-width bottleneck.

    NOTE(review): ``pretrained=True`` is deprecated since torchvision 0.13 in
    favour of ``weights=...``; left as-is for compatibility.

    Args:
        num_classes: Number of output classes.
        dropout_rate: Dropout probability inside the MLP head.
    """

    def __init__(self, num_classes, dropout_rate=0.5):
        super().__init__()
        self.model = models.resnet50(pretrained=True)

        width = self.model.fc.in_features
        bottleneck = width // 2
        self.model.fc = nn.Sequential(
            nn.Linear(width, bottleneck),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(bottleneck, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.model(x)

    def extract_features(self, x):
        """Run the network up to (not including) ``fc`` and flatten the result."""
        # Same sub-modules, in the same order, that precede ``fc`` in the
        # torchvision ResNet forward pass.
        stages = (
            "conv1", "bn1", "relu", "maxpool",
            "layer1", "layer2", "layer3", "layer4",
            "avgpool",
        )
        for name in stages:
            x = getattr(self.model, name)(x)
        return torch.flatten(x, 1)
    
class GCNModel(nn.Module):
    """GCN graph classifier: GCN stack, mean pooling, then an MLP head.

    Args:
        input_dim: Node feature width. When ``data.x`` is missing, constant
            all-ones features of this width are used instead.
        hidden_dim: Width of every GCN layer and of the pooled embedding.
        num_classes: Number of output classes.
        dropout_rate: Dropout probability between conv layers and in the head.
        num_layers: Total number of ``GCNConv`` layers; values below 1 still
            build a single layer.
    """

    def __init__(self, input_dim=1, hidden_dim=64, num_classes=2, dropout_rate=0.5, num_layers=5):
        super().__init__()
        self.input_dim = input_dim
        self.dropout_rate = dropout_rate

        # First layer maps input_dim -> hidden_dim; the rest stay at hidden_dim.
        self.convs = nn.ModuleList([GCNConv(input_dim, hidden_dim)])
        for _ in range(num_layers - 1):
            self.convs.append(GCNConv(hidden_dim, hidden_dim))

        self.linear = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim // 2, num_classes),
        )

    def forward(self, data):
        """Return class logits, one row of ``num_classes`` scores per graph."""
        # Pooled graph embedding followed by the MLP head (previously this
        # duplicated extract_features line-for-line).
        return self.linear(self.extract_features(data))

    def extract_features(self, data):
        """Return the pooled per-graph embedding that feeds the MLP head.

        Args:
            data: Object exposing ``x`` (may be None), ``edge_index`` and
                ``batch`` (None for a single, un-batched graph).
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        if x is None:
            # Featureless graphs get constant ones. ``batch`` is None for a
            # single un-batched graph, so fall back to ``data.num_nodes``.
            num_nodes = batch.size(0) if batch is not None else data.num_nodes
            x = torch.ones(num_nodes, self.input_dim, device=edge_index.device)

        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < last:
                # Dropout + ReLU between layers; the final conv output is left raw.
                x = F.relu(F.dropout(x, p=self.dropout_rate, training=self.training))

        # global_mean_pool treats batch=None as "all nodes belong to one graph".
        return global_mean_pool(x, batch)

class GINModel(nn.Module):
    """GIN graph classifier: GIN stack, mean pooling, then an MLP head.

    Each ``GINConv`` wraps a two-layer MLP (``Linear -> ReLU -> Linear``).

    Args:
        input_dim: Node feature width. When ``data.x`` is missing, constant
            all-ones features of this width are used instead.
        hidden_dim: Width of every GIN layer and of the pooled embedding.
        num_classes: Number of output classes.
        dropout_rate: Dropout probability between conv layers and in the head.
        num_layers: Total number of ``GINConv`` layers; values below 1 still
            build a single layer.
    """

    def __init__(self, input_dim=1, hidden_dim=64, num_classes=2, dropout_rate=0.5, num_layers=5):
        super().__init__()
        self.input_dim = input_dim
        self.dropout_rate = dropout_rate

        # First layer lifts input_dim -> hidden_dim; the rest keep hidden_dim.
        self.convs = nn.ModuleList([GINConv(self._make_mlp(input_dim, hidden_dim))])
        for _ in range(num_layers - 1):
            self.convs.append(GINConv(self._make_mlp(hidden_dim, hidden_dim)))

        self.linear = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim // 2, num_classes),
        )

    @staticmethod
    def _make_mlp(in_dim, out_dim):
        """Build the ``Linear -> ReLU -> Linear`` MLP used inside one GINConv."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        )

    def forward(self, data):
        """Return class logits, one row of ``num_classes`` scores per graph."""
        # Pooled graph embedding followed by the MLP head (previously this
        # duplicated extract_features line-for-line).
        return self.linear(self.extract_features(data))

    def extract_features(self, data):
        """Return the pooled per-graph embedding that feeds the MLP head.

        Args:
            data: Object exposing ``x`` (may be None), ``edge_index`` and
                ``batch`` (None for a single, un-batched graph).
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        if x is None:
            # Featureless graphs get constant ones. ``batch`` is None for a
            # single un-batched graph, so fall back to ``data.num_nodes``.
            num_nodes = batch.size(0) if batch is not None else data.num_nodes
            x = torch.ones(num_nodes, self.input_dim, device=edge_index.device)

        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < last:
                # Dropout + ReLU between layers; the final conv output is left raw.
                x = F.relu(F.dropout(x, p=self.dropout_rate, training=self.training))

        # global_mean_pool treats batch=None as "all nodes belong to one graph".
        return global_mean_pool(x, batch)

class GATModel(nn.Module):
    """GAT graph classifier: multi-head GAT stack, mean pooling, MLP head.

    Every ``GATConv`` emits ``heads`` concatenated outputs of width
    ``hidden_dim // heads``, so ``hidden_dim`` must be divisible by ``heads``
    for layer widths and the head's input width to line up.

    Args:
        input_dim: Node feature width. When ``data.x`` is missing, constant
            all-ones features of this width are used instead.
        hidden_dim: Total (concatenated) width of every GAT layer.
        num_classes: Number of output classes.
        dropout_rate: Dropout probability between conv layers and in the head.
        heads: Number of attention heads per layer (>= 1, divides hidden_dim).
        num_layers: Total number of ``GATConv`` layers; values below 1 still
            build a single layer.

    Raises:
        ValueError: If ``heads < 1`` or ``hidden_dim`` is not divisible by
            ``heads``. Without this check a later layer would fail with an
            opaque shape mismatch.
    """

    def __init__(self, input_dim=1, hidden_dim=64, num_classes=2, dropout_rate=0.5, heads=2, num_layers=5):
        super().__init__()
        if heads < 1:
            raise ValueError(f"heads must be >= 1, got {heads}")
        if hidden_dim % heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by heads ({heads})"
            )

        self.input_dim = input_dim
        self.dropout_rate = dropout_rate

        per_head = hidden_dim // heads
        self.convs = nn.ModuleList([GATConv(input_dim, per_head, heads=heads)])
        for _ in range(num_layers - 1):
            self.convs.append(GATConv(hidden_dim, per_head, heads=heads))

        self.linear = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim // 2, num_classes),
        )

    def forward(self, data):
        """Return class logits, one row of ``num_classes`` scores per graph."""
        # Pooled graph embedding followed by the MLP head (previously this
        # duplicated extract_features line-for-line).
        return self.linear(self.extract_features(data))

    def extract_features(self, data):
        """Return the pooled per-graph embedding that feeds the MLP head.

        Args:
            data: Object exposing ``x`` (may be None), ``edge_index`` and
                ``batch`` (None for a single, un-batched graph).
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        if x is None:
            # Featureless graphs get constant ones. ``batch`` is None for a
            # single un-batched graph, so fall back to ``data.num_nodes``.
            num_nodes = batch.size(0) if batch is not None else data.num_nodes
            x = torch.ones(num_nodes, self.input_dim, device=edge_index.device)

        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < last:
                # Dropout + ReLU between layers; the final conv output is left raw.
                x = F.relu(F.dropout(x, p=self.dropout_rate, training=self.training))

        # global_mean_pool treats batch=None as "all nodes belong to one graph".
        return global_mean_pool(x, batch)
    
class GPSModel(nn.Module):
    """GraphGPS classifier: linear node embedding, GPS layers, mean pool, MLP head.

    Each layer is a ``GPSConv`` whose local message-passing component is a
    ``GINConv`` (``Linear -> ReLU -> Linear`` MLP), configured with ``heads``
    attention heads.

    Args:
        input_dim: Node feature width. When ``data.x`` is missing, constant
            all-ones features of this width are used instead (matching the
            other graph models in this module).
        hidden_dim: Width of the node embedding and every GPS layer.
        num_classes: Number of output classes.
        dropout_rate: Dropout probability between GPS layers and in the head.
        num_layers: Number of ``GPSConv`` layers.
        heads: Attention heads passed to each ``GPSConv``.
    """

    def __init__(self, input_dim=1, hidden_dim=64, num_classes=2, dropout_rate=0.5,
                 num_layers=5, heads=4):
        super().__init__()
        self.input_dim = input_dim
        self.dropout_rate = dropout_rate

        # Project raw node features into the hidden width used by GPSConv.
        self.node_emb = nn.Linear(input_dim, hidden_dim)

        self.convs = nn.ModuleList()
        for _ in range(num_layers):
            gin_nn = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
            self.convs.append(GPSConv(hidden_dim, GINConv(gin_nn), heads=heads))

        self.linear = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim // 2, num_classes),
        )

    def forward(self, data):
        """Return class logits, one row of ``num_classes`` scores per graph."""
        # Pooled graph embedding followed by the MLP head (previously this
        # duplicated extract_features line-for-line).
        return self.linear(self.extract_features(data))

    def extract_features(self, data):
        """Return the pooled per-graph embedding that feeds the MLP head.

        Args:
            data: Object exposing ``x`` (may be None), ``edge_index`` and
                ``batch`` (None for a single, un-batched graph).
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        if x is None:
            # Previously commented out, so featureless datasets crashed in
            # node_emb(None). ``batch`` is None for a single un-batched graph,
            # so fall back to ``data.num_nodes`` there.
            num_nodes = batch.size(0) if batch is not None else data.num_nodes
            x = torch.ones(num_nodes, self.input_dim, device=edge_index.device)

        x = self.node_emb(x)

        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            # GPSConv needs ``batch`` to restrict global attention per graph.
            x = conv(x, edge_index, batch)
            if i < last:
                x = F.relu(F.dropout(x, p=self.dropout_rate, training=self.training))

        # global_mean_pool treats batch=None as "all nodes belong to one graph".
        return global_mean_pool(x, batch)
    
class ConvNeXtV2Encoder(nn.Module):
    """Pretrained ConvNeXt V2 backbone that yields pooled image features.

    The Hugging Face classification head is swapped for ``nn.Identity`` so the
    model's ``logits`` field carries the classifier's input (the pooled
    backbone output) instead of class scores.

    NOTE(review): an earlier class with this same name appears above in this
    module; this later definition replaces it at import time.

    Args:
        model_variant: Size tag substituted into the checkpoint name
            ``facebook/convnextv2-{model_variant}-1k-224``.
    """

    def __init__(self, model_variant="tiny"):
        super().__init__()
        checkpoint = f"facebook/convnextv2-{model_variant}-1k-224"
        backbone = ConvNextV2ForImageClassification.from_pretrained(checkpoint)
        # Width of the last stage == size of each pooled feature vector.
        self.feature_dim = backbone.config.hidden_sizes[-1]
        backbone.classifier = nn.Identity()
        self.model = backbone

    def forward(self, x):
        """Return pooled features (presumably ``(batch, feature_dim)``)."""
        # With an identity head, ``logits`` is just the pooled feature tensor.
        return self.model(x).logits
        
class ViTEncoder(nn.Module):
    """Pretrained Vision Transformer returning the classifier's input features.

    The classification head is replaced by ``nn.Identity`` so ``logits`` holds
    the representation the head would normally consume.

    NOTE(review): an earlier class with this same name appears above in this
    module; this later definition replaces it at import time.

    Args:
        model_variant: Size tag for ``google/vit-{model_variant}-patch16-224``.
    """

    def __init__(self, model_variant="base"):
        super().__init__()
        checkpoint = f"google/vit-{model_variant}-patch16-224"
        backbone = ViTForImageClassification.from_pretrained(checkpoint)
        self.feature_dim = backbone.config.hidden_size
        backbone.classifier = nn.Identity()
        self.model = backbone

    def forward(self, x):
        """Return one feature vector per image."""
        return self.model(x).logits

class SwinEncoder(nn.Module):
    """Pretrained Swin-Tiny backbone returning pooled image features.

    Uses the ``microsoft/swin-tiny-patch4-window7-224`` checkpoint with its
    classifier replaced by ``nn.Identity``, so ``logits`` holds features.

    NOTE(review): an earlier class with this same name appears above in this
    module; this later definition replaces it at import time.
    """

    def __init__(self):
        super().__init__()
        backbone = SwinForImageClassification.from_pretrained(
            "microsoft/swin-tiny-patch4-window7-224"
        )
        self.feature_dim = backbone.config.hidden_size
        backbone.classifier = nn.Identity()
        self.model = backbone

    def forward(self, x):
        """Return one pooled feature vector per image."""
        return self.model(x).logits

class ResNetEncoder(nn.Module):
    """ImageNet-pretrained torchvision ResNet-50 with its ``fc`` layer removed.

    ``forward`` returns the pooled vector that ``fc`` would have received;
    ``feature_dim`` records its width (the original ``fc.in_features``).

    NOTE(review): ``pretrained=True`` is deprecated since torchvision 0.13 in
    favour of ``weights=...``; left as-is for compatibility.
    NOTE(review): an earlier class with this same name appears above in this
    module; this later definition replaces it at import time.
    """

    def __init__(self):
        super().__init__()
        resnet = models.resnet50(pretrained=True)
        self.feature_dim = resnet.fc.in_features
        resnet.fc = nn.Identity()
        self.model = resnet

    def forward(self, x):
        """Return one feature vector per image."""
        return self.model(x)


class GCNEncoder(nn.Module):
    """Stack of GCN layers followed by mean pooling into one vector per graph.

    NOTE(review): an earlier class with this same name appears above in this
    module; this later definition replaces it at import time.

    Args:
        input_dim: Node feature width. When ``data.x`` is missing, constant
            all-ones features of this width are used instead.
        hidden_dim: Width of every GCN layer and of the pooled output
            (exposed as ``feature_dim``).
        dropout_rate: Dropout probability applied between conv layers.
        num_layers: Total number of ``GCNConv`` layers; values below 1 still
            build a single layer.
    """

    def __init__(self, input_dim=1, hidden_dim=64, dropout_rate=0.5, num_layers=5):
        super().__init__()
        self.input_dim = input_dim
        self.feature_dim = hidden_dim
        self.dropout_rate = dropout_rate

        # First layer maps input_dim -> hidden_dim; the rest stay at hidden_dim.
        self.convs = nn.ModuleList([GCNConv(input_dim, hidden_dim)])
        for _ in range(num_layers - 1):
            self.convs.append(GCNConv(hidden_dim, hidden_dim))

    def forward(self, data):
        """Encode a PyG graph (batched or single) into per-graph embeddings.

        Args:
            data: Object exposing ``x`` (may be None), ``edge_index`` and
                ``batch`` (None for a single, un-batched graph).

        Returns:
            Pooled tensor with one ``hidden_dim`` row per graph.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        if x is None:
            # Featureless graphs get constant ones. ``batch`` is None for a
            # single un-batched graph, so fall back to ``data.num_nodes``.
            num_nodes = batch.size(0) if batch is not None else data.num_nodes
            x = torch.ones(num_nodes, self.input_dim, device=edge_index.device)

        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < last:
                # Dropout + ReLU between layers; the final conv output is left raw.
                x = F.relu(F.dropout(x, p=self.dropout_rate, training=self.training))

        # global_mean_pool treats batch=None as "all nodes belong to one graph".
        return global_mean_pool(x, batch)

class GINEncoder(nn.Module):
    """Stack of GIN layers followed by mean pooling into one vector per graph.

    Each ``GINConv`` wraps a two-layer MLP (``Linear -> ReLU -> Linear``).

    NOTE(review): an earlier class with this same name appears above in this
    module; this later definition replaces it at import time.

    Args:
        input_dim: Node feature width. When ``data.x`` is missing, constant
            all-ones features of this width are used instead.
        hidden_dim: Width of every GIN layer and of the pooled output
            (exposed as ``feature_dim``).
        dropout_rate: Dropout probability applied between conv layers.
        num_layers: Total number of ``GINConv`` layers; values below 1 still
            build a single layer.
    """

    def __init__(self, input_dim=1, hidden_dim=64, dropout_rate=0.5, num_layers=5):
        super().__init__()
        self.input_dim = input_dim
        self.feature_dim = hidden_dim
        self.dropout_rate = dropout_rate

        # First layer lifts input_dim -> hidden_dim; the rest keep hidden_dim.
        self.convs = nn.ModuleList([GINConv(self._make_mlp(input_dim, hidden_dim))])
        for _ in range(num_layers - 1):
            self.convs.append(GINConv(self._make_mlp(hidden_dim, hidden_dim)))

    @staticmethod
    def _make_mlp(in_dim, out_dim):
        """Build the ``Linear -> ReLU -> Linear`` MLP used inside one GINConv."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        )

    def forward(self, data):
        """Encode a PyG graph (batched or single) into per-graph embeddings.

        Args:
            data: Object exposing ``x`` (may be None), ``edge_index`` and
                ``batch`` (None for a single, un-batched graph).

        Returns:
            Pooled tensor with one ``hidden_dim`` row per graph.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        if x is None:
            # Featureless graphs get constant ones. ``batch`` is None for a
            # single un-batched graph, so fall back to ``data.num_nodes``.
            num_nodes = batch.size(0) if batch is not None else data.num_nodes
            x = torch.ones(num_nodes, self.input_dim, device=edge_index.device)

        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < last:
                # Dropout + ReLU between layers; the final conv output is left raw.
                x = F.relu(F.dropout(x, p=self.dropout_rate, training=self.training))

        # global_mean_pool treats batch=None as "all nodes belong to one graph".
        return global_mean_pool(x, batch)

class GATEncoder(nn.Module):
    """Stack of multi-head GAT layers followed by mean pooling per graph.

    Every ``GATConv`` emits ``heads`` concatenated outputs of width
    ``hidden_dim // heads``, so ``hidden_dim`` must be divisible by ``heads``
    for layer widths (and ``feature_dim``) to line up.

    NOTE(review): an earlier class with this same name appears above in this
    module; this later definition replaces it at import time.

    Args:
        input_dim: Node feature width. When ``data.x`` is missing, constant
            all-ones features of this width are used instead.
        hidden_dim: Total (concatenated) width of every GAT layer.
        dropout_rate: Dropout probability applied between conv layers.
        heads: Number of attention heads per layer (>= 1, divides hidden_dim).
        num_layers: Total number of ``GATConv`` layers; values below 1 still
            build a single layer.

    Raises:
        ValueError: If ``heads < 1`` or ``hidden_dim`` is not divisible by
            ``heads``. Without this check the next layer would fail later with
            an opaque shape mismatch.
    """

    def __init__(self, input_dim=1, hidden_dim=64, dropout_rate=0.5, heads=2, num_layers=5):
        super().__init__()
        if heads < 1:
            raise ValueError(f"heads must be >= 1, got {heads}")
        if hidden_dim % heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by heads ({heads})"
            )

        self.input_dim = input_dim
        self.feature_dim = hidden_dim
        self.dropout_rate = dropout_rate

        per_head = hidden_dim // heads
        self.convs = nn.ModuleList([GATConv(input_dim, per_head, heads=heads)])
        for _ in range(num_layers - 1):
            self.convs.append(GATConv(hidden_dim, per_head, heads=heads))

    def forward(self, data):
        """Encode a PyG graph (batched or single) into per-graph embeddings.

        Args:
            data: Object exposing ``x`` (may be None), ``edge_index`` and
                ``batch`` (None for a single, un-batched graph).

        Returns:
            Pooled tensor with one ``hidden_dim`` row per graph.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        if x is None:
            # Featureless graphs get constant ones. ``batch`` is None for a
            # single un-batched graph, so fall back to ``data.num_nodes``.
            num_nodes = batch.size(0) if batch is not None else data.num_nodes
            x = torch.ones(num_nodes, self.input_dim, device=edge_index.device)

        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < last:
                # Dropout + ReLU between layers; the final conv output is left raw.
                x = F.relu(F.dropout(x, p=self.dropout_rate, training=self.training))

        # global_mean_pool treats batch=None as "all nodes belong to one graph".
        return global_mean_pool(x, batch)
    

class GPSEncoder(nn.Module):
    """GraphGPS encoder: linear node embedding, GPS layers, then mean pooling.

    Each layer is a ``GPSConv`` whose local message-passing component is a
    ``GINConv`` (``Linear -> ReLU -> Linear`` MLP), configured with ``heads``
    attention heads.

    Args:
        input_dim: Node feature width. When ``data.x`` is missing, constant
            all-ones features of this width are used instead (matching the
            other graph encoders in this module).
        hidden_dim: Width of the node embedding, every GPS layer and the
            pooled output (exposed as ``feature_dim``).
        dropout_rate: Dropout probability applied between GPS layers.
        num_layers: Number of ``GPSConv`` layers.
        heads: Attention heads passed to each ``GPSConv``.
    """

    def __init__(self, input_dim=1, hidden_dim=64, dropout_rate=0.5,
                 num_layers=5, heads=4):
        super().__init__()
        self.input_dim = input_dim
        self.feature_dim = hidden_dim
        self.dropout_rate = dropout_rate

        # Project raw node features into the hidden width used by GPSConv.
        self.node_emb = nn.Linear(input_dim, hidden_dim)

        self.convs = nn.ModuleList()
        for _ in range(num_layers):
            gin_nn = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
            self.convs.append(GPSConv(hidden_dim, GINConv(gin_nn), heads=heads))

    def forward(self, data):
        """Encode a PyG graph (batched or single) into per-graph embeddings.

        Args:
            data: Object exposing ``x`` (may be None), ``edge_index`` and
                ``batch`` (None for a single, un-batched graph).

        Returns:
            Pooled tensor with one ``hidden_dim`` row per graph.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        if x is None:
            # Previously commented out, so featureless datasets crashed in
            # node_emb(None). ``batch`` is None for a single un-batched graph,
            # so fall back to ``data.num_nodes`` there.
            num_nodes = batch.size(0) if batch is not None else data.num_nodes
            x = torch.ones(num_nodes, self.input_dim, device=edge_index.device)

        x = self.node_emb(x)

        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            # GPSConv needs ``batch`` to restrict global attention per graph.
            x = conv(x, edge_index, batch)
            if i < last:
                x = F.relu(F.dropout(x, p=self.dropout_rate, training=self.training))

        # global_mean_pool treats batch=None as "all nodes belong to one graph".
        return global_mean_pool(x, batch)
