# file: user_extensions/baselines/fader_networks/fader_system.py
import torch
import torch.nn as nn

from prism.core.base_objects import BaseSystem
from prism.core.registry import LOSSES, MODELS, SYSTEMS


@SYSTEMS.register("FaderSystem")
class FaderSystem(BaseSystem):
    """Fader Networks baseline: encoder/generator autoencoder trained against a latent discriminator.

    Training uses manual optimization with two optimizers:

    * the discriminator learns to predict ``target_labels`` from a *detached*
      latent code ``z`` (so its loss never reaches the encoder);
    * the encoder/generator pair learns to reconstruct ``data`` from
      ``(z, target_labels)`` while *maximizing* the discriminator's
      cross-entropy on ``z`` (the adversarial term enters the loss negated).

    Batches are unpacked as ``(data, target_labels, style_labels)``;
    ``style_labels`` are only gathered into the validation/test output buffers.
    """

    def __init__(self, config):
        super().__init__(config)
        # Both optimizers are stepped by hand in ``training_step``.
        self.automatic_optimization = False

        self.encoder = MODELS.get("Encoder")(self.config)
        self.generator = MODELS.get("ConditionalGenerator")(self.config)
        self.discriminator = MODELS.get("FaderDiscriminator")(self.config)

        self.recon_loss_fn = LOSSES.get(self.config.loss.recon_loss_type)(self.config)
        self.class_loss_fn = nn.CrossEntropyLoss()

        self.gamma_rec = self.config.loss.weights.gamma_rec
        self.lambda_fader = self.config.loss.weights.lambda_fader

        # Length of the linear warm-up of lambda_fader (see _effective_lambda_fader).
        # 0 disables it; an explicit null in the config is treated the same way
        # instead of failing later on ``None > 0``.
        self.lambda_schedule = self.config.loss.get('lambda_schedule', 0) or 0

        # Per-batch CPU copies collected for epoch-level evaluation.
        self.validation_step_outputs = []
        self.test_step_outputs = []
        # Set at the start of each validation epoch by _should_run_expensive_callbacks.
        self._is_scheduled_epoch = False

    def configure_optimizers(self):
        """Create the autoencoder and discriminator optimizers.

        Both are AdamW and share ``betas`` and ``weight_decay``. The encoder and
        generator use ``optimizer.main.lr``; the discriminator uses
        ``optimizer.adversarial.lr``. The return order (AE first, discriminator
        second) is the order ``training_step`` unpacks from ``self.optimizers()``.
        """
        opt_cfg = self.config.training.optimizer
        ae_params = list(self.encoder.parameters()) + list(self.generator.parameters())
        optimizer_ae = torch.optim.AdamW(
            ae_params, lr=opt_cfg.main.lr, betas=tuple(opt_cfg.betas), weight_decay=opt_cfg.weight_decay
        )
        disc_params = self.discriminator.parameters()
        optimizer_disc = torch.optim.AdamW(
            disc_params, lr=opt_cfg.adversarial.lr, betas=tuple(opt_cfg.betas), weight_decay=opt_cfg.weight_decay
        )
        return optimizer_ae, optimizer_disc

    def _effective_lambda_fader(self):
        """Return the adversarial weight for the current step.

        When ``lambda_schedule > 0`` the weight ramps linearly from 0 to
        ``lambda_fader`` as ``trainer.global_step`` goes from 0 to
        ``lambda_schedule``, then stays at ``lambda_fader``. Otherwise
        ``lambda_fader`` is returned unchanged.
        """
        if self.lambda_schedule > 0:
            # NOTE(review): under manual optimization Lightning advances
            # ``global_step`` on every optimizer ``step()``. That is twice per
            # batch here, so the ramp completes after ~lambda_schedule / 2
            # batches. Confirm this is the intended unit.
            multiplier = min(1.0, self.trainer.global_step / self.lambda_schedule)
            return self.lambda_fader * multiplier
        return self.lambda_fader

    def training_step(self, batch, batch_idx):
        """Run one discriminator update followed by one autoencoder update, then log the losses."""
        opt_ae, opt_disc = self.optimizers()
        data, target_labels, _ = batch

        # Encode once and reuse the result for both updates. The discriminator
        # step below only changes discriminator parameters, so this z matches
        # what a second encoder pass would return (for a deterministic encoder).
        # Its autograd graph stays intact for the AE backward.
        z = self.encoder(data)

        # 1) Discriminator update on a detached latent: no gradient reaches the encoder.
        opt_disc.zero_grad()
        y_pred_disc = self.discriminator(z.detach())
        loss_disc = self.class_loss_fn(y_pred_disc, target_labels)
        self.manual_backward(loss_disc)
        opt_disc.step()

        # 2) Autoencoder update against the freshly updated discriminator.
        opt_ae.zero_grad()
        x_rec = self.generator(z, target_labels)
        loss_rec = self.recon_loss_fn(x_rec, data)

        y_pred_adv = self.discriminator(z)
        # NOTE(review): negated cross-entropy is unbounded below, so this term
        # can dominate the loss once the discriminator is confidently wrong.
        # The original Fader Networks objective instead minimises CE against
        # flipped labels. Kept as-is to preserve existing behavior.
        loss_adv = -self.class_loss_fn(y_pred_adv, target_labels)

        # Evaluated after opt_disc.step(), matching the original ordering of global_step reads.
        effective_lambda_fader = self._effective_lambda_fader()

        loss_ae = (self.gamma_rec * loss_rec) + (effective_lambda_fader * loss_adv)

        # This backward also writes .grad on discriminator parameters. Only
        # opt_ae steps here, and opt_disc.zero_grad() at the next step discards them.
        self.manual_backward(loss_ae)
        opt_ae.step()

        log_payload = {
            'train_loss/ae_total': loss_ae,
            'train_loss/reconstruction': loss_rec,
            'train_loss/discriminator': loss_disc,
            'train_loss/adversarial_encoder': loss_adv
        }
        if self.lambda_schedule > 0:
            log_payload['train_loss/effective_lambda_fader'] = effective_lambda_fader

        self.log_dict(log_payload, on_step=True, on_epoch=False, logger=True)

    def _should_run_expensive_callbacks(self):
        """Return whether the current validation epoch should collect per-batch outputs.

        Always False during Lightning's sanity check. Otherwise True on every
        ``evaluation.intervention_interval``-th epoch (1-based) and on the
        epoch equal to ``trainer.max_epochs``. A missing or non-positive
        interval disables the periodic schedule, leaving only the last epoch,
        instead of raising ``ZeroDivisionError``/``TypeError``.
        """
        if self.trainer.sanity_checking:
            return False
        epoch_number = self.current_epoch + 1
        interval = self.config.evaluation.intervention_interval
        is_scheduled_epoch = interval is not None and interval > 0 and epoch_number % interval == 0
        is_last_epoch = epoch_number == self.trainer.max_epochs
        return is_scheduled_epoch or is_last_epoch

    def on_validation_epoch_start(self):
        """Reset the validation buffer and decide whether this epoch collects outputs."""
        self.validation_step_outputs.clear()
        self._is_scheduled_epoch = self._should_run_expensive_callbacks()
        if self.trainer.is_global_zero and self._is_scheduled_epoch:
            print(f"\n--- Running expensive callbacks for Epoch {self.current_epoch + 1} ---")

    def validation_step(self, batch, batch_idx):
        """Log the epoch-averaged reconstruction loss; on scheduled epochs also store CPU copies.

        Returns a dict with the reconstruction ``x_rec`` and the input ``data``.
        """
        data, target_labels, style_labels = batch
        z = self.encoder(data)
        x_rec = self.generator(z, target_labels)
        loss_rec = self.recon_loss_fn(x_rec, data)

        self.log('val/recon_loss', loss_rec, on_step=False, on_epoch=True, sync_dist=True)

        if self._is_scheduled_epoch:
            self.validation_step_outputs.append({
                "z": z.cpu(),
                "target_labels": target_labels.cpu(),
                "style_labels": style_labels.cpu(),
                "data": data.cpu()
            })

        return {"x_rec": x_rec, "data": data}

    def on_test_epoch_start(self):
        """Reset the test buffer so repeated test runs do not accumulate stale outputs."""
        super().on_test_epoch_start()
        self.test_step_outputs.clear()

    def test_step(self, batch, batch_idx):
        """Store CPU copies of latents, labels and inputs for every test batch.

        Returns a dict with the reconstruction ``x_rec`` and the input ``data``.
        """
        data, target_labels, style_labels = batch
        z = self.encoder(data)
        x_rec = self.generator(z, target_labels)

        self.test_step_outputs.append({
            "z": z.cpu(),
            "target_labels": target_labels.cpu(),
            "style_labels": style_labels.cpu(),
            "data": data.cpu()
        })

        return {"x_rec": x_rec, "data": data}