import sys
import gc

import torch
from torch.utils.data import DataLoader
import torch.distributed as dist
import torch.nn.functional as F
import numpy as np
import pytorch_lightning as pl
import pytorch_lightning.utilities.distributed as pl_dist
import pl_bolts
import torchmetrics
import MinkowskiEngine as ME

import datasets
from util.metric import per_class_iou
from util.scheduler import PolyLR
from util.collate_fns import SparseCollation
import util.transforms as T


class LitSegmentationModule(pl.LightningModule):
    """Lightning wrapper that trains and evaluates a sparse segmentation model.

    The module builds its own dataloaders from the dataset class named by
    ``config.dataset`` (looked up in the ``datasets`` package), optimizes a
    cross-entropy loss, and accumulates a confusion matrix during
    validation/testing from which mIoU and mean accuracy are logged.

    Args:
        model: Network invoked on an ``ME.TensorField``; it must expose a
            ``Q_MODE`` attribute, which is used as the field's quantization
            mode.
        config: Hyperparameters, stored via ``save_hyperparameters`` and read
            back through ``self.hparams``.
        sync_dist: Whether logged values and the per-step loss weighting are
            synchronized across distributed processes.
    """

    def __init__(self, model, config, sync_dist=False):
        super().__init__()
        self.save_hyperparameters(config)
        # Resolve the dataset class by name from the project's `datasets` package.
        self.dataset = getattr(datasets, self.hparams.dataset)
        self.model = model
        self.best_miou = 0.
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=self.hparams.ignore_label)
        self.sync_dist = sync_dist
        self.confusion_matrix = torchmetrics.ConfusionMatrix(
            num_classes=self.dataset.NUM_CLASSES,
            compute_on_step=False,
            dist_sync_on_step=sync_dist
        )

    @staticmethod
    def _world_size():
        """Return the distributed world size, or 1 when not running distributed."""
        # `dist.get_world_size()` raises if no default process group exists,
        # which is the case for plain single-process training.
        if dist.is_available() and dist.is_initialized():
            return dist.get_world_size()
        return 1

    def _random_crop(self):
        """Build the ``RandomCrop`` transform from the crop hyperparameters."""
        return T.RandomCrop(
            x=self.hparams.crop_size,
            y=self.hparams.crop_size,
            z=self.hparams.crop_size,
            min_length=self.hparams.min_length
        )

    def _predict(self, coords, feats):
        """Wrap coordinates/features in a ``TensorField`` and run the model on it."""
        in_field = ME.TensorField(
            features=feats,
            coordinates=coords,
            quantization_mode=self.model.Q_MODE
        )
        return self(in_field)

    def train_dataloader(self):
        """Build the shuffled training ``DataLoader``.

        SemanticKITTI uses only a random crop. Every other dataset must
        provide RGB (``USE_RGB``) and uses either color normalization alone
        (``overfit`` mode, reading the ``'overfit'`` split) or the full
        augmentation pipeline on the ``'train'`` split.
        """
        if self.hparams.dataset == "SemanticKITTIDataset":
            transform = T.Compose([self._random_crop()])
            dset = self.dataset('train', transform, self.hparams)
        else:
            assert self.dataset.USE_RGB
            if self.hparams.overfit:
                transform = T.Compose([T.NormalizeColor()])
                dset = self.dataset('overfit', transform, config=self.hparams)
            else:
                # The order of the augmentations is part of the pipeline; keep it.
                transform = T.Compose([
                    T.RandomRotation(),
                    self._random_crop(),
                    T.RandomAffine(upright_axis="z", application_ratio=0.7),
                    T.CoordinateDropout(),
                    T.ChromaticTranslation(),
                    T.ChromaticJitter(std=0.01, application_ratio=0.7),
                    T.RandomHorizontalFlip(upright_axis="z"),
                    T.RandomTranslation(),
                    T.ElasticDistortion(distortion_params=[(4, 16)], application_ratio=0.7),
                    T.NormalizeColor(),
                ])
                dset = self.dataset('train', transform, self.hparams)
        return DataLoader(
            dataset=dset,
            batch_size=self.hparams.batch_size,
            shuffle=True,
            num_workers=self.hparams.num_workers,
            collate_fn=SparseCollation(self.hparams.max_num_points)
        )

    def val_dataloader(self):
        """Build the validation ``DataLoader`` and cache the class names.

        Side effect: sets ``self.classnames`` from the dataset, which
        ``validation_epoch_end`` reads for per-class logging.
        """
        if self.hparams.dataset == "SemanticKITTIDataset":
            dset = self.dataset('val', None, self.hparams)
        else:
            assert self.dataset.USE_RGB
            transform = T.Compose([T.NormalizeColor()])
            split = 'overfit' if self.hparams.overfit else 'val'
            dset = self.dataset(split, transform, config=self.hparams)
        self.classnames = dset.get_classnames()
        return DataLoader(
            dataset=dset,
            batch_size=self.hparams.val_batch_size,
            shuffle=False,
            num_workers=self.hparams.num_workers,
            # sys.maxsize is passed where training passes max_num_points.
            collate_fn=SparseCollation(sys.maxsize),
            drop_last=False
        )

    def test_dataloader(self):
        """Return the validation dataloader; testing reuses the validation data."""
        return self.val_dataloader()

    def forward(self, *args):
        """Forward all positional arguments to the wrapped model."""
        return self.model(*args)

    def training_step(self, batch, batch_idx):
        """Run one training step.

        Args:
            batch: Tuple ``(coords, feats, labels)`` produced by the collate
                function.
            batch_idx: Index of the batch (unused).

        Returns:
            Dict with the ``"loss"`` tensor and ``"num_points"``, a scalar
            tensor holding ``len(feats)`` on the features' device.
        """
        # NOTE(review): presumably frees memory between steps to limit GPU
        # memory growth -- confirm before removing.
        gc.collect()
        torch.cuda.empty_cache()

        coords, feats, labels = batch
        pred = self._predict(coords, feats)
        loss = self.criterion(pred, labels)
        # Guarded world size: the unguarded call crashed non-distributed runs.
        # NOTE(review): the division by world size is kept from the original;
        # confirm it matches the reduction used by `sync_dist`.
        self.log(
            'train_loss',
            loss.item() / self._world_size(),
            on_step=True,
            on_epoch=False,
            sync_dist=self.sync_dist,
            logger=True
        )
        num_points = torch.tensor(len(feats), device=feats.device)

        return {"loss": loss, "num_points": num_points}

    def training_step_end(self, batch_parts):
        """Re-weight the loss by this process's share of points when syncing.

        With ``sync_dist`` enabled, the loss is scaled in place by
        ``num_points / sum(all num_points) * number_of_processes``.
        NOTE(review): presumably so that averaging across processes yields a
        per-point mean rather than a per-process mean -- confirm.
        """
        if self.sync_dist:
            gathered_num_points = pl_dist.gather_all_tensors(batch_parts["num_points"])
            batch_parts["loss"] *= batch_parts["num_points"] / sum(gathered_num_points) * len(gathered_num_points)
        return batch_parts

    def validation_step(self, batch, batch_idx):
        """Run one validation step and update the confusion matrix.

        Args:
            batch: Tuple ``(coords, feats, labels)``.
            batch_idx: Index of the batch (unused).

        Returns:
            The cross-entropy loss tensor for this batch.
        """
        gc.collect()
        torch.cuda.empty_cache()

        coords, feats, labels = batch
        pred = self._predict(coords, feats)
        loss = self.criterion(pred, labels)
        self.log('val_loss', loss.item(), on_step=False, on_epoch=True, sync_dist=self.sync_dist, logger=True)
        pred = pred.argmax(dim=1, keepdim=False)
        # Entries carrying the ignore label are excluded from the metric.
        valid_mask = labels != self.hparams.ignore_label
        self.confusion_matrix(pred[valid_mask].detach(), labels[valid_mask].detach())

        return loss

    def validation_epoch_end(self, outputs):
        """Compute and log overall and per-class IoU/accuracy, then reset the metric.

        Args:
            outputs: Collected ``validation_step`` results (unused).
        """
        confusion_matrix = self.confusion_matrix.compute().cpu().numpy()
        self.confusion_matrix.reset()
        ious = per_class_iou(confusion_matrix) * 100
        # Diagonal over the sum along axis 1; a zero sum yields NaN, which
        # `np.nanmean` below skips.
        accs = confusion_matrix.diagonal() / confusion_matrix.sum(1) * 100
        # overall logging
        miou = np.nanmean(ious)
        self.best_miou = max(miou, self.best_miou)
        self.log('val_best_mIoU', self.best_miou)
        self.log('val_mIoU', miou)
        self.log('val_mAcc', np.nanmean(accs))

        # class-wise logging
        for class_id in range(self.dataset.NUM_CLASSES):
            self.log(f'val_IoU/{self.classnames[class_id]}', ious[class_id])
            self.log(f'val_Acc/{self.classnames[class_id]}', accs[class_id])

        gc.collect()
        torch.cuda.empty_cache()

    def test_step(self, batch, batch_idx):
        """Delegate to ``validation_step``; testing uses the same logic."""
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        """Delegate to ``validation_epoch_end``; metrics keep their ``val_`` names."""
        return self.validation_epoch_end(outputs)

    def configure_optimizers(self):
        """Create the optimizer and learning-rate scheduler from the hyperparameters.

        Returns:
            Dict with the ``'optimizer'`` and an ``'lr_scheduler'`` config
            whose ``'interval'`` comes from ``hparams.lr_interval``.

        Raises:
            NotImplementedError: If ``hparams.optimizer`` is not ``'SGD'`` or
                ``hparams.scheduler`` is not one of ``'MultiStepLR'``,
                ``'PolyLR'`` or ``'LinearWarmCosineLR'``.
        """
        if self.hparams.optimizer != 'SGD':
            raise NotImplementedError(f"Unsupported optimizer: {self.hparams.optimizer}")

        if self.hparams.stem_only:
            # Freeze every parameter whose name does not contain 'stem'.
            for name, param in self.model.named_parameters():
                if 'stem' not in name:
                    param.requires_grad = False
        optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=self.hparams.lr,
            momentum=self.hparams.momentum,
            weight_decay=self.hparams.weight_decay
        )

        if self.hparams.scheduler == 'MultiStepLR':
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=self.hparams.milestones,
                gamma=self.hparams.gamma
            )
        elif self.hparams.scheduler == 'PolyLR':
            scheduler = PolyLR(
                optimizer,
                max_iter=self.hparams.max_iter,
                power=self.hparams.poly_power
            )
        elif self.hparams.scheduler == 'LinearWarmCosineLR':
            # The scheduler's "epochs" arguments are fed the iteration-based
            # hyperparameters; the stepping unit is set by `lr_interval` below.
            scheduler = pl_bolts.optimizers.lr_scheduler.LinearWarmupCosineAnnealingLR(
                optimizer,
                warmup_epochs=self.hparams.warmup_iter,
                max_epochs=self.hparams.max_iter
            )
        else:
            raise NotImplementedError(f"Unsupported scheduler: {self.hparams.scheduler}")
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': self.hparams.lr_interval
            }
        }
