# file: prism/systems/prism_system.py
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F

from prism.core.base_objects import BaseSystem
from prism.core.registry import LOSSES, MODELS, SYSTEMS
from prism.utils.model_analysis import save_model_graph, save_model_summary


@SYSTEMS.register("PrismSystem")
class PrismSystem(BaseSystem):
    """Autoencoder system with a split latent space trained against adversaries.

    The encoder output ``z`` is split (see ``_split_latent``) into a *target*
    part, which is classified and pulled toward running per-class prototypes,
    and a *non-target* part, which a latent discriminator tries to classify
    and to tell apart from standard-normal samples, and which
    ``discriminator_q`` tries to recover from (noisy) reconstructions.

    Optimization is manual: every training batch first steps the adversary
    optimizer (latent discriminator + discriminator/Q network) and then the
    main optimizer (encoder + generator + classifier).
    """

    def __init__(self, config, fixed_observer=None):
        """Build the sub-networks, losses and buffers described by ``config``.

        Args:
            config: Experiment configuration with attribute-style access.
            fixed_observer: Accepted for signature compatibility; it is not
                used by this class.
        """
        super().__init__(config)
        # Both optimizers are stepped by hand in ``training_step``.
        self.automatic_optimization = False

        # Scale of the Gaussian noise added to encoder outputs during training.
        self._encoder_noise_scale = float(self.config.model.latent_space.encoder_noise_scale)

        self._initialize_models()
        self._initialize_losses()
        self._initialize_buffers()

        # Per-batch results accumulated for consumers outside this class
        # (presumably callbacks — verify who clears ``test_step_outputs``).
        self.validation_step_outputs = []
        self.test_step_outputs = []

        self._model_analysis_complete = False
        # Refreshed at the start of every validation epoch.
        self._is_scheduled_epoch = False

    def _initialize_models(self):
        """Instantiate all sub-networks from the ``MODELS`` registry."""
        # Only the encoder/generator halves of the autoencoder are kept as
        # attributes; the wrapper object itself is not stored.
        autoencoder = MODELS.get("Autoencoder")(self.config)
        self.encoder = autoencoder.encoder
        self.generator = autoencoder.generator

        self.classifier = MODELS.get("Classifier")(self.config)
        self.latent_discriminator = MODELS.get("LatentDiscriminator")(self.config)
        self.discriminator_q = MODELS.get("DiscriminatorQ")(self.config)

    def _initialize_losses(self):
        """Create registry-driven reconstruction/GAN losses and fixed torch losses."""
        loss_cfg = self.config.loss
        self.recon_loss_fn = LOSSES.get(loss_cfg.recon_loss_type)(self.config)
        self.gan_loss_fn = LOSSES.get(loss_cfg.gan_loss_type)(self.config)
        self.class_loss_fn = nn.CrossEntropyLoss()
        # Gaussian NLL of the non-target code under Q's predicted mean/variance.
        self.info_loss_fn = nn.GaussianNLLLoss()

    def _initialize_buffers(self):
        """Register the zero-initialized ``(num_classes, target_dim)`` prototype buffer."""
        prototype_shape = (self.config.data.num_classes, self.config.model.latent_space.target_dim)
        # A buffer rather than a parameter: it moves with the module and is part
        # of the state dict, but is only changed by ``_update_prototypes``.
        self.register_buffer('prototypes', torch.zeros(prototype_shape))

    def _gc_settings(self):
        """Return gradient-clipping settings, or ``None`` when clipping is off.

        Returns:
            ``None`` if clipping is disabled or both clip values are
            non-positive; otherwise a dict with keys ``"algorithm"``,
            ``"main_val"`` and ``"adv_val"``.
        """
        clip_cfg = self.config.training.gradient_clipping
        if not clip_cfg.enabled:
            return None

        main_val = float(clip_cfg.main_clip_val)
        adv_val = float(clip_cfg.adversarial_clip_val)
        if main_val <= 0 and adv_val <= 0:
            return None

        return {"algorithm": clip_cfg.algorithm, "main_val": main_val, "adv_val": adv_val}

    def setup(self, stage):
        """Lightning hook: on ``'fit'``, place VGG-based losses and run model analysis."""
        if stage != 'fit':
            return

        vgg_loss_classes = (
            LOSSES.get('vgg_single_layer'),
            LOSSES.get('vgg_multi_layer'),
        )
        if isinstance(self.recon_loss_fn, vgg_loss_classes):
            self.recon_loss_fn.to(self.device)

        self._run_model_analysis()

    def _analysis_root_dir(self):
        """Return the base directory for model-analysis artifacts.

        Prefers the attached logger's ``log_dir``; falls back to
        ``trainer.default_root_dir`` when there is no logger or the logger does
        not expose a ``log_dir``. (``trainer.log_dir`` is deliberately avoided:
        it broadcasts across ranks, and this is only called on rank zero.)
        """
        log_dir = getattr(self.trainer.logger, "log_dir", None)
        return Path(log_dir if log_dir is not None else self.trainer.default_root_dir)

    def _run_model_analysis(self):
        """Save a summary and a graph for every sub-network (global rank zero, once).

        Skipped unless ``config.analysis.log_model_architecture`` is enabled
        and the analysis has not completed yet. Each network is fed a random
        dummy batch of size 2 shaped from the data / latent-space config.
        """
        if not self.trainer.is_global_zero:
            return
        if not self.config.analysis.log_model_architecture or self._model_analysis_complete:
            return

        analysis_dir = self._analysis_root_dir() / "model_analysis"
        analysis_dir.mkdir(parents=True, exist_ok=True)

        graph_depth = self.config.analysis.graph_depth
        data_cfg = self.config.data
        latent_cfg = self.config.model.latent_space

        batch_size = 2
        dummy_image = torch.randn(batch_size, *data_cfg.image_shape, device=self.device)
        dummy_latent = torch.randn(batch_size, latent_cfg.latent_dim, device=self.device)
        dummy_z1 = torch.randn(batch_size, latent_cfg.target_dim, device=self.device)
        dummy_z0 = torch.randn(batch_size, latent_cfg.nontarget_dim, device=self.device)

        models_to_analyze = {
            "01_encoder": (self.encoder, dummy_image),
            "02_generator": (self.generator, dummy_latent),
            "03_classifier": (self.classifier, dummy_z1),
            "04_latent_discriminator": (self.latent_discriminator, dummy_z0),
            "05_discriminator_q": (self.discriminator_q, dummy_image),
        }

        for name, (model, dummy_input) in models_to_analyze.items():
            # Run the dummy forward passes in eval mode so random inputs cannot
            # touch training-mode state (e.g. BatchNorm running statistics);
            # every submodule's previous mode is restored afterwards.
            previous_modes = [(module, module.training) for module in model.modules()]
            model.eval()
            try:
                save_model_summary(model, dummy_input, analysis_dir, name, depth=graph_depth)
                save_model_graph(model, dummy_input, analysis_dir, name, graph_depth=graph_depth)
            finally:
                for module, was_training in previous_modes:
                    module.training = was_training

        self._model_analysis_complete = True

    def _should_run_expensive_callbacks(self):
        """Return True on scheduled (every ``intervention_interval``-th) or final epochs.

        Always False during the sanity check. A missing or non-positive
        interval disables the periodic schedule; the last epoch still counts.
        """
        if self.trainer.sanity_checking:
            return False

        epoch_number = self.current_epoch + 1  # 1-based epoch index
        interval = self.config.evaluation.intervention_interval
        is_scheduled_epoch = interval is not None and interval > 0 and epoch_number % interval == 0
        is_last_epoch = epoch_number == self.trainer.max_epochs

        return is_scheduled_epoch or is_last_epoch

    def _encode(self, x, add_noise):
        """Encode ``x``; when ``add_noise`` is set, add scaled Gaussian noise to the code."""
        z = self.encoder(x)
        if add_noise and self._encoder_noise_scale > 0:
            z = z + self._encoder_noise_scale * torch.randn_like(z)
        return z

    def _split_latent(self, z):
        """Return ``(z_target, z_nontarget)``: configured column slices of ``z`` along dim 1."""
        latent_cfg = self.config.model.latent_space
        z_target = z[:, latent_cfg.target_slice_start:latent_cfg.target_slice_stop]
        z_nontarget = z[:, latent_cfg.nontarget_slice_start:latent_cfg.nontarget_slice_stop]
        return z_target, z_nontarget

    def forward(self, x):
        """Encode and reconstruct ``x``; returns ``(x_rec, z)``.

        Encoder noise is applied only while the module is in training mode.
        """
        z = self._encode(x, add_noise=self.training)
        x_rec = self.generator(z)
        return x_rec, z

    def _r1_warmup_scale(self):
        """Return the linear warm-up factor in ``[0, 1]`` for the R1 penalty.

        ``warmup_steps`` (driven by ``global_step``) takes precedence over
        ``warmup_epochs`` (driven by ``current_epoch``); with neither set the
        factor is 1.
        """
        r1_cfg = self.config.loss.r1_penalty
        # NOTE(review): under manual optimization Lightning's global_step counts
        # optimizer steps; with two optimizers stepped per batch it may advance
        # by 2 per batch — confirm the intended unit of ``warmup_steps``.
        if r1_cfg.warmup_steps > 0:
            return min(1.0, self.global_step / float(r1_cfg.warmup_steps))
        if r1_cfg.warmup_epochs > 0:
            return min(1.0, (self.current_epoch + 1) / float(r1_cfg.warmup_epochs))
        return 1.0

    def _compute_adversary_losses(self, data, target_labels, z_nontarget, x_rec):
        """Compute the losses minimized by the adversary optimizer.

        Encoder/generator outputs are detached, so no gradient from these
        losses reaches the encoder or generator, and ``data`` itself is never
        switched to ``requires_grad``.

        Returns:
            Dict with ``'d_l'`` (latent label classification), ``'d_gan'``
            (real vs. reconstruction), ``'q'`` (Gaussian NLL of the non-target
            code), ``'d_prior'`` (prior sample vs. code), ``'r1'`` (R1 gradient
            penalty, zero when skipped) and ``'total'`` (weighted sum).
        """
        loss_cfg = self.config.loss
        weights = loss_cfg.weights
        r1_cfg = loss_cfg.r1_penalty
        z_nontarget_sg = z_nontarget.detach()

        # One latent-discriminator pass on the detached code feeds both its
        # label head and its prior (real/fake) head.
        ld_logits, d_prior_fake = self.latent_discriminator(z_nontarget_sg)
        loss_d_l = self.class_loss_fn(ld_logits, target_labels)

        d_output_real, _, _ = self.discriminator_q(data)
        d_output_fake, _, _ = self.discriminator_q(x_rec.detach())
        loss_d_gan = (self.gan_loss_fn(d_output_real, target_is_real=True) +
                      self.gan_loss_fn(d_output_fake, target_is_real=False)) / 2

        prior = torch.randn_like(z_nontarget)
        _, d_prior_real = self.latent_discriminator(prior)
        loss_d_prior = (self.gan_loss_fn(d_prior_real, target_is_real=True) +
                        self.gan_loss_fn(d_prior_fake, target_is_real=False)) / 2

        loss_r1 = torch.tensor(0.0, device=self.device)
        effective_gamma = r1_cfg.gamma_r1 * self._r1_warmup_scale()
        # Lazily applied: only on steps that are multiples of ``interval``.
        if effective_gamma > 0 and self.global_step % r1_cfg.interval == 0:
            # Differentiate w.r.t. a detached alias instead of flipping
            # ``requires_grad`` on the caller's batch, which would make the
            # main-step losses build a graph into (and backprop into) the input.
            real_for_r1 = data.detach().requires_grad_(True)
            pred_real = self.discriminator_q(real_for_r1)[0].sum()
            (grad_real,) = torch.autograd.grad(outputs=pred_real, inputs=real_for_r1, create_graph=True)
            grad_penalty = torch.sum(grad_real.pow(2), dim=list(range(1, grad_real.ndim)))
            loss_r1 = grad_penalty.mean() * (effective_gamma * 0.5)

        noise = torch.randn_like(x_rec) * loss_cfg.noise_std_for_q
        _, q_mu, q_logvar = self.discriminator_q(x_rec.detach() + noise)
        # GaussianNLLLoss expects a variance, hence exp(logvar).
        loss_q = self.info_loss_fn(q_mu, z_nontarget_sg, torch.exp(q_logvar))

        total_loss = (weights.gamma_l * loss_d_l +
                      weights.gamma_gan * loss_d_gan +
                      weights.gamma_info * loss_q +
                      weights.gamma_prior * loss_d_prior +
                      loss_r1)

        return {
            'd_l': loss_d_l, 'd_gan': loss_d_gan, 'q': loss_q,
            'd_prior': loss_d_prior, 'r1': loss_r1, 'total': total_loss
        }

    def _compute_main_losses(self, data, target_labels, z_target, z_nontarget, x_rec):
        """Compute the losses minimized by the main (encoder/generator/classifier) optimizer.

        Returns:
            Dict with ``'rec'`` (reconstruction), ``'cls'`` (target-code
            classification), ``'adv_l'`` (KL of the latent discriminator's
            label prediction to uniform), ``'adv_gan'`` (fool the image
            discriminator), ``'info'`` (Q recovers the non-target code),
            ``'proto'`` (MSE to class prototypes), ``'adv_prior'`` (fool the
            prior head) and ``'total'`` (weighted sum).
        """
        loss_cfg = self.config.loss
        weights = loss_cfg.weights

        loss_rec = self.recon_loss_fn(x_rec, data)
        class_logits = self.classifier(z_target)
        loss_cls = self.class_loss_fn(class_logits, target_labels)

        # Push the latent discriminator's label prediction on the non-target
        # code toward uniform, i.e. make that code uninformative of the label.
        ld_logits_adv, _ = self.latent_discriminator(z_nontarget)
        uniform_dist = torch.full_like(ld_logits_adv, 1.0 / self.config.data.num_classes)
        loss_adv_l = F.kl_div(F.log_softmax(ld_logits_adv, dim=1), uniform_dist, reduction='batchmean')

        d_output_adv, _, _ = self.discriminator_q(x_rec)
        loss_adv_gan = self.gan_loss_fn(d_output_adv, target_is_real=True)

        noise = torch.randn_like(x_rec) * loss_cfg.noise_std_for_q
        _, q_mu, q_logvar = self.discriminator_q(x_rec + noise)
        loss_info = self.info_loss_fn(q_mu, z_nontarget, torch.exp(q_logvar))

        _, d_prior_adv = self.latent_discriminator(z_nontarget)
        loss_adv_prior = self.gan_loss_fn(d_prior_adv, target_is_real=True)

        prototypes_for_batch = self.prototypes[target_labels].detach()
        loss_proto = F.mse_loss(z_target, prototypes_for_batch)

        total_loss = (weights.gamma_rec * loss_rec +
                      weights.gamma_cls * loss_cls +
                      weights.gamma_l * loss_adv_l +
                      weights.gamma_gan * loss_adv_gan +
                      weights.gamma_info * loss_info +
                      weights.gamma_proto * loss_proto +
                      weights.gamma_prior * loss_adv_prior)
        return {
            'rec': loss_rec, 'cls': loss_cls, 'adv_l': loss_adv_l,
            'adv_gan': loss_adv_gan, 'info': loss_info, 'proto': loss_proto,
            'adv_prior': loss_adv_prior, 'total': total_loss
        }

    @torch.no_grad()
    def _update_prototypes(self, z_target, target_labels):
        """Move each present class's prototype toward that class's batch mean.

        Uses ``prototype <- lerp(prototype, mean, 1 - prototype_momentum)``.
        ``torch.unique`` only yields labels present in the batch, so every
        class mask below is non-empty.
        """
        step = 1.0 - self.config.loss.prototype_momentum
        for cls_idx in torch.unique(target_labels):
            class_mean = z_target[target_labels == cls_idx].mean(dim=0)
            # Cast so lerp_ also works when z_target is in reduced precision.
            self.prototypes[cls_idx].lerp_(class_mean.to(self.prototypes.dtype), step)

    def training_step(self, batch, batch_idx):
        """Run one adversary update followed by one main update on ``batch``.

        ``batch`` is unpacked as ``(data, target_labels, *extras)``; extras are
        ignored. Returns the detached reconstruction and the input batch.
        """
        opt_main, opt_adv = self.optimizers()
        data, target_labels, *_ = batch

        # Noise is added unconditionally here (forward() adds it only in train mode).
        z = self._encode(data, add_noise=True)
        z_target, z_nontarget = self._split_latent(z)
        x_rec = self.generator(z)

        clip = self._gc_settings()

        # 1) Adversary update.
        opt_adv.zero_grad()
        adv_losses = self._compute_adversary_losses(data, target_labels, z_nontarget, x_rec)
        self.manual_backward(adv_losses['total'])
        if clip and clip["adv_val"] > 0:
            self.clip_gradients(opt_adv, gradient_clip_val=clip["adv_val"], gradient_clip_algorithm=clip["algorithm"])
        opt_adv.step()

        # 2) Main update. The discriminators are re-run with their just-updated
        # weights; gradients this backward leaves on adversary parameters are
        # cleared by opt_adv.zero_grad() at the next step.
        opt_main.zero_grad()
        main_losses = self._compute_main_losses(data, target_labels, z_target, z_nontarget, x_rec)
        self.manual_backward(main_losses['total'])
        if clip and clip["main_val"] > 0:
            self.clip_gradients(opt_main, gradient_clip_val=clip["main_val"], gradient_clip_algorithm=clip["algorithm"])
        opt_main.step()

        self._update_prototypes(z_target, target_labels)

        log_dict = {f"train_loss/main_{k}": v for k, v in main_losses.items()}
        log_dict.update({f"train_loss/adv_{k}": v for k, v in adv_losses.items()})
        self.log_dict(log_dict, on_step=True, on_epoch=False, logger=True)

        return {'x_rec': x_rec.detach(), 'data': data}

    def on_validation_epoch_start(self):
        """Reset collected outputs and decide whether this epoch runs expensive callbacks."""
        self.validation_step_outputs.clear()
        self._is_scheduled_epoch = self._should_run_expensive_callbacks()
        if self.trainer.is_global_zero and self._is_scheduled_epoch:
            print(f"\n--- Running expensive callbacks for Epoch {self.current_epoch + 1} ---")

    def validation_step(self, batch, batch_idx):
        """Log the target-code classification loss; collect outputs on scheduled epochs.

        ``batch`` is unpacked as ``(data, target_labels, style_labels)``.
        Returns ``None`` on unscheduled epochs, otherwise the reconstruction
        and input batch.
        """
        data, target_labels, style_labels = batch
        x_rec, z = self(data)

        z_target, _ = self._split_latent(z)
        val_loss = self.class_loss_fn(self.classifier(z_target), target_labels)
        self.log('val/class_loss', val_loss, on_step=False, on_epoch=True, sync_dist=True)

        if not self._is_scheduled_epoch:
            return None

        self.validation_step_outputs.append({
            "z": z.cpu(),
            "target_labels": target_labels.cpu(),
            "style_labels": style_labels.cpu(),
            "data": data.cpu()
        })
        return {"x_rec": x_rec, "data": data}

    def test_step(self, batch, batch_idx):
        """Collect latent codes, labels and inputs (on CPU) for every test batch."""
        data, target_labels, style_labels = batch
        x_rec, z = self(data)
        self.test_step_outputs.append({
            "z": z.cpu(),
            "target_labels": target_labels.cpu(),
            "style_labels": style_labels.cpu(),
            "data": data.cpu()
        })

        return {"x_rec": x_rec, "data": data}

    def configure_optimizers(self):
        """Create the main and adversary AdamW optimizers.

        The order ``(main, adversary)`` matters: ``training_step`` unpacks
        ``self.optimizers()`` in exactly this order.
        """
        opt_cfg = self.config.training.optimizer
        main_params = [
            *self.encoder.parameters(),
            *self.generator.parameters(),
            *self.classifier.parameters(),
        ]
        adv_params = [
            *self.latent_discriminator.parameters(),
            *self.discriminator_q.parameters(),
        ]
        shared_kwargs = {"betas": tuple(opt_cfg.betas), "weight_decay": opt_cfg.weight_decay}

        optimizer_main = torch.optim.AdamW(main_params, lr=opt_cfg.main.lr, **shared_kwargs)
        optimizer_adv = torch.optim.AdamW(adv_params, lr=opt_cfg.adversarial.lr, **shared_kwargs)
        return optimizer_main, optimizer_adv