import warnings
from os import makedirs, path
from typing import Callable, Dict, Iterable, Tuple

import numpy as np
import piq
import torch
from PIL import Image
from einops import rearrange
from kmeans_pytorch import kmeans
from lightning import LightningModule
from torch import Tensor
from torch.optim import AdamW, Optimizer

# Signature of the optimizer factory accepted by ``MuCoLA``: it is called with an
# iterable of parameters and returns a ``torch.optim.Optimizer`` instance.
OptimizerCallable = Callable[[Iterable], Optimizer]

from mucola.modules import LatentActionModel, DINOHead, MultiViewDINOLoss
from mucola.tuning import AutoTuner


class MuCoLA(LightningModule):
    def __init__(
            self,
            image_channels: int = 3,
            # Latent action autoencoder
            mucola_model_dim: int = 512,
            mucola_latent_dim: int = 32,
            mucola_patch_size: int = 16,
            mucola_enc_blocks: int = 8,
            mucola_dec_blocks: int = 8,
            mucola_num_heads: int = 8,
            mucola_dropout: float = 0.0,
            beta: float = 0.01,
            log_interval: int = 1000,
            log_path: str = "log_imgs",
            optimizer: OptimizerCallable = AdamW,
            enable_multiview: bool = False,
            dino_loss_weight: float = 0.0,
            proto_dim: int = 4096,
            teacher_temp: float = 0.07,
            student_temp: float = 0.1,
            center_momentum: float = 0.9,
            momentum_teacher: float = 0.996
    ) -> None:
        """Build the latent action model and, optionally, the multi-view teacher branch.

        Args:
            image_channels: Forwarded to ``LatentActionModel`` as ``in_dim``.
            mucola_model_dim: Forwarded to ``LatentActionModel`` as ``model_dim``.
            mucola_latent_dim: Forwarded to ``LatentActionModel`` as ``latent_dim``;
                also the first argument of both ``DINOHead`` instances.
            mucola_patch_size: Forwarded to ``LatentActionModel`` as ``patch_size``.
            mucola_enc_blocks: Forwarded to ``LatentActionModel`` as ``enc_blocks``.
            mucola_dec_blocks: Forwarded to ``LatentActionModel`` as ``dec_blocks``.
            mucola_num_heads: Forwarded to ``LatentActionModel`` as ``num_heads``.
            mucola_dropout: Forwarded to ``LatentActionModel`` as ``dropout``.
            beta: Weight of the KL term in the loss.
            log_interval: Training images are written when
                ``batch_idx % log_interval == 0``.
            log_path: Directory that receives the logged comparison images.
            optimizer: Factory called with the trainable parameters in
                ``configure_optimizers``.
            enable_multiview: When true, a frozen teacher model, the student and
                teacher heads, the DINO loss and the auto-tuner are created.
            dino_loss_weight: Weight of the DINO term in the multi-view loss.
            proto_dim: Second argument of ``DINOHead`` (presumably the number of
                prototypes / output logits -- TODO confirm against ``DINOHead``).
            teacher_temp: Passed to ``MultiViewDINOLoss`` and ``AutoTuner``.
            student_temp: Passed to ``MultiViewDINOLoss`` and ``AutoTuner``.
            center_momentum: Passed to ``MultiViewDINOLoss``.
            momentum_teacher: Decay used for the moving-average teacher update
                performed after every training batch.
        """
        # NOTE: keep a reference to ``super`` in this method; Lightning's
        # ``save_hyperparameters`` relies on the implicit ``__class__`` cell it creates.
        super().__init__()

        # Student and teacher must share one architecture, so build both from the same kwargs.
        model_kwargs = dict(
            in_dim=image_channels,
            model_dim=mucola_model_dim,
            latent_dim=mucola_latent_dim,
            patch_size=mucola_patch_size,
            enc_blocks=mucola_enc_blocks,
            dec_blocks=mucola_dec_blocks,
            num_heads=mucola_num_heads,
            dropout=mucola_dropout
        )
        self.mucola = LatentActionModel(**model_kwargs)
        self.beta = beta
        self.log_interval = log_interval
        self.log_path = log_path
        self.optimizer = optimizer
        self.enable_multiview = enable_multiview
        self.dino_loss_weight = dino_loss_weight

        if self.enable_multiview:
            self.teacher = LatentActionModel(**model_kwargs)
            # Start the teacher from the student's weights. Without this the two
            # networks are independent random initialisations and the slowly moving
            # average would provide unrelated targets for a long time.
            self.teacher.load_state_dict(self.mucola.state_dict())
            # The teacher is only updated through the moving average, never by the optimizer.
            self.teacher.requires_grad_(False)

            self.student_head = DINOHead(mucola_latent_dim, proto_dim)
            self.teacher_head = DINOHead(mucola_latent_dim, proto_dim)
            self.teacher_head.load_state_dict(self.student_head.state_dict())
            self.teacher_head.requires_grad_(False)

            self.dino_loss = MultiViewDINOLoss(teacher_temp, student_temp, center_momentum)
            self.momentum_teacher = momentum_teacher
            self.autotuner = AutoTuner({
                "teacher_temp": teacher_temp,
                "student_temp": student_temp,
            })

        self.save_hyperparameters()

    def shared_step(self, batch: Dict) -> Tuple:
        """Run the model on a batch and compute the loss and monitoring metrics.

        The reconstruction target is ``batch["videos"][:, 1:]``, i.e. everything
        along axis 1 except the first entry. The loss is the mean squared
        reconstruction error plus ``self.beta`` times a KL term computed from
        ``outputs["z_mu"]`` and ``outputs["z_var"]``.

        Args:
            batch: Mapping passed unchanged to ``self.mucola``; must contain ``"videos"``.

        Returns:
            A tuple ``(outputs, loss, aux)``. ``outputs`` is the raw model output,
            ``loss`` the scalar training loss, and ``aux`` a tuple of
            ``(name, value)`` pairs for ``mse_loss``, ``kl_loss``, ``psnr`` and
            ``ssim``. PSNR and SSIM are monitoring values only and carry no
            gradient.
        """
        outputs = self.mucola(batch)
        gt_future_frames = batch["videos"][:, 1:]
        recon = outputs["recon"]
        z_mu, z_var = outputs["z_mu"], outputs["z_var"]

        # Compute loss
        mse_loss = ((gt_future_frames - recon) ** 2).mean()
        # The formula exponentiates ``z_var``, i.e. it is used as a log-variance here.
        kl_loss = -0.5 * torch.sum(1 + z_var - z_mu ** 2 - z_var.exp(), dim=1).mean()
        loss = mse_loss + self.beta * kl_loss

        # Compute monitoring measurements. These are never back-propagated, so
        # skip autograd bookkeeping instead of building a graph through piq.
        with torch.no_grad():
            # Merge the two leading axes and move the trailing axis to position 1
            # (presumably channels-last frames -> channels-first for piq; TODO confirm).
            gt_images = gt_future_frames.clamp(0, 1).reshape(-1, *gt_future_frames.shape[2:]).permute(0, 3, 1, 2)
            recon_images = recon.clamp(0, 1).reshape(-1, *recon.shape[2:]).permute(0, 3, 1, 2)
            psnr = piq.psnr(gt_images, recon_images).mean()
            ssim = piq.ssim(gt_images, recon_images).mean()
        return outputs, loss, (
            ("mse_loss", mse_loss),
            ("kl_loss", kl_loss),
            ("psnr", psnr),
            ("ssim", ssim)
        )

    def training_step(self, batch: Dict, batch_idx: int) -> Tensor:
        if not self.enable_multiview or ("view1" not in batch) or ("view2" not in batch):
            outputs, loss, aux_losses = self.shared_step(batch)
            self.log_dict(
                {**{"train_loss": loss}, **{f"train/{k}": v for k, v in aux_losses}},
                prog_bar=True,
                logger=True,
                on_step=True,
                on_epoch=True,
                sync_dist=True
            )
            self.log(
                "global_step",
                self.global_step,
                prog_bar=True,
                logger=True,
                on_step=True,
                on_epoch=False
            )
            if batch_idx % self.log_interval == 0:
                self.log_images(batch, outputs, "train")
            return loss

        v1 = batch["view1"]["videos"].contiguous()
        v2 = batch["view2"]["videos"].contiguous()
        s1 = self.mucola.latent_actions(v1)
        s2 = self.mucola.latent_actions(v2)
        with torch.no_grad():
            t1 = self.teacher.latent_actions(v1)
            t2 = self.teacher.latent_actions(v2)
        z1 = s1["z_rep"].squeeze(2)
        z2 = s2["z_rep"].squeeze(2)
        tz1 = t1["z_rep"].squeeze(2)
        tz2 = t2["z_rep"].squeeze(2)
        p1 = self.student_head(z1.reshape(-1, z1.shape[-1]))
        p2 = self.student_head(z2.reshape(-1, z2.shape[-1]))
        tp1 = self.teacher_head(tz1.reshape(-1, tz1.shape[-1]))
        tp2 = self.teacher_head(tz2.reshape(-1, tz2.shape[-1]))
        dino_loss = self.dino_loss(p1, p2, tp1, tp2)
        recon1 = self.mucola.reconstruct(v1, s1["z_rep"]).clamp(0, 1)
        recon2 = self.mucola.reconstruct(v2, s2["z_rep"]).clamp(0, 1)
        gt1 = v1[:, 1:]
        gt2 = v2[:, 1:]
        mse1 = ((gt1 - recon1) ** 2).mean()
        mse2 = ((gt2 - recon2) ** 2).mean()
        kl1 = -0.5 * torch.sum(1 + s1["z_var"] - s1["z_mu"] ** 2 - s1["z_var"].exp(), dim=1).mean()
        kl2 = -0.5 * torch.sum(1 + s2["z_var"] - s2["z_mu"] ** 2 - s2["z_var"].exp(), dim=1).mean()
        loss = (mse1 + mse2) / 2 + self.beta * (kl1 + kl2) / 2 + self.dino_loss_weight * dino_loss
        self.log_dict(
            {
                "train_loss": loss,
                "train/mse_v1": mse1,
                "train/mse_v2": mse2,
                "train/kl_v1": kl1,
                "train/kl_v2": kl2,
                "train/dino": dino_loss,
            },
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
            sync_dist=True,
        )
        self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False)
        if batch_idx % self.log_interval == 0:
            self.log_images({"videos": v1}, {"recon": recon1}, "train_v1")
            self.log_images({"videos": v2}, {"recon": recon2}, "train_v2")
        updates = self.autotuner.update({
            "train_loss": loss.item(),
            "train/dino": dino_loss.item(),
        })
        self.autotuner.apply(self.dino_loss, updates)
        return loss

    def on_train_batch_end(self, outputs, batch, batch_idx):
        if self.enable_multiview:
            with torch.no_grad():
                for ps, pt in zip(self.mucola.parameters(), self.teacher.parameters()):
                    pt.data.mul_(self.momentum_teacher).add_(ps.data, alpha=1 - self.momentum_teacher)
                for ps, pt in zip(self.student_head.parameters(), self.teacher_head.parameters()):
                    pt.data.mul_(self.momentum_teacher).add_(ps.data, alpha=1 - self.momentum_teacher)

    # @torch.no_grad()
    # def validation_step(self, batch: Dict, batch_idx: int) -> Tensor:
    #     # Compute the validation loss
    #     outputs, loss, aux_losses = self.shared_step(batch)
    #
    #     # Log the validation loss
    #     self.log_dict(
    #         {**{"val_loss": loss}, **{f"val/{k}": v for k, v in aux_losses}},
    #         prog_bar=True,
    #         logger=True,
    #         on_step=True,
    #         on_epoch=True,
    #         sync_dist=True
    #     )
    #
    #     if batch_idx % self.log_interval == 0:  # Start of the epoch
    #         self.log_images(batch, outputs, "val")
    #     return loss

    @torch.no_grad()
    def test_step(self, batch: Dict, batch_idx: int) -> Tensor:
        """Evaluate one test batch, log its metrics and save a comparison image.

        Args:
            batch: Batch accepted by ``shared_step``.
            batch_idx: Index of the batch within the test run; it is embedded in
                the name of the saved image.

        Returns:
            The scalar test loss.
        """
        # Compute the test loss
        outputs, loss, aux_losses = self.shared_step(batch)

        # Log the test loss
        metrics = {"test_loss": loss}
        metrics.update({f"test/{k}": v for k, v in aux_losses})
        self.log_dict(
            metrics,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
            sync_dist=True
        )

        # ``log_images`` names files by ``global_step`` only, and that counter
        # does not advance while testing. Tag the image with the batch index so
        # each batch gets its own file instead of overwriting the previous one.
        self.log_images(batch, outputs, f"test_batch{batch_idx:06}")
        return loss

    def log_images(self, batch: Dict, outputs: Dict, split: str) -> None:
        """Save a ground-truth vs. reconstruction comparison image for the first sample.

        The first ground-truth frame is prepended to the reconstruction so both
        sequences have equal length. The two sequences are concatenated along
        axis 1 (named ``h`` in the rearrange pattern) and the time steps are laid
        out side by side. The result is written to
        ``<log_path>/<split>_step<global_step>.png``.

        Saving is best effort: a filesystem error is reported as a warning and
        does not interrupt the run.

        Args:
            batch: Must contain ``"videos"``; only element ``[0]`` is used.
            outputs: Must contain ``"recon"``; only element ``[0]`` is used.
            split: Prefix of the output file name.
        """
        # Detach first so no autograd history is carried to the host, and cast to
        # float32 because NumPy cannot represent bfloat16 tensors.
        gt_seq = batch["videos"][0].detach().clamp(0, 1).float().cpu()
        recon_seq = outputs["recon"][0].detach().clamp(0, 1).float().cpu()
        recon_seq = torch.cat([gt_seq[:1], recon_seq], dim=0)
        compare_seq = torch.cat([gt_seq, recon_seq], dim=1)
        compare_seq = rearrange(compare_seq * 255, "t h w c -> h (t w) c")
        compare_seq = compare_seq.numpy().astype(np.uint8)
        if compare_seq.shape[-1] == 1:
            # PIL rejects (H, W, 1) arrays; single-channel images must be 2-D.
            compare_seq = compare_seq[..., 0]
        img = Image.fromarray(compare_seq)

        img_path = path.join(self.log_path, f"{split}_step{self.global_step:06}.png")
        out_dir = path.dirname(img_path)
        try:
            # ``makedirs("")`` raises, so only create a directory when there is one.
            if out_dir:
                makedirs(out_dir, exist_ok=True)
            img.save(img_path)
        except OSError as err:
            # Image logging must never abort training, but the failure should be visible.
            warnings.warn(f"Could not save image log to {img_path!r}: {err}")

    # def on_test_epoch_end(self) -> None:
    #     # For init specialized world models
    #     torch.save(self.mucola.mu_record, f"latent_action_stats.pt")
    #
    #     # For action creation as generative interactive environments
    #     cluster_ids, cluster_centers = kmeans(
    #         X=self.mucola.mu_record,
    #         num_clusters=8,
    #         distance="euclidean"
    #     )
    #     torch.save(cluster_centers, f"latent_action_centers.pt")

    def configure_optimizers(self) -> Optimizer:
        params = [p for p in self.parameters() if p.requires_grad]
        optim = self.optimizer(params)
        return optim
