import random
import torch
from utils import to_device
from torch_geometric.data import Batch, Data 

def apply_virtual_knockdown(graph: Data, knockdown_gene_idx: int) -> Data:
    """Return a copy of ``graph`` with one node virtually knocked down.

    The feature entry at ``knockdown_gene_idx`` is set to zero and every edge
    whose source or target is that node is removed. The input graph is left
    untouched: its features are cloned and the edge tensors are re-indexed
    into new tensors.

    Args:
        graph: Graph providing ``x``, ``edge_index`` and, optionally,
            ``edge_attr``.
        knockdown_gene_idx: Index of the node to silence.

    Returns:
        A new ``Data`` built from ``x``, ``edge_index`` and ``edge_attr`` only;
        any other attribute stored on ``graph`` is not carried over.
    """
    features = graph.x.clone()
    features[knockdown_gene_idx] = 0

    sources = graph.edge_index[0]
    targets = graph.edge_index[1]
    # An edge survives only if neither endpoint is the knocked-down node.
    touches_target = (sources == knockdown_gene_idx) | (targets == knockdown_gene_idx)
    keep = ~touches_target

    kept_edges = graph.edge_index[:, keep]
    if graph.edge_attr is not None:
        kept_attr = graph.edge_attr[keep]
    else:
        kept_attr = None

    return Data(x=features, edge_index=kept_edges, edge_attr=kept_attr)

class TeacherSampler:
    """Sample pairs of knockdown (KD) genes and score them with the model.

    Each call to :meth:`sample` draws two KD genes plus a random candidate
    subset, encodes one graph per drawn gene in a single batch, and returns a
    subsampled softmax log-probability for the pair together with its
    detached exponential as a weight.
    """

    def __init__(self, lincs_graphs, sample_ids_list, kd_meta, model, device, subsample_size):
        """
        lincs_graphs: list of Data objects for LINCS KD graphs
        sample_ids_list: list of sample IDs in same order as lincs_graphs
            (IDs must be hashable; they are used as dictionary keys)
        kd_meta: dict mapping kd_gene -> list of sample_id strings
        model: the shared GNN encoder + head, called as
            model(x, edge_index, edge_attr)
        device: torch.device
        subsample_size: int, size of subset C for denominator approximation
        """
        self.graphs = lincs_graphs
        self.sample_ids = sample_ids_list
        self.kd_meta = kd_meta
        self.model = model
        self.device = device
        self.subsample_size = subsample_size

        # Position of the FIRST occurrence of every sample ID, mirroring
        # list.index(); built once so sample() does not rescan the ID list
        # for every graph it needs.
        self._index_by_sample_id = {}
        for position, sample_id in enumerate(sample_ids_list):
            self._index_by_sample_id.setdefault(sample_id, position)

    def _graph_index(self, sample_id):
        """Return the position of ``sample_id`` in ``self.sample_ids``.

        Falls back to a linear ``list.index`` scan when the ID is not in the
        map built at construction time (e.g. it was appended afterwards), so
        an unknown ID still raises ``ValueError``.
        """
        try:
            return self._index_by_sample_id[sample_id]
        except KeyError:
            return self.sample_ids.index(sample_id)

    def sample(self, tau_sim: float):
        """Draw a KD-gene pair and compute its subsampled log-probability.

        Args:
            tau_sim: Temperature dividing the similarities.

        Returns:
            Tuple ``(a, b, cand, log_p, weight)`` where ``a`` and ``b`` are the
            drawn KD genes, ``cand`` is the list of candidate genes used for
            the denominator, ``log_p`` is a scalar tensor and ``weight`` is
            ``exp(log_p)`` detached from the autograd graph.

        Raises:
            IndexError: If ``kd_meta`` (or the sample list of a drawn gene)
                is empty.
            ValueError: If the candidate subset would be empty
                (``subsample_size`` < 1) or a sample ID is unknown.
        """
        kd_list = list(self.kd_meta)
        V = len(kd_list)

        # (1) Uniformly sample a, b from all KD genes.
        a = random.choice(kd_list)
        b = random.choice(kd_list)

        # (2) Uniformly sample subset C for denominator approximation.
        K = min(V, self.subsample_size)
        if K < 1:
            # Without candidates there is nothing to stack and V / K below is
            # undefined; fail here with a message that names the cause.
            raise ValueError(
                f"subsample_size must be at least 1 to sample candidates, got {self.subsample_size}"
            )
        cand = random.sample(kd_list, K)

        # (3) Create teacher graph set: one randomly chosen sample graph for
        # a, then b, then every candidate, in that order.
        H_graphs = []
        for gene in (a, b, *cand):
            sample_id = random.choice(self.kd_meta[gene])
            graph = self.graphs[self._graph_index(sample_id)]
            H_graphs.append(to_device(graph, self.device))

        # (4) Batch encode all graphs.
        batch = Batch.from_data_list(H_graphs)
        # NOTE(review): graphs were already passed through to_device above, so
        # this second call is presumably redundant; kept to preserve behavior.
        batch = to_device(batch, self.device)
        Z_all = self.model(batch.x, batch.edge_index, batch.edge_attr)

        # Cut the batched output back into per-graph blocks by node count.
        sizes = [g.num_nodes for g in H_graphs]
        Za, Zb, *Zc_list = torch.split(Z_all, sizes, dim=0)
        # torch.stack needs equally shaped blocks, i.e. all candidate graphs
        # must have the same number of nodes.
        Zc = torch.stack(Zc_list, dim=0)  # [K, n, d]

        # (5) Frobenius inner product -> log_p and weight.
        n = Za.size(0)
        sim_ab = torch.einsum('nd,nd->', Za, Zb) / n       # scalar
        sims_c = torch.einsum('nd,knd->k', Za, Zc) / n     # [K]
        log_num = sim_ab / tau_sim
        # log(V / K) rescales the K-term sum to the full set of V genes.
        log_denom = torch.logsumexp(sims_c / tau_sim, dim=0) + torch.log(
            torch.tensor(V / K, device=Za.device)
        )
        log_p = log_num - log_denom                 # Tensor scalar
        weight = torch.exp(log_p).detach()          # Tensor scalar, grad off

        return a, b, cand, log_p, weight