import pandas as pd
import numpy as np
import re
import h5py

import torch
from torch.utils.data import Dataset, DataLoader, Sampler
import random

class PocketDataset(Dataset):
  """Dataset of pocket embeddings and labels read from an HDF5 file.

  The file must contain two groups, ``embeddings`` and ``labels``, each
  holding one dataset per pocket id. On construction every selected pocket
  is classified once: a pocket whose label array sums to more than zero is
  recorded in ``pos_idx``, every other pocket in ``unlabeled_idx``. Both
  lists hold positions into ``poc_ids`` (i.e. valid ``__getitem__`` indices).

  The instance can be used as a context manager so the file is always
  closed::

      with PocketDataset(path) as ds:
          ...

  NOTE(review): the HDF5 handle is opened eagerly and kept on the instance;
  confirm this is safe for the DataLoader worker configuration in use.

  Args:
    h5_path: Path of the HDF5 file, opened read-only.
    poc_ids: Optional sequence of pocket ids to expose. Defaults to every
      key of the ``embeddings`` group.

  Raises:
    KeyError: If a required group or a pocket id is missing from the file.
  """

  def __init__(self, h5_path, poc_ids=None):
    self.h5_path = h5_path
    self.h5f = h5py.File(h5_path, 'r')
    try:
      self.emb_group = self.h5f['embeddings']
      self.label_group = self.h5f['labels']
      # All poc_ids
      self.poc_ids = poc_ids if poc_ids is not None else list(self.emb_group.keys())
      # Positive indices and unlabeled indices (aligned to poc_ids order)
      self.pos_idx = []
      self.unlabeled_idx = []
      for i, poc_id in enumerate(self.poc_ids):
        label_arr = self.label_group[poc_id][()]
        if label_arr.sum() > 0:
          self.pos_idx.append(i)
        else:
          self.unlabeled_idx.append(i)
    except Exception:
      # Do not leak the open file handle when the scan fails part-way.
      self.h5f.close()
      raise

  def __len__(self):
    """Return the number of pockets exposed by this dataset."""
    return len(self.poc_ids)

  def __getitem__(self, idx):
    """Return ``(embeddings, labels)`` of pocket ``idx`` as float32 tensors."""
    poc_id = self.poc_ids[idx]
    embeddings = torch.tensor(self.emb_group[poc_id][()], dtype=torch.float32)
    labels = torch.tensor(self.label_group[poc_id][()], dtype=torch.float32)
    return embeddings, labels

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    self.close()
    return False  # never suppress exceptions raised inside the with-block

  def close(self):
    """Close the underlying HDF5 file; calling it again is a no-op."""
    h5f = getattr(self, 'h5f', None)
    if h5f is not None:
      h5f.close()
      self.h5f = None

# handle variable-length sequences in a batch
def collate_fn(batch):
  """Collate ``(embeddings, labels)`` pairs of varying length into a batch.

  Args:
    batch: Non-empty sequence of ``(embeddings, labels)`` tuples, where each
      ``embeddings`` tensor is 2-D, ``(length, hidden_dim)``, and
      ``hidden_dim`` is the same for every item.

  Returns:
    A tuple ``(padded_embeddings, masks, labels)``:
      * ``padded_embeddings``: float32, ``(batch, max_len, hidden_dim)``,
        zero-padded after each item's own length.
      * ``masks``: ``(batch, max_len)``, 1 for valid positions, 0 for padding.
      * ``labels``: the label tensors stacked along a new first dimension.
        When their shapes differ they are zero-padded along their first
        dimension to the longest label in the batch before stacking.
  """
  embeddings, labels = zip(*batch)
  # Pad sequences to the maximum length in the batch, and record masks
  lengths = [emb.shape[0] for emb in embeddings]
  max_len = max(lengths)
  hidden_dim = embeddings[0].shape[1]

  padded_embeddings = torch.zeros((len(embeddings), max_len, hidden_dim), dtype=torch.float32)
  masks = torch.zeros((len(embeddings), max_len))  # mask, 1 for valid, 0 for padded

  for i, (emb, length) in enumerate(zip(embeddings, lengths)):
    padded_embeddings[i, :length, :] = emb
    masks[i, :length] = 1

  label_shape = labels[0].shape
  if all(lab.shape == label_shape for lab in labels):
    labels = torch.stack(labels)
  else:
    # torch.stack cannot combine tensors of different length, so pad them
    # with zeros along the first dimension instead of failing.
    labels = torch.nn.utils.rnn.pad_sequence(list(labels), batch_first=True)
  return padded_embeddings, masks, labels

# Batch sampler that ensures each batch contains at least one positive sample
class PUSampler(Sampler):
  """Batch sampler for positive-unlabeled data.

  Each iteration yields lists of dataset indices, one list per batch. When
  the dataset has positive samples, every batch holds at least one of them
  and exactly ``batch_size`` indices:

  * positives are dealt out roughly evenly over the batches, with an
    optional random variation per batch; when they run out they are
    reshuffled and reused;
  * the rest of each batch is filled with unlabeled indices; when those run
    out, the batch is topped up with indices drawn at random from the whole
    dataset, so indices may repeat.

  When the dataset has no positive samples the indices are shuffled and cut
  into consecutive batches (the last one may be short unless ``drop_last``).

  Args:
    dataset: Object with ``__len__`` and the index lists ``pos_idx`` and
      ``unlabeled_idx``.
    batch_size: Number of indices per batch; must be at least 1.
    min_pos: Minimum positives per batch (values below 1 are raised to 1).
    max_pos: Maximum positives per batch; defaults to ``batch_size - 1``.
    drop_last: If true, the number of batches is rounded down.
    pos_randomness: Relative size of the random variation applied to each
      batch's positive count; ``0`` disables it.

  Raises:
    ValueError: If ``batch_size`` is smaller than 1.
  """

  def __init__(self, dataset, batch_size, min_pos=1, max_pos=None, drop_last=False, pos_randomness=0.3):
    if batch_size < 1:
      raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
    self.dataset = dataset
    self.batch_size = batch_size
    self.min_pos = max(1, min_pos)
    self.max_pos = max_pos if max_pos is not None else batch_size - 1
    self.drop_last = drop_last
    self.pos_randomness = pos_randomness

  def __iter__(self):
    num_samples = len(self.dataset)
    total_batches = len(self)
    # Nothing to yield for an empty dataset, or when drop_last discards the
    # only (incomplete) batch; this also avoids dividing by zero batches.
    if num_samples == 0 or total_batches == 0:
      return

    pos_indices = list(self.dataset.pos_idx)
    unl_indices = list(self.dataset.unlabeled_idx)

    if not pos_indices:
      # No positives: plain shuffled batching.
      indices = list(range(num_samples))
      random.shuffle(indices)
      for start in range(0, num_samples, self.batch_size):
        batch = indices[start:start + self.batch_size]
        if len(batch) < self.batch_size and self.drop_last:
          continue
        yield batch
      return

    batch_pos_counts = self._plan_pos_counts(len(pos_indices), total_batches)

    random.shuffle(pos_indices)
    random.shuffle(unl_indices)
    pos_iter = iter(pos_indices)
    unl_iter = iter(unl_indices)

    for pos_count in batch_pos_counts:
      batch = []
      for _ in range(pos_count):
        idx = next(pos_iter, None)
        if idx is None:
          # Positives exhausted: reshuffle and start another pass.
          random.shuffle(pos_indices)
          pos_iter = iter(pos_indices)
          idx = next(pos_iter)
        batch.append(idx)

      while len(batch) < self.batch_size:
        idx = next(unl_iter, None)
        if idx is None:
          # Unlabeled samples exhausted: top up with a random index.
          idx = random.randrange(num_samples)
        batch.append(idx)

      yield batch

  def _plan_pos_counts(self, num_pos, total_batches):
    """Return the number of positives to place in each of the batches."""
    # Keep the bounds inside [1, batch_size] so a batch can never grow
    # beyond batch_size, and so that min_pos wins over a smaller max_pos.
    lo = min(self.min_pos, self.batch_size)
    hi = min(max(self.max_pos, lo), self.batch_size)

    base, extra = divmod(num_pos, total_batches)
    counts = []
    for b in range(total_batches):
      count = base + (1 if b < extra else 0)
      if self.pos_randomness > 0:
        max_variation = max(1, int(count * self.pos_randomness))
        count += random.randint(-max_variation, max_variation)
      counts.append(max(lo, min(count, hi)))

    # Best-effort correction towards using each positive once per epoch:
    # nudge randomly chosen batches while staying within the bounds.
    diff = num_pos - sum(counts)
    step = 1 if diff > 0 else -1
    for _ in range(abs(diff)):
      b = random.randrange(total_batches)
      if lo <= counts[b] + step <= hi:
        counts[b] += step
    return counts

  def __len__(self):
    """Return the number of batches produced per iteration."""
    if self.drop_last:
      return len(self.dataset) // self.batch_size
    return (len(self.dataset) + self.batch_size - 1) // self.batch_size