#!/usr/bin/env python3
"""
MNIST CVAE Synthetic Retraining — Filtered pipeline

- Trains an initial CVAE on a small balanced real subset.
- Runs K rounds of synthetic retraining with a verifier (discriminator) that keeps the top-q synthetic samples.
- Adds a flag (--verifier-train-size) to control the verifier's training sample size independently of the other sizes; by default the verifier trains on the full MNIST training set.
- Logs validation metrics + FID, saves checkpoints and per-round sample grids.
"""

import os
import sys
import argparse
import random
from typing import List, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, TensorDataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt


# -----------------------------
# Utilities
# -----------------------------
def set_seed(seed: int):
    """Seed every RNG this script relies on with the same ``seed``.

    Covers Python's ``random``, NumPy's global RNG, PyTorch's CPU RNG and, when CUDA
    is available, the RNGs of all CUDA devices.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def ensure_dir(path: str):
    """Create directory ``path`` together with any missing parents.

    Calling it on a directory that already exists is a no-op (``exist_ok=True``).
    """
    os.makedirs(name=path, mode=0o777, exist_ok=True)

def build_size_schedule_constant(constant_size: int, rounds: int) -> List[int]:
    """Return a schedule that repeats ``int(constant_size)`` once per round.

    Both arguments are coerced with ``int``; a non-positive ``rounds`` yields ``[]``.
    """
    size = int(constant_size)
    return [size for _ in range(int(rounds))]

def build_size_schedule_linear(start: int, end: int, rounds: int) -> List[int]:
    """Return ``rounds`` evenly spaced sizes from ``start`` to ``end``, rounded to multiples of 10.

    Rounding uses NumPy's round-half-to-even on the value divided by 10
    (e.g. 25 -> 20, 35 -> 40). The result is a plain Python list of ints.
    """
    points = np.linspace(start, end, rounds)
    in_tens = np.round(points / 10).astype(int)
    return (in_tens * 10).tolist()

def parse_explicit_schedule(s: str, rounds: int) -> List[int]:
    """Parse a comma-separated list of per-round sizes.

    Empty / whitespace-only entries are skipped, and surrounding whitespace is allowed.

    Raises:
        ValueError: if an entry is not an integer, or if the number of parsed
            entries differs from ``rounds``.
    """
    sizes = []
    for token in s.split(","):
        if not token.strip():
            continue
        sizes.append(int(token))
    if len(sizes) == rounds:
        return sizes
    raise ValueError(f"--explicit-schedule length ({len(sizes)}) must match --rounds ({rounds}).")

def append_result(csv_path: str, row: dict):
    """Append one result row to a CSV file, writing the header only when creating it.

    Args:
        csv_path: destination CSV. Its parent directory is created if missing; a bare
            filename (no directory part) is written in the current working directory.
        row: mapping of column name -> value for a single row. Rows appended to the
            same file should share the same keys in the same order, because appending
            does not re-read the existing header.
    """
    # os.path.dirname("results.csv") == "" and os.makedirs("") raises
    # FileNotFoundError, so fall back to the current directory in that case.
    os.makedirs(os.path.dirname(csv_path) or ".", exist_ok=True)
    write_header = not os.path.exists(csv_path)
    pd.DataFrame([row]).to_csv(csv_path, mode="a", header=write_header, index=False)

@torch.no_grad()
def plot_model_samples(model, save_path: str, latent_dim: int = 20, num_classes: int = 10,
                       per_class: int = 8, device: str = None):
    """Save a ``num_classes`` x ``per_class`` grid of class-conditional decoder samples.

    Row ``c`` holds ``per_class`` images decoded from fresh standard-normal ``z`` with
    one-hot label ``c``, passed through a sigmoid and shown as 28x28 grayscale.

    Args:
        model: object exposing ``model.decoder.decode(z, y_onehot)`` whose output can be
            viewed as ``(N, 1, 28, 28)``. Side effect: the model is moved to ``device``
            and left in eval mode.
        save_path: output image path; its parent directory is created if missing.
        latent_dim: dimensionality of the noise vector ``z``.
        num_classes: number of label rows (and width of the one-hot label).
        per_class: number of samples (columns) per class.
        device: torch device string; defaults to ``"cuda"`` when available, else ``"cpu"``.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()
    z = torch.randn(num_classes * per_class, latent_dim, device=device)
    y = torch.arange(num_classes, device=device).repeat_interleave(per_class)
    y_onehot = F.one_hot(y, num_classes=num_classes).float()
    logits_flat = model.decoder.decode(z, y_onehot)
    imgs = torch.sigmoid(logits_flat).view(-1, 1, 28, 28).cpu().numpy()

    # squeeze=False keeps ``axes`` 2-D even when num_classes or per_class is 1,
    # so ``axes[c, j]`` indexing always works.
    fig, axes = plt.subplots(num_classes, per_class,
                             figsize=(2 * per_class, 2 * num_classes), squeeze=False)
    try:
        for c in range(num_classes):
            for j in range(per_class):
                ax = axes[c, j]
                ax.imshow(imgs[c * per_class + j].squeeze(), cmap='gray')
                # Hide ticks and frame rather than calling ax.axis('off'): 'off' also
                # hides axis labels, which would make the row label below invisible.
                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(False)
                if j == 0:
                    ax.set_ylabel(f"Class {c}", fontsize=10)
        fig.tight_layout()
        ensure_dir(os.path.dirname(save_path) or ".")
        fig.savefig(save_path, dpi=150)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)

@torch.no_grad()
def _generate_images_in_batches(model, total_samples: int, latent_dim: int, num_classes: int,
                                batch_size: int, device: str):
    model.eval()
    labels_full = torch.arange(total_samples) % num_classes
    imgs_all, labels_all = [], []
    for start in range(0, total_samples, batch_size):
        end = min(start + batch_size, total_samples)
        n = end - start
        z = torch.randn(n, latent_dim, device=device)
        y = labels_full[start:end].to(device)
        y_onehot = F.one_hot(y, num_classes=num_classes).float()
        logits_flat = model.decoder.decode(z, y_onehot)     # (n, 784) logits
        imgs = torch.sigmoid(logits_flat).view(-1, 1, 28, 28).cpu()
        imgs_all.append(imgs)
        labels_all.append(y.cpu())
    images = torch.cat(imgs_all, dim=0)
    labels = torch.cat(labels_all, dim=0)
    return images, labels

def compute_fid(model, fid_gen_size: int, device: str = "cuda", latent_dim: int = 20,
                num_classes: int = 10, data_root: str = "./data"):
    """Compute FID between the MNIST test split and samples generated by ``model``.

    Generates ``fid_gen_size`` class-balanced samples (labels cycle 0..num_classes-1)
    in batches of at most 10000, then passes the real and synthetic datasets to the
    project's ``FID.calculate_fid_score``.

    Args:
        model: CVAE-like object exposing ``model.decoder.decode(z, y_onehot)``.
        fid_gen_size: number of synthetic samples to generate; must be positive.
        device: device used for generation.
        latent_dim: decoder latent size (the CVAEs built in ``main`` use 20).
        num_classes: number of label classes (10 for MNIST).
        data_root: root directory for the torchvision MNIST download/cache.

    Returns:
        The value returned by ``calculate_fid_score`` (callers convert it with ``float``).

    Raises:
        ValueError: if ``fid_gen_size`` is not positive.
    """
    if fid_gen_size <= 0:
        # Fail early with a clear message instead of an opaque range()/torch error.
        raise ValueError(f"fid_gen_size must be positive (got {fid_gen_size}).")
    # Imported lazily: FID.py lives in --repo-path, which main() puts on sys.path.
    from FID import calculate_fid_score
    images, labels = _generate_images_in_batches(
        model, total_samples=fid_gen_size, latent_dim=latent_dim, num_classes=num_classes,
        batch_size=min(10000, fid_gen_size), device=device
    )
    real_ds = datasets.MNIST(root=data_root, train=False, download=True, transform=transforms.ToTensor())
    synthetic_ds = TensorDataset(images, labels)
    return calculate_fid_score(real_ds, synthetic_ds)


# -----------------------------
# Main
# -----------------------------
def main():
    parser = argparse.ArgumentParser(description="MNIST CVAE retraining — filtered pipeline")
    # Project paths
    parser.add_argument("--repo-path", type=str, default="./MNIST/conv_cvae",
                        help="Path containing FID.py, models.py, train_helper.py, utils.py, data_helper.py")
    parser.add_argument("--out-root", type=str, default="./outputs_filtered",
                        help="Root dir for models/results/images.")

    # Experiment config
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--init-size", type=int, default=500, help="Balanced real subset size.")
    parser.add_argument("--rounds", type=int, default=40)
    parser.add_argument("--threshold", type=float, default=0.10, help="Top-q kept (0<q<=1).")

    # Schedule config (choose ONE of the following ways)
    parser.add_argument("--constant-size", type=int, default=None, help="If set, uses a constant kept size each round.")
    parser.add_argument("--schedule-start", type=int, default=None, help="Linear schedule start (if constant not set).")
    parser.add_argument("--schedule-end", type=int, default=None, help="Linear schedule end (if constant not set).")
    parser.add_argument("--explicit-schedule", type=str, default="",
                        help="Comma-separated list overriding schedule (must match --rounds).")

    # Training hyperparams
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--patience", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=128)

    # Verifier (discriminator) training size (NEW)
    parser.add_argument("--verifier-train-size", type=int, default=None,
                        help="Number of samples to train the discriminator each round. "
                             "If None, defaults to min(5000, 10 * batch_size).")

    # FID
    parser.add_argument("--fid-gen-size", type=int, default=6000)

    # Housekeeping
    parser.add_argument("--keep-temp", action="store_true", help="Keep per-round synthetic dirs.")

    args = parser.parse_args()

    # Project helpers
    sys.path.append(args.repo_path)
    import models as models
    import train_helper as train_helper
    import utils as utils
    import data_helper as data_helper

    # Repro & device
    set_seed(args.seed)
    device = torch.device(args.device)

    # Outputs
    model_dir   = os.path.join(args.out_root, "model_saved")
    data_dir    = os.path.join(args.out_root, "data_saved")
    result_dir  = os.path.join(args.out_root, "results_saved")
    image_dir   = os.path.join(args.out_root, "images")
    for d in (model_dir, data_dir, result_dir, image_dir):
        ensure_dir(d)

    # Size schedule
    if args.explicit_schedule.strip():
        size_schedule = parse_explicit_schedule(args.explicit_schedule, args.rounds)
    elif args.constant_size is not None:
        size_schedule = build_size_schedule_constant(args.constant_size, args.rounds)
    else:
        # default linear if provided, else fallback to 10k→256k
        start = args.schedule_start if args.schedule_start is not None else 10000
        end   = args.schedule_end   if args.schedule_end   is not None else 256000
        size_schedule = build_size_schedule_linear(start, end, args.rounds)

        # Data
    full_train = datasets.MNIST(root="./data", train=True, download=True, transform=transforms.ToTensor())
    test_ds    = datasets.MNIST(root="./data", train=False, download=True, transform=transforms.ToTensor())
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False)
    
    # Verifier (discriminator) training size
    if args.verifier_train_size is None:
        VERIFIER_TRAIN_SIZE = len(full_train)   
    else:
        VERIFIER_TRAIN_SIZE = int(args.verifier_train_size)

    # Initial balanced subset
    idx_by_digit = utils.create_balanced_subset_indices(full_train, seed=args.seed)
    init_idx = utils.get_balanced_subset(idx_by_digit, args.init_size)
    init_subset = Subset(full_train, init_idx)
    init_loader = DataLoader(init_subset, batch_size=args.batch_size, shuffle=True)

    # Train initial model
    base_name  = f"cvae_conv_real_{args.init_size}"
    curr_model = models.CVAE(input_dim=784, label_dim=10, latent_dim=20, name=base_name, arch="conv").to(device)
    train_helper.train_model(curr_model, init_loader, device,
                             epochs=args.epochs, lr=args.lr, patience=args.patience, verbose=False)
    plot_model_samples(curr_model, os.path.join(image_dir, f"initial_samples_{args.init_size}.png"), device=str(device))

    # Initial eval + FID
    val_loss, val_recon, val_kl = train_helper.calculate_validation_loss(curr_model, test_loader, device)
    fid0 = compute_fid(curr_model, fid_gen_size=args.fid_gen_size, device=str(device))
    csv_path = os.path.join(result_dir, f"results_table_init{args.init_size}_K{args.rounds}.csv")
    append_result(csv_path, dict(round=0, model_name=base_name,
                                 val_loss=float(val_loss), val_recon=float(val_recon),
                                 val_kl=float(val_kl), fid=float(fid0)))
    utils.save_model(curr_model, getattr(curr_model, "get_name", lambda: base_name)(), model_dir)
    print(f"[Init] size={args.init_size}  val_loss={val_loss:.4f}  KL={val_kl:.4f}  Recon={val_recon:.4f}  FID={fid0:.4f}")

    # -----------------------------
    # Iterative filtered retraining
    # -----------------------------
    for r in range(1, args.rounds + 1):
        kept = int(size_schedule[r - 1])
        print(f"\n[Filtered] Round {r}/{args.rounds}  kept={kept}  q={args.threshold}  verifier_train_size={VERIFIER_TRAIN_SIZE}")

        # (A) Train discriminator on a controlled number of samples
        disc_ds = data_helper.prepare_discriminator_dataset(
            full_train, curr_model, num_samples=VERIFIER_TRAIN_SIZE, device=device
        )
        disc_loader = DataLoader(disc_ds, batch_size=args.batch_size, shuffle=True)
        disc_model = models.SyntheticDiscriminator(input_dim=784).to(device)
        train_helper.train_model(disc_model, disc_loader, device,
                                 epochs=80, lr=args.lr, patience=args.patience, verbose=False)
        del disc_loader, disc_ds

        # (B) Generate filtered synthetic dataset to a temp dir
        round_name = f"filtered_init{args.init_size}_q{args.threshold}_s{kept}_r{r}"
        tmp_dir = os.path.join(data_dir, round_name)
        print(f"[Filtered] Generating filtered synthetic data -> {tmp_dir}")
        data_helper.generate_balanced_images_with_filtering(
            model=curr_model,
            save_directory=tmp_dir,
            total_samples=kept,
            discriminator=disc_model,
            selection_threshold=args.threshold,
            verbose=False,
            use_quantile_filtering=True
        )

        # (C) Train CVAE on filtered synthetic
        syn_loader = data_helper.create_directory_based_dataloader(tmp_dir, batch_size=args.batch_size)
        syn_model = models.CVAE(input_dim=784, label_dim=10, latent_dim=20,
                                name=round_name, arch="conv").to(device)
        train_helper.train_model(syn_model, syn_loader, device,
                                 epochs=args.epochs, lr=args.lr, patience=args.patience, verbose=False)
        plot_model_samples(syn_model, os.path.join(image_dir, f"round{r}_samples.png"), device=str(device))

        # (D) Eval + FID, log, save
        vloss, vrecon, vkl = train_helper.calculate_validation_loss(syn_model, test_loader, device)
        fid_score = compute_fid(syn_model, fid_gen_size=args.fid_gen_size, device=str(device))
        append_result(csv_path, dict(round=r, model_name=round_name,
                                     val_loss=float(vloss), val_recon=float(vrecon),
                                     val_kl=float(vkl), fid=float(fid_score)))
        utils.save_model(syn_model, getattr(syn_model, "get_name", lambda: round_name)(), model_dir)
        print(f"[Filtered][Round {r}] val_loss={vloss:.4f}  KL={vkl:.4f}  Recon={vrecon:.4f}  FID={fid_score:.4f}")

        # (E) Advance chain
        curr_model = syn_model

        # (F) Cleanup temp dir
        if not args.keep_temp:
            import shutil
            try:
                if os.path.exists(tmp_dir):
                    shutil.rmtree(tmp_dir)
                    print(f"[CLEAN] Removed temp dir: {tmp_dir}")
            except Exception as e:
                print(f"[WARN] Failed to remove {tmp_dir}: {e}")

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print("\nDone. Outputs under:", os.path.abspath(args.out_root))


if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()
