import torch
from tqdm import tqdm

from dae.utils.torch_utils import freeze_model, reproducible_rand, unwrap

from ..models.ae.losses import AUX_LOSSES, GanLoss
from ..registers import AUTOENCODERS
from .base_tasks import TASKS, BaseAutoencodingTask

####################################################################
# DAE Tasks
####################################################################


@TASKS.register("ae")
class AETasks(BaseAutoencodingTask):
    """
    Autoencoder ("ae") tasks.

    - dae_train -> train the DAE model
    - dae_eval -> multi-thread, fast evaluation of all metrics (almost exact)
    - dae_fid -> single-thread, FID evaluation, get exact final metrics
    - dae_files_fid -> old implementation, deprecated, use 'dae_fid' instead
    - z_stats (``task_z_stats``) -> global mean/std of the encoder latents over
      the test set (main process only)
    """

    SHOW_MODEL_PARTS = ["decoder", "decoder.down", "decoder.down_blocks", "vae"]

    def load_models(self):
        """Build the autoencoder and, in training mode, its auxiliary models.

        Always builds ``"ae"`` from ``self.cfg.ae``. When ``self.training`` is
        set, it also builds:

        - ``"aux_losses"`` from ``self.cfg.aux_losses``, given the AE and the
          accelerator;
        - ``"gan"`` (a ``GanLoss``) when ``cfg.gan`` is set, given the AE's
          last-layer weight;
        - ``"teacher"`` when ``cfg.teacher`` is set: an AE passed through
          ``freeze_model``, excluded from checkpointing, and used as the
          distillation target in ``_compute_train_loss``.
        """
        self._build_model(AUTOENCODERS, name="ae", **self.cfg.ae)
        if self.training:
            self._build_model(AUX_LOSSES, name="aux_losses", **self.cfg.aux_losses, ae=self.models["ae"], accelerator=self.accelerator)

            if self.cfg.get("gan", None) is not None:
                # The AE is unwrapped (including EMA, via ``unw_ema=True``) before
                # its last-layer weight is handed to the GAN loss.
                model_gan = GanLoss(
                    model_last_layer=unwrap(self.models["ae"], unw_ema=True).get_last_layer_weight(),
                    **self.cfg.gan,
                )
                self.prepare_model(model_gan, name="gan")

            if self.cfg.get("teacher", None) is not None:
                self._build_model(AUTOENCODERS, name="teacher", **self.cfg.teacher, remove_from_checkpointing=True)
                freeze_model(self.models["teacher"])
                # NOTE(review): the teacher stays in train() mode after being
                # frozen, presumably so its forward pass behaves as during
                # training (e.g. draws noise/timesteps). Confirm this is intended.
                self.models["teacher"].train()

    def _compute_train_loss(self, batch, models_names, train_ctx):
        """Compute the losses for one optimization step.

        Args:
            batch: ``(x, label)`` pair; only ``x`` is used.
            models_names: names of the models this step computes losses for
                (presumably the ones being optimized now). ``"ae"``,
                ``"aux_losses"`` and ``"gan"`` are recognized.
            train_ctx: mutable per-step context dict. ``"cur_steps"`` is read.
                The ``"ae"`` branch stores the AE outputs under ``"gan_ctx"``
                (when a GAN model exists), and the ``"gan"`` branch reads them back.

        Returns:
            dict mapping loss names to loss values.
        """
        x, _ = batch
        losses = {}

        if "ae" in models_names:
            target_x = None
            # Train DAE: main loss & predict x0
            if "teacher" in self.models:
                # Distillation: the AE reconstructs the teacher's prediction,
                # reusing the teacher's latent and noise. No gradients flow
                # into the teacher.
                with torch.no_grad():
                    teacher_gen = self.models["teacher"](x)
                target_x = teacher_gen.x0_pred
                dae_out = self.models["ae"](target_x, z=teacher_gen.z, noise=teacher_gen.noise)
            else:
                # Train on target
                dae_out = self.models["ae"](x)
            losses.update(dae_out.losses)

            # Add auxiliary losses
            if "aux_losses" in models_names:
                losses.update(self.models["aux_losses"](x, dae_out.x0_pred, target_x=target_x))

            # Add the adversarial term for the AE. This checks ``self.models``
            # (not ``models_names``): the term applies whenever a GAN exists,
            # even if the GAN is not being optimized in this step.
            if "gan" in self.models:
                losses.update(
                    self.models["gan"](
                        x_gt=x if target_x is None else target_x,
                        x_pred=dae_out.x0_pred,
                        xt=dae_out.xt,
                        t=dae_out.t,
                        existing_losses=losses,
                        n_train_steps=train_ctx["cur_steps"],
                        step="disc_loss",
                    )
                )
                # Hand the AE outputs to the "gan" branch below.
                train_ctx["gan_ctx"] = {
                    "x_pred": dae_out.x0_pred,
                    "xt": dae_out.xt,
                    "t": dae_out.t,
                }

        if "gan" in models_names:
            # Requires ``gan_ctx`` from an earlier "ae" pass (KeyError otherwise).
            # NOTE(review): ``x_gt`` is the real ``x`` here even when distilling
            # from a teacher, whereas the "ae" branch passes the teacher target.
            # Confirm this asymmetry is intended.
            losses.update(
                self.models["gan"](
                    x_gt=x,
                    **train_ctx["gan_ctx"],
                    n_train_steps=train_ctx["cur_steps"],
                    step="train",
                )
            )

        return losses

    def _generate_for_eval(self, x, generator=None):
        """Run the AE on ``x`` with noise drawn reproducibly from ``generator``.

        The noise has the same shape as ``x`` and is drawn through
        ``reproducible_rand``. Returns the AE's output object unchanged.
        """
        noise = reproducible_rand(self.accelerator, generator, x.shape)
        gen_x = self.models["ae"](x, noise=noise)
        return gen_x

    @torch.no_grad()
    def task_z_stats(self):
        """Compute global statistics of the AE encoder latents over the test set.

        Encodes every test batch under autocast and takes ``.mode()`` of the
        encoder output (presumably the mean of its posterior; confirm). It then
        upcasts to float64 on CPU and reduces all latent elements to one scalar
        mean and one unbiased standard deviation.

        Returns:
            dict with float entries ``"z_mean"`` and ``"z_std"``.

        Raises:
            RuntimeError: if called on a non-main process, or if the test
                loader yields no batches.
        """
        acc = self.accelerator
        # Use an explicit check instead of ``assert`` so it still fires under
        # ``python -O``. It runs before any model state is modified.
        if not acc.is_main_process:
            raise RuntimeError("Z stats can only be computed on the main process")

        self.models["ae"].eval()
        # Only the main process gets here, so only the verbosity flag matters.
        tqdm_dis = not self.cfg.verbose

        z_batches = []
        enum_tests = tqdm(self.test_loader, desc="Reconstructing from test set", disable=tqdm_dis)
        for x, _ in enum_tests:
            x = x.to(acc.device)
            with acc.autocast():
                z = self.models["ae"].encode(x).mode()
            # Upcast before reducing, so the statistics are not limited by the
            # autocast precision.
            z_batches.append(z.to(torch.float64).cpu())

        if not z_batches:
            raise RuntimeError("Z stats: the test loader yielded no batches")

        z_dataset = torch.cat(z_batches, dim=0)
        z_mean = z_dataset.mean().item()
        z_std = z_dataset.std(unbiased=True).item()
        acc.print(f"Z stats: z_mean={z_mean:.5f} z_std={z_std:.5f}")
        return {"z_mean": z_mean, "z_std": z_std}
