"""
Constrained generative adversarial networks (GAN-C)

Chao, Xiaopeng, et al.
"Constrained generative adversarial networks."
IEEE Access 9 (2021): 19208-19218.
"""

import torch
import torch.optim as optim
import os
import time
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from dataset import get_dataloader
import logging
import swanlab
import hydra
from omegaconf import DictConfig, OmegaConf
from utils import set_seeds, save_model, evaluate_metrics, save_image_grid, ModelEMA, DiffAugment
import torch.nn.functional as F


# Module-level logger, named after this module so records can be traced back to it.
log = logging.getLogger(__name__)


@hydra.main(version_base=None, config_path="conf", config_name="config_ganc")
def train(cfg: DictConfig):
    """Run a full GAN-C training job described by a Hydra config.

    The function calls ``set_seeds``, opens a SwanLab run, builds the train
    and test data loaders, instantiates a DCGAN or SNGAN
    discriminator/generator pair and then, for every batch, performs one
    discriminator update followed by one generator update. The ``shadow``
    model held by ``ModelEMA`` is the one that is evaluated (FID and
    Inception Score), rendered with fixed noise, and saved at the end.

    Args:
        cfg: Hydra configuration. The function reads ``cfg.general``
            (problem settings), ``cfg.algo`` (learning rates, Adam betas,
            ``rho`` and the run name), ``cfg.swanlab_group`` and
            ``cfg.model_save_dir``.

    Raises:
        ValueError: If ``cfg.general.model`` is neither ``dcgan`` nor
            ``sngan`` (compared case-insensitively).
    """
    # Record the configuration in the log.
    log.info(OmegaConf.to_yaml(cfg))

    # The config is organised in two sections; seed before anything random.
    run_cfg = cfg.general
    opt_cfg = cfg.algo
    set_seeds(run_cfg.seed)

    # "auto" groups runs as "<model>-<DATASET>"; any other value is used as given.
    run_name = opt_cfg.name
    if cfg.swanlab_group == "auto":
        run_group = run_cfg.model + "-" + run_cfg.dataset.upper()
    else:
        run_group = cfg.swanlab_group
    swanlab.init(
        project="MinMaxCon-GAN-C",
        experiment_name=run_name,
        group=run_group,
        config=OmegaConf.to_container(cfg, resolve=True),
    )

    # Hyper-parameters are read up front so a missing key fails early.
    disc_lr = opt_cfg.lr_D
    gen_lr = opt_cfg.lr_G
    adam_beta1 = opt_cfg.beta1
    adam_beta2 = opt_cfg.beta2
    rho = opt_cfg.rho
    batch_size = run_cfg.batch_size
    # The augmentation policy is optional; an empty string is the fallback.
    if hasattr(run_cfg, 'diff_aug'):
        aug_policy = run_cfg.diff_aug
    else:
        aug_policy = ""
    ema_decay = run_cfg.ema_decay
    dataset_name = run_cfg.dataset
    model_type = run_cfg.model

    # Pick the GPU when one is visible, otherwise fall back to the CPU.
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    log.info(f"Using device: {device}")

    # Number of image channels.
    channels = run_cfg.nc

    # The data directory sits next to this script's parent directory, so the
    # path is anchored on the file location rather than on the working dir.
    data_root = os.path.join(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
        'data',
    )

    def make_loader(is_train):
        # Both splits share every setting except the train/test flag.
        return get_dataloader(
            dataset_name,
            batch_size,
            img_size=tuple(run_cfg.img_size),
            root=data_root,
            train=is_train,
            download=True,
            num_workers=4,
        )

    train_loader = make_loader(True)
    # The test split is only used by the metric evaluation.
    eval_loader = make_loader(False)
    log.info(f"Data loader ready for {dataset_name}.")

    # Select the architecture; anything but the two known names is rejected.
    arch = model_type.lower()
    if arch not in ('dcgan', 'sngan'):
        raise ValueError(f"Unknown model type: {model_type}")
    if arch == 'dcgan':
        log.info("Using DCGAN models.")
        from models.DCGAN import Discriminator, Generator
    else:
        log.info("Using SNGAN models.")
        from models.SNGAN import Discriminator, Generator

    img_side = run_cfg.img_size[0]
    z_dim = run_cfg.z_dim
    discriminator = Discriminator(nc=channels, img_size=img_side, ndf=run_cfg.ndf).to(device)
    generator = Generator(z_dim=z_dim, nc=channels, img_size=img_side, ngf=run_cfg.ngf).to(device)

    # Evaluation metrics.
    fid_metric = FrechetInceptionDistance(normalize=True).to(device)
    is_metric = InceptionScore(normalize=True).to(device)

    # One Adam optimizer per network, sharing the same betas.
    adam_betas = (adam_beta1, adam_beta2)
    optimizer_D = optim.Adam(discriminator.parameters(), lr=disc_lr, betas=adam_betas)
    optimizer_G = optim.Adam(generator.parameters(), lr=gen_lr, betas=adam_betas)

    # Noise drawn once and reused for every monitoring image grid.
    fixed_z = torch.randn(run_cfg.fixed_noise_samples, z_dim, device=device)

    # EMA tracker for the generator; its `shadow` model is what gets evaluated.
    ema_G = ModelEMA(generator, decay=ema_decay)

    def run_evaluation():
        # Score the current EMA shadow model against the test split.
        return evaluate_metrics(
            ema_G.shadow,
            eval_loader,
            batch_size,
            z_dim,
            fid=fid_metric,
            inception_score=is_metric,
            num_samples=None,
        )

    def augment_fake(images):
        # The generator output may be a tensor or a list of tensors.
        if isinstance(images, list):
            return [DiffAugment(img, policy=aug_policy) for img in images]
        return DiffAugment(images, policy=aug_policy)

    def detach_fake(images):
        # Detached copy so the discriminator step does not backprop into G.
        if isinstance(images, list):
            return [img.detach() for img in images]
        return images.detach()

    # Baseline scores before any update, logged at step 0.
    log.info("Performing initial evaluation...")
    scores = run_evaluation()
    log.info(f"Initial FID: {scores['fid']:.3f}, IS: {scores['is_mean']:.3f}±{scores['is_std']:.3f}")
    initial_record = {
        "eval/fid": scores['fid'],
        "eval/is_mean": scores['is_mean'],
        "eval/is_std": scores['is_std'],
    }
    swanlab.log(initial_record, step=0)

    log.info("Starting training...")
    start_time = time.perf_counter()

    # Weight of the constraint term h in the discriminator loss.
    penalty_weight = 0.5 * rho

    # Epochs are counted from 1 so intervals and file names need no offset.
    for epoch_num in range(1, run_cfg.epochs + 1):
        eval_due = epoch_num % run_cfg.log_epoch_interval == 0
        sample_due = epoch_num % run_cfg.sampling_epoch_interval == 0

        discriminator.train()
        generator.train()
        for images, _ in train_loader:
            real_img = DiffAugment(images.to(device), policy=aug_policy)
            n_real = real_img.size(0)

            # ---- Discriminator update ----
            discriminator.zero_grad()

            noise = torch.randn(n_real, z_dim, device=device)
            fake_img = augment_fake(generator(noise))
            fake_for_d = detach_fake(fake_img)

            _, real_logits = discriminator(real_img)
            _, fake_logits = discriminator(fake_for_d)

            # With D = sigmoid(logits): logsigmoid(x) = log D and
            # logsigmoid(-x) = log(1 - D).
            # h = mean((log D(real) - log D(fake))^2); it enters the loss
            # scaled by 0.5 * rho.
            h_loss = (F.logsigmoid(real_logits) - F.logsigmoid(fake_logits)).pow(2).mean()
            # Standard GAN term: -mean(log D(real) + log(1 - D(fake))).
            d_gan_loss = -(F.logsigmoid(real_logits) + F.logsigmoid(-fake_logits)).mean()
            d_loss = d_gan_loss + penalty_weight * h_loss

            d_loss.backward()
            optimizer_D.step()

            # ---- Generator update ----
            generator.zero_grad()

            # The non-detached fakes still carry the graph back to G.
            _, fake_logits_g = discriminator(fake_img)

            # The real-sample term is detached, so it only shifts the
            # reported value and contributes no gradient.
            g_gan_loss = (F.logsigmoid(real_logits.detach()) + F.logsigmoid(-fake_logits_g)).mean()

            g_gan_loss.backward()
            optimizer_G.step()
            ema_G.update(generator)

        # ---- Periodic logging and evaluation ----
        if eval_due:
            # Loss figures come from the last batch of the epoch; the
            # discriminator value is reported with its sign flipped.
            d_value = -d_gan_loss.item()
            g_value = g_gan_loss.item()
            h_value = h_loss.item()

            scores = run_evaluation()

            summary = ", ".join([
                f"epoch: {epoch_num:3d}",
                f"d_loss: {d_value:.3f}",
                f"g_loss: {g_value:.3f}",
                f"h_loss: {h_value:.3f}",
                f"fid: {scores['fid']:.3f}",
                f"is: {scores['is_mean']:.3f}±{scores['is_std']:.3f}",
            ])
            log.info(summary)

            epoch_record = {
                "train/d_loss": d_value,
                "train/g_loss": g_value,
                "train/h_loss": h_value,
                "train/time": time.perf_counter() - start_time,
                "eval/fid": scores['fid'],
                "eval/is_mean": scores['is_mean'],
                "eval/is_std": scores['is_std'],
            }
            swanlab.log(epoch_record, step=epoch_num)

        # ---- Periodic image grid from the fixed noise ----
        if sample_due:
            grid_path = os.path.join(run_cfg.out_dir, 'images', f"{epoch_num}.png")
            save_image_grid(ema_G.shadow, fixed_z, grid_path)

    log.info("Training finished.")

    # Persist the EMA shadow generator together with the config that produced it.
    save_model(
        model=ema_G.shadow,
        model_save_dir=cfg.model_save_dir,
        dataset=run_cfg.dataset,
        algo_name=opt_cfg.name,
        model_name=run_cfg.model,
        seed=run_cfg.seed,
        cfg=cfg,
        filename="generator.pt",
    )


# Script entry point: `train` is called without arguments because the
# `@hydra.main` decorator on it is what provides `cfg`.
if __name__ == "__main__":
    train()
