from collections import namedtuple, OrderedDict
import copy
import math
import sys
import copy
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair, _quadruple
 

def make_list(var, n=None):
    """Wrap ``var`` in a list (unless it already is one), optionally broadcasting to length ``n``.

    Args:
        var: a list, or any single value to be wrapped as ``[var]``.
        n: desired length. ``None`` returns the (wrapped) list unchanged.

    Returns:
        The list itself when ``n`` is None or it already has ``n`` items; otherwise its single
        element repeated ``n`` times.

    Raises:
        AssertionError: if ``n`` is given and the list has neither 1 nor ``n`` items.
    """
    items = [var] if not isinstance(var, list) else var
    if n is None:
        return items
    assert len(items) in (1, n), 'Wrong list length for make_list'
    if len(items) == 1:
        return items * n
    return items

def same_shape(shape1, shape2):
    """Return True when two shapes have the same length and equal entries at every position.

    Works for any indexable/iterable shape representation (tuple, list, ``torch.Size``).
    The previous version compared every ``shape1[i]`` against ``shape2[2]`` instead of
    ``shape2[i]``.
    """
    if len(shape1) != len(shape2):
        return False
    return all(a == b for a, b in zip(shape1, shape2))

def inv2depth(disp):
    """Convert inverse depth (disparity) to depth by taking the reciprocal.

    Values are clamped to at least 1e-6 before inversion, so depth is capped at 1e6.
    A list or tuple is handled recursively, and the result is always returned as a list.
    """
    if not isinstance(disp, (tuple, list)):
        return 1. / disp.clamp(min=1e-6)
    return [inv2depth(element) for element in disp]

def reprojection_loss(input_color, warped_input, ssim=None):
    """Photometric error between a target image and its warped reconstruction.

    Computes ``0.85 * SSIM-term + 0.15 * L1``. Both terms are averaged over dim 1 with the dim
    kept (presumably the channel axis — confirm).

    Args:
        input_color: target image tensor.
        warped_input: reconstruction with the same shape as ``input_color``.
        ssim: optional callable ``ssim(a, b)`` returning a per-element dissimilarity map.
            When None, the SSIM term is 0.

    NOTE(review): without ``ssim`` the L1 term is still weighted by 0.15; confirm this
    down-weighting is intended.
    """
    l1_term = torch.abs(input_color - warped_input).mean(1, True)
    if ssim is None:
        ssim_term = 0.
    else:
        ssim_term = 0. + ssim(input_color, warped_input).mean(1, True)
    return 0.85 * ssim_term + 0.15 * l1_term


def depth_loss(pan_processed_output, depth_output):
    """Pull predicted depth inside selected panoptic segments toward a median-filtered pseudo-label.

    For each image, the prediction is detached and median-filtered (kernel 5) to form a
    pseudo-label. For every segment whose ``label_id`` is in a fixed class set, the squared
    difference between prediction and pseudo-label is collected over the segment's pixels.
    The loss is the mean over all collected pixels.

    Args:
        pan_processed_output: per-image panoptic results. Each is a dict with a
            'segmentation' id map and a 'segments_info' list of dicts carrying 'id' and
            'label_id'.
        depth_output: batched prediction, iterated per image alongside
            ``pan_processed_output``. Presumably (B, 1, H, W) with the segmentation at
            (H, W) — TODO confirm.

    Returns:
        0-dim tensor. When no pixel is selected, a zero leaf tensor with
        ``requires_grad=True`` so ``backward()`` on a sum of losses still works.
    """
    # Presumably Cityscapes trainIds (traffic light/sign, human and vehicle classes) — confirm.
    selected_classes = frozenset((6, 7, 11, 12, 13, 14, 15, 16, 17, 18))

    sq_errors = []
    for pan_out, depth_out in zip(pan_processed_output, depth_output):
        # Build the pseudo-label per image: median_filter squeezes its input to 2-D for
        # cv2.medianBlur, so feeding it the whole batch only worked for batch size 1.
        pseudo = median_filter(depth_out.detach(), 5)
        # NOTE(review): the mask must match depth_out's spatial size; at coarser output scales
        # this indexing fails unless the prediction was upsampled — confirm.
        seg_map = pan_out['segmentation'].unsqueeze(0).to(depth_out.device)
        for segment in pan_out['segments_info']:
            if segment['label_id'] not in selected_classes:
                continue
            mask = seg_map == segment['id']
            sq_errors.append(torch.square(depth_out[mask] - pseudo[mask]))

    # Concatenate instead of accumulating in place into a shape-(1,) tensor. The old
    # `loss += ...` could not broadcast N errors into one slot, and on CPU it mutated a
    # leaf tensor that requires grad.
    errors = torch.cat(sq_errors) if sq_errors else None
    if errors is None or errors.numel() == 0:
        return torch.zeros((), device=depth_output.device, requires_grad=True)
    return errors.mean()

def smooth_loss(input_color, disp_out):
    """Edge-aware first-order smoothness penalty on a mean-normalized disparity map.

    ``disp_out`` is divided by its global mean (plus 1e-7). Its absolute finite differences
    along the last axis and along axis 2 are then weighted by ``exp(-|image gradient|)``,
    where the image gradient is averaged over dim 1 (presumably channels — confirm).

    Returns:
        Sum of the two directional mean penalties as a 0-dim tensor.
    """
    norm_disp = disp_out / (disp_out.mean() + 1e-7)

    def grad_last(t):
        # absolute difference between neighbours along the last axis
        return torch.abs(t[..., 1:] - t[..., :-1])

    def grad_axis2(t):
        # absolute difference between neighbours along axis 2 (expects a 4-D tensor)
        return torch.abs(t[:, :, 1:, :] - t[:, :, :-1, :])

    weight_x = torch.exp(-grad_last(input_color).mean(1, keepdim=True))
    weight_y = torch.exp(-grad_axis2(input_color).mean(1, keepdim=True))
    penalty_x = (grad_last(norm_disp) * weight_x).mean()
    penalty_y = (grad_axis2(norm_disp) * weight_y).mean()
    return penalty_x + penalty_y

def edge_loss(rgb, pred_disp, filter_size=1):
    """Elementwise mismatch between image and disparity second-order edge maps.

    The image is converted to grayscale by averaging dim 1 (presumably channels — confirm).
    Both maps go through ``second_order_edge`` with neighbour distance ``filter_size``.
    The image edge map is multiplied by its max-normalised copy, giving
    ``edge**2 / (max(edge) + 1e-6)``.

    Returns:
        The unreduced absolute difference; callers reduce it (e.g. with ``.mean()``).
    """
    # Keep the averaged dim when pred_disp has as many dims as rgb (e.g. (B, 1, H, W)).
    # Otherwise a (B, H, W) gray map broadcasts against (B, 1, H, W) into (B, B, H, W),
    # comparing every image with every disparity in the batch.
    keep_dim = pred_disp.dim() == rgb.dim()
    im_gray = torch.mean(rgb, dim=1, keepdim=keep_dim)

    edge_map = second_order_edge(im_gray, filter_size)
    disp_edge = second_order_edge(pred_disp, filter_size)
    edge_norm = edge_map / (edge_map.max() + 1e-6)
    # NOTE(review): this squares the edge response rather than replacing it with the
    # normalised map; confirm which was intended. Out-of-place, so the tensor edge_norm
    # was derived from is not mutated.
    edge_map = edge_map * edge_norm

    return torch.abs(edge_map - disp_edge)

def second_order_edge(array, k):
    """Absolute 5-point Laplacian over the last two axes, using neighbours ``k`` pixels away.

    The input is reflect-padded by ``k`` on both spatial axes, so the output has the same
    shape as ``array``.

    Args:
        array: tensor whose last two dims are spatial. F.pad with mode='reflect' needs a 3-D
            or 4-D input, and ``k`` smaller than both spatial sizes.
        k: neighbour distance (>= 1).

    Raises:
        TypeError: if ``k`` is not an int.
        ValueError: if ``k`` < 1. With k == 0 the ``k:-k`` slices are empty and the result
            would silently be an empty tensor.
    """
    if not isinstance(k, int):
        raise TypeError(f'kernel size should be int but take {type(k)}')
    if k < 1:
        raise ValueError(f'kernel size should be >= 1 but got {k}')

    kk = 2 * k
    padded = F.pad(array, (k, k, k, k), mode='reflect')
    center = padded[..., k:-k, k:-k]
    laplacian = (
        padded[..., k:-k, kk:]      # right neighbour
        + padded[..., k:-k, :-kk]   # left neighbour
        + padded[..., kk:, k:-k]    # lower neighbour
        + padded[..., :-kk, k:-k]   # upper neighbour
        - 4 * center
    )
    return torch.abs(laplacian)

def median_filter(pred_disp, kernel_size):
    """Apply ``cv2.medianBlur`` to (H, W) sized depth estimates.

    Args:
        pred_disp: a torch tensor, which is detached, squeezed and moved to the CPU, or a
            NumPy array, which is used as-is.
        kernel_size: aperture size. Even values are bumped to the next odd value.
            Per the OpenCV docs, float32 input only supports sizes 3 and 5.

    Returns:
        Filtered map as a tensor with a leading dim of 1. It is placed on the input tensor's
        device, or on the CPU for NumPy input.
    """
    if isinstance(pred_disp, np.ndarray):
        # np.ndarray has no `.device` before NumPy 2.0, so do not read it here.
        device = torch.device('cpu')
        disp_np = pred_disp
    else:
        device = pred_disp.device
        # detach(): .numpy() refuses tensors that require grad.
        disp_np = pred_disp.detach().squeeze().cpu().numpy()
    if kernel_size % 2 == 0:
        kernel_size += 1
    filtered_disp = cv2.medianBlur(disp_np, kernel_size)
    return torch.from_numpy(filtered_disp).unsqueeze(0).to(device)

def compute_loss(input_color, output, pan_color_processed, opt, idx):
    """Sum the enabled adaptation losses over ``opt.scales``, weighting scale ``s`` by ``1 / 2**s``.

    Args:
        input_color: input image batch, used by the edge loss.
        output: dict of predictions keyed by ``('disp', s)``.
        pan_color_processed: panoptic results forwarded to ``depth_loss``.
        opt: options read here: ``scales``, ``depth_loss``, ``edge_loss``, ``lamb``,
            ``filter_size``.
        idx: unused; kept for interface compatibility.

    Returns:
        The accumulated loss, or the float ``0.`` when no loss is enabled.
    """
    total_loss = 0.
    for s in opt.scales:
        loss = 0.
        disp = output[('disp', s)]

        # depth loss
        if opt.depth_loss:
            disp_ness = depth_loss(pan_color_processed, disp)
            loss += disp_ness / 2**s

        # edge loss: weight by opt.lamb once. It was previously applied both here and when
        # computing edge_ness, which scaled the term by lamb**2.
        if opt.edge_loss:
            edge_ness = edge_loss(input_color, disp, filter_size=opt.filter_size).mean()
            loss += opt.lamb * edge_ness / 2**s
        total_loss += loss
    return total_loss


def predict_tta(model, image, min_depth, max_depth, device):
    """Average a prediction with the prediction of a horizontally flipped input (flip TTA).

    ``model(x)[-1]`` is taken as the prediction (presumably the final/finest output — confirm).
    Both predictions are clipped to ``[min_depth, max_depth]``. The flipped one is flipped
    back along the last axis, the two are averaged, and the result is bilinearly resized
    (align_corners=True) to the input's last two dims.

    Args:
        model: callable returning an indexable sequence of predictions.
        image: input batch tensor, flipped along its last axis (presumably width).
        min_depth: lower clipping bound.
        max_depth: upper clipping bound.
        device: device the flipped input is moved to before the second forward pass.

    Returns:
        float32 tensor on the CPU.
    """
    # Inference only: no autograd graph is needed. detach() also guards against model
    # outputs that require grad, which made the old `.numpy()` call fail on the flipped pass.
    with torch.no_grad():
        pred = model(image)[-1].detach().cpu().float().clamp(min_depth, max_depth)

        # Flip on-device instead of round-tripping through NumPy.
        flipped = torch.flip(image.detach(), dims=[-1]).float().to(device)
        pred_lr = model(flipped)[-1].detach().cpu().float()
        pred_lr = torch.flip(pred_lr, dims=[-1]).clamp(min_depth, max_depth)

    final = 0.5 * (pred + pred_lr)
    return nn.functional.interpolate(final, flipped.shape[-2:], mode='bilinear', align_corners=True)

def get_seg_image(mask, image):
    # a label and all meta information
    Label = namedtuple( 'Label' , [

        'name'        , # The identifier of this label, e.g. 'car', 'person', ... .
                        # We use them to uniquely name a class

        'id'          , # An integer ID that is associated with this label.
                        # The IDs are used to represent the label in ground truth images
                        # An ID of -1 means that this label does not have an ID and thus
                        # is ignored when creating ground truth images (e.g. license plate).
                        # Do not modify these IDs, since exactly these IDs are expected by the
                        # evaluation server.

        'trainId'     , # Feel free to modify these IDs as suitable for your method. Then create
                        # ground truth images with train IDs, using the tools provided in the
                        # 'preparation' folder. However, make sure to validate or submit results
                        # to our evaluation server using the regular IDs above!
                        # For trainIds, multiple labels might have the same ID. Then, these labels
                        # are mapped to the same class in the ground truth images. For the inverse
                        # mapping, we use the label that is defined first in the list below.
                        # For example, mapping all void-type classes to the same ID in training,
                        # might make sense for some approaches.
                        # Max value is 255!

        'category'    , # The name of the category that this label belongs to

        'categoryId'  , # The ID of this category. Used to create ground truth images
                        # on category level.

        'hasInstances', # Whether this label distinguishes between single instances or not

        'ignoreInEval', # Whether pixels having this class as ground truth label are ignored
                        # during evaluations or not

        'color'       , # The color of this label
        ] )


    #--------------------------------------------------------------------------------
    # A list of all labels
    #--------------------------------------------------------------------------------

    # Please adapt the train IDs as appropriate for your approach.
    # Note that you might want to ignore labels with ID 255 during training.
    # Further note that the current train IDs are only a suggestion. You can use whatever you like.
    # Make sure to provide your results using the original IDs and not the training IDs.
    # Note that many IDs are ignored in evaluation and thus you never need to predict these!

    labels = [
        #       name                     id    trainId   category            catId     hasInstances   ignoreInEval   color
        Label(  'unlabeled'            ,  0 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
        Label(  'ego vehicle'          ,  1 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
        Label(  'rectification border' ,  2 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
        Label(  'out of roi'           ,  3 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
        Label(  'static'               ,  4 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
        Label(  'dynamic'              ,  5 ,      255 , 'void'            , 0       , False        , True         , (111, 74,  0) ),
        Label(  'ground'               ,  6 ,      255 , 'void'            , 0       , False        , True         , ( 81,  0, 81) ),
        Label(  'road'                 ,  7 ,        0 , 'flat'            , 1       , False        , False        , (128, 64,128) ),
        Label(  'sidewalk'             ,  8 ,        1 , 'flat'            , 1       , False        , False        , (244, 35,232) ),
        Label(  'parking'              ,  9 ,      255 , 'flat'            , 1       , False        , True         , (250,170,160) ),
        Label(  'rail track'           , 10 ,      255 , 'flat'            , 1       , False        , True         , (230,150,140) ),
        Label(  'building'             , 11 ,        2 , 'construction'    , 2       , False        , False        , ( 70, 70, 70) ),
        Label(  'wall'                 , 12 ,        3 , 'construction'    , 2       , False        , False        , (102,102,156) ),
        Label(  'fence'                , 13 ,        4 , 'construction'    , 2       , False        , False        , (190,153,153) ),
        Label(  'guard rail'           , 14 ,      255 , 'construction'    , 2       , False        , True         , (180,165,180) ),
        Label(  'bridge'               , 15 ,      255 , 'construction'    , 2       , False        , True         , (150,100,100) ),
        Label(  'tunnel'               , 16 ,      255 , 'construction'    , 2       , False        , True         , (150,120, 90) ),
        Label(  'pole'                 , 17 ,        5 , 'object'          , 3       , False        , False        , (153,153,153) ),
        Label(  'polegroup'            , 18 ,      255 , 'object'          , 3       , False        , True         , (153,153,153) ),
        Label(  'traffic light'        , 19 ,        6 , 'object'          , 3       , False        , False        , (250,170, 30) ),
        Label(  'traffic sign'         , 20 ,        7 , 'object'          , 3       , False        , False        , (220,220,  0) ),
        Label(  'vegetation'           , 21 ,        8 , 'nature'          , 4       , False        , False        , (107,142, 35) ),
        Label(  'terrain'              , 22 ,        9 , 'nature'          , 4       , False        , False        , (152,251,152) ),
        Label(  'sky'                  , 23 ,       10 , 'sky'             , 5       , False        , False        , ( 70,130,180) ),
        Label(  'person'               , 24 ,       11 , 'human'           , 6       , True         , False        , (220, 20, 60) ),
        Label(  'rider'                , 25 ,       12 , 'human'           , 6       , True         , False        , (255,  0,  0) ),
        Label(  'car'                  , 26 ,       13 , 'vehicle'         , 7       , True         , False        , (  0,  0,142) ),
        Label(  'truck'                , 27 ,       14 , 'vehicle'         , 7       , True         , False        , (  0,  0, 70) ),
        Label(  'bus'                  , 28 ,       15 , 'vehicle'         , 7       , True         , False        , (  0, 60,100) ),
        Label(  'caravan'              , 29 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0, 90) ),
        Label(  'trailer'              , 30 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0,110) ),
        Label(  'train'                , 31 ,       16 , 'vehicle'         , 7       , True         , False        , (  0, 80,100) ),
        Label(  'motorcycle'           , 32 ,       17 , 'vehicle'         , 7       , True         , False        , (  0,  0,230) ),
        Label(  'bicycle'              , 33 ,       18 , 'vehicle'         , 7       , True         , False        , (119, 11, 32) ),
        Label(  'license plate'        , -1 ,       -1 , 'vehicle'         , 7       , False        , True         , (  0,  0,142) ),
    ]

    color_palette = [np.array(label[-1]) for label in labels]
    color_palette = {label[2]: np.array(label[-1]) for label in labels}
    color_palette[255] = np.array([0, 0, 0])

    color_seg = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
    for label, color in color_palette.items():
        color_seg[mask == label, :] = color
    color_seg = color_seg[..., ::-1]  # convert to BGR
    img = np.round(image * 255) * 0.5 + color_seg * 0.5  # plot the image with the segmentation map
    img = img.astype(np.uint8)
    return img


def even_sweep(encoder, opt):
    """Select the trailing ``opt.ada_ratio`` fraction of normalization and of other parameters.

    Every parameter of ``encoder`` gets ``requires_grad=True``.
    - BatchNorm modules are put in train mode only when ``opt.train_bn``.
    - LayerNorm modules are put in train mode.
    - All other modules get their own training flag set.

    Args:
        encoder: module whose parameters are swept.
        opt: options read here: ``train_bn`` and ``ada_ratio`` (fraction in [0, 1]).

    Returns:
        The last ``int(len(norm) * ada_ratio)`` norm parameters, followed by the last
        ``int(len(other) * ada_ratio)`` other parameters (empty slices when those are 0).
    """
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    norm_params = []
    core_params = []
    for module in encoder.modules():
        # recurse=False: modules() already visits every submodule, so a recursive
        # parameters() would collect each parameter once per ancestor (encoder included).
        own_params = list(module.parameters(recurse=False))
        if isinstance(module, bn_types):
            if opt.train_bn:
                module.train()
            norm_params.extend(own_params)
        elif isinstance(module, nn.LayerNorm):
            module.train()
            norm_params.extend(own_params)
        else:
            # Set only this module's flag. module.train() recurses, and called on a container
            # (first of all on `encoder`) it would force every BatchNorm into train mode
            # regardless of opt.train_bn.
            module.training = True
            core_params.extend(own_params)
        for param in own_params:
            param.requires_grad_(True)

    norm_idx = int(len(norm_params) * opt.ada_ratio)
    core_idx = int(len(core_params) * opt.ada_ratio)
    # Slice from len - idx rather than -idx: `lst[-0:]` would return the whole list.
    adapt_params = norm_params[max(len(norm_params) - norm_idx, 0):]
    adapt_params.extend(core_params[max(len(core_params) - core_idx, 0):])
    return adapt_params

def norm_sweep(encoder, opt):
    """Select the trailing ``opt.ada_ratio`` fraction of normalization-layer parameters.

    When ``opt.ViT`` is false, only ``nn.BatchNorm2d`` layers are considered. When it is true,
    ``nn.LayerNorm`` and ``nn.BatchNorm1d/2d/3d`` layers are considered. LayerNorm modules are
    put in train mode; BatchNorm modules only when ``opt.train_bn``. Every collected
    parameter gets ``requires_grad=True``.

    Args:
        encoder: module whose submodules are scanned.
        opt: options read here: ``ViT``, ``train_bn`` and ``ada_ratio`` (fraction in [0, 1]).

    Returns:
        The last ``int(len(norm_params) * ada_ratio)`` collected parameters (empty when 0).
    """
    use_layer_norm = bool(opt.ViT)
    if use_layer_norm:
        bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    else:
        bn_types = (nn.BatchNorm2d,)

    norm_params = []
    for module in encoder.modules():
        if use_layer_norm and isinstance(module, nn.LayerNorm):
            module.train()
        elif isinstance(module, bn_types):
            if opt.train_bn:
                module.train()
        else:
            continue
        for param in module.parameters():
            param.requires_grad_(True)
            norm_params.append(param)

    target_idx = int(len(norm_params) * opt.ada_ratio)
    # Slice from len - idx rather than -idx: `lst[-0:]` would return the whole list.
    return norm_params[max(len(norm_params) - target_idx, 0):]

def dynamic_sweep(encoder, opt):
    """Split the adaptation budget between norm layers (``ada_ratio``) and conv/linear layers (``1 - ada_ratio``).

    BatchNorm modules are put in train mode only when ``opt.train_bn``. LayerNorm, Conv1d/2d/3d
    and Linear modules are put in train mode. Every collected parameter gets
    ``requires_grad=True``.

    Args:
        encoder: module whose submodules are scanned.
        opt: options read here: ``train_bn`` and ``ada_ratio`` (fraction in [0, 1]).

    Returns:
        The last ``int(len(norm) * ada_ratio)`` norm parameters, followed by the last
        ``int(len(core) * (1 - ada_ratio))`` conv/linear parameters (empty slices when 0).
    """
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    core_group = (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)
    norm_params = []
    core_params = []
    for module in encoder.modules():
        if isinstance(module, bn_types):
            if opt.train_bn:
                module.train()
            bucket = norm_params
        elif isinstance(module, nn.LayerNorm):
            module.train()
            bucket = norm_params
        elif isinstance(module, core_group):
            module.train()
            bucket = core_params
        else:
            continue
        for param in module.parameters():
            param.requires_grad_(True)
            bucket.append(param)

    norm_idx = int(len(norm_params) * opt.ada_ratio)
    core_idx = int(len(core_params) * (1 - opt.ada_ratio))
    # Slice from len - idx rather than -idx: `lst[-0:]` would return the whole list, e.g. all
    # core params when ada_ratio == 1.
    adapt_params = norm_params[max(len(norm_params) - norm_idx, 0):]
    adapt_params.extend(core_params[max(len(core_params) - core_idx, 0):])
    return adapt_params

def sweep_params(encoder, opt):
    """Freeze ``encoder`` and return the parameters selected for test-time adaptation.

    All encoder parameters are first set to ``requires_grad=False``. The strategy named by
    ``opt.tta_params`` ('even', 'norm' or 'dynamic') then re-enables gradients on the
    parameters it scans and returns the subset to adapt.

    Raises:
        ValueError: if ``opt.tta_params`` names no known strategy. The alternative is
            returning None, which would only fail later inside the optimizer.
    """
    encoder.requires_grad_(False)
    mode = opt.tta_params
    if mode == "even":
        return even_sweep(encoder, opt)
    if mode == "norm":
        return norm_sweep(encoder, opt)
    if mode == 'dynamic':
        return dynamic_sweep(encoder, opt)
    raise ValueError(f"unknown tta_params {mode!r}; expected 'even', 'norm' or 'dynamic'")