# src/objective.py

import time
import torch
import numpy as np
from torch import nn
from torch.utils.data import TensorDataset, DataLoader as TorchDataLoader
from sklearn.model_selection import KFold, train_test_split
import optuna
import torch.nn.functional as F

# PyTorch Geometric
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

# Project-local modules
from src.model import Weaker_First, Weaker_Middle, GraphLevelPredictor
from src.meta_graph import (
    compute_node_attention_scores,
    construct_subgraphs,
    construct_meta_graphs,
    adjust_edge_index,
)
from src.utils import (
    set_weights,
    normalize_features,
)
from src.train_utils import (
    compute_loss,
    model_acc,
    weight_loss,
    weight_update,
    quality_update,
)


def _prepare_modality(dataset, thres):
    """Precompute the fold-independent arrays for one modality.

    Args:
        dataset: Sequence of PyG ``Data`` samples exposing ``x``, ``y`` and
            ``edge_index``.
        thres: Correlation threshold; mean correlations below it give no edge.

    Returns:
        ``(features, labels, sample_edges, graph_edge_index)``:
        ``features`` stacks ``data.x.T`` for every sample (an earlier comment
        states shape ``(num_samples, 400, 21)`` — TODO confirm), ``labels`` is
        an array of ``data.y.item()`` values, ``sample_edges`` is the list of
        per-sample ``edge_index`` objects, and ``graph_edge_index`` is the
        ``[2, E]`` edge list of the thresholded mean-correlation graph.
    """
    features = np.array([data.x.numpy().T for data in dataset])
    labels = np.array([data.y.item() for data in dataset])
    sample_edges = [data.edge_index for data in dataset]

    # NOTE(review): the mean correlation is taken over *all* samples, including
    # those that later fall into the test/validation splits, so the shared graph
    # structure sees held-out data. Kept as-is; confirm this is intended.
    adj = np.mean([np.corrcoef(sample.T) for sample in features], axis=0)
    adj[adj < thres] = 0
    # Exactly-1 entries are the diagonal (self-correlation) plus any perfectly
    # correlated pair; both are dropped.
    adj[adj == 1] = 0
    adj[adj > 0] = 1
    graph_edge_index = torch.tensor(adj, dtype=torch.float).nonzero().t().contiguous()
    return features, labels, sample_edges, graph_edge_index


def _build_samples(features, edge_list):
    """Pair each ``features[k].T`` with its own edge_index as a PyG ``Data``.

    ``x`` is stored as the NumPy array it arrives as. A list-valued
    edge_index is stacked column-wise into a tensor first.
    """
    samples = []
    for feat, edge_index in zip(features, edge_list):
        if isinstance(edge_index, list):
            edge_index = torch.stack(edge_index, dim=1)
        samples.append(Data(x=feat.T, edge_index=edge_index))
    return samples


def _fit_step(model, optimizer, features, edge_index, labels, device):
    """Run one optimisation step for a weak learner and return compute_loss's output.

    NOTE(review): the loss is built on a detached copy of the model's
    soft-probabilities. backward() therefore only fills that copy's ``.grad``.
    The model parameters get no gradient from this loss, so optimizer.step()
    has nothing to apply, unless compute_loss/model_acc backpropagate elsewhere.
    The detach is kept because compute_loss (not visible here) may return an
    output without a grad_fn, in which case backward() on it would raise. This
    should be fixed where compute_loss builds its output.
    """
    result = compute_loss(model, features, edge_index, device)
    softprob = result[0].clone().detach().requires_grad_(True)
    optimizer.zero_grad()
    loss = F.nll_loss(softprob, labels.to(softprob.device))
    loss.backward()
    optimizer.step()
    # The accuracy value is unused. The call is kept because model_acc may run
    # a forward pass that refreshes state read later (e.g. attention_scores).
    model_acc(model, features, edge_index, labels, device)
    return result


def _train_boosted_stack(features, edge_index, labels, num_classes, device, *,
                         layer_num, train_weak, train_middle, lr_weak,
                         weight_decay_weak, lr_middle, weight_decay_middle):
    """Fit a boosting-style stack of ``layer_num`` weak learners on one modality.

    The first learner (``Weaker_First``) sees ``features``. Each following
    ``Weaker_Middle`` sees the previous learner's latent output
    (``compute_loss(...)[1]``). After each learner, ``weight_loss`` yields an
    error rate that ``quality_update`` turns into that learner's quality score.

    Returns:
        ``(stack, latents, qualities)``: the fitted models, their latent
        outputs on the training features (NumPy arrays), and one quality score
        per model.
    """
    stack, latents, qualities = [], [], []

    # features.shape[2] / shape[1] are passed as num_nodes / num_node_features.
    model = Weaker_First(features.shape[2], features.shape[1], num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr_weak, weight_decay=weight_decay_weak)
    model.train()
    start_time = time.time()
    for _ in range(train_weak):
        result = _fit_step(model, optimizer, features, edge_index, labels, device)
    print(f"weakfirst time {time.time() - start_time:.4f} seconds")
    stack.append(model)
    latents.append(np.array(result[1]))

    # Sample weights start uniform with unit L2 norm.
    weights = nn.functional.normalize(torch.ones(features.shape[0]), p=2, dim=0)
    err_rate = weight_loss(model, features, weights, edge_index, labels, device)
    qualities.append(quality_update(err_rate))
    # NOTE(review): unlike the learners below, the sample weights are not
    # updated after the first learner. Confirm this is intentional.

    for _ in range(layer_num - 1):
        # NOTE(review): 256 presumably must match the previous learner's latent
        # width. TODO confirm.
        model = Weaker_Middle(256, num_classes).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr_middle, weight_decay=weight_decay_middle)
        model.train()
        layer_input = np.array(result[1])  # latent output of the previous learner
        for _ in range(train_middle):
            result = _fit_step(model, optimizer, layer_input, edge_index, labels, device)
        stack.append(model)
        latents.append(np.array(result[1]))
        err_rate = weight_loss(model, layer_input, weights, edge_index, labels, device)
        qualities.append(quality_update(err_rate))
        weights = weight_update(model, err_rate, layer_input, weights, edge_index, labels, device)
    return stack, latents, qualities


def _propagate_stack(stack, features, edge_index, device):
    """Feed ``features`` through each model of ``stack`` in order.

    Each model receives the latent output (``compute_loss(...)[1]``) of the
    previous one. Returns those latents as a list of NumPy arrays.
    """
    latents = []
    for model in stack:
        result = compute_loss(model, features, edge_index, device)
        latents.append(np.array(result[1]))
        features = np.array(result[1])
    return latents


def _quality_weighted_sum(latents, qualities, device):
    """Return ``sum_j qualities[j] * latents[j]`` as a tensor on ``device``."""
    fused = torch.zeros_like(torch.tensor(latents[0])).to(device)
    for latent, quality in zip(latents, qualities):
        fused += quality * torch.tensor(latent).to(device)
    return fused


def _encode_samples(embeddings, samples):
    """Rebuild ``samples`` with row ``k`` of ``embeddings`` as node features."""
    return [
        Data(x=embeddings[k], edge_index=sample.edge_index)
        for k, sample in enumerate(samples)
    ]


def _split_subgraphs(attention_scores, samples, encoded, threshold):
    """Compute node attention scores and construct the subgraphs for one split.

    The first sample's edge_index and node count are used for the node-score
    computation, which presumes all samples share one graph layout.
    TODO confirm.
    """
    node_scores = compute_node_attention_scores(
        attention_scores, samples[0].edge_index, samples[0].x.shape[0]
    )
    return construct_subgraphs(node_scores, samples, encoded, threshold=threshold)


def _attach_labels(meta_graphs, labels):
    """Set ``graph.y`` to a 1-element long tensor holding ``labels[k]``."""
    for k, graph in enumerate(meta_graphs):
        y = torch.as_tensor(labels[k], dtype=torch.long).clone()
        graph.y = y.unsqueeze(0) if y.dim() == 0 else y


def objective(trial, gnn_datasets, device):
    """Optuna objective for the boosted weak-GNN + meta-graph classifier.

    The first dataset's samples are split by a shuffled 5-fold ``KFold``
    (``random_state=1``). In each fold, the training part is split again
    80/20 into train/validation. For every modality, a stack of ``layer_num``
    weak learners is fit, and their latent outputs are fused with quality
    weights. Attention-selected subgraphs from all modalities are combined
    into meta-graphs, on which a ``GraphLevelPredictor`` is trained.

    Args:
        trial: Optuna trial used to sample hyperparameters.
        gnn_datasets: Indexable collection whose first five entries are the
            per-modality datasets. They are presumably sample-aligned; the
            meta-graph labels are taken from the first one.
        device: Torch device for the models and batches.

    Returns:
        Mean validation accuracy across the folds.
    """
    # Indexing (not slicing) keeps the IndexError when fewer than 5 are given.
    datasets = [gnn_datasets[k] for k in range(5)]
    start_time_total = time.time()

    thres = trial.suggest_float("thres", 0.05, 1, step=0.05)
    layer_num = trial.suggest_categorical("layer_num", [2, 3, 4, 5, 6])
    threshold1 = trial.suggest_float("threshold1", 0.1, 1, step=0.1)
    batch_size = trial.suggest_categorical("batch_size", [8, 16, 32])
    hidden_dim = 32
    num_train = trial.suggest_categorical("num_train", [5, 10, 20, 30, 40])
    # NOTE(review): num_epochs is sampled, so it stays in the study's search
    # space, but nothing reads it. The final model trains for num_train epochs.
    trial.suggest_categorical("num_epochs", [5, 10, 20, 30, 40])
    dropout1 = trial.suggest_float("dropout1", 0.1, 0.9, step=0.05)
    # suggest_float(..., log=True) replaces the deprecated suggest_loguniform
    # and samples the same log-uniform distribution.
    lr = trial.suggest_float("lr", 1e-4, 1e-0, log=True)
    train_weak = trial.suggest_categorical("train_weak", [5, 10, 15, 20, 30, 40])
    train_middle = trial.suggest_categorical("train_middle", [5, 10, 15, 20, 30, 40])
    weight_decay_weak = trial.suggest_float("weight_decay_weak", 1e-5, 1e-2, log=True)
    lr_weak = trial.suggest_float("lr_weak", 1e-4, 1e-0, log=True)
    weight_decay_middle = trial.suggest_float("weight_decay_middle", 1e-5, 1e-2, log=True)
    lr_middle = trial.suggest_float("lr_middle", 1e-4, 1e-0, log=True)
    # batch1 has no consumer; it is still sampled to keep the search space unchanged.
    trial.suggest_categorical("batch1", [8, 16, 32])

    # Fold-independent preprocessing, done once rather than once per fold.
    prepared = [_prepare_modality(dataset, thres) for dataset in datasets]
    edge_indices = [entry[3] for entry in prepared]
    # The class count comes from the last modality's labels.
    num_classes = len(torch.unique(torch.tensor(prepared[-1][1], dtype=torch.long)))

    val_accuracies = []
    kf = KFold(n_splits=5, shuffle=True, random_state=1)
    indices = list(range(len(datasets[0])))
    for fold, (train_idx, test_idx) in enumerate(kf.split(indices)):
        print(f"Starting Fold {fold + 1} - Time: {time.time() - start_time_total:.4f} seconds")
        print(f"Fold {fold + 1}:")
        train_idx, validation_idx = train_test_split(train_idx, test_size=0.2, random_state=1)

        # ---- Per-modality splits -------------------------------------------
        training, testing, validation = {}, {}, {}
        training_labels, testing_labels, validation_labels = {}, {}, {}
        train_dataset, test_datasett, validation_datasett = {}, {}, {}
        for i, (features, labels, sample_edges, _) in enumerate(prepared):
            raw_train = features[train_idx]
            raw_test = features[test_idx]
            raw_val = features[validation_idx]

            training_labels[i] = torch.tensor(labels[train_idx], dtype=torch.long)
            testing_labels[i] = torch.tensor(labels[test_idx], dtype=torch.long)
            validation_labels[i] = torch.tensor(labels[validation_idx], dtype=torch.long)

            # Un-normalised per-sample graphs. They are handed to
            # construct_subgraphs; here only their edge_index and node count are read.
            train_dataset[i] = _build_samples(raw_train, [sample_edges[j] for j in train_idx])
            test_datasett[i] = _build_samples(raw_test, [sample_edges[j] for j in test_idx])
            validation_datasett[i] = _build_samples(raw_val, [sample_edges[j] for j in validation_idx])

            training[i], testing[i], validation[i] = normalize_features(raw_train, raw_test, raw_val)

        # ---- Boosted weak-learner stacks + quality-weighted embeddings -------
        torch.manual_seed(1)
        stacks, quality_vec = {}, {}
        final_emb, test_emb, validation_emb = {}, {}, {}
        for i in range(len(datasets)):
            stack, latents, qualities = _train_boosted_stack(
                training[i], edge_indices[i], training_labels[i], num_classes, device,
                layer_num=layer_num,
                train_weak=train_weak,
                train_middle=train_middle,
                lr_weak=lr_weak,
                weight_decay_weak=weight_decay_weak,
                lr_middle=lr_middle,
                weight_decay_middle=weight_decay_middle,
            )
            stacks[i] = stack
            quality_vec[i] = qualities
            final_emb[i] = _quality_weighted_sum(latents, qualities, device)
            test_emb[i] = _quality_weighted_sum(
                _propagate_stack(stack, testing[i], edge_indices[i], device), qualities, device
            )
            validation_emb[i] = _quality_weighted_sum(
                _propagate_stack(stack, validation[i], edge_indices[i], device), qualities, device
            )

        # ---- Attention-based subgraphs per modality -------------------------
        train_subgraphs, test_subgraphs, validation_subgraphs = [], [], []
        for i in range(len(datasets)):
            encoded_train = _encode_samples(final_emb[i], train_dataset[i])
            encoded_test = _encode_samples(test_emb[i], test_datasett[i])
            encoded_validation = _encode_samples(validation_emb[i], validation_datasett[i])

            # Quality-weighted sum of every learner's attention scores.
            attention_scores = torch.zeros_like(stacks[i][0].attention_scores)
            for model, quality in zip(stacks[i], quality_vec[i]):
                attention_scores += quality * model.attention_scores

            train_subgraphs.append(
                _split_subgraphs(attention_scores, train_dataset[i], encoded_train, threshold1)
            )
            test_subgraphs.append(
                _split_subgraphs(attention_scores, test_datasett[i], encoded_test, threshold1)
            )
            validation_subgraphs.append(
                _split_subgraphs(attention_scores, validation_datasett[i], encoded_validation, threshold1)
            )

        # ---- Meta-graphs (labels come from the first modality) --------------
        train_meta_graphs = construct_meta_graphs(train_subgraphs)
        val_meta_graphs = construct_meta_graphs(validation_subgraphs)
        test_meta_graphs = construct_meta_graphs(test_subgraphs)
        _attach_labels(train_meta_graphs, training_labels[0])
        _attach_labels(val_meta_graphs, validation_labels[0])
        _attach_labels(test_meta_graphs, testing_labels[0])

        train_loader = DataLoader(train_meta_graphs, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_meta_graphs, batch_size=batch_size, shuffle=False)

        # ---- Final graph-level classifier -----------------------------------
        # NOTE(review): input_dim=128 and output_dim=5 are hard-coded. They
        # presumably must match the fused embedding width and num_classes. Confirm.
        final_model = GraphLevelPredictor(input_dim=128, hidden_dim=hidden_dim, output_dim=5, dropout=dropout1)
        torch.manual_seed(1)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        set_weights(final_model)
        final_model = final_model.to(device)
        optimizer = torch.optim.Adam(final_model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        for _ in range(num_train):
            final_model.train()
            for batch in train_loader:
                batch = batch.to(device)
                optimizer.zero_grad()
                loss = criterion(final_model(batch), batch.y.view(-1).long())
                loss.backward()
                optimizer.step()

        # Score in eval mode so dropout is disabled during validation.
        final_model.eval()
        correct_val, total_val = 0, 0
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                pred = torch.argmax(final_model(batch), dim=1)
                correct_val += (pred == batch.y.view(-1).long()).sum().item()
                total_val += batch.y.size(0)

        val_accuracy = correct_val / total_val
        val_accuracies.append(val_accuracy)
        print(f"Fold {fold + 1} Validation Accuracy: {val_accuracy:.4f}")

    # Optuna optimises the mean validation accuracy (test splits are not scored).
    final_val_accuracy = np.mean(val_accuracies)
    print(f"Final Average Validation Accuracy across folds: {final_val_accuracy:.4f}")
    return final_val_accuracy
