import os
import glob
import argparse
from tqdm import tqdm
import sys
sys.path.append("/mnt/private-user-data/ed/Sparsedrivev11")
import torch
import cv2
import numpy as np
from PIL import Image

import mmcv
from mmcv import Config
from mmdet.datasets import build_dataset

from tools.visualization.bev_render import BEVRender
from tools.visualization.cam_render import CamRender
from projects.mmdet3d_plugin.datasets.utils import draw_lidar_bbox3d
import matplotlib.pyplot as plt
# from bev_render import BEVRender
# from cam_render import CamRender

# plot_choices = dict(
#     draw_pred = True, # True: draw gt and pred; False: only draw gt
#     det = True,
#     track = True, # True: draw history tracked boxes
#     motion = True,
#     map = True,
#     planning = True,
# )
# Index range (start, exclusive stop, step) for iterating frames with ``range``.
START, END, INTERVAL = 0, 81, 1


class Visualizer:
    """Render GT 3D boxes on camera images and assemble combined frames into a video.

    ``add_vis`` draws the ground-truth boxes of one batch onto the
    de-normalized camera images and saves the figure under ``./vis_gt``.
    ``combine`` and ``image2video`` stitch renders stored under
    ``<out_dir>/combine`` into ``<out_dir>/video.mp4``.
    """

    def __init__(self, *, plot_choices=None, out_dir="vis_dir"):
        """Create the output directories and the BEV / camera renderers.

        Args:
            plot_choices: rendering switches forwarded to ``BEVRender`` and
                ``CamRender``. ``None`` enables every switch (``draw_pred``,
                ``det``, ``track``, ``motion``, ``map``, ``planning``).
            out_dir: root directory for rendered outputs; its ``combine``
                subdirectory is created immediately.
        """
        # NOTE(review): dataset / prediction loading (Config.fromfile +
        # build_dataset, mmcv.load of the results file) is disabled, so the
        # BEV / camera render pipeline is not wired up from here.
        self.out_dir = out_dir
        self.combine_dir = os.path.join(self.out_dir, 'combine')
        os.makedirs(self.combine_dir, exist_ok=True)

        if plot_choices is None:
            plot_choices = dict(
                draw_pred=True,  # True: draw gt and pred; False: only draw gt
                det=True,
                track=True,  # True: draw history tracked boxes
                motion=True,
                map=True,
                planning=True,
            )

        self.bev_render = BEVRender(plot_choices, self.out_dir)
        self.cam_render = CamRender(plot_choices, self.out_dir)
        # Per-channel mean / std used to undo image normalization in add_vis
        # (presumably the ImageNet statistics used by the data pipeline —
        # confirm against the config).
        self.img_norm_mean = np.array([123.675, 116.28, 103.53])
        self.img_norm_std = np.array([58.395, 57.12, 57.375])
        # Frame counter used to name the files written by add_vis.
        self.iter = 0

    def add_vis(self, img, data):
        """Draw the GT 3D boxes of one batch on its camera images and save it.

        The figure is written to ``./vis_gt/boxes_frame_<iter>.png`` and the
        frame counter ``self.iter`` is incremented.

        Args:
            img: batched image tensor. ``img[0]`` is permuted with
                ``(0, 2, 3, 1)``, so it is assumed to be normalized
                (num_cams, C, H, W) images — TODO confirm.
            data: batch dict providing ``gt_bboxes_3d``, ``projection_mat``
                and ``timestamp``. A ``next`` key labels the frame as
                nuScenes; otherwise it is labelled as CARLA.
        """
        # Undo the input normalization so the images are viewable.
        raw_imgs = img[0].permute(0, 2, 3, 1).cpu().numpy()
        raw_imgs = raw_imgs * self.img_norm_std + self.img_norm_mean

        gt_bbox3d = data['gt_bboxes_3d'][0]
        num_det = gt_bbox3d.shape[0]
        # NOTE(review): the color list has 2 * num_det entries (green, then
        # red), left over from drawing predictions plus anchors. Only the
        # num_det GT boxes are passed here, so confirm that draw_lidar_bbox3d
        # ignores the surplus entries.
        vis_img = draw_lidar_bbox3d(
            gt_bbox3d,
            raw_imgs,
            data["projection_mat"][0],
            color=[(0, 255, 0)] * num_det + [(255, 0, 0)] * num_det,
        )
        plt.imshow(vis_img)

        source = "nuscenes" if 'next' in data else "carla"
        name_ = source + " " + str(data["timestamp"].cpu().numpy()[0])
        print(" time stamp =", name_)
        plt.text(100, 100, name_, fontsize=12, color="r", style="italic",
                 weight="light", verticalalignment='center')
        # savefig does not create missing parent directories, so make sure
        # the target exists before writing the frame.
        os.makedirs("./vis_gt", exist_ok=True)
        plt.savefig(f"./vis_gt/boxes_frame_{self.iter}.png", dpi=700)
        self.iter += 1
        # Clear the current figure so the next frame starts from a blank canvas.
        plt.clf()

        # Disabled BEV / camera pipeline (needs dataset results and an index):
        #   bev_gt_path, bev_pred_path = self.bev_render.render(data, result, index)
        #   cam_pred_path = self.cam_render.render(data, result, index)
        #   self.combine(bev_gt_path, bev_pred_path, cam_pred_path, index)

    def combine(self, bev_gt_path, bev_pred_path, cam_pred_path, index):
        """Concatenate the camera, BEV-prediction and BEV-GT renders side by side.

        The result is written to ``<combine_dir>/<index zero-padded to 4>.jpg``.

        Raises:
            FileNotFoundError: if any of the input images cannot be read.
        """
        images = []
        for path in (cam_pred_path, bev_pred_path, bev_gt_path):
            image = cv2.imread(path)
            # cv2.imread returns None instead of raising on a missing or
            # unreadable file; fail with a clear message rather than an
            # opaque hconcat error.
            if image is None:
                raise FileNotFoundError(f"could not read image: {path}")
            images.append(image)
        merge_image = cv2.hconcat(images)
        save_path = os.path.join(self.combine_dir, str(index).zfill(4) + '.jpg')
        cv2.imwrite(save_path, merge_image)

    def image2video(self, fps=12, downsample=4):
        """Encode the combined ``*.jpg`` frames (sorted by name) into an MP4.

        Each frame is shrunk by ``downsample`` in both dimensions. The writer
        is sized from the first readable frame, and unreadable frames are
        skipped.

        Args:
            fps: frame rate of the output video.
            downsample: integer factor by which width and height are divided.

        Returns:
            Path of ``<out_dir>/video.mp4``, or ``None`` when there were no
            readable frames to encode.
        """
        img_paths = sorted(glob.glob(os.path.join(self.combine_dir, '*.jpg')))
        if not img_paths:
            return None

        out_path = os.path.join(self.out_dir, 'video.mp4')
        writer = None
        try:
            # Stream frames straight to the writer instead of buffering them all.
            for img_path in tqdm(img_paths):
                frame = cv2.imread(img_path)
                if frame is None:
                    continue
                height, width = frame.shape[:2]
                frame = cv2.resize(frame, (width // downsample, height // downsample),
                                   interpolation=cv2.INTER_AREA)
                if writer is None:
                    size = (frame.shape[1], frame.shape[0])
                    writer = cv2.VideoWriter(
                        out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, size)
                writer.write(frame)
        finally:
            if writer is not None:
                writer.release()
        return out_path if writer is not None else None


def parse_args(argv=None):
    """Parse command-line options for the visualization script.

    Args:
        argv: optional list of argument strings. ``None`` makes argparse read
            ``sys.argv``, so existing callers are unaffected.

    Returns:
        argparse.Namespace with ``config``, ``result_path`` and ``out_dir``.
    """
    parser = argparse.ArgumentParser(
        description='Visualize groundtruth and results')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '--result-path',
        default=None,
        # Adjacent literals are concatenated; keep the ". " separator so the
        # help reads as two sentences.
        help='prediction result to visualize. '
        'If submission file is not provided, only gt will be visualized')
    parser.add_argument(
        '--out-dir',
        default='vis',
        help='directory where visualization results will be saved')
    return parser.parse_args(argv)

def main():
    """Script entry point: build a ``Visualizer`` and encode the combined frames.

    NOTE(review): ``Visualizer.add_vis`` takes an image tensor and a data
    batch (``add_vis(img, data)``), not a dataset index. The dataset / result
    loading in ``Visualizer.__init__`` is also commented out. A per-index
    loop calling ``visualizer.add_vis(idx)`` therefore cannot work against
    this interface, and ``Visualizer(args, plot_choices)`` cannot either:
    ``plot_choices`` is not defined at module level, and ``__init__`` accepts
    no positional arguments. Per-frame rendering has to be driven by the code
    that owns the batches. This entry point only stitches the frames already
    present in the combine directory into a video.
    """
    # Still parse the CLI so --help and argument validation behave as before.
    parse_args()
    visualizer = Visualizer()
    visualizer.image2video()

# Run main() only when executed as a script, not when imported.
if __name__ == '__main__':
    main()


    # Example invocation:
    #python tools/visualization/visualize1.py  projects/configs/ATDRIVE_small_stage1_mix.py  --result-path /mnt/private-user-data/ed/Sparsedrivev9/work_dirs/sparsedrive_small_stage1_mix/results.pkl