import torch
import torch.nn as nn
import pytorch_lightning as pl

from fiery.config import get_cfg
# from fiery.models.fiery_polar import Fiery
from fiery.models.fiery import Fiery
from fiery.losses import ProbabilisticLoss, SpatialRegressionLoss, SegmentationLoss
from fiery.metrics import IntersectionOverUnion, PanopticMetric
from fiery.utils.geometry import cumulative_warp_features_reverse
from fiery.utils.instance import predict_instance_segmentation_and_trajectories
from fiery.utils.visualisation import visualise_output


class TrainingModule(pl.LightningModule):
    """Lightning wrapper that trains ``Fiery`` with learned multi-task loss weighting.

    Supervises semantic segmentation, instance centerness and instance offset,
    plus instance flow and the probabilistic loss when enabled in the config.
    Each task loss is scaled by a learnable parameter ``w`` stored on
    ``self.model`` (``segmentation_weight``, ``centerness_weight``,
    ``offset_weight`` and optionally ``flow_weight``) through a factor of
    ``1 / exp(w)`` (segmentation) or ``1 / (2 exp(w))`` (regression heads).
    The matching ``0.5 * w`` regulariser is reported as a separate
    ``*_uncertainty`` loss term.
    """

    def __init__(self, hparams):
        """Build the model, the loss functions and the validation metrics.

        Args:
            hparams: configuration dictionary; converted back into a config
                object with ``get_cfg(cfg_dict=...)`` (see config.py for details).
        """
        super().__init__()

        # NOTE(review): assigning ``self.hparams`` directly only works on older
        # pytorch_lightning releases; newer ones expect ``save_hyperparameters``.
        self.hparams = hparams
        # pytorch lightning does not support saving a YACS CfgNode, so the
        # config is kept as a plain dict in hparams and rebuilt here.
        cfg = get_cfg(cfg_dict=self.hparams)
        self.cfg = cfg
        self.n_classes = len(self.cfg.SEMANTIC_SEG.WEIGHTS)

        # Bird's-eye view extent in meters
        assert self.cfg.LIFT.X_BOUND[1] > 0 and self.cfg.LIFT.Y_BOUND[1] > 0
        self.spatial_extent = (self.cfg.LIFT.X_BOUND[1], self.cfg.LIFT.Y_BOUND[1])

        # Model
        self.model = Fiery(cfg)

        # Losses. The uncertainty parameters are registered on self.model in a
        # fixed order (segmentation, centerness, offset, flow); keep that order so
        # optimizer state in existing checkpoints still lines up.
        self.losses_fn = nn.ModuleDict()
        self.losses_fn['segmentation'] = SegmentationLoss(
            class_weights=torch.Tensor(self.cfg.SEMANTIC_SEG.WEIGHTS),
            use_top_k=self.cfg.SEMANTIC_SEG.USE_TOP_K,
            top_k_ratio=self.cfg.SEMANTIC_SEG.TOP_K_RATIO,
            future_discount=self.cfg.FUTURE_DISCOUNT,
        )

        # Uncertainty weighting
        self.model.segmentation_weight = nn.Parameter(torch.tensor(0.0), requires_grad=True)

        self.metric_iou_val = IntersectionOverUnion(self.n_classes)

        self.losses_fn['instance_center'] = SpatialRegressionLoss(
            norm=2, future_discount=self.cfg.FUTURE_DISCOUNT
        )
        self.losses_fn['instance_offset'] = SpatialRegressionLoss(
            norm=1, future_discount=self.cfg.FUTURE_DISCOUNT, ignore_index=self.cfg.DATASET.IGNORE_INDEX
        )

        # Uncertainty weighting
        self.model.centerness_weight = nn.Parameter(torch.tensor(0.0), requires_grad=True)
        self.model.offset_weight = nn.Parameter(torch.tensor(0.0), requires_grad=True)

        self.metric_panoptic_val = PanopticMetric(n_classes=self.n_classes)

        if self.cfg.INSTANCE_FLOW.ENABLED:
            self.losses_fn['instance_flow'] = SpatialRegressionLoss(
                norm=1, future_discount=self.cfg.FUTURE_DISCOUNT, ignore_index=self.cfg.DATASET.IGNORE_INDEX
            )
            # Uncertainty weighting
            self.model.flow_weight = nn.Parameter(torch.tensor(0.0), requires_grad=True)

        if self.cfg.PROBABILISTIC.ENABLED:
            self.losses_fn['probabilistic'] = ProbabilisticLoss()

        # Manual step counter used as the global_step for all logger calls.
        self.training_step_count = 0

    def shared_step(self, batch, is_train):
        """Run the forward pass and compute all loss terms for one batch.

        Args:
            batch: dataloader dict; must contain 'image', 'intrinsics',
                'extrinsics', 'lidar2imgs' and 'future_egomotion', plus the label
                keys read by ``prepare_future_labels``.
            is_train: when False, the validation IoU and panoptic metrics are
                also updated with this batch's predictions.

        Returns:
            Tuple ``(output, labels, loss)`` where ``output`` is the model output,
            ``labels`` the warped labels and ``loss`` a dict of loss tensors.
        """
        image = batch['image']
        intrinsics = batch['intrinsics']
        extrinsics = batch['extrinsics']
        lidar2imgs = batch['lidar2imgs']
        future_egomotion = batch['future_egomotion']

        # Warp labels
        labels, future_distribution_inputs = self.prepare_future_labels(batch)

        # Forward pass
        output = self.model(
            image, intrinsics, extrinsics, lidar2imgs, future_egomotion, future_distribution_inputs
        )

        #####
        # Loss computation
        #####
        # Each task loss is multiplied by a factor derived from its learned
        # parameter w, and a 0.5 * w term is added so that lowering a task's
        # factor (raising w) is itself penalised.
        loss = {}

        segmentation_factor = 1 / torch.exp(self.model.segmentation_weight)
        loss['segmentation'] = segmentation_factor * self.losses_fn['segmentation'](
            output['segmentation'], labels['segmentation']
        )
        loss['segmentation_uncertainty'] = 0.5 * self.model.segmentation_weight

        centerness_factor = 1 / (2 * torch.exp(self.model.centerness_weight))
        loss['instance_center'] = centerness_factor * self.losses_fn['instance_center'](
            output['instance_center'], labels['centerness']
        )

        offset_factor = 1 / (2 * torch.exp(self.model.offset_weight))
        loss['instance_offset'] = offset_factor * self.losses_fn['instance_offset'](
            output['instance_offset'], labels['offset']
        )

        loss['centerness_uncertainty'] = 0.5 * self.model.centerness_weight
        loss['offset_uncertainty'] = 0.5 * self.model.offset_weight

        if self.cfg.INSTANCE_FLOW.ENABLED:
            flow_factor = 1 / (2 * torch.exp(self.model.flow_weight))
            loss['instance_flow'] = flow_factor * self.losses_fn['instance_flow'](
                output['instance_flow'], labels['flow']
            )
            loss['flow_uncertainty'] = 0.5 * self.model.flow_weight

        if self.cfg.PROBABILISTIC.ENABLED:
            loss['probabilistic'] = self.cfg.PROBABILISTIC.WEIGHT * self.losses_fn['probabilistic'](output)

        # Metrics
        if not is_train:
            seg_prediction = output['segmentation'].detach()
            seg_prediction = torch.argmax(seg_prediction, dim=2, keepdim=True)
            self.metric_iou_val(seg_prediction, labels['segmentation'])

            pred_consistent_instance_seg = predict_instance_segmentation_and_trajectories(
                output, compute_matched_centers=False
            )

            self.metric_panoptic_val(pred_consistent_instance_seg, labels['instance'])

        return output, labels, loss

    def prepare_future_labels(self, batch):
        """Warp the label sequences from the present frame onward into the present's reference frame.

        Only timesteps from index ``self.model.receptive_field - 1`` onward are
        kept. Each label is warped with ``cumulative_warp_features_reverse``
        using nearest-neighbour sampling.

        Args:
            batch: dataloader dict with 'segmentation', 'centerness', 'offset',
                'instance', 'future_egomotion' and, when instance flow is
                enabled, 'flow'.

        Returns:
            Tuple ``(labels, future_distribution_inputs)``. ``labels`` has keys
            'segmentation', 'instance', 'centerness', 'offset' (and 'flow' when
            enabled). ``future_distribution_inputs`` is the concatenation along
            dim 2 of segmentation, centerness, offset (and flow) labels.
        """
        labels = {}
        future_distribution_inputs = []

        # First supervised timestep (the present frame).
        present = self.model.receptive_field - 1
        future_egomotion = batch['future_egomotion'][:, present:]

        def warp(tensor):
            # Warp an already time-sliced label tensor to the present frame.
            return cumulative_warp_features_reverse(
                tensor, future_egomotion, mode='nearest', spatial_extent=self.spatial_extent,
            )

        # Integer labels go through float for warping and back to long.
        segmentation_labels = warp(batch['segmentation'][:, present:].float()).long().contiguous()
        labels['segmentation'] = segmentation_labels
        future_distribution_inputs.append(segmentation_labels)

        # Instance ids get a temporary channel dim for warping, removed afterwards.
        gt_instance = warp(batch['instance'][:, present:].float().unsqueeze(2)).long().contiguous()[:, :, 0]
        labels['instance'] = gt_instance

        instance_center_labels = warp(batch['centerness'][:, present:]).contiguous()
        labels['centerness'] = instance_center_labels

        instance_offset_labels = warp(batch['offset'][:, present:]).contiguous()
        labels['offset'] = instance_offset_labels

        future_distribution_inputs.append(instance_center_labels)
        future_distribution_inputs.append(instance_offset_labels)

        if self.cfg.INSTANCE_FLOW.ENABLED:
            # Only read 'flow' when it is actually supervised.
            instance_flow_labels = warp(batch['flow'][:, present:]).contiguous()
            labels['flow'] = instance_flow_labels
            future_distribution_inputs.append(instance_flow_labels)

        future_distribution_inputs = torch.cat(future_distribution_inputs, dim=2)

        return labels, future_distribution_inputs

    def visualise(self, labels, output, batch_idx, prefix='train'):
        """Render labels/predictions with ``visualise_output`` and log them as a video.

        Validation videos get the batch index appended to their name.
        """
        visualisation_video = visualise_output(labels, output, self.cfg)
        name = f'{prefix}_outputs'
        if prefix == 'val':
            name = name + f'_{batch_idx}'
        self.logger.experiment.add_video(name, visualisation_video, global_step=self.training_step_count, fps=2)

    def training_step(self, batch, batch_idx):
        """Compute and log training losses; return their sum for backpropagation."""
        output, labels, loss = self.shared_step(batch, True)
        self.training_step_count += 1
        for key, value in loss.items():
            self.logger.experiment.add_scalar(key, value, global_step=self.training_step_count)
        if self.training_step_count % self.cfg.VIS_INTERVAL == 0:
            self.visualise(labels, output, batch_idx, prefix='train')
        return sum(loss.values())

    def validation_step(self, batch, batch_idx):
        """Compute validation losses, update metrics and visualise the first batch."""
        output, labels, loss = self.shared_step(batch, False)
        for key, value in loss.items():
            self.log('val_' + key, value)

        if batch_idx == 0:
            self.visualise(labels, output, batch_idx, prefix='val')

    def shared_epoch_end(self, step_outputs, is_train):
        """Log epoch-level metrics and the current task weighting factors.

        Args:
            step_outputs: per-step outputs collected by Lightning (unused).
            is_train: when False, the validation IoU and panoptic metrics are
                computed, logged and reset.
        """
        if not is_train:
            # Per-class IoU. Classes beyond the named ones get a generic name
            # instead of being silently dropped.
            class_names = ['background', 'dynamic']
            scores = self.metric_iou_val.compute()
            for index, value in enumerate(scores):
                key = class_names[index] if index < len(class_names) else f'class_{index}'
                self.logger.experiment.add_scalar('val_iou_' + key, value, global_step=self.training_step_count)
            self.metric_iou_val.reset()

            scores = self.metric_panoptic_val.compute()
            for key, value in scores.items():
                for instance_name, score in zip(['background', 'vehicles'], value):
                    if instance_name != 'background':
                        self.logger.experiment.add_scalar(f'val_{key}_{instance_name}', score.item(),
                                                          global_step=self.training_step_count)
            self.metric_panoptic_val.reset()

        # Effective multiplicative factors applied to each task loss.
        self.logger.experiment.add_scalar('segmentation_weight',
                                          1 / (torch.exp(self.model.segmentation_weight)),
                                          global_step=self.training_step_count)
        self.logger.experiment.add_scalar('centerness_weight',
                                          1 / (2 * torch.exp(self.model.centerness_weight)),
                                          global_step=self.training_step_count)
        self.logger.experiment.add_scalar('offset_weight', 1 / (2 * torch.exp(self.model.offset_weight)),
                                          global_step=self.training_step_count)
        if self.cfg.INSTANCE_FLOW.ENABLED:
            self.logger.experiment.add_scalar('flow_weight', 1 / (2 * torch.exp(self.model.flow_weight)),
                                              global_step=self.training_step_count)

    def training_epoch_end(self, step_outputs):
        self.shared_epoch_end(step_outputs, True)

    def validation_epoch_end(self, step_outputs):
        self.shared_epoch_end(step_outputs, False)

    def configure_optimizers(self):
        """Build an AdamW optimizer over the model parameters with a per-step OneCycleLR schedule.

        Returns:
            ``([optimizer], [scheduler_config])`` where the scheduler is stepped
            every optimisation step (``'interval': 'step'``).
        """
        optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=self.cfg.OPTIMIZER.LR, weight_decay=self.cfg.OPTIMIZER.WEIGHT_DECAY
        )

        # NOTE: OneCycleLR overwrites the optimizer's learning rate (starting at
        # MAX_LR / DIV_FACTOR), so OPTIMIZER.LR does not drive the schedule.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.cfg.SCHEDULER.MAX_LR,
            total_steps=self.cfg.SCHEDULER.TOTAL_STEPS,
            pct_start=self.cfg.SCHEDULER.PCT_START,
            cycle_momentum=self.cfg.SCHEDULER.CYCLE_MOMENTUM,
            div_factor=self.cfg.SCHEDULER.DIV_FACTOR,
            final_div_factor=self.cfg.SCHEDULER.FINAL_DIV_FACTOR,
        )

        return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]
