#%%
import torch
from model_classes import GCNNet, GraphSAGENet, GATNet, GINNet

print("Model classes imported successfully.")

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Random seeds; main() repeats the split/train/evaluate cycle once per seed.
seeds = [
    42, 1997, 1901, 1012, 1015,
    1742, 20563, 21000, 568425, 99687523,
    8956742, 444458585, 1111155555, 1648568479, 1778956423,
]

#%%


from sklearn.model_selection import train_test_split
from collections import defaultdict


def collect_by_ids(graph_ids, groups=None):
    """Flatten the objects grouped under each id in ``graph_ids`` into one list.

    Args:
        graph_ids: Iterable of graph ids, e.g. the ids assigned to one split.
        groups: Mapping from graph id to a list of objects. Defaults to the
            module-level ``graph_id_groups`` (looked up at call time, so it
            only needs to exist by the time this function is called).

    Returns:
        A list of every object in each id's group, ordered by ``graph_ids``
        and then by the order within each group.

    Raises:
        KeyError: for an unknown id when ``groups`` is a plain dict. With a
            ``defaultdict(list)`` an unknown id contributes nothing and is
            inserted as an empty group instead.
    """
    if groups is None:
        groups = graph_id_groups
    return [data for gid in graph_ids for data in groups[gid]]

# Load both embedding datasets.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects;
# only point these paths at trusted files.
random_data_list = torch.load("embeddings/embedding-random_graphs.pt", weights_only=False)
gpt_data_list = torch.load("embeddings/updated_embedding-gpt_generated_graphs3.pt", weights_only=False)

# Merge both sources, tag each graph with graph_id (copied from paper_id) and
# bucket graphs by that id so the train/val/test split can be made per id.
all_graphs = random_data_list + gpt_data_list
graph_id_groups = defaultdict(list)
for data in all_graphs:
    data.graph_id = data.paper_id
    graph_id_groups[data.graph_id].append(data)
all_graph_ids = list(graph_id_groups)

#%%
import json

# Load the W&B sweep configuration. The context manager closes the file
# deterministically, and the file is decoded as UTF-8 (the JSON standard
# encoding) regardless of the platform's locale default.
with open("sweep_parameters.json", encoding="utf-8") as sweep_config_file:
    hyperparameter_setup = json.load(sweep_config_file)

print(f"Hyperparameter setup loaded: {hyperparameter_setup}")
#%% Assign a simple per-split graph_index and drop the graph_id / graph_type attributes (optional cleanup)
def clean_and_index(dataset_split):
    """Number the items of one split and strip their grouping attributes.

    Each item gets ``graph_index`` set to its 0-based position in
    ``dataset_split``. Then ``graph_id`` (a potential leak source) and
    ``graph_type`` are deleted when present. Items are mutated in place and
    nothing is returned.
    """
    for position, item in enumerate(dataset_split):
        item.graph_index = position
        for attr_name in ('graph_id', 'graph_type'):
            if hasattr(item, attr_name):
                delattr(item, attr_name)
#%% Training and evaluation setup
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, roc_auc_score
import torch.nn.functional as F

class EarlyStopping:
    """Latch a stop flag once the validation loss stops improving.

    Call the instance with each new validation loss. A strictly lower loss
    becomes the new best and resets the counter. Any other loss increments
    the counter, and once it reaches ``patience`` the ``early_stop`` flag is
    set and stays set.
    """

    def __init__(self, patience):
        self.patience = patience
        self.best_loss = float('inf')   # best (lowest) loss seen so far
        self.counter = 0                # consecutive calls without improvement
        self.early_stop = False         # latched once counter reaches patience

    def __call__(self, val_loss):
        """Record ``val_loss`` and return the current stop flag."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.counter = 0
            return self.early_stop

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        return self.early_stop


# === Training and Evaluation ===
def train_one_epoch(model, optimizer, loader, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Puts ``model`` in training mode, then for every batch: moves it to
    ``device``, computes the cross-entropy against ``batch.y``, back-propagates
    and steps ``optimizer``. The returned value is the sum of per-batch losses
    divided by ``len(loader)``.
    """
    model.train()
    batch_losses = []
    for graph_batch in loader:
        graph_batch = graph_batch.to(device)
        optimizer.zero_grad()
        logits = model(graph_batch.x, graph_batch.edge_index, graph_batch.batch)
        loss = F.cross_entropy(logits, graph_batch.y)
        # Sanity check: labels must be a 1-D tensor of integer class ids.
        assert graph_batch.y.dtype == torch.long and len(graph_batch.y.shape) == 1
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())

    return sum(batch_losses) / len(loader)

@torch.no_grad()
def evaluate(model, loader, device):
    """Return the mean per-batch cross-entropy loss of ``model`` over ``loader``.

    The model is put in eval mode, and gradient tracking is disabled for the
    whole call. Evaluation never calls ``backward()``, so building an
    autograd graph for every batch would only waste memory.

    Returns:
        Sum of per-batch losses divided by ``len(loader)``.
    """
    model.eval()
    total_loss = 0.0
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)
        total_loss += F.cross_entropy(out, data.y).item()
    return total_loss / len(loader)

@torch.no_grad()
def eval_metrics(model, loader, device):
    """Compute classification metrics of ``model`` over every batch in ``loader``.

    Runs in eval mode with gradient tracking disabled (nothing here
    back-propagates). The positive-class score is column 1 of the softmax
    output, so this presumably expects a binary classifier with two logits.
    TODO confirm against the model classes.

    Returns:
        dict with 'accuracy', 'f1', 'precision', 'recall' (the last three
        macro-averaged) and 'auc'. 'auc' is NaN when the collected labels
        contain a single class, because ROC AUC is undefined there.
    """
    model.eval()
    y_true, y_pred, y_prob = [], [], []
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)
        y_prob.extend(F.softmax(out, dim=1)[:, 1].cpu().numpy())
        y_pred.extend(out.argmax(dim=1).cpu().numpy())
        y_true.extend(data.y.cpu().numpy())

    # roc_auc_score raises ValueError for single-class labels. Report NaN
    # instead so one degenerate split does not abort the whole sweep run.
    if len(set(y_true)) < 2:
        auc = float('nan')
    else:
        auc = roc_auc_score(y_true, y_prob)

    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'f1': f1_score(y_true, y_pred, average='macro'),
        'precision': precision_score(y_true, y_pred, average='macro'),
        'recall': recall_score(y_true, y_pred, average='macro'),
        'auc': auc,
    }
# Registry mapping config['model_name'] to the model class to instantiate.
# NOTE: the triple-"s" spelling is kept because main() looks it up by this name.
model_classses = dict(
    GCNNet=GCNNet,
    GraphSAGENet=GraphSAGENet,
    GATNet=GATNet,
    GINNet=GINNet,
)
#%%
import wandb
import numpy as np
from torch_geometric.loader import DataLoader

from tqdm import tqdm
# Input feature width passed as the first constructor argument of every model
# in main(); presumably it must equal the width of data.x — confirm.
in_dim: int = 3072


# Validation metrics aggregated across seeds; the tuple order fixes the order
# of keys passed to wandb.log.
_METRIC_NAMES = ('accuracy', 'f1', 'precision', 'recall', 'auc')


def _split_datasets(random_seed):
    """Split graph ids 70/15/15 for ``random_seed``; return (train, val, test) lists.

    The split is made over graph ids, so all Data objects sharing an id land
    in the same split. Each split is then passed through ``clean_and_index``,
    which mutates its objects in place.
    """
    # 70% train / 30% temp, then temp is halved into 15% validation / 15% test.
    train_ids, temp_ids = train_test_split(
        all_graph_ids, test_size=0.3, random_state=random_seed
    )
    val_ids, test_ids = train_test_split(
        temp_ids, test_size=0.5, random_state=random_seed
    )
    splits = (
        collect_by_ids(train_ids),
        collect_by_ids(val_ids),
        collect_by_ids(test_ids),
    )
    for split in splits:
        clean_and_index(split)
    return splits


def _run_single_seed(config, random_seed):
    """Train one model for ``random_seed`` and return its validation metrics dict.

    Metrics are computed with the weights from the epoch with the lowest
    validation loss, not the weights of the last epoch run.
    """
    torch.manual_seed(random_seed)
    # The test split is still carved out so train/val keep their 70/15 share
    # of the ids, but it is not evaluated during the sweep.
    train_dataset, val_dataset, _test_dataset = _split_datasets(random_seed)

    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'],
                              shuffle=True, exclude_keys=['node_names'])
    val_loader = DataLoader(val_dataset, batch_size=config['batch_size'],
                            shuffle=False, exclude_keys=['node_names'])

    model_class = model_classses[config['model_name']]
    model = model_class(in_dim, config['hidden'], config['num_layers'], config['dropout']).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'],
                                 weight_decay=config['weight_decay'])
    early_stopper = EarlyStopping(patience=config['patience'])

    best_state = None
    for epoch in tqdm(range(1, config['max_epochs'] + 1)):
        train_one_epoch(model, optimizer, train_loader, device)
        val_loss = evaluate(model, val_loader, device)

        # Feed every epoch's loss to the stopper. It cannot fire before seeing
        # `patience` non-improving epochs anyway, and skipping early epochs
        # would hide their (possibly best) losses from it.
        should_stop = early_stopper(val_loss)
        if early_stopper.counter == 0:
            # counter resets to 0 only on a new best loss: snapshot the weights.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        if should_stop:
            print(f"Early stopping at epoch {epoch} (patience={config['patience']})")
            break

    if best_state is not None:
        model.load_state_dict(best_state)
    return eval_metrics(model, val_loader, device)


def main():
    """W&B sweep entry point.

    Reads the run configuration from ``wandb.config``. For every seed in
    ``seeds`` it re-splits the data, trains one model and collects its
    validation metrics. Finally it logs the mean and standard deviation of
    each metric across seeds as ``final_val_<metric>`` and
    ``final_val_<metric>_std``.
    """
    wandb.init()
    config = wandb.config
    wandb.run.name = f"{config['model_name']}_bs{config['batch_size']}_h{config['hidden']}_l{config['num_layers']}_lr{config['learning_rate']:.0e}"

    per_seed = {name: [] for name in _METRIC_NAMES}
    for random_seed in seeds:
        print(f"Using random seed: {random_seed}")
        seed_metrics = _run_single_seed(config, random_seed)
        for name in _METRIC_NAMES:
            per_seed[name].append(seed_metrics[name])

    summary = {}
    for name in _METRIC_NAMES:
        summary[f'final_val_{name}'] = np.mean(per_seed[name])
        summary[f'final_val_{name}_std'] = np.std(per_seed[name])
    wandb.log(summary)

#%%

# Register the sweep defined in sweep_parameters.json, then let one agent
# execute a single run of main() with the parameters the sweep assigns.
sweep_id = wandb.sweep(sweep=hyperparameter_setup, project="embedding_gpt_groundtruth")
wandb.agent(sweep_id=sweep_id, function=main, count=1)
