import argparse
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from sklearn.metrics import average_precision_score

from DataLoader import PocketDataset, PUSampler, collate_fn
from model_attn import PU_AttnPoolNet
from loss_es import pu_loss_ranking_multi, ranking_val_loss
from utils import estimate_prior


# ------------------------------
# Set random seeds
# ------------------------------
def set_seed(seed):
    """Seed Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices).

    Args:
        seed: Integer seed handed unchanged to every generator.
    """
    # Same order as before: stdlib, NumPy, torch CPU, torch CUDA.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


# ------------------------------
# Single-run training
# ------------------------------
def _train_epoch(model, loader, optimizer, prior_per_label, device):
    """Run one optimization pass over ``loader`` and return the mean batch loss."""
    model.train()
    batch_losses = []
    for padded_embs, masks, labels in loader:
        padded_embs = padded_embs.to(device)
        masks = masks.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        # The model returns a pair; only the first element (logits) is used.
        logits, _ = model(padded_embs, masks)
        loss = pu_loss_ranking_multi(logits, labels, prior=prior_per_label)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    return np.mean(batch_losses)


def _validation_loss(model, loader, device):
    """Return ``ranking_val_loss`` over every batch of ``loader`` as a float."""
    model.eval()
    logits_chunks, label_chunks = [], []
    with torch.no_grad():
        for padded_embs, masks, labels in loader:
            padded_embs = padded_embs.to(device)
            masks = masks.to(device)
            labels = labels.to(device)
            logits, _ = model(padded_embs, masks)
            logits_chunks.append(logits)
            label_chunks.append(labels)
        # The loss is computed once over the concatenated validation set,
        # not averaged per batch.
        loss = ranking_val_loss(
            torch.cat(logits_chunks, dim=0),
            torch.cat(label_chunks, dim=0),
        )
    # ranking_val_loss is defined elsewhere and may return a 0-dim tensor on
    # `device`; float() yields a plain number that can be compared, saved in
    # the checkpoint, and aggregated with NumPy regardless of device.
    return float(loss)


def train_one_run(run_id, args, device):
    """Train one model with early stopping and checkpoint the best epoch.

    Args:
        run_id: Identifier used in log messages and in the checkpoint name.
        args: Parsed CLI namespace; only ``args.samples`` (batch size) is read.
        device: ``torch.device`` on which the model and batches are placed.

    Returns:
        The best validation loss seen during the run as a ``float``
        (``inf`` if no epoch ever improved on the initial value).
    """
    batch_size = args.samples

    # Data paths
    train_h5_path = "./path_to/train_pockets.h5"
    val_h5_path = "./path_to/val_pockets.h5"

    # Model parameters
    n_features = 1280
    n_labels = 8

    print(f"\n=== Run {run_id} ===")

    best_val_loss = float('inf')
    train_dataset = None
    val_dataset = None
    try:
        print("Loading training dataset...")
        train_dataset = PocketDataset(train_h5_path)
        print(f"Train size: {len(train_dataset)}")
        print(f"#Positives: {len(train_dataset.pos_idx)}")
        print(f"#Unlabeled: {len(train_dataset.unlabeled_idx)}")

        print("Loading validation dataset...")
        val_dataset = PocketDataset(val_h5_path)
        print(f"Val size: {len(val_dataset)}")
        print(f"#Positives: {len(val_dataset.pos_idx)}")
        print(f"#Unlabeled: {len(val_dataset.unlabeled_idx)}")

        train_sampler = PUSampler(train_dataset, batch_size=batch_size, min_pos=1, pos_randomness=0.3)
        # NOTE(review): pos_randomness=0.0 presumably makes validation batches
        # deterministic -- confirm against PUSampler.
        val_sampler = PUSampler(val_dataset, batch_size=batch_size, min_pos=1, pos_randomness=0.0)

        num_workers = 12
        loader_kwargs = dict(
            collate_fn=collate_fn,
            num_workers=num_workers,
            pin_memory=(device.type == 'cuda'),  # pinning only helps host->GPU copies
            persistent_workers=True,
            prefetch_factor=2,
        )
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_sampler=train_sampler, **loader_kwargs
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_sampler=val_sampler, **loader_kwargs
        )

        print(f"Num train batches: {len(train_loader)}, num val batches: {len(val_loader)}")

        # Initialize model
        model = PU_AttnPoolNet(n_features, 128, n_labels).to(device)

        optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

        # Estimate priors
        prior_per_label = estimate_prior(train_dataset, device)

        n_epochs = 1000
        patience = 15
        no_improve = 0

        # NOTE: the file name says "64bs" regardless of args.samples.
        checkpoint_path = f"./checkpoints/best_nnpu_attn_pooling_64bs_only_BioLiP2_run{run_id}.pt"
        os.makedirs("./checkpoints", exist_ok=True)

        print("Start training...")
        for epoch in range(n_epochs):
            mean_train_loss = _train_epoch(model, train_loader, optimizer, prior_per_label, device)
            mean_val_loss = _validation_loss(model, val_loader, device)

            print(f"Epoch {epoch+1:03d} | TrainLoss={mean_train_loss:.4f} | ValLoss={mean_val_loss:.4f}")

            # LR scheduling
            scheduler.step(mean_val_loss)

            # Early stopping: an epoch counts as an improvement only if it
            # beats the best loss by more than 1e-6.
            if mean_val_loss + 1e-6 < best_val_loss:
                best_val_loss = mean_val_loss
                # Clone so later optimizer steps cannot alter the snapshot.
                best_state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}
                no_improve = 0
                torch.save({
                    'model_type': 'PU_Attn_onlyBioLiP2',
                    'model_state_dict': best_state_dict,
                    'val_loss': best_val_loss,
                    'epoch': epoch + 1,
                    'prior_per_label': prior_per_label.detach().cpu(),
                }, checkpoint_path)
            else:
                no_improve += 1
                if no_improve >= patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    break
    finally:
        # Close whichever datasets were opened, even if loading or training
        # raised part-way through.
        for dataset in (train_dataset, val_dataset):
            if dataset is not None:
                dataset.close()

    print(f"[Run {run_id}] Best val loss: {best_val_loss:.6f}")
    return best_val_loss


# ------------------------------
# Main entry (multiple runs)
# ------------------------------
def main():
    """Parse CLI options, run several independently seeded trainings, and report.

    Command-line options:
        --samples: batch size (default 64, must be >= 1).
        --runs: number of independent runs (default 5, must be >= 1).

    Prints each run's best validation loss and their mean and standard
    deviation. Exits with an argparse usage error on invalid option values.
    """
    parser = argparse.ArgumentParser(description="Multi-run PU training")
    parser.add_argument("--samples", default=64, type=int, help="batch size")
    parser.add_argument("--runs", default=5, type=int, help="number of independent runs")
    args = parser.parse_args()

    # Fail early: with fewer than one run the summary below would be the
    # mean/std of an empty array (nan plus NumPy warnings).
    if args.runs < 1:
        parser.error("--runs must be at least 1")
    if args.samples < 1:
        parser.error("--samples must be at least 1")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    all_losses = []
    for run_id in range(1, args.runs + 1):
        set_seed(run_id * 100)  # different seed per run
        val_loss = train_one_run(run_id, args, device)
        # float() accepts plain numbers and 0-dim tensors on any device, so
        # the NumPy aggregation below cannot fail on a CUDA tensor.
        all_losses.append(float(val_loss))

    all_losses = np.array(all_losses)
    print("\n=== Multi-run results ===")
    print("Validation losses:", all_losses)
    print(f"Mean loss: {all_losses.mean():.6f} ± {all_losses.std():.6f}")


# Start the multi-run experiment only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
