import torch
from PIL import Image,ImageOps
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from transformers import AutoImageProcessor, AutoModelForDepthEstimation, AutoModel, AutoTokenizer
from datasets import load_from_disk
import os
import matplotlib.pyplot as plt
from transformers import  AutoTokenizer, AutoProcessor
from nuscenes.nuscenes import NuScenes
from tqdm import tqdm
import json
from utils.cam_info import *
from utils.common import *
import argparse

# Command-line interface. Paths to all three models and the dataset are
# required up front so a missing flag fails immediately, instead of after
# NuScenes/model loading with an opaque ``from_pretrained(None)`` error.
parser = argparse.ArgumentParser(
    description="Evaluate a vision-language model on an Ego3D benchmark split, "
    "adding depth- and detection-derived 3D object locations to each prompt."
)
parser.add_argument(
    "--model_path", type=str, required=True,
    help="Checkpoint of the chat VLM (loaded with AutoModel, trust_remote_code=True).",
)
parser.add_argument(
    "--depth_path", type=str, required=True,
    help="Checkpoint of the depth-estimation model (AutoModelForDepthEstimation).",
)
parser.add_argument(
    "--rec_path", type=str, required=True,
    help="Checkpoint of the zero-shot object detector (AutoModelForZeroShotObjectDetection).",
)
parser.add_argument(
    "--model_name", type=str,
    help="Name used for the log folder and the output .jsonl file names.",
)
parser.add_argument(
    "--source_dataset", type=str, required=True,
    choices=["nuscenes", "waymo", "argoverse"],
    help="Source dataset of the benchmark split to evaluate.",
)
args = parser.parse_args()

# Fixed seed; decoding below is greedy (do_sample=False).
torch.manual_seed(42)
do_sample = False

path = args.model_path
model_name = args.model_name


# Per-dataset setup. Each branch sets the raw-data root (``data_dir``), the
# benchmark split (``dataset``) and the prompt template. The template's
# "<image>" slots are listed in the same order in which the main loop collects
# the camera views. nuScenes and Argoverse also define ``CAMERAS`` (view order)
# and ``CAMERAS_dic`` (camera id -> readable view name).
save_folder = f'Ego3D/eval_vlms/logs/{model_name}_{args.source_dataset}'

if args.source_dataset == "nuscenes":
    data_dir = "Ego3D/source_datasets/nuscenes"
    nusc = NuScenes(version='v1.0-trainval', dataroot=data_dir, verbose=True)
    dataset = load_from_disk('Ego3D/benchmark/nuscenes')

    CAMERAS = [
        'CAM_FRONT_LEFT', 'CAM_FRONT', 'CAM_FRONT_RIGHT',
        'CAM_BACK_RIGHT', 'CAM_BACK', 'CAM_BACK_LEFT',
    ]
    CAMERAS_dic = dict(zip(CAMERAS, [
        'front left view', 'front view', 'front right view',
        'back right view', 'back view', 'back left view',
    ]))

    INPUT_TEXT_TEMPLATE = (
        "These are six camera views mounted on an ego car\n\n"
        "Front Left view: <image>\n"
        "Front view: <image>\n"
        "Front Right view: <image>\n"
        "Back Right view: <image>\n"
        "Back view: <image>\n"
        "Back Left view: <image> \n"
        "{cog_map}\n{question}"
    )

elif args.source_dataset == "waymo":
    # Waymo-only dependencies: frames are decoded from TFRecords in the main loop.
    import tensorflow as tf
    from waymo_open_dataset import dataset_pb2 as open_dataset
    import io

    data_dir = 'Ego3D/source_datasets/waymo'
    dataset = load_from_disk('Ego3D/benchmark/waymo')

    INPUT_TEXT_TEMPLATE = (
        "These are five camera views mounted on an ego car\n\n"
        "Front view: <image>\n"
        "Front Left view: <image>\n"
        "Left view: <image>\n"
        "Front Right view: <image>\n"
        "Right view: <image>\n"
        "{cog_map}\n{question}"
    )

elif args.source_dataset == "argoverse":
    from argoverse.data_loading.argoverse_tracking_loader import ArgoverseTrackingLoader

    data_dir = "Ego3D/source_datasets/argoverse"
    dataset = load_from_disk('Ego3D/benchmark/argoverse')

    CAMERAS = [
        'ring_front_left', 'ring_front_center', 'ring_front_right',
        'ring_side_right', 'ring_rear_right', 'ring_rear_left', 'ring_side_left',
    ]
    CAMERAS_dic = dict(zip(CAMERAS, [
        'front left view', 'front view', 'front right view',
        'right view', 'back right view', 'back left view', 'left view',
    ]))
    argoverse_loader = ArgoverseTrackingLoader(data_dir)

    INPUT_TEXT_TEMPLATE = (
        "These are seven camera views mounted on an ego car\n\n"
        "Front Left view: <image>\n"
        "Front view: <image>\n"
        "Front Right view: <image>\n"
        "Right view: <image>\n"
        "Back Right view: <image>\n"
        "Back Left view: <image>\n"
        "Left view: <image>\n"
        "{cog_map}\n{question}"
    )

else:
    raise NotImplementedError("Dataset name is wrong!")

# Per-camera calibration; the main loop indexes it by view position
# (presumably in the same order as the views above — see utils.cam_info).
cams = camera_info(args.source_dataset)

# Device for the depth and detection models. The VLM below is placed according
# to the device map returned by ``split_model`` (from utils.common) instead.
device = "cuda"

# Chat VLM loaded in bfloat16 via its remote code; ``use_flash_attn`` is passed
# through to that remote implementation (assumed to be a kwarg it accepts).
device_map = split_model(path)
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    load_in_8bit=False,
    low_cpu_mem_usage=True,
    use_flash_attn=True,
    trust_remote_code=True,
    device_map=device_map,
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
generation_config = {"max_new_tokens": 1024, "do_sample": do_sample}

# Monocular depth model: its predicted depth is unprojected to 3D per view.
processor_depth = AutoImageProcessor.from_pretrained(args.depth_path)
model_depth = AutoModelForDepthEstimation.from_pretrained(args.depth_path)
model_depth = model_depth.to(device)

# Zero-shot detector queried with phrases taken from the question text.
processor_gdino = AutoProcessor.from_pretrained(args.rec_path)
model_gdino = AutoModelForZeroShotObjectDetection.from_pretrained(args.rec_path)
model_gdino = model_gdino.to(device)

# NOTE(review): ``objects`` is never read in this script; kept only so any
# external code referring to it keeps working.
objects = []

# Create the log folder. ``exist_ok=True`` replaces the racy exists()/makedirs()
# check-then-create pair with a single call.
os.makedirs(save_folder, exist_ok=True)

def _camera_extrinsic(cam_info):
    """Return the 4x4 matrix [[R, t], [0, 0, 0, 1]] for one calibration entry.

    ``cam_info`` is one element of ``cams``: its ``'rotation'`` fills the
    top-left 3x3 block and its ``'translation'`` fills the last column.
    """
    extrinsic = torch.eye(4)
    extrinsic[:3, :3] = torch.tensor(cam_info['rotation'])
    extrinsic[:3, 3] = torch.tensor(cam_info['translation'])
    return extrinsic


def _pad_to_common_size(images):
    """Center each PIL image on a black canvas of the largest width and height."""
    max_width = max(img.width for img in images)
    max_height = max(img.height for img in images)
    padded = []
    for img in images:
        delta_w = max_width - img.width
        delta_h = max_height - img.height
        # (left, top, right, bottom); an odd leftover pixel goes right/bottom.
        padding = (
            delta_w // 2,
            delta_h // 2,
            delta_w - (delta_w // 2),
            delta_h - (delta_h // 2),
        )
        padded.append(ImageOps.expand(img, padding, fill=0))  # fill=0 for black
    return padded


for sub_category in dataset.keys():
    for sample in tqdm(dataset[sub_category]):

        # Prompt preamble describing the ego coordinate frame; one line per
        # detection is appended below. NOTE(review): the sentences are joined
        # without spaces and contain typos ("infront"); kept byte-identical so
        # prompts match previously logged runs.
        cog_map="\nEgo Car is at 3D location [x,y,z]=[0, 0, 0]."
        cog_map+="Positive x means right and negative x mean left to the ego car."
        cog_map+="Positive z means infront of the ego car and negative z means behind the ego car"

        images = []
        image_paths = []  # what prepare_images_internvl receives: paths or PIL images
        cams_RT = []
        cams_K = []

        if args.source_dataset == 'nuscenes':
            view_names = ['Front Left View','Front View','Front Right View','Back Right View','Back View','Back Left View']
            my_sample = nusc.get('sample', sample['sample_token'])
            for idx, CAM in enumerate(CAMERAS):
                cam_data = nusc.get('sample_data', my_sample['data'][CAM])
                img_path = os.path.join(data_dir, cam_data['filename'])
                image_paths.append(img_path)
                images.append(Image.open(img_path))
                cams_K.append(torch.tensor(cams[idx]['intrinsic']))
                cams_RT.append(_camera_extrinsic(cams[idx]))
            scale = 1/1.5

        elif args.source_dataset == 'waymo':
            view_names = ['Front View','Front Left View','Left View','Front Right View','Right View']
            # NOTE(review): path is built by plain concatenation; presumably
            # sample["file_name"] starts with a path separator.
            record_path = data_dir+sample["file_name"]
            sample_tf_data = tf.data.TFRecordDataset(record_path, compression_type='')
            # Decode frames until the 1-based ``frame_number``; if the record is
            # shorter, the last decoded frame is used (original behavior).
            frame = None
            for count, data in enumerate(sample_tf_data):
                frame = open_dataset.Frame()
                frame.ParseFromString(bytes(data.numpy()))
                if count+1==sample["frame_number"]:
                    break
            if frame is None:
                raise ValueError(f"No frames found in Waymo record {record_path}")

            # Iterate through all camera images in the frame (JPEG-encoded).
            for idx, camera_image in enumerate(frame.images):
                images.append(Image.open(io.BytesIO(camera_image.image)))
                cams_K.append(torch.tensor(cams[idx]['intrinsic']))
                cams_RT.append(_camera_extrinsic(cams[idx]))

            # Views may differ in resolution; pad them to one size because the
            # depth maps below are interpolated to a single shared size.
            images = _pad_to_common_size(images)
            image_paths = images
            scale = 0.8

        elif args.source_dataset == "argoverse":
            view_names = [
                'Front Left View', 'Front View','Front Right View',
                'Right View', 'Back Right View', 'Back Left View', 'Left View',
            ]
            log = argoverse_loader.get(sample["log_id"])
            for idx, cam in enumerate(CAMERAS):
                img_path = log.get_image_list_sync(cam)[sample["timestamp_id"]]
                image_paths.append(img_path)
                images.append(Image.open(img_path))
                cams_K.append(torch.tensor(cams[idx]['intrinsic']))
                cams_RT.append(_camera_extrinsic(cams[idx]))
            scale = 0.64

        # One (height, width) per view. Depth is interpolated to a single size
        # and unprojected with per-view intrinsics, so every view must share the
        # same size; fail loudly instead of silently mis-scaling boxes/depth.
        target_sizes = [img.size[::-1] for img in images]
        if len(set(target_sizes)) != 1:
            raise ValueError(
                f"Camera views have differing sizes {target_sizes}; "
                "depth unprojection needs a common image size."
            )

        cams_K = torch.stack(cams_K).to(device)
        cams_RT = torch.stack(cams_RT).to(device)
        question = create_question(sample, sub_category)

        # Detector text queries: the question split on '.', with '?' removed.
        question_trimed = [qs.replace('?','') for qs in sample["question"].split('.')]
        print(question_trimed)
        text_labels = [question_trimed]*len(images)

        inputs_gdino = processor_gdino(images=images, text=text_labels, return_tensors="pt").to(device)
        inputs_depth = processor_depth(images=images, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs_gdino = model_gdino(**inputs_gdino)
            outputs_depth = model_depth(**inputs_depth).predicted_depth

        predictions_gdino = processor_gdino.post_process_grounded_object_detection(
            outputs_gdino,
            inputs_gdino.input_ids,
            box_threshold=0.4,
            text_threshold=0.3,
            target_sizes=target_sizes
        )

        # Interpolate to the original (shared) image size, then apply the
        # dataset-specific depth scale.
        prediction_depth = torch.nn.functional.interpolate(
            outputs_depth.unsqueeze(1),
            size=target_sizes[0],
            mode="bicubic",
            align_corners=False,
        ).squeeze()*scale

        world_coords = unproject(cams_K.float(), cams_RT.float(), prediction_depth.float()).cpu()    # (V, H, W, 3)
        map_h, map_w = world_coords.shape[1], world_coords.shape[2]
        bboxes_list = []
        for idx, prediction_gdino in enumerate(predictions_gdino):
            bboxes_list_img = []
            for box, score, labels in zip(prediction_gdino["boxes"], prediction_gdino["scores"], prediction_gdino["labels"]):
                if 'ego' in labels:
                    continue
                box = [round(x, 2) for x in box.tolist()]
                # Box centre pixel. Boxes are not guaranteed here to lie inside the
                # image, so clamp: a negative index would silently wrap to the
                # opposite edge and a too-large one would raise IndexError.
                center_x = min(max(int((box[0]+box[2])/2), 0), map_w - 1)
                center_y = min(max(int((box[1]+box[3])/2), 0), map_h - 1)

                cog_map+=f"\n{view_names[idx]}: detected {labels} at 3D location {world_coords[idx,center_y,center_x].round()}"
                bboxes_list_img.append(box)

            bboxes_list.append(bboxes_list_img)

        input_text = INPUT_TEXT_TEMPLATE.format(question=question,cog_map=str(cog_map))
        pixel_values, num_patches_list = prepare_images_internvl(image_paths)
        response = model.chat(tokenizer, pixel_values, input_text, generation_config,
                            num_patches_list=num_patches_list,)

        # Append one JSON line per sample. NOTE(review): 'Ped' looks like a typo
        # for 'Pred' but is kept since downstream scoring may read this key.
        with open(f'{save_folder}/{model_name}-{sub_category}.jsonl',"a") as file_out:
            file_out.write(json.dumps({'Question': question,'Ped':response,'GT':sample['answer']})+"\n")
        


     