# gnn-meta-graph/src/lfp_dataset.py
import os
import numpy as np
import torch
from torch_geometric.data import Data

"""
lfp_dataset.py

Data loading and preprocessing utilities for the GNN-based LFP classification project.
Includes functions to:
- Load raw LFP, spike, and trial-info arrays for each rat from .npy files
- Balance class (label) counts across the per-rat datasets
- Construct edge indices by thresholding a correlation matrix of the input features
- Create PyTorch Geometric Data objects
"""


def _make_reproducible(seed):
    """Seed the global NumPy/PyTorch RNGs and force deterministic torch kernels.

    Called once at import time so that every function in this module (notably
    the random trial subsampling) produces repeatable results.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn = torch.backends.cudnn
    # Disable cuDNN autotuning (which may pick different kernels per run)
    # and request deterministic cuDNN kernels.
    cudnn.benchmark = False
    cudnn.deterministic = True

    # Make torch raise on operations that have no deterministic implementation.
    torch.use_deterministic_algorithms(True)


_make_reproducible(1)

def load_raw_data(dir_data, rat_name):
    """Load the trial-info, binned-spike and sampled-LFP arrays for one rat.

    Files are read from ``<dir_data>/<rat_name>/<rat_name.lower()><suffix>``.

    Args:
        dir_data: root directory holding one sub-directory per rat.
        rat_name: name of the rat's sub-directory (file prefixes use its lower-case form).

    Returns:
        ``(trial_info, spike_data, lfp_data)``; ``lfp_data`` has axes 1 and 2
        swapped relative to the array stored on disk.
    """
    rat_dir = os.path.join(dir_data, rat_name)
    prefix = rat_name.lower()

    def _load(suffix):
        return np.load(os.path.join(rat_dir, prefix + suffix))

    trial_info = _load('_trial_info.npy')
    spike_data = _load('_spike_data_binned.npy')
    lfp_data = _load('_lfp_data_sampled.npy').swapaxes(1, 2)
    return trial_info, spike_data, lfp_data

def count_labels(trial_info, target_col=3):
    """Count how many trials carry each zero-based label.

    Labels are read from column ``target_col`` of ``trial_info`` and shifted
    down by one (the same ``- 1`` applied by the other helpers in this module).

    Returns:
        dict mapping each distinct label to its trial count, in ascending label order.
    """
    zero_based = trial_info[:, target_col] - 1
    values, freqs = np.unique(zero_based, return_counts=True)
    return {value: freq for value, freq in zip(values, freqs)}

def select_balanced_trials(trial_info, spike_data, lfp_data, target_col=3, n_labels=None, rng=None):
    """Randomly subsample trials so each label contributes at most a given count.

    Args:
        trial_info: 2-D array; column ``target_col`` holds 1-based labels.
        spike_data: array indexed by trial along axis 0, aligned with ``trial_info``.
        lfp_data: array indexed by trial along axis 0, aligned with ``trial_info``.
        target_col: column of ``trial_info`` holding the label.
        n_labels: mapping ``{zero-based label: max trials to keep}``. Defaults to
            10 trials for each of labels 0-4. Labels not listed are dropped.
        rng: optional object with a NumPy-style ``choice`` method (e.g. a
            ``numpy.random.Generator``). Defaults to the global ``np.random``
            stream, so results match earlier versions of this function.

    Returns:
        ``(trial_info, spike_data, lfp_data)`` restricted to the selected trials,
        kept in ascending original-index order.
    """
    if n_labels is None:
        n_labels = {label: 10 for label in range(5)}
    if rng is None:
        rng = np.random

    labels = trial_info[:, target_col] - 1
    chosen = []
    for label, n_label in n_labels.items():
        indices_label = np.where(labels == label)[0]
        n_take = min(len(indices_label), n_label)
        chosen.append(rng.choice(indices_label, n_take, replace=False))

    # Build an integer index explicitly. ``np.sort([])`` would give a float
    # array, which NumPy refuses as an index when nothing was selected.
    if chosen:
        selected_indices = np.sort(np.concatenate(chosen).astype(np.intp, copy=False))
    else:
        selected_indices = np.empty(0, dtype=np.intp)
    return trial_info[selected_indices], spike_data[selected_indices], lfp_data[selected_indices]

def create_edge_index_from_correlation(data, threshold=0.7):
    """Build a PyG-style ``edge_index`` by thresholding absolute correlations.

    Correlations are computed between the *columns* of ``data``
    (``np.corrcoef(data.T)``). A directed edge ``(i, j)`` is emitted for every
    pair with ``|corr[i, j]| > threshold`` and ``i != j``. Pairs whose
    endpoint index is ``>= data.shape[0]`` are discarded.

    NOTE(review): PyG treats the *rows* of ``x`` as nodes, while the
    correlation here is taken across columns and then truncated to the row
    count. If nodes are meant to be rows, this should probably be
    ``np.corrcoef(data)``. Confirm the intended axis layout of the LFP input.

    Args:
        data: 2-D ``torch.Tensor`` (any device, may require grad) or array-like.
        threshold: strict lower bound on ``|correlation|`` for an edge.

    Returns:
        ``torch.long`` tensor of shape ``(2, num_edges)``; ``(2, 0)`` if no edges.
    """
    if isinstance(data, torch.Tensor):
        # .numpy() alone fails for CUDA tensors and tensors requiring grad.
        data_np = data.detach().cpu().numpy()
    else:
        data_np = np.asarray(data)
    num_nodes = data_np.shape[0]

    # Constant columns make corrcoef divide by zero. The resulting NaNs never
    # pass the threshold, so the warnings are just noise.
    with np.errstate(divide="ignore", invalid="ignore"):
        # A single variable makes corrcoef return a 0-d scalar; keep it 2-D.
        corr = np.atleast_2d(np.corrcoef(data_np.T))

    src, tgt = np.nonzero(np.abs(corr) > threshold)
    keep = (src != tgt) & (src < num_nodes) & (tgt < num_nodes)
    # Stack into one contiguous int64 array first. torch.tensor on a list of
    # ndarrays is slow and emits a warning.
    edges = np.stack([src[keep], tgt[keep]]).astype(np.int64, copy=False)
    return torch.from_numpy(edges)

def prepare_gnn_dataset(lfp_data, trial_info, target_col=3, corr_threshold=0.5):
    """Build one PyTorch Geometric ``Data`` graph per trial.

    For trial ``i``:
      * ``x`` is ``lfp_data[i]`` converted to a float32 tensor,
      * ``edge_index`` comes from ``create_edge_index_from_correlation(x, corr_threshold)``,
      * ``y`` is ``trial_info[i, target_col] - 1`` as a 0-d long tensor.

    Returns:
        list of ``Data`` objects, one per trial along axis 0 of ``lfp_data``.
    """
    def _trial_to_graph(idx):
        features = torch.tensor(lfp_data[idx], dtype=torch.float)
        return Data(
            x=features,
            edge_index=create_edge_index_from_correlation(features, threshold=corr_threshold),
            y=torch.tensor(trial_info[idx, target_col] - 1, dtype=torch.long),
        )

    return [_trial_to_graph(idx) for idx in range(lfp_data.shape[0])]

def prepare_all_datasets(data_dir="lfp_data", rat_names=None):
    """Load, class-balance and convert every rat's LFP data into PyG graphs.

    Steps:
      1. Load trial info, spikes and LFP for each rat from ``data_dir``.
      2. For each label 0-4, take the smallest count across rats and subsample
         every rat down to it, so all rats share the same label distribution.
         A label missing from any rat ends up with a count of 0.
      3. Convert each rat's balanced LFP trials into a list of ``Data`` graphs.

    Args:
        data_dir: root directory containing one sub-directory per rat.
        rat_names: rats to include. Defaults to
            ``["superchris", "barat", "stella", "mitt", "buchanan"]``.

    Returns:
        dict mapping each rat's position in ``rat_names`` (0, 1, ...) to its
        list of graphs. Returns ``{}`` when ``rat_names`` is empty.
    """
    if rat_names is None:
        rat_names = ["superchris", "barat", "stella", "mitt", "buchanan"]
    rat_names = list(rat_names)  # iterated several times below
    if not rat_names:
        return {}

    # Honour the caller's directory (it used to be hard-coded to "lfp_data").
    raw_data = {rat: load_raw_data(data_dir, rat) for rat in rat_names}

    counts = {rat: count_labels(raw_data[rat][0]) for rat in rat_names}
    min_label_counts = {
        label: min(counts[rat].get(label, 0) for rat in rat_names)
        for label in range(5)
    }

    gnn_datasets = {}
    for i, rat in enumerate(rat_names):
        tinfo, spk, lfp = raw_data[rat]
        # Spike data is subsampled alongside but is not used to build graphs.
        t_bal, _, l_bal = select_balanced_trials(tinfo, spk, lfp, n_labels=min_label_counts)
        gnn_datasets[i] = prepare_gnn_dataset(l_bal, t_bal)

    return gnn_datasets