import torch
import torch.nn as nn
import torch.nn.functional as F
from opencood.loss.contrastive_loss import ContrastiveLoss
from opencood.loss.point_pillar_depth_loss import PointPillarDepthLoss
from opencood.loss.point_pillar_pyramid_loss import PointPillarPyramidLoss
from opencood.tools.feature_show import feature_show


class PnpdaLoss(PointPillarDepthLoss):
    """Loss whose computation is selected by ``args['stage']``.

    * ``'rep'``: only the contrastive loss between ``query_for_train`` and
      ``key_for_train`` is computed.
    * ``'ft'`` (fine-tune): only a detection loss is computed, either by the
      pyramid loss (when ``output_dict['fusion_method'] == 'pyramid'``) or by
      the parent class.

    The individual terms of the most recent calls are kept in
    ``self.loss_dict`` and reported by :meth:`logging`.
    """

    # Stage names understood by ``forward``.
    _STAGES = ('rep', 'ft')

    def __init__(self, args):
        """
        Parameters
        ----------
        args: dict
            Loss configuration. Read here: ``args['contrastive']`` (passed to
            ``ContrastiveLoss``) and ``args['stage']``. The full dict is also
            handed to the parent class and to ``PointPillarPyramidLoss``.
        """
        super().__init__(args)
        self.contrastive = ContrastiveLoss(args['contrastive'])
        self.pyramid_loss = PointPillarPyramidLoss(args)
        self.stage = args['stage']
        self.loss_dict = {}

    def forward(self, output_dict, label_dict, label_dict_single, suffix=""):
        """
        Compute the loss for the configured stage.

        Parameters
        ----------
        output_dict: dict
            Model output. Keys read here: ``query_for_train``,
            ``key_for_train`` and ``adapt_agent_idx`` in the ``'rep'`` stage;
            ``fusion_method`` in the ``'ft'`` stage.
        label_dict: dict
            Labels used for the detection loss when ``suffix`` is not
            ``'_single'``.
        label_dict_single: dict
            Labels used when ``suffix == '_single'``; in the ``'rep'`` stage
            its ``pos_region_ranges`` entry is passed to the contrastive loss.
        suffix: str
            ``'_single'`` selects the single-supervision branch of the
            ``'ft'`` stage.

        Returns
        -------
        The loss value of the selected branch.

        Raises
        ------
        ValueError
            If ``self.stage`` is neither ``'rep'`` nor ``'ft'``.

        Notes
        -----
        The ``'_single'`` branch returns before ``self.total_loss`` is
        assigned, so a following :meth:`logging` call reports the total of the
        last non-single call.
        """
        if self.stage == 'rep':
            contrastive_loss, _ = self.contrastive(
                output_dict['query_for_train'],
                output_dict['key_for_train'],
                label_dict_single['pos_region_ranges'],
                output_dict['adapt_agent_idx'],
            )
            self.loss_dict.update({"contra_loss": contrastive_loss})
            total_loss = contrastive_loss

        elif self.stage == 'ft':
            # Fine-tune stage: only the detection loss is computed.
            fusion_method = output_dict['fusion_method']

            if suffix == '_single':
                if fusion_method == 'pyramid':
                    sup_single_loss = self.pyramid_loss(
                        output_dict, label_dict_single, suffix)
                else:
                    sup_single_loss = super().forward(
                        output_dict, label_dict_single, suffix)

                self.loss_dict.update({"sup_single_loss": sup_single_loss})
                return sup_single_loss

            if fusion_method == 'pyramid':
                det_loss = self.pyramid_loss(output_dict, label_dict, suffix)
            else:
                # The parent loss is called with its default suffix here.
                det_loss = super().forward(output_dict, label_dict)

            self.loss_dict.update({"det_loss": det_loss})
            total_loss = det_loss

        else:
            # Without this branch an unknown stage fell through to an
            # UnboundLocalError on ``total_loss``.
            raise ValueError(
                "Unsupported stage %r; expected one of %s"
                % (self.stage, self._STAGES)
            )

        self.total_loss = total_loss
        return total_loss

    def logging(self, epoch, batch_id, batch_len, writer = None, suffix="", pbar=None):
        """
        Report the latest total loss and every entry of ``self.loss_dict``.

        Parameters
        ----------
        epoch: int
            Current epoch index.
        batch_id: int
            Index of the batch inside the epoch.
        batch_len: int
            Number of batches per epoch; the global step is
            ``epoch * batch_len + batch_id``.
        writer: optional
            Object exposing ``add_scalar(tag, value, step)``. When ``None``
            no scalars are written.
        suffix: str
            Text inserted into the printed message after the batch counter.
        pbar: optional
            Object exposing ``set_description``; when given, the message is
            set as its description instead of being printed.
        """
        step = epoch * batch_len + batch_id
        total = self.total_loss.item()

        # Convert every entry to a plain Python number once, so the writer and
        # the message formatting below see the same values.
        scalars = {
            name: value.item() if isinstance(value, torch.Tensor) else value
            for name, value in self.loss_dict.items()
        }

        if writer is not None:
            writer.add_scalar('Total_loss', total, step)
            for name, value in scalars.items():
                writer.add_scalar(name, value, step)

        print_msg = "[epoch %d][%d/%d]%s, || Loss: %.4f" % \
            (epoch, batch_id + 1, batch_len, suffix, total)
        for name, value in scalars.items():
            label = name.replace("_loss", "").capitalize()
            print_msg = print_msg + f" || {label}: {value:.4f}"

        if pbar is None:
            print(print_msg)
        else:
            pbar.set_description(print_msg)


if __name__ == "__main__":
    # Ad-hoc run of ContrastiveLoss on random tensors.
    import os

    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    feature_shape = (4, 256, 50, 176)
    ego = torch.rand(*feature_shape)
    cav = torch.rand(*feature_shape)
    # Boolean mask: True wherever the random draw exceeds 0.2.
    mask = torch.rand(4, 50, 50, 176) > 0.2

    args = dict(tau=0.1, max_voxel=40)
    # These two dicts are built but not passed to the call below.
    data_dict = dict(features_q=ego, features_k=cav)
    target_dict = dict(pos_region_ranges=mask)

    model = ContrastiveLoss(args)
    output = model(ego, cav, mask)