from cProfile import label
import torch
import numpy as np
import pandas as pd
import scipy.sparse as sp
import os

from collections import defaultdict


class Dataset(torch.utils.data.Dataset):
    """Gene-pair dataset yielding expression profiles, neighbor pairs and GNN scores.

    The training table is read from a CSV whose first column is the row index.
    Of the remaining columns, the first two are treated as the (TF, target)
    gene indices and the last one as the label. Gene indices are used to index
    ``expression_data`` along its first axis.
    """

    def __init__(self, data_path, expression_data, num_neighbors=5):
        """Load the training table and optional GNN prediction scores.

        Args:
            data_path: Path to the training CSV (first column is the index).
            expression_data: Array-like indexed by gene along axis 0; must
                expose ``.shape`` and be accepted by ``np.mean``.
            num_neighbors: Number of related pairs returned with each sample.
        """
        super().__init__()

        # Parse the CSV once; the raw table and the derived views share it.
        data = pd.read_csv(data_path, index_col=0, header=0)
        self.train_set = data.values
        self.dataset = np.array(data.iloc[:, :2])
        self.label = np.array(data.iloc[:, -1])
        self.expression_data = expression_data

        self.num_gene = expression_data.shape[0]
        self.num_neighbors = num_neighbors

        # GNN predictions are looked up next to the training CSV.
        data_dir = os.path.dirname(data_path)
        gnn_pred_path = os.path.join(data_dir, 'gat_predictions.csv')
        try:
            gnn_preds_df = pd.read_csv(gnn_pred_path)
        except FileNotFoundError:
            print(f"Warning: GNN prediction file not found at {gnn_pred_path}. Using default score 0.5.")
            self.gnn_scores_series = None
        else:
            gnn_preds_df = gnn_preds_df.set_index(['TFs', 'Targets'])
            self.gnn_scores_series = gnn_preds_df['Predicted Labels']
            print("Successfully loaded GAT prediction scores.")

        # Mean of the expression data along axis 0; not used inside this
        # class, kept as a public attribute for external consumers.
        self.pad_value = np.mean(expression_data, axis=0)

        self.gene_adj_dict = self._build_gene_adj_dict()

    def _build_gene_adj_dict(self):
        """Return a gene -> neighbor-list mapping built from positive rows.

        A row is positive when its last column equals 1. Both genes of a
        positive row are registered as neighbors of each other. Columns are
        addressed positionally (first, second, last), the same way
        ``Adj_Generate`` reads them, so extra middle columns are tolerated.
        """
        adj_dict = defaultdict(list)
        for row in self.train_set:
            tf, target, label = row[0], row[1], row[-1]
            if label == 1:
                adj_dict[tf].append(target)
                adj_dict[target].append(tf)
        return adj_dict

    def _get_related_pairs(self, gene_pair):
        """Return ``num_neighbors`` pairs that share a gene with ``gene_pair``.

        Pairs anchored on the first gene come first, then pairs anchored on
        the second gene. The pair itself is excluded from the neighbors and
        is used only as padding when too few neighbors are known.
        """
        g1, g2 = gene_pair
        limit = self.num_neighbors
        related_pairs = []

        for anchor, partner in ((g1, g2), (g2, g1)):
            for neighbor in self.gene_adj_dict.get(anchor, [])[:limit]:
                if neighbor != partner:
                    related_pairs.append((anchor, neighbor))

        # Pad with the main pair so the result always has `limit` entries.
        shortfall = limit - len(related_pairs)
        if shortfall > 0:
            related_pairs.extend([(g1, g2)] * shortfall)

        return related_pairs[:limit]

    def _pair_expression(self, pair):
        """Stack the expression entries of the two genes in ``pair`` on a new axis 0."""
        return np.stack(
            (self.expression_data[pair[0]], self.expression_data[pair[1]]),
            axis=0,
        )

    def __len__(self):
        """Return the number of gene pairs in the training table."""
        return len(self.dataset)

    def __getitem__(self, i):
        """Return the ``i``-th sample.

        Returns:
            A tuple ``(main_pair, main_expr, related_exprs, gnn_score,
            main_label)`` where ``main_expr`` stacks the two genes'
            expression entries, ``related_exprs`` stacks the same structure
            for every related pair, and ``gnn_score`` falls back to 0.5 when
            no prediction is available for the pair.
        """
        main_pair = self.dataset[i]
        main_label = self.label[i]

        if self.gnn_scores_series is not None:
            gnn_score = self.gnn_scores_series.get(tuple(main_pair), 0.5)
        else:
            gnn_score = 0.5

        main_expr = self._pair_expression(main_pair)

        related_pairs = self._get_related_pairs(main_pair)
        if related_pairs:
            related_exprs = np.stack(
                [self._pair_expression(pair) for pair in related_pairs]
            )
        else:
            # np.stack rejects an empty list (e.g. num_neighbors == 0), so
            # build an empty array with the matching trailing shape instead.
            related_exprs = np.empty((0,) + main_expr.shape, dtype=main_expr.dtype)

        return main_pair, main_expr, related_exprs, gnn_score, main_label

    def Adj_Generate(self, TF_set, direction=False, loop=False):
        """Build a ``num_gene x num_gene`` float32 DOK adjacency matrix.

        Only rows whose last column equals 1 contribute edges.

        Args:
            TF_set: Container of gene indices; consulted only when
                ``direction`` is truthy.
            direction: When falsy, every edge is stored in both directions.
                When truthy, the reverse edge target -> tf is stored only if
                the target is itself in ``TF_set``.
            loop: When truthy, the identity matrix is added to the result.

        Returns:
            The adjacency matrix in scipy DOK format.
        """
        adj = sp.dok_matrix((self.num_gene, self.num_gene), dtype=np.float32)

        for row in self.train_set:
            if row[-1] != 1:
                continue
            # Cast so that float-typed tables still yield valid indices.
            tf, target = int(row[0]), int(row[1])
            adj[tf, target] = 1.0
            if not direction or target in TF_set:
                adj[target, tf] = 1.0

        if loop:
            # Match the matrix dtype so the sum is not promoted to float64.
            adj = adj + sp.identity(self.num_gene, dtype=np.float32)

        return adj.todok()

def adj2saprse_tensor(adj):
    """Convert a scipy sparse matrix into a torch sparse COO tensor.

    The matrix is converted to COO format; its row/column indices become an
    int64 index tensor and its values a float32 value tensor. The result has
    the same shape as ``adj``.
    """
    coo_matrix = adj.tocoo()
    # torch.tensor copies, so the result never aliases the scipy buffers.
    row_index = torch.tensor(coo_matrix.row, dtype=torch.long)
    col_index = torch.tensor(coo_matrix.col, dtype=torch.long)
    values = torch.tensor(coo_matrix.data, dtype=torch.float32)
    indices = torch.stack((row_index, col_index))
    return torch.sparse_coo_tensor(indices, values, coo_matrix.shape)
