# src/train.py

import time
import torch
import numpy as np
import torch.nn.functional as F
from torch import nn
from torch.utils.data import TensorDataset, DataLoader as TorchDataLoader
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from sklearn.model_selection import KFold

# Project-local imports
from src.model import Weaker_First, Weaker_Middle, GraphLevelPredictor
from src.meta_graph import (
    compute_node_attention_scores,
    construct_subgraphs,
    construct_meta_graphs,
    adjust_edge_index,
)
from src.utils import set_weights, normalize_features_eval
from src.train_utils import (
    compute_loss,
    model_acc,
    weight_loss,
    weight_update,
    quality_update,
)


def _as_edge_tensor(edge_index):
    """Return ``edge_index`` as-is, or stacked along dim 1 if it is a list."""
    if isinstance(edge_index, list):
        return torch.stack(edge_index, dim=1)
    return edge_index


def _build_graphs(features, edge_indices):
    """Pair each sample of ``features`` (transposed) with its edge index.

    The ``x`` stored on every ``Data`` object is a transposed view of the
    corresponding row of ``features``; no copy is made here.
    """
    return [
        Data(x=sample.T, edge_index=_as_edge_tensor(edge_index))
        for sample, edge_index in zip(features, edge_indices)
    ]


def _correlation_edge_index(data_lfp, thres):
    """Derive an edge index from the sample-averaged correlation matrix.

    ``np.corrcoef`` is evaluated on the transpose of every sample, the
    matrices are averaged, and the mean is binarised: entries below
    ``thres`` and entries exactly equal to 1 become 0, every remaining
    positive entry becomes 1.
    """
    adj_mat = np.mean([np.corrcoef(sample.T) for sample in data_lfp], axis=0)
    adj_mat[adj_mat < thres] = 0
    # Presumably drops the self-correlations on the diagonal -- TODO confirm;
    # an exact float comparison misses values such as 0.9999999.
    adj_mat[adj_mat == 1] = 0
    adj_mat[adj_mat > 0] = 1
    adjacency = torch.tensor(adj_mat, dtype=torch.float)
    return adjacency.nonzero().t().contiguous()


def _run_epochs(model, optimizer, features, edge_index, labels, n_epochs,
                device, outputs=None):
    """Run ``n_epochs`` optimisation steps and return the last ``compute_loss`` result.

    ``outputs`` is returned unchanged when ``n_epochs`` is 0, so a caller can
    carry the previous learner's result forward.
    """
    for _ in range(n_epochs):
        outputs = compute_loss(model, features, edge_index, device)
        # NOTE(review): detach() yields a fresh leaf tensor, so backward()
        # below only populates the gradient of that leaf and never reaches
        # the model parameters; optimizer.step() then has nothing to apply.
        # Confirm whether compute_loss is meant to return a graph-attached
        # output before removing the detach.
        log_probs = outputs[0].clone().detach().requires_grad_(True)
        optimizer.zero_grad()
        # Labels are created on the CPU; align them with the predictions.
        loss = F.nll_loss(log_probs, labels.to(log_probs.device))
        loss.backward()
        optimizer.step()
        # Result unused. The call is kept because model_acc is opaque from
        # here and may touch model state read later -- TODO confirm.
        model_acc(model, features, edge_index, labels, device)
    return outputs


def _weighted_sum(latents, qualities, device):
    """Accumulate ``quality * latent`` over paired entries on ``device``."""
    combined = torch.zeros_like(torch.tensor(latents[0])).to(device)
    for quality, latent in zip(qualities, latents):
        combined += quality * torch.tensor(latent).to(device)
    return combined


def _attach_labels(graphs, labels):
    """Set ``graph.y`` to a shape-``[1]`` copy of ``labels[k]`` for the k-th graph."""
    for idx, graph in enumerate(graphs):
        # clone().detach() copies an existing tensor without the UserWarning
        # that torch.tensor(<tensor>) emits.
        graph.y = labels[idx].clone().detach().reshape(1)


def evaluate_best_model(best_params, gnn_datasets, device, num_classes=5):
    """Fit the pipeline on one train/test split and report test accuracy.

    The first fold of a seeded 5-fold ``KFold`` split (``shuffle=True``,
    ``random_state=1``) over the samples of ``gnn_datasets[0]`` is held out
    as the test set, and the same indices are applied to all five datasets.
    Per dataset, one ``Weaker_First`` learner followed by ``layer_num - 1``
    ``Weaker_Middle`` learners is run, and the latent outputs are combined
    into a quality-weighted embedding. Attention-based subgraphs built from
    these embeddings are merged into meta-graphs on which a
    ``GraphLevelPredictor`` is trained and evaluated.

    Args:
        best_params: Mapping with the keys ``thres``, ``lr_weak``,
            ``weight_decay_weak``, ``lr_middle``, ``weight_decay_middle``,
            ``train_weak``, ``train_middle``, ``layer_num``, ``threshold1``,
            ``batch_size``, ``dropout1``, ``lr`` and ``num_train``.
        gnn_datasets: Indexable collection; entries 0-4 are used. Every
            sample must expose ``x`` (tensor), ``y`` (single-element tensor)
            and ``edge_index``.
        device: Device on which the models and batches are placed.
        num_classes: Kept for interface compatibility. The value is
            overwritten by the number of distinct labels found in the last
            of the five datasets.

    Returns:
        Tuple ``(final_test_accuracy, all_preds, all_labels)`` where the
        accuracy is ``correct / total`` on the held-out fold and the two
        lists hold the per-sample predictions and labels.

    Raises:
        ValueError: If ``train_weak`` or ``layer_num`` is smaller than 1.
        KeyError: If a required key is missing from ``best_params``.
        IndexError: If ``gnn_datasets`` has fewer than five entries.
    """
    # Function-scope imports; the PyG DataLoader is the one that batches
    # graph ``Data`` objects.
    from sklearn.model_selection import KFold
    from torch_geometric.loader import DataLoader

    print("\n🔍 Starting final evaluation using best hyperparameters...\n")

    # Read every hyperparameter up front so a missing key fails before any
    # expensive work is done.
    thres = best_params["thres"]
    lr_weak = best_params["lr_weak"]
    weight_decay_weak = best_params["weight_decay_weak"]
    lr_middle = best_params["lr_middle"]
    weight_decay_middle = best_params["weight_decay_middle"]
    train_weak = best_params["train_weak"]
    train_middle = best_params["train_middle"]
    layer_num = best_params["layer_num"]
    threshold1 = best_params["threshold1"]
    batch_size = best_params["batch_size"]
    dropout = best_params["dropout1"]
    lr_final = best_params["lr"]
    num_train = best_params["num_train"]

    if train_weak < 1:
        raise ValueError("best_params['train_weak'] must be at least 1")
    if layer_num < 1:
        raise ValueError("best_params['layer_num'] must be at least 1")

    datasets = [gnn_datasets[k] for k in range(5)]
    num_views = len(datasets)

    # Pick first fold as test set.
    # assumes all five datasets have the same length and sample order as
    # gnn_datasets[0] -- TODO confirm against caller
    kf = KFold(n_splits=5, shuffle=True, random_state=1)
    train_idx, test_idx = next(kf.split(list(range(len(datasets[0])))))

    training = {}
    testing = {}
    train_dataset = {}
    test_datasett = {}
    training_labels = {}
    testing_labels = {}
    edge_indices = {}

    # ---- Stage 1: split every dataset and derive its shared graph --------
    for i, dataset in enumerate(datasets):
        data_lfp = np.array([data.x.numpy().T for data in dataset])
        data_trial = np.array([data.y.item() for data in dataset])

        training[i] = data_lfp[train_idx]
        testing[i] = data_lfp[test_idx]
        training_labels[i] = torch.tensor(data_trial[train_idx], dtype=torch.long)
        testing_labels[i] = torch.tensor(data_trial[test_idx], dtype=torch.long)

        # Overrides the ``num_classes`` argument; the value from the last
        # dataset is the one used when the models are built.
        num_classes = len(torch.unique(torch.tensor(data_trial, dtype=torch.long)))

        # Built before normalisation: ``x`` holds views of the raw split.
        train_dataset[i] = _build_graphs(
            training[i], [dataset[j].edge_index for j in train_idx]
        )
        test_datasett[i] = _build_graphs(
            testing[i], [dataset[j].edge_index for j in test_idx]
        )

        # NOTE(review): the correlation graph is computed from all samples,
        # including the held-out fold -- confirm this is intended.
        edge_indices[i] = _correlation_edge_index(data_lfp, thres)

        training[i], testing[i] = normalize_features_eval(training[i], testing[i])

    models = {i: [] for i in range(num_views)}
    latent_rep = {i: [] for i in range(num_views)}
    quality_vec = {i: [] for i in range(num_views)}
    final_emb = {}
    test_emb = {}

    torch.manual_seed(1)

    # ---- Stage 2: stacked weak learners per dataset ----------------------
    for i in range(num_views):
        num_node_features = training[i].shape[1]
        num_nodes = training[i].shape[2]

        model = Weaker_First(num_nodes, num_node_features, num_classes).to(device)
        optimizer = torch.optim.Adam(
            model.parameters(), lr=lr_weak, weight_decay=weight_decay_weak
        )
        model.train()
        start_time = time.time()
        outputs = _run_epochs(
            model, optimizer, training[i], edge_indices[i],
            training_labels[i], train_weak, device,
        )
        curr_feat = outputs[1]
        first_latent = np.array(outputs[1])
        print(f"weakfirst time {time.time() - start_time:.4f} seconds")

        models[i].append(model)
        latent_rep[i].append(first_latent)

        # Uniform, L2-normalised sample weights for the first learner.
        weights = F.normalize(torch.ones(training[i].shape[0]), p=2, dim=0)
        err_rate = weight_loss(
            model, training[i], weights, edge_indices[i], training_labels[i], device
        )
        quality_vec[i].append(quality_update(err_rate))

        for _ in range(layer_num - 1):
            # 256 is passed as the first constructor argument; presumably the
            # latent width of the previous learner -- TODO confirm.
            model = Weaker_Middle(256, num_classes).to(device)
            optimizer = torch.optim.Adam(
                model.parameters(), lr=lr_middle, weight_decay=weight_decay_middle
            )
            model.train()
            # Each middle learner consumes the previous learner's features.
            training[i] = np.array(curr_feat)
            outputs = _run_epochs(
                model, optimizer, training[i], edge_indices[i],
                training_labels[i], train_middle, device, outputs=outputs,
            )
            curr_feat = outputs[1]

            models[i].append(model)
            latent_rep[i].append(np.array(outputs[1]))
            err_rate = weight_loss(
                model, training[i], weights, edge_indices[i], training_labels[i], device
            )
            quality_vec[i].append(quality_update(err_rate))
            weights = weight_update(
                model, err_rate, training[i], weights, edge_indices[i],
                training_labels[i], device,
            )

        final_emb[i] = _weighted_sum(latent_rep[i], quality_vec[i], device)

        # Push the test split through the same stack, layer by layer.
        # NOTE(review): the learners are still in train() mode here --
        # confirm whether eval() is wanted for the test pass.
        test_latents = []
        for learner in models[i]:
            test_outputs = compute_loss(learner, testing[i], edge_indices[i], device)
            test_latents.append(np.array(test_outputs[1]))
            testing[i] = np.array(test_outputs[1])
        test_emb[i] = _weighted_sum(test_latents, quality_vec[i], device)

    # ---- Stage 3: attention-based subgraphs per dataset ------------------
    train_subgraphs = []
    test_subgraphs = []
    for i in range(num_views):
        encoded_train_data_list = [
            Data(x=final_emb[i][k], edge_index=graph.edge_index)
            for k, graph in enumerate(train_dataset[i])
        ]
        encoded_test_data_list = [
            Data(x=test_emb[i][k], edge_index=graph.edge_index)
            for k, graph in enumerate(test_datasett[i])
        ]

        # Quality-weighted sum of the attention scores stored on each learner.
        attention_scores = torch.zeros_like(models[i][0].attention_scores)
        for quality, learner in zip(quality_vec[i], models[i]):
            attention_scores += quality * learner.attention_scores

        node_attention_scores_train = compute_node_attention_scores(
            attention_scores,
            train_dataset[i][0].edge_index,
            train_dataset[i][0].x.shape[0],
        )
        train_subgraphs.append(
            construct_subgraphs(
                node_attention_scores_train,
                train_dataset[i],
                encoded_train_data_list,
                threshold=threshold1,
            )
        )

        node_attention_scores_test = compute_node_attention_scores(
            attention_scores,
            test_datasett[i][0].edge_index,
            test_datasett[i][0].x.shape[0],
        )
        test_subgraphs.append(
            construct_subgraphs(
                node_attention_scores_test,
                test_datasett[i],
                encoded_test_data_list,
                threshold=threshold1,
            )
        )

    # ---- Stage 4: meta-graphs and the graph-level predictor --------------
    train_meta_graphs = construct_meta_graphs(train_subgraphs)
    test_meta_graphs = construct_meta_graphs(test_subgraphs)

    # Labels are taken from dataset 0.
    # presumably all five datasets share the same labels -- verify
    _attach_labels(train_meta_graphs, training_labels[0])
    _attach_labels(test_meta_graphs, testing_labels[0])

    train_loader = DataLoader(train_meta_graphs, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_meta_graphs, batch_size=batch_size, shuffle=False)

    model = GraphLevelPredictor(
        input_dim=128,
        hidden_dim=32,  # Adjust if tunable
        output_dim=num_classes,
        dropout=dropout,
    ).to(device)

    set_weights(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr_final, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()

    model.train()
    for epoch in range(num_train):
        total_loss = 0.0
        for batch in train_loader:
            # The labels were attached as CPU tensors; move the whole batch
            # so the model output and the targets share a device.
            batch = batch.to(device)
            optimizer.zero_grad()
            out = model(batch)
            loss = criterion(out, batch.y.view(-1).long())
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

    # Final test evaluation
    model.eval()
    correct, total = 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            out = model(batch)
            preds = torch.argmax(out, dim=1)
            targets = batch.y.view(-1)
            correct += (preds == targets).sum().item()
            total += batch.y.size(0)
            all_preds.extend(preds.tolist())
            all_labels.extend(targets.tolist())

    final_test_accuracy = correct / total
    print(f"\n✅ Final Test Accuracy (on held-out fold): {final_test_accuracy:.4f}")
    return final_test_accuracy, all_preds, all_labels