import numpy as np
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from sympy import plot
import torch
from torch import optim
from torch.nn import functional as F
from torchvision import utils as vutils
from torchvision.transforms import Resize
from slot_attention.losses.ari import preprocess_and_batch_aris

from slot_attention.model.model_slatn import SlotAttentionModel
from slot_attention.model.model_utils import Tensor
from slot_attention.model.model_utils import to_rgb_from_tensor
from pytorch_lightning.utilities import grad_norm
from slot_attention.visualization.vis_masks import visualize_segmentation_masks

from slot_attention.visualization.visualization_carrier import VisualizationCarrier


class SlotAttentionMethod(pl.LightningModule):
    """Lightning wrapper that trains and evaluates a ``SlotAttentionModel``.

    Handles the reconstruction loss (L1 or MSE, selected by ``params.l1_loss``),
    ARI / foreground-ARI metric logging, the warmup + exponential-decay
    learning-rate schedule, and construction of qualitative visualisations.
    """

    def __init__(self, model: SlotAttentionModel, datamodule: pl.LightningDataModule, params):
        super().__init__()
        self.model = model
        self.datamodule = datamodule
        self.params = params
        # Per-batch metric dicts collected in ``validation_step`` and reduced
        # (then cleared) in ``on_validation_epoch_end``.
        self.validation_step_outputs = []
        self.vis_carrier = VisualizationCarrier(self.params)
        # Half-resolution resize; applied to the visualised input image when
        # ``params.use_average_pool`` is set.
        self.resize = Resize((self.params.resolution[0] // 2, self.params.resolution[1] // 2))

    def forward(self, input: Tensor, **kwargs) -> Tensor:
        """Run the wrapped model, passing extra keyword arguments through.

        NOTE(review): the annotation says ``Tensor``, but callers in this class
        unpack the result as ``(recon_combined, recons, soft_pred_masks, slots)``.
        """
        return self.model(input, **kwargs)

    @torch.no_grad()  # visualisation only: no autograd graph is needed
    def sample_images(self):
        """Build the qualitative plots for one shuffled validation batch.

        Returns:
            A dict containing everything produced by ``vis_carrier.create_plots()``
            plus:
              * ``'images'``: grid of the first samples of the shuffled batch,
              * ``'images_failures'``: grid of the samples with the lowest ARI,
              * ``'segmentations'``: ground-truth mask visualisation of sample 0.
        """
        dl = self.datamodule.val_dataloader()
        batch, in_masks = next(iter(dl))

        # Shuffle using the real batch length, not ``params.val_batch_size``:
        # a smaller batch would otherwise get an out-of-range permutation.
        perm = np.random.permutation(len(batch))
        # ``.to`` is a no-op when the module lives on the CPU.
        batch = batch[perm].to(self.device)
        in_masks = in_masks[perm].to(self.device)

        plt.close('all')
        self.vis_carrier.reset()

        recon_combined, recons, soft_pred_masks, _ = self.model.forward(batch, vis_carrier=self.vis_carrier)

        img = batch[0]
        if self.params.use_average_pool:
            img = self.resize(img)
        self.vis_carrier.add_image(to_rgb_from_tensor(img).cpu().numpy().transpose(1, 2, 0))
        plot_dict = self.vis_carrier.create_plots()

        # Never request more samples than the batch actually holds.
        n_show = min(self.params.n_samples, len(batch))

        # The first ``n_show`` samples of the (already shuffled) batch.
        first_idcs = torch.arange(n_show)
        plot_dict['images'] = self.get_images_of_idcs(first_idcs, batch, recon_combined, recons, soft_pred_masks)

        # The ``n_show`` samples with the LOWEST ARI, i.e. the worst segmentations.
        # The ARI helper gets the masks without their channel axis (dim 2).
        _, b_ari = preprocess_and_batch_aris(true_masks=in_masks, soft_pred_masks=soft_pred_masks[:, :, 0])
        failure_idcs = torch.argsort(b_ari, descending=False)[:n_show]
        plot_dict['images_failures'] = self.get_images_of_idcs(
            failure_idcs, batch, recon_combined, recons, soft_pred_masks
        )

        plot_dict['segmentations'] = visualize_segmentation_masks(batch[0].cpu().numpy(), in_masks[0].cpu().numpy())
        return plot_dict

    def get_images_of_idcs(self, idx, batch, recon_combined, recons, soft_pred_masks):
        """Arrange inputs, reconstructions and per-slot renders of selected samples in one grid.

        Args:
            idx: indices into the batch dimension of the samples to show.
            batch: input images; after indexing they are concatenated with the
                reconstructions, so presumably ``(B, C, H, W)`` — TODO confirm.
            recon_combined: combined reconstructions, same shape as ``batch``.
            recons: per-slot reconstructions of shape ``(B, num_slots, C, H, W)``.
            soft_pred_masks: per-slot masks broadcastable against ``recons``
                (presumably ``(B, num_slots, 1, H, W)`` — TODO confirm).

        Returns:
            A ``make_grid`` image with one row per selected sample: the input,
            the combined reconstruction, then each slot's reconstruction
            blended as ``recon * mask + (1 - mask)``.
        """
        batch = batch[idx]
        recon_combined = recon_combined[idx]
        recons = recons[idx]
        soft_pred_masks = soft_pred_masks[idx]
        batch_size, _, C, H, W = recons.shape

        # Stack everything along dim 1 so one grid row holds one sample;
        # ``to_rgb_from_tensor`` rescales the output for display.
        out = to_rgb_from_tensor(
            torch.cat(
                [
                    batch.unsqueeze(1),  # original images
                    recon_combined.unsqueeze(1),  # reconstructions
                    recons * soft_pred_masks + (1 - soft_pred_masks),  # each slot
                ],
                dim=1,
            )
        )  # (n_samples, num_slots + 2, C, H, W)

        return vutils.make_grid(
            out.reshape(batch_size * out.shape[1], C, H, W).cpu(), normalize=False, nrow=out.shape[1],
        )

    def _loss_and_aris(self, imgs, true_masks):
        """Forward ``imgs`` and return ``(loss, b_fg_ari, b_ari)`` for one batch.

        The loss is L1 or MSE between the combined reconstruction and ``imgs``,
        selected by ``params.l1_loss``. The ARIs are computed without gradients.
        """
        recon_combined, _, soft_pred_masks, _ = self.forward(imgs)
        # Drop only the mask channel axis (dim 2). A bare ``squeeze()`` would
        # also drop the batch axis for a single-sample batch (or the slot axis
        # for a single slot), breaking the ARI computation.
        soft_pred_masks = soft_pred_masks.squeeze(2)
        if self.params.l1_loss:
            loss = F.l1_loss(recon_combined, imgs)
        else:
            loss = F.mse_loss(recon_combined, imgs)

        with torch.no_grad():
            b_fg_ari, b_ari = preprocess_and_batch_aris(true_masks=true_masks, soft_pred_masks=soft_pred_masks)
        return loss, b_fg_ari, b_ari

    def training_step(self, batch, batch_idx, optimizer_idx=0):
        """Compute the reconstruction loss, log loss/ARI/FG-ARI, and return the loss."""
        imgs, true_masks = batch
        loss, b_fg_ari, b_ari = self._loss_and_aris(imgs, true_masks)

        with torch.no_grad():
            logs = {
                "loss": loss,
                'ari': torch.mean(b_ari),
                'fg_ari': torch.mean(b_fg_ari),
            }
            self.log_dict(logs, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx, optimizer_idx=0):
        """Compute validation loss/ARI/FG-ARI and stash them for the epoch-end reduction."""
        imgs, true_masks = batch
        loss, b_fg_ari, b_ari = self._loss_and_aris(imgs, true_masks)

        metrics_dict = {
            "loss": loss.detach(),
            'ari': torch.mean(b_ari),
            'fg_ari': torch.mean(b_fg_ari),
        }
        self.validation_step_outputs.append(metrics_dict)
        return loss

    def on_validation_epoch_end(self):
        """Average the per-batch validation metrics and log them as ``avg_val_*``."""
        outputs = self.validation_step_outputs
        self.validation_step_outputs = []
        if not outputs:
            # No validation batches ran; ``torch.stack([])`` would raise.
            return

        logs = {
            f"avg_val_{key}": torch.stack([x[key] for x in outputs]).mean()
            for key in ("loss", "ari", "fg_ari")
        }
        self.log_dict(logs, sync_dist=True)

    def configure_optimizers(self):
        """Create an Adam optimizer with a per-step warmup + exponential-decay schedule.

        The LR multiplier ramps linearly from 0 to 1 over ``params.warmup_steps``
        steps and afterwards equals ``params.scheduler_gamma ** (step / params.decay_steps)``.
        """
        optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.params.lr,
            weight_decay=self.params.weight_decay,
        )

        warmup_steps = self.params.warmup_steps
        decay_steps = self.params.decay_steps

        def warm_and_decay_lr_scheduler(step: int) -> float:
            if step < warmup_steps:
                return step / warmup_steps
            return self.params.scheduler_gamma ** (step / decay_steps)

        scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=warm_and_decay_lr_scheduler)

        # ``interval: step`` makes Lightning call ``scheduler.step()`` after
        # every optimizer step rather than every epoch.
        return (
            [optimizer],
            [{"scheduler": scheduler, "interval": "step"}],
        )

    def on_before_optimizer_step(self, optimizer):
        """Optionally log the per-layer 2-norm of the model gradients."""
        # If using mixed precision, the gradients are already unscaled here.
        if self.params.log_gradients:
            norms = grad_norm(self.model, norm_type=2)
            self.log_dict(norms)

