import os
import numpy as np
from PIL import Image

from typing import Any, Dict, List, Optional, Tuple
from einops import rearrange, repeat

import torch
import torch.nn.functional as F
import torchvision as tv
from torch import nn
from lightning_utilities.core.rank_zero import rank_zero_only
from torchmetrics import MaxMetric, MeanMetric
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

from .base import BaseSystem
from ..utils.visual import concatenate_images
from ..utils.ops import batch_gather

def binary_cross_entropy(input, target):
    """Return the mean binary cross-entropy between probabilities and targets.

    Hand-written replacement for ``F.binary_cross_entropy``, which is not
    numerically stable in mixed-precision training.

    Args:
        input: Tensor of predicted probabilities.
        target: Tensor of targets, broadcastable against ``input``.

    Returns:
        Scalar tensor: the loss averaged over all elements.
    """
    # Clamp the log terms at -100 (the same floor ``torch.nn.BCELoss`` uses).
    # Without it a prediction of exactly 0 or 1 yields ``-inf``, and
    # ``0 * -inf`` turns the whole mean into ``nan``.
    log_p = torch.log(input).clamp(min=-100.0)
    log_not_p = torch.log(1 - input).clamp(min=-100.0)
    return -(target * log_p + (1 - target) * log_not_p).mean()

def mse_loss(input, target, mask=None):
    """Return the (optionally masked) mean squared error.

    Args:
        input: Predicted tensor.
        target: Reference tensor, broadcastable against ``input``.
        mask: Optional per-element weights multiplied into the squared error.
            ``None`` means every element counts fully.

    Returns:
        Scalar tensor: the weighted squared error averaged over all elements
        (the mean is taken over every element, not over the mask's sum).
    """
    squared_error = (input - target) ** 2
    # A missing mask must leave the error untouched; the previous all-zero
    # default silently forced the loss to 0.
    if mask is not None:
        squared_error = squared_error * mask
    return squared_error.mean()

class GeneralizeNERFSystem(BaseSystem):
    def __init__(
        self,
        nerf: torch.nn.Module,
        criterion: torch.nn.Module,
        num_val_dataloaders: int = 1,
        num_test_dataloaders: int = 1,
        **kwargs,
    ) -> None:
        """Wire up the NeRF network, the loss and one metric set per dataloader.

        Args:
            nerf: Network invoked as ``nerf(src_images, target_cameras)``; it
                must provide ``set_use_memory_efficient_attention_xformers``.
            criterion: Module invoked as ``criterion(batch, outputs)``.
            num_val_dataloaders: Number of validation metric sets to allocate.
            num_test_dataloaders: Number of test metric sets to allocate.
            **kwargs: Forwarded to the parent constructor.
        """
        super().__init__(**kwargs)

        # Expose the init arguments through ``self.hparams`` and store them in
        # checkpoints; the network itself is left out.
        self.save_hyperparameters(logger=False, ignore=["nerf"])

        self.nerf = nerf
        self.nerf.set_use_memory_efficient_attention_xformers(True)

        self.criterion = criterion

        def _psnr():
            return PeakSignalNoiseRatio(data_range=1.0)

        def _lpips():
            return LearnedPerceptualImagePatchSimilarity(normalize=True)

        def _one_per_loader(make_metric, count):
            # One independent metric instance for every dataloader index.
            return nn.ModuleList([make_metric() for _ in range(count)])

        # Validation metrics, plus a running maximum of the validation PSNR.
        self.val_psnr = _one_per_loader(_psnr, num_val_dataloaders)
        self.val_ssim = _one_per_loader(StructuralSimilarityIndexMeasure, num_val_dataloaders)
        self.val_lpips = _one_per_loader(_lpips, num_val_dataloaders)
        self.val_psnr_best = _one_per_loader(MaxMetric, num_val_dataloaders)

        # Test metrics.
        self.test_psnr = _one_per_loader(_psnr, num_test_dataloaders)
        self.test_ssim = _one_per_loader(StructuralSimilarityIndexMeasure, num_test_dataloaders)
        self.test_lpips = _one_per_loader(_lpips, num_test_dataloaders)

    def setup(self, stage: str) -> None:
        """Forward the ``setup`` hook to the parent class unchanged.

        Args:
            stage: Stage identifier, passed straight through to
                ``super().setup``.
        """
        super().setup(stage)

    def configure_optimizers(self):
        """Configure optimizers and learning rate schedulers for training."""
        param_groups = []
        param_groups.append({"params": self.nerf.parameters() , "lr": 1.0e-5})

        optimizer = torch.optim.AdamW(param_groups)
        return optimizer

    def on_train_start(self) -> None:
        """Lightning hook that is called when training begins."""
        # by default lightning executes validation step sanity checks before training starts,
        # so it's worth to make sure validation metrics don't store results from these checks
        for metrics in [self.val_psnr, self.val_ssim, self.val_lpips, self.val_psnr_best]:
            if isinstance(metrics, nn.ModuleList):
                for metric in metrics:
                    metric.reset()
            else:
                metrics.reset()

    def training_step(self, batch, batch_idx):
        """Run one optimisation step and return the loss.

        Args:
            batch: Mapping providing ``"cameras"``, ``"images"``,
                ``"src_indices"`` and ``"target_indices"``; the whole batch is
                also handed to the criterion.
            batch_idx: Index of the batch (unused).

        Returns:
            The loss produced by ``self.criterion``.
        """
        cameras, images, source_ids, target_ids = (
            batch[key] for key in ("cameras", "images", "src_indices", "target_indices")
        )

        # Source views condition the network; target cameras are what it renders.
        source_views = batch_gather(images, source_ids)
        query_cameras = batch_gather(cameras, target_ids)
        predictions = self.nerf(source_views, query_cameras)

        loss, info = self.criterion(batch, predictions)
        self.log_dict(info, prog_bar=True)
        return loss


    def convert_to_image(self, x, trans_fn=None, normalize=False):
        """Tile the views of a single-sample tensor side by side as a PIL image.

        Args:
            x: Tensor laid out as ``(1, m, c, h, w)``. The leading axis must
                have size 1; the ``m`` views are concatenated along the width.
            trans_fn: Optional callable applied to the resulting numpy array
                before it is handed to ``Image.fromarray`` (for example a
                float -> ``uint8`` conversion). ``None`` leaves the array as is.
            normalize: When true, min-max rescale the array to ``[0, 1]``.
                Only applied to single-channel inputs.

        Returns:
            The ``PIL.Image.Image`` built from the converted array.
        """
        x = x.detach().cpu().numpy()
        x = rearrange(x, "1 m c h w -> h (m w) c")
        if x.shape[-1] == 1:
            # Single-channel maps become plain 2-D arrays.
            x = x[..., 0]
            if normalize:
                lowest = x.min()
                span = x.max() - lowest
                # A constant map has zero span; emit zeros instead of dividing
                # by zero and pushing NaNs into the image conversion.
                x = (x - lowest) / span if span > 0 else np.zeros_like(x)
        # The default ``trans_fn=None`` used to be called unconditionally.
        if trans_fn is not None:
            x = trans_fn(x)
        return Image.fromarray(x)
        

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Render the target views and save a ground-truth/prediction sheet.

        Args:
            batch: Mapping providing ``"images"``, ``"normals"``, ``"depths"``,
                ``"masks"``, ``"target_indices"`` and ``"prompt"``, plus the
                keys read by ``_generate_images``. ``convert_to_image``
                requires a leading batch axis of size 1.
            batch_idx: Index of the batch within its dataloader.
            dataloader_idx: Index of the validation dataloader.

        Returns:
            Whatever ``_save_image`` returns for the saved sheet (the file path
            on the rank that writes it).
        """
        outputs = self._generate_images(batch)

        # Tensors below are laid out as (b, m, c, h, w). These helpers map the
        # two value ranges in use onto uint8 for ``Image.fromarray``.
        def from_unit_range(x):
            return (x * 255).astype(np.uint8)

        def from_signed_range(x):
            return ((x + 1) / 2 * 255).astype(np.uint8)

        target_indices = batch["target_indices"]
        target_images = batch_gather(batch["images"], target_indices)
        target_depths = batch_gather(batch["depths"], target_indices)
        target_masks = batch_gather(batch["masks"], target_indices)
        target_normals = batch_gather(batch["normals"], target_indices)

        # Absolute error maps between ground truth and prediction.
        delta_depth = torch.abs(target_depths - outputs["depth"])
        delta_normal = torch.abs(target_normals - outputs["normal"])

        gt_image = self.convert_to_image(target_images, from_unit_range)
        pred_image = self.convert_to_image(outputs["rgb"], from_unit_range)
        gt_depth = self.convert_to_image(target_depths, from_unit_range, normalize=True)
        pred_depth = self.convert_to_image(outputs["depth"], from_unit_range, normalize=True)
        delta_depth_image = self.convert_to_image(delta_depth, from_unit_range)
        gt_normal = self.convert_to_image(target_normals, from_signed_range)
        pred_normal = self.convert_to_image(outputs["normal"], from_signed_range)
        delta_normal_image = self.convert_to_image(delta_normal, from_unit_range)
        gt_mask = self.convert_to_image(target_masks, from_unit_range)
        pred_mask = self.convert_to_image(outputs["mask"], from_unit_range)

        concat_image = concatenate_images(
            [
                gt_image,
                pred_image,
                gt_depth,
                pred_depth,
                delta_depth_image,
                gt_mask,
                pred_mask,
                gt_normal,
                pred_normal,
                delta_normal_image,
            ],
            "v",
        )

        image_fp = self._save_image(
            concat_image,
            batch["prompt"],
            f"{dataloader_idx}_{batch_idx}_{self.global_rank}",
            stage="validation",
        )

        # NOTE: PSNR/SSIM/LPIPS updates are currently disabled in this step, so
        # ``self.val_psnr`` / ``self.val_ssim`` / ``self.val_lpips`` receive no
        # data here. The tensors that were prepared only for those disabled
        # updates are no longer built.

        return image_fp

    def on_validation_epoch_end(self) -> None:
        "Lightning hook that is called when a validation epoch ends."
        for i in range(self.hparams.num_val_dataloaders):
            acc = self.val_psnr[i].compute()  # get current val acc
            self.val_psnr_best[i](acc)  # update best so far val acc
            # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
            # otherwise metric would be reset by lightning after each epoch
            self.log(f"val_psnr_best_{i}", self.val_psnr_best[i].compute(), sync_dist=True, prog_bar=True)

        # log images
        if "wandb" in str(self.logger):
            self._log_to_wandb("validation")

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Render the target views, save a comparison image and update metrics.

        Args:
            batch: Mapping providing ``"images"``, ``"target_indices"`` and
                ``"prompt"``, plus the keys read by ``_generate_images``.
                ``convert_to_image`` requires a leading batch axis of size 1.
            batch_idx: Index of the batch within its dataloader.
            dataloader_idx: Index of the test dataloader. Defaults to 0 so the
                hook also works when it is called without this argument,
                matching ``validation_step``.

        Returns:
            Whatever ``_save_image`` returns for the saved image (the file path
            on the rank that writes it).
        """
        # ``_generate_images`` returns the network's output mapping;
        # ``validation_step`` reads the rendered colours from its "rgb" entry.
        outputs = self._generate_images(batch)
        pred_rgb = outputs["rgb"]
        target_images = batch_gather(batch["images"], batch["target_indices"])

        def to_uint8(x):
            return (x * 255).astype(np.uint8)

        # Same conversion validation_step uses for ground truth and prediction.
        gt_image = self.convert_to_image(target_images, to_uint8)
        pred_image = self.convert_to_image(pred_rgb, to_uint8)
        comparison = concatenate_images([gt_image, pred_image], "v")

        # _save_image takes (im, prompt, batch_idx, stage); the tag mirrors the
        # one used in validation_step.
        image_fp = self._save_image(
            comparison,
            batch["prompt"],
            f"{dataloader_idx}_{batch_idx}_{self.global_rank}",
            stage="test",
        )

        # NOTE(review): images are treated as already lying in [0, 1], in line
        # with the visualisation above and in validation_step. The previous
        # metric code rescaled ground truth with ``/ 2 + 0.5`` -- confirm the
        # dataset's value range before trusting these numbers.
        images_gt_ = rearrange(target_images, "b m c h w -> (b m) c h w").float()
        images_pred_ = rearrange(pred_rgb, "b m c h w -> (b m) c h w").to(
            device=images_gt_.device, dtype=torch.float32
        )
        # LPIPS is built with normalize=True; clamp so tiny overshoots in the
        # rendering do not fall outside [0, 1].
        images_pred_ = images_pred_.clamp(0.0, 1.0)

        self.test_psnr[dataloader_idx](images_gt_, images_pred_)
        self.test_ssim[dataloader_idx](images_gt_, images_pred_)
        self.test_lpips[dataloader_idx](images_gt_, images_pred_)
        for metric_name, metric_set in (
            ("psnr", self.test_psnr),
            ("ssim", self.test_ssim),
            ("lpips", self.test_lpips),
        ):
            self.log(
                f"test_{metric_name}_{dataloader_idx}",
                metric_set[dataloader_idx],
                on_step=False,
                on_epoch=True,
                prog_bar=True,
                add_dataloader_idx=False,
            )

        return image_fp

    def on_test_epoch_end(self) -> None:
        """Lightning hook that is called when a test epoch ends."""

        # log images
        if "wandb" in str(self.logger):
            self._log_to_wandb("test")

    @torch.no_grad()
    def _generate_images(self, batch):
        """Run the network on the batch's source views for its target cameras.

        Only the keys needed for inference are read, so the batch does not have
        to carry ground-truth normals, depths or masks.

        Args:
            batch: Mapping providing ``"images"``, ``"cameras"``,
                ``"src_indices"`` and ``"target_indices"``.

        Returns:
            The output of ``self.nerf(src_images, target_cameras)``, returned
            unchanged. ``validation_step`` indexes it with ``"rgb"``,
            ``"depth"``, ``"normal"`` and ``"mask"``.
        """
        src_images = batch_gather(batch["images"], batch["src_indices"])
        target_cameras = batch_gather(batch["cameras"], batch["target_indices"])
        return self.nerf(src_images, target_cameras)

    @torch.no_grad()
    @rank_zero_only
    def _save_image(self, im, prompt, batch_idx, stage="validation"):
        """Write an image and its prompts to ``self.save_dir``.

        Two files are written, both prefixed with
        ``{stage}_{global_step}_{batch_idx}``: a ``.txt`` file with the prompts
        (one per line) and a ``.png`` whose name also carries the first prompt.
        When a tensorboard logger is attached the image is added to it as well.

        Args:
            im: Object exposing ``save(path)`` and convertible with
                ``np.array`` (a PIL image in the callers shown in this file).
            prompt: Sequence of prompt strings.
            batch_idx: Tag identifying the batch inside the file names.
            stage: Stage name used as file-name prefix and logger namespace.

        Returns:
            Path of the written ``.png``. The method is decorated with
            ``rank_zero_only``.
        """
        save_dir = self.save_dir
        # Make sure the target directory exists before writing into it.
        os.makedirs(save_dir, exist_ok=True)

        stem = f"{stage}_{self.global_step}_{batch_idx}"
        # Explicit UTF-8 so non-ASCII prompts do not depend on the locale.
        with open(os.path.join(save_dir, f"{stem}.txt"), "w", encoding="utf-8") as f:
            f.write("\n".join(prompt))

        # Tolerate an empty prompt list instead of failing on prompt[0].
        first_prompt = prompt[0] if prompt else ""
        prompt_tag = first_prompt.replace(' ', '_').replace('/', '_')
        im_fp = os.path.join(save_dir, f"{stem}--{prompt_tag}.png")
        im.save(im_fp)

        # add image to logger
        if "tensorboard" in str(self.logger):
            pixels = np.array(im) / 255.0
            if pixels.ndim == 2:
                # Single-channel image: add the channel axis that the
                # HWC -> CHW permute below expects.
                pixels = pixels[..., None]
            log_image = torch.tensor(pixels).permute(2, 0, 1).float().cpu()
            self.logger.experiment.add_image(
                f"{stage}/{self.global_step}_{batch_idx}",
                log_image,
                global_step=self.global_step,
            )

        return im_fp

    @torch.no_grad()
    @rank_zero_only
    def _log_to_wandb(self, stage, output_images_fp: Optional[List[Any]] = None):
        """Upload saved images for the current global step to wandb.

        Args:
            stage: Stage name; used as the key of the logged image list and as
                the file-name prefix when searching ``self.save_dir``.
            output_images_fp: Explicit image paths to upload. When ``None``,
                every ``.png`` in ``self.save_dir`` named
                ``{stage}_{global_step}_*`` is uploaded.
        """
        import wandb

        if output_images_fp is None:
            # Saved files are named "{stage}_{global_step}_{batch_idx}...", so
            # the trailing underscore keeps step 1 from also matching steps
            # 10, 100, ... Sorting makes the upload order deterministic.
            prefix = f"{stage}_{self.global_step}_"
            captions = sorted(
                name
                for name in os.listdir(self.save_dir)
                if name.startswith(prefix) and name.endswith(".png")
            )
            images = [os.path.join(self.save_dir, name) for name in captions]
        else:
            images = list(output_images_fp)
            # ``os.basename`` does not exist; the helper lives in ``os.path``.
            captions = [os.path.basename(fp) for fp in images]

        self.logger.experiment.log(
            {
                stage: [
                    wandb.Image(im_fp, caption=caption)
                    for im_fp, caption in zip(images, captions)
                ]
            },
            step=self.global_step,
        )
