"""
Stochastic Algorithm for Coupled-Constrained Minimax Optimization (SPACO)
"""

import torch
import torch.optim as optim
import os
import time
import copy
import math
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from dataset import get_dataloader
import logging
import swanlab
import hydra
from omegaconf import DictConfig, OmegaConf
from utils import set_seeds, save_model, evaluate_metrics, save_image_grid, ModelEMA, DiffAugment
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR


log = logging.getLogger(__name__)

# Number of consecutive optimizer steps that share one schedule index: the
# schedulers in this file derive their index as `step // STPES_RATIO + 1`.
# NOTE(review): the name looks like a typo of "STEPS_RATIO"; it is left
# unchanged because it is referenced under this spelling throughout the file.
STPES_RATIO = 10

class SPACOGeneratorScheduler:
    """Polynomially decaying learning-rate schedule with periodic restarts.

    The multiplicative learning-rate factor at scheduler step ``step`` is
    ``k ** (-6 * pilot_t - pilot_s)`` where
    ``k = (step % (restart_interval + 1)) // STPES_RATIO + 1``.

    Args:
        optimizer: Optimizer whose param-group learning rates are scheduled.
        pilot_t: First exponent parameter (weighted by 6 in the decay power).
        pilot_s: Second exponent parameter.
        restart_interval: The schedule index wraps every
            ``restart_interval + 1`` scheduler steps.
    """

    def __init__(self, optimizer, pilot_t, pilot_s, restart_interval):
        self.optimizer = optimizer
        self.restart_interval = restart_interval + 1

        # The exponent does not depend on the step, so compute it once.
        decay_power = -6 * pilot_t - pilot_s

        def _lr_factor(step):
            stage = (step % self.restart_interval) // STPES_RATIO + 1
            return stage ** decay_power

        self.lr_scheduler = LambdaLR(optimizer, lr_lambda=_lr_factor)

    def step(self):
        """Advance the wrapped ``LambdaLR`` by one step."""
        self.lr_scheduler.step()

    def get_lr(self):
        """Return the current learning rate of every param group."""
        return [param_group['lr'] for param_group in self.optimizer.param_groups]

    def get_weight_decay(self):
        """Return the weight decay of every param group (0.0 if unset)."""
        return [
            param_group.get('weight_decay', 0.0)
            for param_group in self.optimizer.param_groups
        ]


class SPACODiscriminatorScheduler:
    """Joint learning-rate / weight-decay schedule with periodic restarts.

    Learning rate: multiplied by ``k_lr ** (-pilot_t - pilot_s)`` with
    ``k_lr = (step % (restart_interval + 1)) // STPES_RATIO + 1``.

    Weight decay: each group's initial value is multiplied by
    ``k_wd ** (-pilot_t)`` with ``k_wd = (step % (restart_interval + 1)) + 1``.
    Groups whose initial weight decay is not positive are left untouched.

    NOTE(review): ``k_wd`` is not divided by ``STPES_RATIO`` while ``k_lr``
    is, so the two indices advance at different rates -- confirm this is
    intended.

    Args:
        optimizer: Optimizer whose param groups are scheduled.
        pilot_t: Exponent parameter shared by both schedules.
        pilot_s: Extra exponent parameter used by the learning rate only.
        restart_interval: Both indices wrap every ``restart_interval + 1``
            scheduler steps.
    """

    def __init__(self, optimizer, pilot_t, pilot_s, restart_interval):
        self.optimizer = optimizer
        self.pilot_t = pilot_t
        self.pilot_s = pilot_s
        self.restart_interval = restart_interval + 1

        def _lr_factor(step):
            stage = (step % self.restart_interval) // STPES_RATIO + 1
            return stage ** (-pilot_t - pilot_s)

        self.lr_scheduler = LambdaLR(optimizer, lr_lambda=_lr_factor)

        # Remember the starting weight decay of each group; step() rescales
        # from these base values rather than compounding.
        self.initial_wds = [
            param_group.get('weight_decay', 0.0)
            for param_group in optimizer.param_groups
        ]

    def step(self):
        """Advance the learning rate, then rescale each group's weight decay."""
        self.lr_scheduler.step()

        position = self.lr_scheduler.last_epoch % self.restart_interval
        scale = (position + 1) ** (-self.pilot_t)

        for idx, param_group in enumerate(self.optimizer.param_groups):
            base_wd = self.initial_wds[idx]
            if not base_wd > 0:
                continue
            param_group['weight_decay'] = base_wd * scale

    def get_lr(self):
        """Return the current learning rate of every param group."""
        return [param_group['lr'] for param_group in self.optimizer.param_groups]

    def get_weight_decay(self):
        """Return the current weight decay of every param group (0.0 if unset)."""
        return [
            param_group.get('weight_decay', 0.0)
            for param_group in self.optimizer.param_groups
        ]


class SPACORhoScheduler:
    """Polynomially growing schedule for the penalty coefficient ``rho``.

    After ``n`` calls to :meth:`step`, ``rho = rho_init * k ** pilot_t`` with
    ``k = n // STPES_RATIO + 1``.

    Args:
        rho_init: Starting value of rho (also returned before any step).
        pilot_t: Growth exponent.
        restart_interval: Stored as ``restart_interval + 1``; ``step`` does
            not read it, so rho is never reset.
    """

    def __init__(self, rho_init, pilot_t, restart_interval):
        self.step_count = 0
        self.rho_init = rho_init
        self.current_rho = rho_init
        self.pilot_t = pilot_t
        self.restart_interval = restart_interval + 1

    def step(self):
        """Count one more step and recompute the current rho."""
        self.step_count += 1
        stage = self.step_count // STPES_RATIO + 1
        self.current_rho = self.rho_init * stage ** self.pilot_t

    def get_rho(self):
        """Return the most recently computed rho."""
        return self.current_rho


def _compute_h_loss(real_logits, fake_logits):
    h_loss = torch.mean((F.logsigmoid(real_logits) - F.logsigmoid(fake_logits)) ** 2)
    return h_loss.detach()


class SPACOEtaScheduler:
    """Schedule for the STORM mixing coefficient ``eta`` with a rising floor.

    ``current_eta`` decays as ``eta0 * k ** (-pilot_s)`` with
    ``k = (step_count % (restart_interval + 1)) // STPES_RATIO + 1``.
    :meth:`get_eta` returns ``max(current_eta, eta_min)``. ``eta_min`` is
    recomputed whenever ``step_count`` is a multiple of
    ``restart_interval + 1`` and follows a warped cosine ramp from
    ``lower_bound`` (0.5) to 1.0 as the number of completed restarts
    approaches ``total_epochs``.

    NOTE(review): ``eta_min`` starts at 0.0 and is first raised at the first
    restart, so the 0.5 floor is not applied before then -- confirm intent.

    Args:
        eta0: Initial eta.
        pilot_s: Decay exponent.
        restart_interval: The decay index wraps every ``restart_interval + 1``
            steps.
        total_epochs: Number of restarts over which the floor ramps to 1.0.
    """

    def __init__(self, eta0, pilot_s, restart_interval, total_epochs):
        self.eta_init = eta0
        self.current_eta = eta0
        self.pilot_s = pilot_s
        self.restart_interval = restart_interval + 1
        self.restart_number = total_epochs
        # Empirically, avoid using a momentum with factor > 0.5 when training GANs
        self.lower_bound = 0.5
        self.eta_min = 0.0
        self.step_count = 0

    def step(self):
        """Advance one step; refresh the floor when a restart boundary is hit."""
        self.step_count += 1
        position = self.step_count % self.restart_interval
        stage = position // STPES_RATIO + 1
        self.current_eta = self.eta_init * (stage ** (-self.pilot_s))

        if position == 0:
            self.update_eta_min()

    def update_eta_min(self):
        """Recompute ``eta_min`` from the fraction of restarts completed."""
        restarts_done = self.step_count // self.restart_interval
        progress = min(1.0, restarts_done / self.restart_number)

        # Piecewise-linear warp: real progress 0.3 maps to the ramp's midpoint.
        pivot = 0.3
        warped = (
            0.5 * (progress / pivot)
            if progress < pivot
            else 0.5 + 0.5 * ((progress - pivot) / (1.0 - pivot))
        )

        span = 1.0 - self.lower_bound
        self.eta_min = self.lower_bound + span * 0.5 * (1 - math.cos(math.pi * warped))

    def get_eta(self):
        """Return the scheduled eta, clamped from below by ``eta_min``."""
        return max(self.current_eta, self.eta_min)


@hydra.main(version_base=None, config_path="conf", config_name="config_spaco")
def train(cfg: DictConfig):
    """Train a GAN with the SPACO schedules and periodically evaluate it.

    Builds the data loaders, a Discriminator/Generator pair (DCGAN or SNGAN),
    Adam optimizers and the SPACO schedulers from ``cfg``. Each batch runs one
    discriminator update followed by one generator update; when
    ``cfg.algo.use_storm`` is set, the generator gradient is replaced by a
    recursive STORM-style estimate. FID / Inception Score are computed on
    ``ema_G.shadow`` before training and on evaluation epochs, metrics are
    sent to swanlab, image grids are saved on sampling epochs, and
    ``ema_G.shadow`` is saved at the end.

    Args:
        cfg: Hydra config. Reads ``cfg.general`` (problem settings),
            ``cfg.algo`` (algorithm hyper-parameters), ``cfg.swanlab_group``
            and ``cfg.model_save_dir``.

    Raises:
        ValueError: If ``cfg.general.model`` is neither ``dcgan`` nor
            ``sngan`` (case-insensitive).
    """
    # Print config
    log.info(OmegaConf.to_yaml(cfg))
    
    # Load config
    problem_cfg = cfg.general
    algo_cfg = cfg.algo
    set_seeds(problem_cfg.seed)

    # The conditional below applies to the whole concatenation: the group is
    # "<model>-<DATASET>" when swanlab_group == "auto", otherwise the
    # configured group name is used as-is.
    swanlab.init(
        project="MinMaxCon-GAN-C",
        experiment_name=algo_cfg.name,
        group=problem_cfg.model + "-" + problem_cfg.dataset.upper() if cfg.swanlab_group == "auto" else cfg.swanlab_group,
        config=OmegaConf.to_container(cfg, resolve=True),
    )

    # Training config
    pilot_t = algo_cfg.pilot_t
    pilot_s = algo_cfg.pilot_s
    lr_decay = algo_cfg.lr_decay # boolean, whether to use learning rate decay
    rho_increase = algo_cfg.rho_increase # boolean, whether to increase rho during training
    
    lr_alpha = algo_cfg.lr_alpha  # generator learning rate (passed to optimizer_G)
    lr_beta = algo_cfg.lr_beta  # discriminator learning rate (passed to optimizer_D)
    prox_y = algo_cfg.prox_y  # used as the discriminator optimizer's weight_decay
    rho_init = algo_cfg.rho

    use_storm = algo_cfg.use_storm
    storm_eta0 = algo_cfg.storm_eta0

    log.info(f"Storm settings: use_storm={use_storm}, storm_eta0={storm_eta0}")

    beta1 = algo_cfg.beta1
    beta2 = algo_cfg.beta2
    batch_size = problem_cfg.batch_size
    # An empty policy string is used when the config has no `diff_aug` entry.
    diff_aug_policy = problem_cfg.diff_aug if hasattr(problem_cfg, 'diff_aug') else ""
    ema_decay = problem_cfg.ema_decay
    dataset_name = problem_cfg.dataset
    model_type = problem_cfg.model

    # Setup Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log.info(f"Using device: {device}")

    # Determine channels
    nc = problem_cfg.nc

    # Train loader
    # Get project root directory (parent of GANC directory)
    # If running from GANC directory, go up one level; if from project root, use current
    script_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(script_dir)  # Go up from GANC/ to project root
    data_root = os.path.join(project_root, 'data')
    dataloader = get_dataloader(dataset_name, batch_size, img_size=tuple(problem_cfg.img_size), root=data_root, train=True, download=True, num_workers=4)
    # Test loader for FID
    test_dataloader = get_dataloader(dataset_name, batch_size, img_size=tuple(problem_cfg.img_size), root=data_root, train=False, download=True, num_workers=4)
    log.info(f"Data loader ready for {dataset_name}.")

    # Models
    # Imported lazily so only the selected architecture's module is loaded.
    if model_type.lower() == 'dcgan':
        log.info("Using DCGAN models.")
        from models.DCGAN import Discriminator, Generator
    elif model_type.lower() == 'sngan':
        log.info("Using SNGAN models.")
        from models.SNGAN import Discriminator, Generator
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    discriminator = Discriminator(nc=nc, img_size=problem_cfg.img_size[0], ndf=problem_cfg.ndf).to(device)
    generator = Generator(z_dim=problem_cfg.z_dim, nc=nc, img_size=problem_cfg.img_size[0], ngf=problem_cfg.ngf).to(device)

    # FID Metric
    fid = FrechetInceptionDistance(normalize=True).to(device)
    inception_score = InceptionScore(normalize=True).to(device)

    # Optimizers
    optimizer_D = optim.Adam(discriminator.parameters(), lr=lr_beta, betas=(beta1, beta2), weight_decay=prox_y)
    optimizer_G = optim.Adam(generator.parameters(), lr=lr_alpha, betas=(beta1, beta2))

    # STORM for Generator
    # One epoch's worth of batches is used as the restart interval of every
    # scheduler constructed below.
    steps_per_epoch = len(dataloader)
    
    if use_storm:
        log.info(f"Using STORM momentum for Generator (eta0={storm_eta0}, pilot_s={pilot_s})")
        # Second copy of the generator, used to evaluate the gradient at the
        # previous iterate on the current batch.
        generator_prev = copy.deepcopy(generator)
        generator_prev.to(device)
        generator_prev.train()
        # Running gradient estimate, one tensor per generator parameter.
        storm_buffer = {name: torch.zeros_like(p) for name, p in generator.named_parameters()}
        storm_prev_state = {k: v.detach().clone() for k, v in generator.state_dict().items()}
        # False until the first generator update has produced a previous iterate.
        storm_has_prev = False
        eta_scheduler = SPACOEtaScheduler(storm_eta0, pilot_s, steps_per_epoch, problem_cfg.epochs)
    else:
        generator_prev = None
        storm_buffer = None
        storm_prev_state = None
        eta_scheduler = None

    # Schedulers
    scheduler_G = SPACOGeneratorScheduler(optimizer_G, pilot_t, pilot_s, steps_per_epoch)
    scheduler_D = SPACODiscriminatorScheduler(optimizer_D, pilot_t, pilot_s, steps_per_epoch)
    scheduler_rho = SPACORhoScheduler(rho_init, pilot_t, steps_per_epoch)

    # Fixed noise for visualization
    fixed_z = torch.randn(problem_cfg.fixed_noise_samples, problem_cfg.z_dim, device=device)

    # EMA
    ema_G = ModelEMA(generator, decay=ema_decay)

    # Initial Rigorous Evaluation
    log.info("Performing initial evaluation...")
    metrics = evaluate_metrics(ema_G.shadow, test_dataloader, batch_size, problem_cfg.z_dim, fid=fid, inception_score=inception_score, num_samples=None)
    log.info(f"Initial FID: {metrics['fid']:.3f}, IS: {metrics['is_mean']:.3f}±{metrics['is_std']:.3f}")
    swanlab.log({
        "eval/fid": metrics['fid'],
        "eval/is_mean": metrics['is_mean'],
        "eval/is_std": metrics['is_std']
    }, step=0)

    # Training Loop
    log.info("Starting training...")
    start_time = time.perf_counter()
    
    for epoch in range(problem_cfg.epochs):
        is_eval_epoch = (epoch + 1) % problem_cfg.log_epoch_interval == 0
        is_sampling_epoch = (epoch + 1) % problem_cfg.sampling_epoch_interval == 0

        # ---- Training ----
        discriminator.train()
        generator.train()
        for batch_idx, (data, _) in enumerate(dataloader):
            real_img = DiffAugment(data.to(device), policy=diff_aug_policy)
            current_batch_size = real_img.size(0)
            # rho is read once per batch so both players use the same value.
            current_rho = scheduler_rho.get_rho()
            
            # ---- Train Discriminator ----
            discriminator.zero_grad()
            
            z = torch.randn(current_batch_size, problem_cfg.z_dim, device=device)
            fake_img = generator(z)
            # The generator output may be a list of images or a single tensor;
            # both forms are augmented and detached element-wise.
            fake_img = [DiffAugment(img, policy=diff_aug_policy) for img in fake_img] if isinstance(fake_img, list) else DiffAugment(fake_img, policy=diff_aug_policy)
            fake_img_detach = [img.detach() for img in fake_img] if isinstance(fake_img, list) else fake_img.detach()
            
            # Standard GAN training
            _, real_logits = discriminator(real_img)
            _, fake_logits = discriminator(fake_img_detach)

            # F.logsigmoid(-x) = log(1.0 - sigmoid(x))
            d_h_loss = _compute_h_loss(real_logits, fake_logits)
            d_gan_loss = - torch.mean(F.logsigmoid(real_logits) + F.logsigmoid(-fake_logits))
            d_loss = d_gan_loss + (0.5 * current_rho) * d_h_loss
            # Gradients are computed here, but optimizer_D.step() is deferred
            # until after the STORM pre-calculation below.
            d_loss.backward()

            # ---- STORM Pre-calculation (Grad Prev) ----
            grad_prev_buffer = {}
            if use_storm and storm_has_prev:
                generator_prev.load_state_dict(storm_prev_state)
                generator_prev.zero_grad(set_to_none=True)
                
                # Re-use z from current batch, and IMPORTANTLY: use the OLD discriminator (y^k) before step()
                fake_img_prev = generator_prev(z)
                fake_img_prev = [DiffAugment(img, policy=diff_aug_policy) for img in fake_img_prev] if isinstance(fake_img_prev, list) else DiffAugment(fake_img_prev, policy=diff_aug_policy)

                _, fake_logits_prev = discriminator(fake_img_prev)

                g_gan_loss_prev = torch.mean(F.logsigmoid(real_logits.detach()) + F.logsigmoid(-fake_logits_prev))

                g_h_loss_prev = - _compute_h_loss(real_logits.detach(), fake_logits_prev)
                g_loss_prev = g_gan_loss_prev + (0.5 * current_rho) * g_h_loss_prev

                # `inputs=` restricts gradient accumulation to generator_prev's
                # parameters, so the discriminator gradients from
                # d_loss.backward() are not modified by this pass.
                torch.autograd.backward(g_loss_prev, inputs=tuple(generator_prev.parameters()))
                
                for name, p in generator_prev.named_parameters():
                    if p.grad is not None:
                        grad_prev_buffer[name] = p.grad.detach()
                    else:
                        grad_prev_buffer[name] = torch.zeros_like(p)
                
                # set_to_none drops the .grad references; the tensors stored
                # in grad_prev_buffer stay alive.
                generator_prev.zero_grad(set_to_none=True)

            optimizer_D.step()
            if lr_decay:
                scheduler_D.step()
            
            # ---- Train Generator ----
            generator.zero_grad()
            if use_storm:
                # Snapshot of the generator before this update; it becomes
                # storm_prev_state for the next iteration.
                storm_curr_state = {k: v.detach().clone() for k, v in generator.state_dict().items()}
            
            # The non-detached fake images are scored by the discriminator
            # after its update above.
            _, fake_logits_g = discriminator(fake_img)
            
            g_gan_loss = torch.mean(F.logsigmoid(real_logits.detach()) + F.logsigmoid(-fake_logits_g))

            g_h_loss = - _compute_h_loss(real_logits.detach(), fake_logits_g)
            g_loss = g_gan_loss + (0.5 * current_rho) * g_h_loss
            
            g_loss.backward()

            if use_storm:
                eta_k = eta_scheduler.get_eta()

                for name, p in generator.named_parameters():
                    grad_curr = p.grad if p.grad is not None else torch.zeros_like(p)
                    # Use pre-calculated gradient
                    grad_prev = grad_prev_buffer[name] if (storm_has_prev and name in grad_prev_buffer) else torch.zeros_like(p)
                    
                    # Recursive estimate: d_new = g_curr + (1 - eta) * (d_prev - g_prev).
                    # On the first update d_prev and g_prev are both zero, so
                    # d_new equals the plain gradient.
                    d_prev = storm_buffer[name]
                    d_new = grad_curr + (1 - eta_k) * (d_prev - grad_prev)
                    storm_buffer[name] = d_new.detach()

                    if p.grad is None:
                        p.grad = d_new.detach()
                    else:
                        # detach() shares storage with p.grad, so copy_
                        # overwrites the gradient in place.
                        p.grad.detach().copy_(d_new)

                eta_scheduler.step()
                optimizer_G.step()
                storm_prev_state = storm_curr_state
                storm_has_prev = True
            else:
                optimizer_G.step()

            ema_G.update(generator)
            if rho_increase:
                scheduler_rho.step()
            if lr_decay:
                scheduler_G.step()
            
        # ---- Logging ----
        if is_eval_epoch:
            # Metrics logging
            # The loss values below come from the last batch of the epoch.
            log_items = [
                f"epoch: {epoch + 1:3d}",
                f"d_loss: {- d_gan_loss.item():.3f}",
                f"g_loss: {g_gan_loss.item():.3f}",
                f"h_loss: {d_h_loss.item():.3f}",
            ]
            
            metrics = evaluate_metrics(ema_G.shadow, test_dataloader, batch_size, problem_cfg.z_dim, fid=fid, inception_score=inception_score, num_samples=None)
            log_items.extend([
                f"fid: {metrics['fid']:.3f}",
                f"is: {metrics['is_mean']:.3f}±{metrics['is_std']:.3f}",
                f"rho: {scheduler_rho.get_rho():.3f}",
                f"lr_D: {scheduler_D.get_lr()[0]:.3e}",
                f"lr_G: {scheduler_G.get_lr()[0]:.3e}",
                f"decay_D: {scheduler_D.get_weight_decay()[0]:.3e}",
                # NOTE(review): when use_storm is False this appends an empty
                # string, so the joined log line ends with a trailing ", ".
                f"eta_D: {eta_scheduler.get_eta():.3e}" if use_storm else "",
            ])

            log.info(", ".join(log_items))

            swanlab.log({
                "train/d_loss": - d_gan_loss.item(),
                "train/g_loss": g_gan_loss.item(),
                "train/h_loss": d_h_loss.item(),
                "train/time": time.perf_counter() - start_time,
                "eval/fid": metrics['fid'],
                "eval/is_mean": metrics['is_mean'],
                "eval/is_std": metrics['is_std'],
                # "params/lr_D": scheduler_D.get_lr()[0],
                # "params/lr_G": scheduler_G.get_lr()[0],
                # "params/weight_decay_D": scheduler_D.get_weight_decay()[0],
                # "params/weight_decay_G": scheduler_G.get_weight_decay()[0],
            }, step=epoch + 1)

        # ---- Monitoring images with fixed noise ----
        if is_sampling_epoch:
            save_path = os.path.join(problem_cfg.out_dir, 'images', f"{epoch + 1}.png")
            save_image_grid(ema_G.shadow, fixed_z, save_path)
                        
    log.info("Training finished.")

    # Persist the EMA copy of the generator rather than the raw generator.
    save_model(
        model=ema_G.shadow,
        model_save_dir=cfg.model_save_dir,
        dataset=problem_cfg.dataset,
        algo_name=algo_cfg.name,
        model_name=problem_cfg.model,
        seed=problem_cfg.seed,
        cfg=cfg,
        filename="generator.pt",
    )


# Script entry point. `train` is wrapped by @hydra.main, so it is invoked
# without arguments here.
if __name__ == "__main__":
    train()
