from PIL import Image, ImageDraw, ImageFont
import numpy as np
import pandas as pd
from tqdm import tqdm
import json
import os
import ast

from utils.view_retrieval import retrieve_object_images
from utils.vlm_check import analyze_image_for_object
from utils.vlm_check import analyze_view_for_object
from utils.dataset_pre import SR3DPlusPreprocess
from utils.llm_helper import llm_object_extractor
from utils.llm_helper import final_llm
from utils.evaluation_object_retrieval import iou_calc
from utils.view_retrieval import retrieve_object_views

import argparse

# ------------------------------
# Configurable defaults (centralized).
# Only the two model-name defaults below are defined at module level here;
# main() may overwrite them from --vlm_model_name / --clip_model_name.
#
# NOTE(review): API_KEY, OUTPUT_DIR_BASE, OUTPUT_DIR_V3, DATASET_ROOT,
# CLIP_STATE_DICT_PATH and BBQ_DATASET_PATH are read as globals by the
# functions in this file but have no module-level definition in it. They are
# only bound inside main(), and (except BBQ_DATASET_PATH) only when the
# matching CLI flag is passed, so using them without that flag raises NameError.

VLM_MODEL_NAME_DEFAULT = "qwen/qwen2.5-vl-32b-instruct"
CLIP_MODEL_NAME_DEFAULT = 'EVA02-L-14-336'
# ------------------------------


def draw_bounding_box(pil_img, box, color="red", width=8):
    """Outline ``box`` on ``pil_img`` and return that same image object.

    The outline is drawn directly onto ``pil_img``; no copy is made.

    Args:
        pil_img: Image handed to ``ImageDraw.Draw``.
        box: Four values unpacked as ``(x_min, x_max, y_min, y_max)``. Note
            the order: both x values come before both y values. Each value is
            converted with ``int`` before drawing.
        color: Colour of the outline.
        width: Line width of the outline.

    Returns:
        ``pil_img`` itself, now carrying the outline.
    """
    pen = ImageDraw.Draw(pil_img)

    left, right, top, bottom = box
    left, right, top, bottom = int(left), int(right), int(top), int(bottom)

    # Closed polyline: walk the four corners and come back to the first one.
    outline = [
        (left, top),
        (right, top),
        (right, bottom),
        (left, bottom),
        (left, top),
    ]
    pen.line(outline, fill=color, width=width)

    return pil_img



def object_search(
    object_name,
    scene_name,
    vlm_model_name=None,
    vlm_api_key=None,
    output_dir=None,
    object_id=20,
    dataset_path=None,
    clip_model_name=None,
    clip_state_dict_path=None):
    """Retrieve candidate images of an object and keep those a VLM confirms.

    Candidates come from ``retrieve_object_images``. Each candidate with a
    usable bounding box gets the box drawn on it and is passed to
    ``analyze_image_for_object``; the clustered points of every candidate
    whose VLM answer is truthy are collected.

    Args:
        object_name: Name of the object to look for.
        scene_name: Scene identifier; also used as the output sub-directory.
        vlm_model_name: VLM model name; defaults to ``VLM_MODEL_NAME_DEFAULT``.
        vlm_api_key: API key for the VLM; defaults to the global ``API_KEY``.
        output_dir: Base output directory; defaults to the global
            ``OUTPUT_DIR_V3``.
        object_id: Forwarded unchanged to ``retrieve_object_images``.
        dataset_path: Dataset root; defaults to the global ``DATASET_ROOT``.
        clip_model_name: CLIP model name; defaults to
            ``CLIP_MODEL_NAME_DEFAULT``.
        clip_state_dict_path: CLIP weights path; defaults to the global
            ``CLIP_STATE_DICT_PATH``.

    Returns:
        list: ``clustered_points[i]`` for every candidate ``i`` the VLM
        accepted, in candidate order.
    """
    # Defaults are resolved at call time so that overrides applied to the
    # module globals (see main()) are picked up.
    if vlm_model_name is None:
        vlm_model_name = VLM_MODEL_NAME_DEFAULT
    if vlm_api_key is None:
        vlm_api_key = API_KEY
    if output_dir is None:
        output_dir = OUTPUT_DIR_V3
    if dataset_path is None:
        dataset_path = DATASET_ROOT
    if clip_model_name is None:
        clip_model_name = CLIP_MODEL_NAME_DEFAULT
    if clip_state_dict_path is None:
        clip_state_dict_path = CLIP_STATE_DICT_PATH

    clustered_points, images, boxes = retrieve_object_images(
        dataset_path=dataset_path, scene_name=scene_name, object_name=object_name,
        object_id=object_id, clip_model_name=clip_model_name, clip_state_dict_path=clip_state_dict_path
    )

    # The per-scene directory is still created, although this function does
    # not write any image into it.
    im_path = os.path.join(output_dir, scene_name)
    os.makedirs(im_path, exist_ok=True)

    vlm_results = []
    results = []

    for i, img in enumerate(images):
        box = boxes[i]

        # Only list/tuple boxes with exactly four entries are usable.
        if not isinstance(box, (list, tuple)) or len(box) != 4:
            print(f"Skipping image {i}: invalid bounding box")
            continue

        pil_img = Image.fromarray(img) if isinstance(img, np.ndarray) else img

        # The VLM is shown the image with the candidate box drawn on it.
        pil_img = draw_bounding_box(pil_img, box)

        vr = analyze_image_for_object(pil_img, object_name, vlm_api_key, vlm_model_name)
        vlm_results.append(vr)
        if vr:
            results.append(clustered_points[i])

    print(vlm_results)
    return results



def save_load_scene_object(scene_id, obj_name, saved_dataset):
    """Return the stored data for ``obj_name``, searching for it on a miss.

    Args:
        scene_id: Scene identifier passed to ``object_search`` on a miss.
        obj_name: Name of the object to load.
        saved_dataset: Object providing ``load_scene_object(name)`` and
            ``save_scene_object(name, data)``.

    Returns:
        Whatever ``saved_dataset.load_scene_object`` returned when that is
        not ``None``; otherwise the fresh ``object_search`` result, which is
        stored through ``save_scene_object`` before being returned.
    """
    cached = saved_dataset.load_scene_object(obj_name)
    if cached is not None:
        return cached

    # Cache miss: object_search is called with its own defaults for
    # everything except the object and scene names.
    found = object_search(obj_name, scene_id)
    saved_dataset.save_scene_object(obj_name, found)
    return found




def process_scene_and_save(
    scene_id: str,
    dataset,
    model_name: str,
    api_key: str,
):
    """Run object extraction on every utterance of a scene and cache objects.

    For each record from ``dataset.get_data_dicts()`` the utterance is sent
    to ``llm_object_extractor``, the outcome is recorded through
    ``dataset.append_scene_command``, and the main and related objects named
    in the outcome are loaded or searched via ``save_load_scene_object``.

    Args:
        scene_id: Scene identifier forwarded to ``save_load_scene_object``.
        dataset: Object providing ``get_data_dicts`` and
            ``append_scene_command`` plus the load/save methods used by
            ``save_load_scene_object``.
        model_name: Model name forwarded to ``llm_object_extractor``.
        api_key: API key forwarded to ``llm_object_extractor``.

    Returns:
        list: Every ``llm_object_extractor`` output, in input order.
    """
    extracted = []

    for entry in tqdm(dataset.get_data_dicts()):
        utterance = entry["utterance"]
        target_id = entry["target_id"]

        parsed = llm_object_extractor(utterance, model=model_name, api_key=api_key)
        extracted.append(parsed)

        # The command is recorded even when no objects are looked up below.
        dataset.append_scene_command(
            utterance, parsed, target_id, entry['is_easy'], entry['is_view_dep']
        )

        # Outputs with two or fewer entries are not searched for objects.
        # NOTE(review): presumably these are incomplete extractions; confirm
        # against llm_object_extractor.
        if len(parsed) > 2:
            save_load_scene_object(scene_id, parsed['main_object'], dataset)
            for related_name in parsed['related_objects']:
                save_load_scene_object(scene_id, related_name, dataset)

    return extracted



def make_grid_with_numbers(images, n_cols=2, padding=60):
    """Tile images into a white grid, framing and numbering each tile.

    Tiles are laid out row by row and numbered from 1 in input order. The
    size of the first image is used as the cell size for every tile.

    Args:
        images: List of PIL Images.
        n_cols: Number of columns; must be at least 1.
        padding: Space in pixels between tiles and around the grid.

    Returns:
        A new RGB PIL image containing the grid, or ``None`` when ``images``
        is empty.

    Raises:
        ValueError: If ``images`` is non-empty and ``n_cols`` is below 1.
    """
    if not images:
        return None
    if n_cols < 1:
        raise ValueError(f"n_cols must be at least 1, got {n_cols}")

    # Reference size: every cell is as large as the first image.
    w, h = images[0].size
    n_rows = (len(images) + n_cols - 1) // n_cols  # ceiling division

    # Extra space on top for big numbers
    number_space = 120

    # Output image size
    out_w = n_cols * w + (n_cols + 1) * padding
    out_h = n_rows * h + (n_rows + 1) * padding + number_space
    out_img = Image.new("RGB", (out_w, out_h), "white")

    draw = ImageDraw.Draw(out_img)

    # Load a large font; fall back to PIL's built-in one when the TrueType
    # file cannot be read or FreeType support is unavailable. Anything else
    # (including KeyboardInterrupt) is deliberately not swallowed.
    try:
        font = ImageFont.truetype("DejaVuSans-Bold.ttf", 70)
    except (OSError, ImportError):
        font = ImageFont.load_default()

    for idx, img in enumerate(images):
        row, col = divmod(idx, n_cols)
        x = padding + col * (w + padding)
        y = padding + row * (h + padding) + number_space

        # Paste image
        out_img.paste(img, (x, y))

        # Draw thick blue rectangle around image
        draw.rectangle([x, y, x + w, y + h], outline="blue", width=8)

        # Draw number above image (centered, blue)
        num = str(idx + 1)
        bbox = draw.textbbox((0, 0), num, font=font)
        text_w, text_h = bbox[2] - bbox[0], bbox[3] - bbox[1]
        text_x = x + w // 2 - text_w // 2
        text_y = y - text_h - 10
        draw.text((text_x, text_y), num, fill="blue", font=font)

    return out_img


def view_search(scene_name, obj_name, orientation, dataset_class, output_dir=None):
    """Ask the VLM which view of each stored object instance fits an orientation.

    For every instance returned by ``dataset_class.load_scene_object`` the
    views from ``retrieve_object_views`` are boxed, tiled into a numbered
    grid, saved, and sent to ``analyze_view_for_object``. The VLM answer is
    read as a 1-based tile number and mapped to that tile's angle.

    Args:
        scene_name: Scene identifier; also used as the output sub-directory.
        obj_name: Name of the object whose stored instances are examined.
        orientation: Orientation description forwarded to the VLM.
        dataset_class: Object providing ``load_scene_object(name)``.
        output_dir: Base output directory; defaults to the global
            ``OUTPUT_DIR_BASE``.

    Returns:
        list: One entry per stored instance: the angle in degrees of the
        chosen view, or ``-1`` when no grid could be built or the VLM answer
        is not a valid tile number.
    """
    if output_dir is None:
        output_dir = OUTPUT_DIR_BASE
    im_path = os.path.join(output_dir, scene_name, 'views')
    os.makedirs(im_path, exist_ok=True)

    results = []
    for d in dataset_class.load_scene_object(obj_name):
        # The third return value of retrieve_object_views is not used; angles
        # are derived from the view count instead.
        # NOTE(review): this treats the views as evenly spaced over 360
        # degrees -- confirm against retrieve_object_views.
        images, boxes, _ = retrieve_object_views(scene_name, d)
        delta_deg = 360 / len(images) if len(images) > 0 else 0

        pil_imgs = []
        my_degs = []
        for i, img in enumerate(images):
            box = boxes[i]
            if not isinstance(box, (list, tuple)) or len(box) != 4:
                print(f"Skipping image {i}: invalid bounding box")
                continue

            pil_img = Image.fromarray(img) if isinstance(img, np.ndarray) else img
            pil_imgs.append(draw_bounding_box(pil_img, box))
            # Uses the original view index, so skipped views leave gaps
            # instead of shifting the angles of the remaining ones.
            my_degs.append(delta_deg * i)

        grid_with_numbers = make_grid_with_numbers(pil_imgs)
        vlm_results = None
        if grid_with_numbers is not None:
            # NOTE: the same file name is used for every instance of obj_name,
            # so only the last instance's grid remains on disk.
            grid_with_numbers.save(f"{im_path}/image_grid_{obj_name}.png")
            vlm_results = analyze_view_for_object(
                grid_with_numbers, obj_name, orientation,
                api_key=API_KEY, model_name=VLM_MODEL_NAME_DEFAULT,
            )

        # Tile numbers are 1-based. An answer of 0 or below must not wrap
        # around to the end of the list, so it is rejected explicitly rather
        # than left to negative indexing.
        selected_deg = -1
        try:
            choice_idx = vlm_results - 1
            if choice_idx >= 0:
                selected_deg = my_degs[choice_idx]
        except (TypeError, IndexError):
            # None / non-numeric answer, or a number beyond the last tile.
            pass
        results.append(selected_deg)

    return results
    

def save_load_scene_view(scene_id, obj_name, orientation, dataset_class):
    """Return the stored view for an object/orientation, computing it on a miss.

    Args:
        scene_id: Scene identifier passed to ``view_search`` on a miss.
        obj_name: Name of the object.
        orientation: Orientation the view was requested for.
        dataset_class: Object providing ``load_object_view(name, orientation)``
            and ``save_object_view(name, orientation, view)``.

    Returns:
        Whatever ``dataset_class.load_object_view`` returned when that is not
        ``None``; otherwise the fresh ``view_search`` result, which is stored
        through ``save_object_view`` before being returned.
    """
    cached_view = dataset_class.load_object_view(obj_name, orientation)
    if cached_view is not None:
        return cached_view

    fresh_view = view_search(scene_id, obj_name, orientation, dataset_class)
    dataset_class.save_object_view(obj_name, orientation, fresh_view)
    return fresh_view

 
def command_processor_llm(command, main_object, related_objects, dataset_class, orientation_importance=None):
    """Pick which stored instance of ``main_object`` a command refers to.

    With no stored instance the answer is ``(None, None)``; with exactly one
    it is returned directly. Otherwise the per-instance mean points of the
    main and related objects, plus any requested orientation views, are
    handed to ``final_llm``, whose answer is used as an index into the
    stored instances.

    Args:
        command: The command text forwarded to ``final_llm``.
        main_object: Name of the object being referred to.
        related_objects: Names of the other objects mentioned in the command.
        dataset_class: Object providing ``load_scene_object(name)`` and a
            ``scene_id`` attribute.
        orientation_importance: Mapping of object name to orientation for
            which a view should be looked up; ``None`` means no orientation
            lookups.

    Returns:
        tuple: ``(instance_data, index)``. ``(None, None)`` when nothing is
        stored for ``main_object``; the first instance with index ``0`` when
        the LLM answer is not a usable index.
    """
    if orientation_importance is None:
        orientation_importance = {}

    main_object_data = dataset_class.load_scene_object(main_object)

    # A never-saved object (None) is treated like one with no instances.
    if main_object_data is None or len(main_object_data) == 0:
        return None, None

    if len(main_object_data) == 1:
        return main_object_data[0], 0

    main_object_cents = [d.mean(axis=0) for d in main_object_data]

    rel = {}
    for r in related_objects:
        ds = dataset_class.load_scene_object(r)
        # A related object that was never saved contributes no instances.
        rel[r] = [] if ds is None else [d.mean(axis=0) for d in ds]

    ori_result = {}
    for k, wanted_orientation in orientation_importance.items():
        ori_result[(k, wanted_orientation)] = save_load_scene_view(
            dataset_class.scene_id, k, wanted_orientation, dataset_class
        )

    llm_result = final_llm(main_object, main_object_cents, rel, command, orientation=ori_result,
                           model="openai/o4-mini", api_key=API_KEY)

    # Negative answers are rejected explicitly so they cannot wrap around to
    # the end of the candidate list; anything unusable falls back to the
    # first candidate.
    try:
        if llm_result >= 0:
            return main_object_data[llm_result], llm_result
    except (TypeError, IndexError):
        pass
    return main_object_data[0], 0

    

def _empty_scene_result():
    """Return a fresh, empty per-scene result record."""
    return {
        'true_target': [],
        'selected_index': [],
        'iou': [],
        'is_easy': [],
        'is_view_dep': [],
    }


def _load_scene_result(output_path):
    """Load saved results for resuming, or a fresh record when none is usable."""
    if not os.path.exists(output_path):
        return _empty_scene_result()

    with open(output_path, "r") as f:
        try:
            scene_result = json.load(f)
        except json.JSONDecodeError:
            return _empty_scene_result()

    # Some files hold the record as a string containing a Python literal.
    if isinstance(scene_result, str):
        try:
            scene_result = ast.literal_eval(scene_result)
        except Exception:
            return _empty_scene_result()

    # Anything that is not a mapping cannot be resumed from.
    if not isinstance(scene_result, dict):
        return _empty_scene_result()
    return scene_result


def command_dataset_to_result(dataset, scene_name, output_base=None):
    """Resolve every stored command of a scene and record the outcome.

    Results are kept in ``<output_base>/<scene_name>/final.json``. An
    existing file is loaded first and processing resumes after the commands
    already recorded in it. The file is rewritten after every command.

    Args:
        dataset: Object providing ``load_scene_commands()`` and whatever
            ``command_processor_llm`` needs.
        scene_name: Scene identifier; used for the output path and passed to
            ``iou_calc``.
        output_base: Base output directory; defaults to the global
            ``OUTPUT_DIR_V3``.

    Returns:
        None. The results are only written to disk.
    """
    if output_base is None:
        output_base = OUTPUT_DIR_V3
    output_path = os.path.join(output_base, scene_name, 'final.json')

    scene_result = _load_scene_result(output_path)

    # Resume after the commands that already have a recorded target.
    start_idx = len(scene_result['true_target'])

    commands_list = dataset.load_scene_commands()

    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    tmp_path = output_path + ".tmp"

    for command in tqdm(commands_list[start_idx:]):
        data, index = command_processor_llm(
            command['command'],
            command['main_object'],
            command['related_objects'],
            dataset,
            command['orientation_importance']
        )

        scene_result['true_target'].append(command['true_target'])
        scene_result['selected_index'].append((index, command['main_object']))
        scene_result['is_easy'].append(bool(command.get('is_easy', False)))
        scene_result['is_view_dep'].append(bool(command.get('is_view_dep', False)))

        # No selected instance counts as an IoU of 0.
        temp_iou = 0
        if data is not None:
            temp_iou = iou_calc(data, scene_result['true_target'][-1], scene_name, DATASET_ROOT)

        scene_result['iou'].append(temp_iou)

        # Update the file incrementally. Writing to a temporary file and then
        # replacing the target means an interrupted or failed dump never
        # leaves a truncated final.json, which would otherwise be read as
        # "no results" on the next run and restart the scene from scratch.
        with open(tmp_path, "w") as f:
            json.dump(scene_result, f, indent=2)
        os.replace(tmp_path, output_path)



def main():
    """Command-line entry point: parse options, apply overrides, run the pipeline.

    CLI values are written to the module-level configuration globals so the
    other functions in this file pick them up. ``--part1`` runs
    ``process_scene_and_save`` and ``--part2`` runs
    ``command_dataset_to_result``; when neither flag is given, both run.
    """
    parser = argparse.ArgumentParser(description="Run SR3D scene processing")
    parser.add_argument("scene_id", type=str, help="Scene ID (e.g., scene0011_00)")
    parser.add_argument("--part1", action="store_true",
                        help="Run the first part (process_scene_and_save); "
                             "both parts run when neither flag is given")
    parser.add_argument("--part2", action="store_true",
                        help="Run the second part (command_dataset_to_result); "
                             "both parts run when neither flag is given")

    parser.add_argument("--api_key", type=str, default=None, help="API key to use (overrides default)")
    parser.add_argument("--output_base", type=str, default=None, help="Base output directory (overrides default)")
    parser.add_argument("--dataset_root", type=str, default=None, help="Root dataset path (overrides default)")
    parser.add_argument("--bbq_dataset", type=str, default=None,
                        help="BBQ dataset path passed to SR3DPlusPreprocess")

    parser.add_argument("--clip_state_dict_path", type=str, default=None, help="path to open_clip state dict")
    parser.add_argument("--vlm_model_name", type=str, default=None, help="VLM model name (overrides default)")
    parser.add_argument("--clip_model_name", type=str, default=None, help="clip model name (overrides default)")

    args = parser.parse_args()

    # Apply overrides to module-level config (so functions using globals get new values)
    global API_KEY, OUTPUT_DIR_BASE, OUTPUT_DIR_V3, DATASET_ROOT, CLIP_STATE_DICT_PATH, VLM_MODEL_NAME_DEFAULT, CLIP_MODEL_NAME_DEFAULT, BBQ_DATASET_PATH

    if args.api_key:
        API_KEY = args.api_key
    if args.output_base:
        OUTPUT_DIR_BASE = args.output_base
        OUTPUT_DIR_V3 = os.path.join(OUTPUT_DIR_BASE, 'v3')
    if args.dataset_root:
        DATASET_ROOT = args.dataset_root
    if args.clip_state_dict_path:
        CLIP_STATE_DICT_PATH = args.clip_state_dict_path
    if args.vlm_model_name:
        VLM_MODEL_NAME_DEFAULT = args.vlm_model_name
    if args.clip_model_name:
        CLIP_MODEL_NAME_DEFAULT = args.clip_model_name

    BBQ_DATASET_PATH = args.bbq_dataset

    # No flag means "run everything"; a single flag restricts to that part.
    run_part1 = args.part1 or not args.part2
    run_part2 = args.part2 or not args.part1

    # Part 1 always reads API_KEY and part 2 always reads OUTPUT_DIR_V3.
    # Report a missing value as a usage error up front instead of a
    # NameError later. The membership test keeps any module-level default
    # that may exist. Other settings (dataset root, CLIP weights) are only
    # needed on some code paths and are not checked here.
    module_globals = globals()
    missing_flags = []
    if run_part1 and "API_KEY" not in module_globals:
        missing_flags.append("--api_key")
    if run_part2 and "OUTPUT_DIR_V3" not in module_globals:
        missing_flags.append("--output_base")
    if missing_flags:
        parser.error("missing required option(s) with no default: " + ", ".join(missing_flags))

    sr3d_plus = SR3DPlusPreprocess(args.scene_id, BBQ_DATASET_PATH)

    if run_part1:
        process_scene_and_save(
            scene_id=args.scene_id,
            dataset=sr3d_plus,
            model_name="openai/gpt-5-mini",
            api_key=API_KEY
        )

    if run_part2:
        command_dataset_to_result(sr3d_plus, args.scene_id)

if __name__ == "__main__":
    main()
