from __future__ import absolute_import, division, print_function
import sys

import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1, 2, 3'
import json

import torch.utils
import torch.utils.data
import datasets.waymo
from torchvision.utils import save_image

import cv2
import datetime
import numpy as np
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from layers import disp_to_depth, SSIM
from utils import readlines
from options import MonodepthOptions
import datasets
from models import model_factory
from data_utils import dataset_factory
import pdb
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
import warnings
warnings.filterwarnings("ignore", category=UserWarning) 
from custom_utils import depth_loss, edge_loss, sweep_params, compute_loss

# Turn off OpenCV's internal threading; the authors report this makes evaluation
# about 5x faster on their Unix systems (OpenCV 3.3.1).
cv2.setNumThreads(0)


# Folder (relative to the working directory) holding the evaluation split files.
splits_dir = "splits"

# Scale used to turn predicted disparity into depth for stereo evaluation and for
# the KITTI benchmark export.
STEREO_SCALE_FACTOR = 5.4
# Order in which the per-weather ground-truth files are concatenated for the
# "all_weather" split (presumably must match the order the predictions are
# produced in -- confirm). Subsets such as ["rainy", "sunny"] or
# ["foggy", "rainy", "sunny"] have also been used here.
WEATHER_ORDER = ["cloudy", "foggy", "rainy", "sunny"]


def compute_errors(gt, pred):
    """Compute the standard depth-estimation error metrics.

    Args:
        gt: array of ground-truth depths.
        pred: array of predicted depths, broadcast-compatible with ``gt``.
            Both are divided by and log-transformed, so zero or negative
            values yield inf/nan.

    Returns:
        Tuple ``(abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3)``, where ``aK`` is
        the fraction of values with ``max(gt/pred, pred/gt) < 1.25**K``.
    """
    ratio = np.maximum(gt / pred, pred / gt)
    delta_1, delta_2, delta_3 = ((ratio < 1.25 ** k).mean() for k in (1, 2, 3))

    diff = gt - pred
    sq_diff = diff ** 2
    rmse = np.sqrt(sq_diff.mean())

    log_diff = np.log(gt) - np.log(pred)
    rmse_log = np.sqrt((log_diff ** 2).mean())

    abs_rel = np.mean(np.abs(diff) / gt)
    sq_rel = np.mean(sq_diff / gt)

    return abs_rel, sq_rel, rmse, rmse_log, delta_1, delta_2, delta_3


def batch_post_process_disparity(l_disp, r_disp):
    """Blend a batch of disparities with their flipped counterparts (Monodepthv1 post-processing).

    Args:
        l_disp: array of shape (N, H, W) with disparities for the original images.
        r_disp: array of the same shape with disparities for the flipped images,
            already flipped back to the original orientation by the caller.

    Returns:
        (N, H, W) array. Near the left border ``r_disp`` is used, near the right
        border ``l_disp`` is used, and in between the two are averaged.
    """
    _, height, width = l_disp.shape
    mean_disp = (l_disp + r_disp) * 0.5

    # Horizontal coordinate in [0, 1] for every pixel, shape (H, W).
    xs = np.tile(np.linspace(0, 1, width), (height, 1))
    # Weight ramps from 1 at the left edge to 0 over the first ~5-10% of the width.
    left_weight = (1.0 - np.clip(20 * (xs - 0.05), 0, 1))[np.newaxis]
    right_weight = np.flip(left_weight, axis=2)
    center_weight = 1.0 - left_weight - right_weight

    return right_weight * l_disp + left_weight * r_disp + center_weight * mean_disp


def _build_optimizer(params, opt):
    """Create the test-time-adaptation optimizer selected by ``opt.optim``.

    Raises:
        ValueError: if ``opt.optim`` is neither ``'adam'`` nor ``'sgd'`` (instead
            of failing later with an unbound ``optim`` name).
    """
    if opt.optim == 'adam':
        return torch.optim.Adam(params, lr=opt.learning_rate)
    if opt.optim == 'sgd':
        return torch.optim.SGD(params, lr=opt.learning_rate, momentum=opt.momentum)
    raise ValueError("Unsupported optimizer '{}'; expected 'adam' or 'sgd'".format(opt.optim))


def _load_gt_depths(opt):
    """Load the ground-truth depth maps for ``opt.eval_split`` from ``splits_dir``.

    Returns a list with one depth array per image for the multi-file splits
    ("all_weather", "waymo_da"), and the stored ``data`` array otherwise.
    """
    if opt.eval_split == "all_weather":
        gt_depths = []
        for sp in WEATHER_ORDER:
            gt_depth = np.load(os.path.join(splits_dir, sp, "gt_depths.npz"))["data"]
            gt_depths.extend(gt_depth)  # one entry per image along axis 0
        return gt_depths

    if opt.eval_split == "waymo_da":
        tf_list = os.path.join(splits_dir, "waymo_da", "tf_list.txt")
        with open(tf_list, 'r') as f:
            # Skip blank lines: an empty name would resolve to the weather folder itself.
            data_list = [line.strip() for line in f if line.strip()]

        gt_depths = []
        for weather in ['unknown_day', 'unknown_dusk']:
            t_file_dir = os.path.join(splits_dir, opt.eval_split, weather)
            for data in data_list:
                # Drop the 9-character suffix of each listed name to get its folder
                # (presumably the ".tfrecord" extension -- confirm).
                tar_folder = os.path.join(t_file_dir, data[:-9])
                if os.path.isdir(tar_folder):
                    gt_depth = np.load(os.path.join(tar_folder, 'gt_depths.npz'))['data']
                    gt_depths.extend(gt_depth)
        return gt_depths

    gt_path = os.path.join(splits_dir, opt.eval_split, "gt_depths.npz")
    # allow_pickle=True: only acceptable because the split files are local, trusted data.
    return np.load(gt_path, fix_imports=True, encoding='latin1', allow_pickle=True)["data"]


def _eval_mask(gt_depth, eval_split, min_depth, max_depth):
    """Return the boolean mask of ground-truth pixels that are scored for ``eval_split``."""
    if eval_split in ("eigen", "eigen_zhou"):
        gt_height, gt_width = gt_depth.shape[:2]
        mask = np.logical_and(gt_depth > min_depth, gt_depth < max_depth)
        # Fixed fractional crop (these constants match the commonly used
        # Garg/Eigen evaluation crop -- confirm).
        crop = np.array([0.40810811 * gt_height, 0.99189189 * gt_height,
                         0.03594771 * gt_width, 0.96405229 * gt_width]).astype(np.int32)
        crop_mask = np.zeros(mask.shape)
        crop_mask[crop[0]:crop[1], crop[2]:crop[3]] = 1
        return np.logical_and(mask, crop_mask)
    if eval_split == 'cityscape':
        return np.logical_and(gt_depth > min_depth, gt_depth < max_depth)
    return gt_depth > 0


def adapt(opt):
    """Adapt a pretrained model to a specified test set and evaluate its depth predictions.

    Predictions either come from running the model over ``opt.eval_split``
    (``opt.ext_disp_to_eval is None``) or are loaded with ``np.load`` from
    ``opt.ext_disp_to_eval``. When ``opt.tta`` is set, the parameters returned by
    ``sweep_params(encoder, opt)`` are updated after every batch with
    ``compute_loss``, which also receives a Mask2Former (Cityscapes panoptic)
    segmentation of the batch.

    The process exits via ``sys.exit()`` when ``opt.no_eval`` is set, or after the
    KITTI ``benchmark`` predictions are written. Otherwise the mean abs_rel, sq_rel,
    rmse, rmse_log, a1, a2, a3 are printed, and per-image errors and disparities
    are optionally saved.

    Raises:
        AssertionError: unless exactly one of ``opt.eval_mono`` / ``opt.eval_stereo``
            is set, or if the weights folder does not exist.
        ValueError: for an unsupported ``opt.optim`` when ``opt.tta`` is set, or if
            the dataloader produced no batches.
    """
    MIN_DEPTH = 1e-3
    MAX_DEPTH = 80

    device = torch.device(opt.device)

    dataset_name = opt.data_path.split('/')[-1]
    # Recorded up front so it exists on both the predict and the load-from-file path.
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H: %M: %S')

    assert sum((opt.eval_mono, opt.eval_stereo)) == 1, \
        "Please choose mono or stereo evaluation by setting either --eval_mono or --eval_stereo"

    if opt.ext_disp_to_eval is None:

        opt.load_weights_folder = os.path.expanduser(opt.load_weights_folder)

        assert os.path.isdir(opt.load_weights_folder), \
            "Cannot find a folder at {}".format(opt.load_weights_folder)

        print("-> Loading weights from {}".format(opt.load_weights_folder))
        dataset = dataset_factory(dataset_name, opt.data_path, splits_dir, opt.eval_split, opt.height, opt.width)

        # NOTE(review): drop_last=True means a trailing partial batch is never evaluated.
        dataloader = DataLoader(dataset, opt.batch_size, shuffle=False, num_workers=6,
                                pin_memory=True, drop_last=True)

        model = model_factory(opt.model)
        if opt.model not in ['newcrfs', 'adabins']:
            encoder, depth_decoder = model(opt.load_weights_folder)

            encoder.to(device)
            encoder.eval()
            depth_decoder.to(device)
            depth_decoder.eval()
        else:
            model = model(opt.load_weights_folder)
            model.eval()
            model.to(device)
            encoder = model.backbone if opt.model == 'newcrfs' else model.encoder

        if opt.tta:
            # The panoptic segmenter is only consumed by the adaptation loss, so it is
            # only loaded when adapting.
            pan_size = 'base-IN21k' if opt.pan_size == 'base' else opt.pan_size
            pan_checkpoint = f"facebook/mask2former-swin-{pan_size}-cityscapes-panoptic"
            pan_processor = AutoImageProcessor.from_pretrained(pan_checkpoint)
            pan_model = Mask2FormerForUniversalSegmentation.from_pretrained(pan_checkpoint)
            pan_model.to(device)
            pan_model.eval()

            adapt_params = sweep_params(encoder, opt)
            print('# of adapting layer: ', len(adapt_params))
            optim = _build_optimizer(adapt_params, opt)
        print("-> Computing predictions with size {}x{}".format(opt.height, opt.width))

        pred_disps = []
        for idx, data in enumerate(tqdm.tqdm(dataloader)):
            # Only build the autograd graph when the model is being adapted.
            with torch.set_grad_enabled(bool(opt.tta)):
                input_color = data[("color", 0, 0)].to(device)
                if opt.post_process:
                    # Post-processed results require each image to have two forward passes:
                    # the batch is doubled with copies flipped along dim 3 (presumably width).
                    input_color = torch.cat((input_color, torch.flip(input_color, [3])), 0)

                if opt.model == 'newcrfs':
                    output = {('disp', 0): model(input_color)}
                elif opt.model == 'adabins':
                    # Only the last element of the AdaBins output is used.
                    output = {('disp', 0): model(input_color)[-1]}
                elif opt.model in ['monodepth2', 'monovit', 'litemono', 'hrdepth']:
                    output = depth_decoder(encoder(input_color))
                    if isinstance(output, torch.Tensor):
                        output = {('disp', i): output[s] for i, s in enumerate(opt.scales)}
                else:
                    output = depth_decoder(*(encoder(input_color)[0]))
                    output = {('disp', i): output[s] for i, s in enumerate(opt.scales[::-1])}

                # NOTE(review): newcrfs/adabins only provide ('disp', 0), so they presumably
                # run with opt.scales == [0].
                for s in opt.scales:
                    output[('disp', s)] = F.interpolate(output[('disp', s)], [opt.height, opt.width], mode='bilinear', align_corners=False)
                    output[('scaled_disp', s)], output[('depth', s)] = disp_to_depth(output[('disp', s)], opt.min_depth, opt.max_depth)

                if opt.tta:
                    # The post-processed panoptic result is a discrete segmentation, so no
                    # gradient can flow through it; skip autograd for the segmenter.
                    with torch.no_grad():
                        pan_out = pan_model(pixel_values=input_color)
                        pan_color_processed = pan_processor.post_process_panoptic_segmentation(
                            pan_out, target_sizes=[input_color.size()[2:] for _ in range(input_color.size(0))], label_ids_to_fuse=[10])
                    optim.zero_grad()
                    total_loss = compute_loss(input_color, output, pan_color_processed, opt, idx)

                    if total_loss != 0:
                        total_loss.backward()
                        if opt.grad_clip:
                            torch.nn.utils.clip_grad_norm_(adapt_params, opt.max_grad_norm)
                        optim.step()

            # Predictions for this batch come from the weights before this batch's update.
            pred_disp = output[('scaled_disp', 0)]
            pred_disp = pred_disp.cpu()[:, 0].detach().numpy()

            if opt.post_process:
                N = pred_disp.shape[0] // 2
                pred_disp = batch_post_process_disparity(pred_disp[:N], pred_disp[N:, :, ::-1])

            pred_disps.append(pred_disp)

        if not pred_disps:
            raise ValueError("The dataloader produced no batches (dataset smaller than batch_size with drop_last=True?)")
        pred_disps = np.concatenate(pred_disps)

    else:
        # Load predictions from file
        print("-> Loading predictions from {}".format(opt.ext_disp_to_eval))
        pred_disps = np.load(opt.ext_disp_to_eval)

        if opt.eval_eigen_to_benchmark:
            eigen_to_benchmark_ids = np.load(
                os.path.join(splits_dir, "benchmark", "eigen_to_benchmark_ids.npy"))

            pred_disps = pred_disps[eigen_to_benchmark_ids]

    if opt.no_eval:
        print("-> Evaluation disabled. Done.")
        sys.exit()

    if opt.eval_split == 'benchmark':
        save_dir = os.path.join(opt.load_weights_folder, "benchmark_predictions")
        print("-> Saving out benchmark predictions to {}".format(save_dir))
        os.makedirs(save_dir, exist_ok=True)

        for idx, pred_disp in enumerate(pred_disps):
            disp_resized = cv2.resize(pred_disp, (1216, 352))
            depth = STEREO_SCALE_FACTOR / disp_resized
            depth = np.clip(depth, 0, 80)
            depth = np.uint16(depth * 256)
            save_path = os.path.join(save_dir, "{:010d}.png".format(idx))
            cv2.imwrite(save_path, depth)

        print("-> No ground truth is available for the KITTI benchmark, so not evaluating. Done.")
        sys.exit()

    gt_depths = _load_gt_depths(opt)

    print("-> Evaluating")

    if opt.eval_stereo:
        print("   Stereo evaluation - "
              "disabling median scaling, scaling by {}".format(STEREO_SCALE_FACTOR))
        opt.disable_median_scaling = True
        opt.pred_depth_scale_factor = STEREO_SCALE_FACTOR
    else:
        print("   Mono evaluation - using median scaling")

    errors = []
    ratios = []

    # Optionally score only the last 500 frames (presumably the final sequence).
    if opt.eval_last_seq:
        gt_depths = gt_depths[-500:]
        pred_disps = pred_disps[-500:]
    print(pred_disps.shape)

    for i in tqdm.tqdm(range(pred_disps.shape[0])):
        gt_depth = gt_depths[i]
        gt_height, gt_width = gt_depth.shape[:2]

        pred_disp = cv2.resize(pred_disps[i], (gt_width, gt_height))
        # newcrfs/adabins outputs are used directly as depth; other models' are inverted.
        if opt.model not in ['newcrfs', 'adabins']:
            pred_depth = 1 / pred_disp
        else:
            pred_depth = pred_disp

        mask = _eval_mask(gt_depth, opt.eval_split, MIN_DEPTH, MAX_DEPTH)
        # Boolean indexing copies, so the in-place scaling below leaves pred_disps untouched.
        pred_depth = pred_depth[mask]
        gt_depth = gt_depth[mask]
        pred_depth *= opt.pred_depth_scale_factor  # presumably 1 unless stereo evaluation

        if not opt.disable_median_scaling:
            ratio = np.median(gt_depth) / np.median(pred_depth)
            ratios.append(ratio)
            pred_depth *= ratio

        pred_depth[pred_depth < MIN_DEPTH] = MIN_DEPTH
        pred_depth[pred_depth > MAX_DEPTH] = MAX_DEPTH

        errors.append(compute_errors(gt_depth, pred_depth))

    if opt.save_results:
        os.makedirs("stats/continuous", exist_ok=True)
        # np.save appends ".npy", so the file on disk ends in "_errors.npz.npy".
        np.save(f"stats/continuous/{opt.model}_{opt.eval_split}_{opt.tta}_errors.npz", np.array(errors))

    if opt.save_disp:
        os.makedirs(f"res/{opt.model}/pitta", exist_ok=True)
        print(f'saving disp to: res/{opt.model}/pitta/{opt.eval_split}.npz')
        np.savez(f"res/{opt.model}/pitta/{opt.eval_split}.npz", pred_disps)
    if not opt.disable_median_scaling:
        ratios = np.array(ratios)
        med = np.median(ratios)
        print(" Scaling ratios | med: {:0.3f} | std: {:0.3f}".format(med, np.std(ratios / med)))
    mean_errors = np.array(errors).mean(0)
    print('EXP time: ', timestamp)
    print(f"\n Model: {opt.load_weights_folder}  - optim: {opt.optim} - lr: {opt.learning_rate} - {opt.eval_split} - TTA: {opt.tta} - Params: {opt.tta_params} - depth_loss: {opt.depth_loss} - edge_loss: {opt.edge_loss} - lambda: {opt.lamb} - filter size: {opt.filter_size}")
    print("\n  " + ("{:>8} | " * 7).format("abs_rel", "sq_rel", "rmse", "rmse_log", "a1", "a2", "a3"))
    print(("&{: 8.3f}  " * 7).format(*mean_errors.tolist()) + "\\\\")

    print("\n-> Done!")

if __name__ == "__main__":
    # Entry point: parse the command-line options and run adaptation + evaluation.
    adapt(MonodepthOptions().parse())