#%%
import torch
from model_classes import GCNNet, GraphSAGENet, GATNet, GINNet
from collections import defaultdict
print("Model classes imported successfully.")

# Use CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# One experiment repetition per seed: main() uses each value both for the
# graph_id train/val/test split and for torch.manual_seed.
seeds = [
    42, 1997, 1901, 1012, 1015,
    1742, 20563, 21000, 568425,
    99687523, 8956742, 444458585,
    1111155555, 1648568479, 1778956423,
]

from collections import Counter
from torch_geometric.data import InMemoryDataset

class MyGraphDataset(InMemoryDataset):
    """In-memory PyG dataset assembled from a ready-made list of ``Data`` objects.

    NOTE(review): not instantiated in this file; presumably defined so that
    ``torch.load(..., weights_only=False)`` can unpickle a dataset saved from
    this class — confirm.
    """

    def __init__(self, data_list=None):
        # root='.' — no download/process hooks are overridden here.
        super().__init__('.')
        if data_list is None:
            # Nothing to collate: leave the dataset empty.
            return
        self.data, self.slices = self.collate(data_list)

new_path = "my_graph_dataset_gpt_groundfiltered.pt"
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects — only ever point this at a trusted file.
dataset = torch.load(new_path, weights_only=False)

# How many graphs carry each graph_type label.
graph_types = [g.graph_type for g in dataset]
type_counts = Counter(graph_types)

# Bucket Data objects by their original graph_id. The train/val/test split is
# drawn over these ids, so every entry sharing a graph_id stays in one split.
graph_id_groups = defaultdict(list)
for data in dataset:
    graph_id_groups[data.graph_id].append(data)

all_graph_ids = list(graph_id_groups)

#%% Assign a simple 0…N-1 graph_index and drop the original graph_id and graph_type attributes (optional cleanup)
def clean_and_index(dataset_split):
    """Number the graphs of one split 0..N-1 in place and strip metadata.

    Each object gets ``graph_index`` set to its position in ``dataset_split``.
    Its ``graph_id`` (the leak source) and ``graph_type`` attributes are
    deleted when present. Mutates the objects; returns None.
    """
    for position, graph in enumerate(dataset_split):
        graph.graph_index = position
        for attr_name in ('graph_id', 'graph_type'):
            if hasattr(graph, attr_name):
                delattr(graph, attr_name)

# Helper to collect your Data objects back into lists
def collect_by_ids(graph_ids):
    """Flatten the Data objects of every id in ``graph_ids`` into one list.

    Order follows ``graph_ids``, then the grouping order within each id.
    """
    collected = []
    for gid in graph_ids:
        collected.extend(graph_id_groups[gid])
    return collected

#%% Training and evaluation setup
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, roc_auc_score
import torch.nn.functional as F

class EarlyStopping:
    """Signal a stop once validation loss fails to improve ``patience`` times in a row.

    Calling the instance with the latest validation loss returns the current
    stop flag. Once set, the flag stays True even if a later loss improves.
    """

    def __init__(self, patience):
        self.patience = patience
        self.best_loss = float('inf')
        self.counter = 0          # consecutive non-improving calls
        self.early_stop = False

    def __call__(self, val_loss):
        if not val_loss < self.best_loss:
            # No strict improvement (ties and NaN count as stalls).
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return self.early_stop

        self.best_loss = val_loss
        self.counter = 0
        return self.early_stop


# === Training and Evaluation ===
def train_one_epoch(model, optimizer, loader, device):
    """Run one optimisation pass over ``loader`` and return the mean per-batch loss.

    Args:
        model: called as ``model(x, edge_index, batch)``; its output is used as
            class logits for ``F.cross_entropy``.
        optimizer: zeroed and stepped once per batch.
        loader: iterable of batches exposing ``x``, ``edge_index``, ``batch``,
            ``y`` and ``.to(device)``; must support ``len()``.
        device: passed to each batch's ``.to``.

    Returns:
        float: sum of batch losses divided by ``len(loader)`` (an empty loader
        raises ZeroDivisionError).

    Raises:
        ValueError: if a batch's ``y`` is not a 1-D ``torch.long`` tensor.
    """
    model.train()
    total_loss = 0.0
    for data in loader:
        data = data.to(device)
        # Validate targets before the loss: previously the check ran after
        # cross_entropy, which fails (or mis-interprets float targets) first.
        if data.y.dtype != torch.long or data.y.dim() != 1:
            raise ValueError(
                f"expected 1-D torch.long class labels, got dtype={data.y.dtype}, "
                f"shape={tuple(data.y.shape)}"
            )
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = F.cross_entropy(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def evaluate(model, loader, device):
    """Return the mean per-batch cross-entropy of ``model`` over ``loader``.

    Puts the model in eval mode (it is left there) and runs without autograd,
    so no computation graph is kept for these forward passes.

    Args:
        model: called as ``model(x, edge_index, batch)`` returning class logits.
        loader: iterable of batches with ``x``, ``edge_index``, ``batch``, ``y``
            and ``.to(device)``; must support ``len()``.
        device: passed to each batch's ``.to``.

    Returns:
        float: sum of batch losses divided by ``len(loader)``.
    """
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, data.batch)
            total_loss += F.cross_entropy(out, data.y).item()
    return total_loss / len(loader)

def eval_metrics(model, loader, device):
    """Compute binary classification metrics of ``model`` over ``loader``.

    Runs in eval mode without autograd. Predictions are the argmax of the
    logits; the score used for ROC-AUC is the softmax probability of class 1
    (so the model must output at least two logits per graph).

    Returns:
        dict with 'accuracy', 'f1', 'precision', 'recall' (macro-averaged where
        applicable) and 'auc'. 'auc' is NaN when the collected labels contain a
        single class, because ROC-AUC is undefined there and ``roc_auc_score``
        would raise.
    """
    model.eval()
    y_true, y_pred, y_prob = [], [], []
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, data.batch)
            y_prob.extend(F.softmax(out, dim=1)[:, 1].cpu().numpy())
            y_pred.extend(out.argmax(dim=1).cpu().numpy())
            y_true.extend(data.y.cpu().numpy())

    # Guard the single-class case instead of aborting the whole seed sweep.
    auc = roc_auc_score(y_true, y_prob) if len(set(y_true)) > 1 else float('nan')
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'f1': f1_score(y_true, y_pred, average='macro'),
        'precision': precision_score(y_true, y_pred, average='macro'),
        'recall': recall_score(y_true, y_pred, average='macro'),
        'auc': auc,
    }
# Registry mapping config['model_name'] to the model class main() instantiates.
# NOTE: the triple-"s" spelling is the name main() looks up; keep it as is.
model_classses = dict(
    GCNNet=GCNNet,
    GraphSAGENet=GraphSAGENet,
    GATNet=GATNet,
    GINNet=GINNet,
)

#%%
import wandb
import numpy as np
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split

# Number of input node features handed to every model constructor.
# NOTE(review): hard-coded; presumably must equal data.x.size(1) of the loaded
# dataset — confirm.
in_dim = 5


def _split_graph_ids(random_seed):
    """Split the unique graph_ids 70/15/15 into (train, val, test) id lists.

    Splitting on graph_id keeps every Data object sharing an id in one split.
    """
    train_ids, temp_ids = train_test_split(
        all_graph_ids, test_size=0.3, random_state=random_seed
    )
    val_ids, test_ids = train_test_split(
        temp_ids, test_size=0.5, random_state=random_seed
    )
    return train_ids, val_ids, test_ids


def _make_loaders(split_ids, batch_size):
    """Build (train, val, test) DataLoaders; only the train loader shuffles."""
    loaders = []
    for ids, shuffle in zip(split_ids, (True, False, False)):
        split = collect_by_ids(ids)
        clean_and_index(split)
        loaders.append(
            DataLoader(split, batch_size=batch_size, shuffle=shuffle,
                       exclude_keys=['node_names'])
        )
    return tuple(loaders)


def _train_with_early_stopping(model, optimizer, train_loader, val_loader, config):
    """Train until early stopping fires, then load the best-validation-loss weights."""
    early_stopper = EarlyStopping(patience=config['patience'])
    best_val_loss = float('inf')
    best_state = None
    for epoch in range(1, config['max_epochs'] + 1):
        train_one_epoch(model, optimizer, train_loader, device)
        val_loss = evaluate(model, val_loader, device)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Clone: state_dict() holds references to the live parameters.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        # Feed the stopper every epoch so warm-up losses count toward its best;
        # it needs `patience` stalled calls, so it cannot fire before epoch `patience`.
        if early_stopper(val_loss):
            print(f"Early stopping at epoch {epoch} (patience={config['patience']})")
            break
    if best_state is not None:
        model.load_state_dict(best_state)
    return model


def main(cnfg):
    """Train/evaluate the configured model once per seed and summarise test metrics.

    For every seed in ``seeds``:
    1. seed torch and split graph_ids 70/15/15;
    2. train with early stopping on validation loss;
    3. restore the weights with the lowest validation loss;
    4. evaluate once on the test split.

    Args:
        cnfg: dict with 'batch_size', 'dropout', 'hidden', 'learning_rate',
            'max_epochs', 'model_name' (a key of ``model_classses``),
            'num_layers', 'patience' and 'weight_decay'.

    Returns:
        dict mapping 'final_test_<metric>' / 'final_test_<metric>_std' to the
        mean / population std (np.std, ddof=0) across seeds, for accuracy, f1,
        precision, recall and auc. Accuracy and F1 are also printed.
    """
    config = cnfg
    metric_names = ('accuracy', 'f1', 'precision', 'recall', 'auc')
    per_seed = {name: [] for name in metric_names}
    model_class = model_classses[config['model_name']]

    for random_seed in seeds:
        print(f"Using random seed: {random_seed}")
        torch.manual_seed(random_seed)
        train_loader, val_loader, test_loader = _make_loaders(
            _split_graph_ids(random_seed), config['batch_size']
        )

        model = model_class(in_dim, config['hidden'], config['num_layers'],
                            config['dropout']).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'],
                                     weight_decay=config['weight_decay'])
        _train_with_early_stopping(model, optimizer, train_loader, val_loader, config)

        # Test set is touched exactly once per seed, after model selection.
        test_metrics = eval_metrics(model, test_loader, device)
        for name in metric_names:
            per_seed[name].append(test_metrics[name])

    final_results = {}
    for name in metric_names:
        final_results[f'final_test_{name}'] = np.mean(per_seed[name])
        final_results[f'final_test_{name}_std'] = np.std(per_seed[name])

    # "\\pm" prints a literal \pm (the old "\pm" was an invalid escape sequence).
    print(
        f"Model: {config['model_name']} \n Accuracy: {final_results['final_test_accuracy']} \\pm {final_results['final_test_accuracy_std']}  \n F1: {final_results['final_test_f1']} \\pm {final_results['final_test_f1_std']} "
    )
    return final_results

# Per-architecture hyper-parameters passed to main().
config_GCN = dict(
    batch_size=13000,
    dropout=0.14935278283713443,
    hidden=64,
    learning_rate=0.004736681622124355,
    max_epochs=500,
    model_name="GCNNet",
    num_layers=1,
    patience=15,
    weight_decay=0.0006118275179875643,
)

config_GIN = dict(
    batch_size=13000,
    dropout=0.008756805905831555,
    hidden=32,
    learning_rate=0.004737909511043368,
    max_epochs=500,
    model_name="GINNet",
    num_layers=1,
    patience=15,
    weight_decay=0.0025742636561280987,
)

config_GAT = dict(
    batch_size=8000,
    dropout=0.16245366605697703,
    hidden=64,
    learning_rate=0.0010269679637545368,
    max_epochs=500,
    model_name="GATNet",
    num_layers=1,
    patience=15,
    weight_decay=0.0003443990126653762,
)

config_GraphSage = dict(
    batch_size=8000,
    dropout=0.2391449021551822,
    hidden=32,
    learning_rate=0.0008881393403007299,
    max_epochs=500,
    model_name="GraphSAGENet",
    num_layers=1,
    patience=15,
    weight_decay=0.00001845659446619629,
)


# Evaluate each tuned architecture in turn (GCN, GIN, GAT, GraphSAGE).
for _run_config in (config_GCN, config_GIN, config_GAT, config_GraphSage):
    main(_run_config)
