"""
Copyright (c) Facebook, Inc. and its affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

from collections import defaultdict

import sys
sys.path.insert(0, '../external/fastMRI')

import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
from torch.utils.data import DistributedSampler, DataLoader
import torch.distributed as dist


from training_utils.mri_data import SliceData
from training_utils.volume_sampler import VolumeSampler
from training_utils import evaluate
from training_utils.evaluate import NMSE, PSNR, SSIM, DistributedMetricSum

from common.utils import save_reconstructions


class MRIModel(pl.LightningModule):
    """
    Abstract super class for Deep Learning based reconstruction models.
    This is a subclass of the LightningModule class from pytorch_lightning, with
    some additional functionality specific to fastMRI:
        - fastMRI data loaders
        - Evaluating reconstructions
        - Visualization
        - Saving test reconstructions

    To implement a new reconstruction model, inherit from this class and implement the
    following methods:
        - train_data_transform, val_data_transform, test_data_transform:
            Create and return data transformer objects for each data split
        - training_step, validation_step, test_step:
            Define what happens in one step of training, validation and testing respectively
        - configure_optimizers:
            Create and return the optimizers
    Other methods from LightningModule can be overridden as needed.

    validation_step / test_step must return a dict whose values are all tensors
    (validation_step_end calls ``.cpu()`` on every value) and which contains at
    least the keys "output" and "target" (3D: one 2D image per slice in the
    batch), "val_loss", "fname" (integer volume ids) and "slice" (slice indices).
    """

    def __init__(self, hparams):
        """
        Args:
            hparams: Hyperparameter namespace. Attributes read by this class:
                data_path (joined with ``/``, presumably a pathlib.Path),
                challenge, contrast_type, sample_rate, val_sample_rate and
                batch_size.
        """
        super().__init__()
        self.hparams = hparams
        # Populated by train_dataloader().
        self.training_loader = None

        # Sums that validation_epoch_end reduces across processes.
        self.NMSE = DistributedMetricSum(name="NMSE")
        self.SSIM = DistributedMetricSum(name="SSIM")
        self.PSNR = DistributedMetricSum(name="PSNR")
        self.ValLoss = DistributedMetricSum(name="ValLoss")
        self.TotExamples = DistributedMetricSum(name="TotExamples")

    def _create_data_loader(self, data_transform, data_partition, sample_rate=None):
        """
        Build a DataLoader over a SliceData dataset for one data split.

        Args:
            data_transform: Transform passed to SliceData for each example.
            data_partition (str): 'train', 'val' or 'eval'. 'eval' is mapped to
                the 'val' directory.
            sample_rate (float, optional): Passed to SliceData. Any falsy value
                (None or 0) falls back to ``self.hparams.sample_rate``.

        Returns:
            torch.utils.data.DataLoader: Uses a DistributedSampler and drops
            the last incomplete batch for 'train'; uses a VolumeSampler
            otherwise.
        """
        num_workers = 4

        def worker_init(worker_id):
            # Give every worker of every process a distinct NumPy seed.
            # np.random.seed only accepts values in [0, 2**32 - 1], so the
            # rank offset is wrapped back into that range instead of
            # occasionally overflowing it.
            seed = int(torch.initial_seed()) % (2 ** 32 - 1) + dist.get_rank() * num_workers
            np.random.seed(seed % 2 ** 32)

        sample_rate = sample_rate or self.hparams.sample_rate

        if data_partition == 'eval':
            data_partition = 'val'

        dataset = SliceData(
            root=self.hparams.data_path / f'{self.hparams.challenge}_{data_partition}',
            transform=data_transform,
            sample_rate=sample_rate,
            challenge=self.hparams.challenge,
            contrast_type=self.hparams.contrast_type
        )

        is_train = (data_partition == 'train')
        if is_train:
            sampler = DistributedSampler(dataset)
        else:
            sampler = VolumeSampler(dataset)

        return DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size,
            num_workers=num_workers,
            pin_memory=False,
            drop_last=is_train,
            sampler=sampler,
            worker_init_fn=worker_init
        )

    def train_data_transform(self):
        """Return the transform for the training split (subclasses must implement)."""
        raise NotImplementedError

    def train_dataloader(self):
        """Create the training loader and keep a reference in ``self.training_loader``."""
        self.training_loader = self._create_data_loader(self.train_data_transform(),
                                                        data_partition='train',
                                                        sample_rate=self.hparams.sample_rate)
        return self.training_loader

    def val_data_transform(self):
        """Return the transform for the validation split (subclasses must implement)."""
        raise NotImplementedError

    def val_dataloader(self):
        """Create the validation loader using ``hparams.val_sample_rate``."""
        return self._create_data_loader(self.val_data_transform(),
                                        data_partition='val',
                                        sample_rate=self.hparams.val_sample_rate)

    def test_data_transform(self):
        """Return the transform for the test split (subclasses must implement)."""
        raise NotImplementedError

    def test_dataloader(self):
        """Create the test loader: all volumes of the validation directory."""
        return self._create_data_loader(self.test_data_transform(), data_partition='eval', sample_rate=1.)

    def _visualize(self, val_outputs, val_targets):
        """
        Log at most 16 target / reconstruction / error images to the experiment logger.

        Only slices with the same shape as the first output are used, so the
        images can be tiled into one grid.

        Args:
            val_outputs: Sequence of 2D reconstructed slices.
            val_targets: Sequence of 2D target slices, aligned with ``val_outputs``.
        """
        def _normalize(image):
            # Add a channel axis and rescale to [0, 1].
            image = image[np.newaxis]
            image = image - image.min()
            max_value = image.max()
            # A constant image would otherwise be 0 / 0 -> NaN.
            return image / max_value if max_value > 0 else image

        def _save_image(image, tag):
            grid = torchvision.utils.make_grid(torch.Tensor(image), nrow=4, pad_value=1)
            self.logger.experiment.add_image(tag, grid, self.global_step)

        if len(val_outputs) == 0:
            # Nothing to show; also avoids indexing [0] and a zero range step below.
            return

        # only process first size to simplify visualization.
        visualize_size = val_outputs[0].shape
        val_outputs = [x for x in val_outputs if x.shape == visualize_size]
        val_targets = [x for x in val_targets if x.shape == visualize_size]

        num_logs = len(val_outputs)
        assert num_logs == len(val_targets)

        # Evenly subsample so that at most num_viz_images slices are logged.
        num_viz_images = 16
        step = (num_logs + num_viz_images - 1) // num_viz_images
        outputs, targets = [], []

        for i in range(0, num_logs, step):
            outputs.append(_normalize(val_outputs[i]))
            targets.append(_normalize(val_targets[i]))

        outputs = np.stack(outputs)
        targets = np.stack(targets)
        _save_image(targets, "Target")
        _save_image(outputs, "Reconstruction")
        _save_image(np.abs(targets - outputs), "Error")

    def validation_step_end(self, val_logs):
        """
        Move every logged tensor to the CPU and remember the original device.

        The device of ``val_logs["output"]`` is stored under the "device" key
        so validation_epoch_end can move reduced metrics back onto it.
        """
        device = val_logs["output"].device
        # move to CPU to save GPU memory
        val_logs = {key: value.cpu() for key, value in val_logs.items()}
        val_logs["device"] = device

        return val_logs

    def test_step_end(self, test_logs):
        """Same post-processing as validation_step_end."""
        return self.validation_step_end(test_logs)

    def validation_epoch_end(self, val_logs):
        """
        Aggregate per-batch validation logs into volume-level metrics.

        Slices are regrouped by volume (``fname``). A (fname, slice) pair seen
        more than once is only counted the first time. NMSE, SSIM and PSNR are
        computed per volume, summed, reduced across processes with
        DistributedMetricSum and divided by the total number of volumes.

        Args:
            val_logs (list[dict]): Outputs of validation_step_end, one per batch.

        Returns:
            dict: ``log`` (metrics keyed ``metrics/<name>``) plus the same
            metrics at top level.
        """
        assert val_logs[0]["output"].ndim == 3
        device = val_logs[0]["device"]

        # run the visualizations. A flat list of 2D slices is passed instead of
        # one concatenated array: np.concatenate fails as soon as two batches
        # have different image sizes, while _visualize itself already keeps
        # only the first size.
        self._visualize(
            val_outputs=[slc for log in val_logs for slc in log["output"].numpy()],
            val_targets=[slc for log in val_logs for slc in log["target"].numpy()],
        )

        # aggregate losses and regroup slices per volume
        losses = []
        outputs = defaultdict(list)
        targets = defaultdict(list)
        seen_slices = defaultdict(set)  # fname -> slice indices already collected

        for val_log in val_logs:
            losses.append(val_log["val_loss"])
            for i, (fname, slice_ind) in enumerate(
                zip(val_log["fname"], val_log["slice"])
            ):
                fname, slice_ind = int(fname), int(slice_ind)
                # need to check for duplicate slices (set lookup keeps this O(1))
                if slice_ind in seen_slices[fname]:
                    continue
                seen_slices[fname].add(slice_ind)
                outputs[fname].append((slice_ind, val_log["output"][i]))
                targets[fname].append((slice_ind, val_log["target"][i]))

        def _stack_volume(slices):
            # Order (slice_index, image) pairs by index and stack them into a volume.
            ordered = sorted(slices, key=lambda item: item[0])
            return torch.stack([image for _, image in ordered], dim=0).numpy()

        # handle aggregation for distributed case with pytorch_lightning metrics
        metrics = dict(val_loss=0, nmse=0, ssim=0, psnr=0)
        for fname in outputs:
            output = _stack_volume(outputs[fname])
            target = _stack_volume(targets[fname])
            metrics["nmse"] = metrics["nmse"] + evaluate.nmse(target, output)
            metrics["ssim"] = metrics["ssim"] + evaluate.ssim(target, output)
            metrics["psnr"] = metrics["psnr"] + evaluate.psnr(target, output)

        # currently ddp reduction requires everything on CUDA, so we'll do this manually
        metrics["nmse"] = self.NMSE(torch.tensor(metrics["nmse"]).to(device))
        metrics["ssim"] = self.SSIM(torch.tensor(metrics["ssim"]).to(device))
        metrics["psnr"] = self.PSNR(torch.tensor(metrics["psnr"]).to(device))
        # NOTE(review): val_loss is a sum of per-batch losses but is divided by the
        # number of volumes below, like the other metrics; confirm this is intended.
        metrics["val_loss"] = self.ValLoss(torch.sum(torch.stack(losses)).to(device))

        num_examples = torch.tensor(len(outputs)).to(device)
        tot_examples = self.TotExamples(num_examples)

        log_metrics = {
            f"metrics/{metric}": values / tot_examples
            for metric, values in metrics.items()
        }
        metrics = {metric: values / tot_examples for metric, values in metrics.items()}
        print(metrics)
        return dict(log=log_metrics, **metrics)

    def test_epoch_end(self, test_logs):
        """Same aggregation as validation_epoch_end."""
        return self.validation_epoch_end(test_logs)

