import os
import sys
import re
import cv2
import json
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import pickle
import json
from src.navgym.models.CityNavData import CityNavData
from src.navgym.models.NavGym import NavGym
from src.navgym.agents.CityNavAgent import GPTAgent
from src.navgym.tools.EvalTools import eval_planning_metrics
from src.gsamllavanav.observation import cropclient
from src.gsamllavanav.mapdata import GROUND_LEVEL
from src.gsamllavanav.space import Pose4D, view_area_corners
from concurrent.futures import ThreadPoolExecutor, as_completed

#R1PhotoData

# Make GPU indices 0-7 visible and blank out the proxy variables; clearing the
# proxies presumably keeps requests to the local model server
# (API_CONFIG["api_base"]) from being routed through a proxy.
os.environ.update(
    {
        "CUDA_VISIBLE_DEVICES": "0,1,2,3,4,5,6,7",
        "http_proxy": "",
        "https_proxy": "",
    }
)
# Populate cropclient's image cache once, as an import-time side effect.
cropclient.load_image_cache()

# Endpoint / model settings consumed by initialize_agent().
API_CONFIG = dict(
    api_key="EMPTY",
    api_base="http://0.0.0.0:8888/v1",  # use your port
    api_version="2024-05-01-preview",
    model="qwen_2_5_vl_3b",
    system_prompt="You are an intelligent autonomous aerial vehicle (UAV) equipped for real-world navigation and visual target localization.",
)
# Root directory for evaluation outputs (metrics.json, optional visualizations).
SAVE_PATH = "./eval/results"


def create_dir(file_path):
    """Ensure the parent directory of ``file_path`` exists.

    The file itself is not created. Existing directories are left alone.

    Args:
        file_path: Path to a file whose containing directory should exist.
    """
    dir_path = os.path.dirname(file_path)
    # A bare file name ("metrics.json") has no directory component:
    # os.path.dirname returns "" and os.makedirs("") would raise, so there is
    # nothing to create (the file lives in the current directory).
    if dir_path:
        os.makedirs(dir_path, exist_ok=True)

def initialize_agent(navGym):
    """Create a GPTAgent for the episode currently loaded in ``navGym``.

    Connection and model settings (key, base URL, API version, model name,
    system prompt) are read from the module-level API_CONFIG. The target
    description, drone view shape, pixel scale and map top-left corner are read
    from ``navGym``.
    """
    api_settings = {
        name: API_CONFIG[name]
        for name in ("api_key", "api_base", "api_version", "model", "system_prompt")
    }
    return GPTAgent(
        **api_settings,
        target_description=navGym.target_description,
        drone_see_shape=navGym.drone_view_shape,
        scale=navGym.px_real_size,
        top_left=navGym.top_left,
    )

def parse_bbox(result_str, key="landmark_bbox"):
    """Extract a four-integer box ``"<key>": [a, b, c, d]`` from model output.

    Args:
        result_str: Raw text returned by the agent.
        key: Name of the JSON-style key whose list value holds the box.

    Returns:
        ``[a, b, c, d]`` as ints from the first match, or ``[0, 0, 0, 0]``
        when no such list is found.
    """
    # Accept any whitespace around the brackets and commas, so "[1,2,3,4]"
    # parses as well as "[1, 2, 3, 4]". The key is escaped so characters such
    # as "." are matched literally.
    number = r"\s*(\d+)\s*"
    pattern = r'"' + re.escape(key) + r'"\s*:\s*\[' + ",".join([number] * 4) + r"\]"
    match = re.search(pattern, result_str)
    return [int(g) for g in match.groups()] if match else [0, 0, 0, 0]

def parse_location(result_str, key="target_location"):
    """Extract a two-integer point ``"<key>": [x, y]`` from model output.

    Args:
        result_str: Raw text returned by the agent.
        key: Name of the JSON-style key whose list value holds the point.
            Defaults to ``"target_location"``.

    Returns:
        ``[x, y]`` as ints from the first match, or ``[0, 0]`` when no such
        list is found. ``[0, 0]`` is the sentinel that compute_pose() treats
        as "no prediction".
    """
    # Accept any whitespace around the brackets and comma, so "[12,34]"
    # parses as well as "[12, 34]". The key is escaped for literal matching.
    number = r"\s*(\d+)\s*"
    pattern = r'"' + re.escape(key) + r'"\s*:\s*\[' + ",".join([number] * 2) + r"\]"
    match = re.search(pattern, result_str)
    return [int(g) for g in match.groups()] if match else [0, 0]

def visualize_prediction(navGym, source_path, landmark_box, target_pred, true_target, save_path):
    """Annotate the map image at ``source_path`` and write it to ``save_path``.

    Drawn with OpenCV's BGR colour order:
      * every landmark in ``navGym.map.landmark_map`` (bbox corners 0 and 2)
        as a magenta rectangle,
      * ``landmark_box`` ``[x1, y1, x2, y2]`` as a red rectangle,
      * ``target_pred`` ``[x, y]`` as a filled green circle,
      * ``true_target`` ``[x, y]`` as a filled blue circle.

    The parent directory of ``save_path`` is created if needed.

    Raises:
        FileNotFoundError: ``cv2.imread`` could not read ``source_path``.
        OSError: ``cv2.imwrite`` reported that the image was not written.
    """
    image = cv2.imread(source_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising;
        # without this check the error surfaces later as an opaque cv2.error.
        raise FileNotFoundError(f"Could not read image: {source_path}")

    def as_point(xy):
        # OpenCV drawing functions need integer pixel coordinates.
        return tuple(int(v) for v in xy)

    for landmark in navGym.map.landmark_map.landmarks:
        top_left = as_point(navGym._get_px(landmark.bbox_corners[0]))
        bottom_right = as_point(navGym._get_px(landmark.bbox_corners[2]))
        cv2.rectangle(image, top_left, bottom_right, color=(255, 0, 255), thickness=2)

    cv2.rectangle(image, as_point(landmark_box[:2]), as_point(landmark_box[2:4]), (0, 0, 255), 2)
    cv2.circle(image, as_point(target_pred), radius=30, color=(0, 255, 0), thickness=-1)
    cv2.circle(image, as_point(true_target), radius=30, color=(255, 0, 0), thickness=-1)

    create_dir(save_path)
    if not cv2.imwrite(save_path, image):
        raise OSError(f"Could not write image: {save_path}")

def compute_pose(navGym, predicted_px, true_start_px, map_name,
                 px_per_meter=10, base_altitude=66.05, z_margin=5):
    """Convert a predicted target pixel on the map into a Pose4D for the drone.

    The pixel offset from ``true_start_px`` is divided by ``px_per_meter`` and
    applied to ``navGym.episode.start_pose``. The y offset is subtracted
    because pixel rows grow downward while world y is taken to grow upward.
    The altitude is then refined from a depth crop taken at the candidate
    position.

    Args:
        navGym: NavGym environment for the current round.
        predicted_px: ``[x, y]`` pixel predicted by the agent. ``[0, 0]`` is
            the "no prediction" sentinel produced by ``parse_location``.
        true_start_px: Pixel position of the current start pose.
        map_name: Map identifier passed to ``cropclient.crop_image``.
        px_per_meter: Pixels per world unit on the map (default 10, the value
            previously hard-coded).
        base_altitude: z of the provisional pose used for the depth crop.
        z_margin: Amount added to z after subtracting the centre depth.

    Returns:
        ``navGym.start_pose`` for the sentinel prediction. Otherwise a
        ``Pose4D`` whose fourth component is 0 (presumably yaw).
    """
    if predicted_px == [0, 0]:
        return navGym.start_pose

    start = navGym.episode.start_pose
    dx = predicted_px[0] - true_start_px[0]
    dy = predicted_px[1] - true_start_px[1]
    world_x = start.x + dx / px_per_meter
    world_y = start.y - dy / px_per_meter
    base_pose = Pose4D(world_x, world_y, base_altitude, 0)

    # Mean depth over the central 10x10 patch of a 100x100 depth crop.
    depth_img = cropclient.crop_image(map_name, base_pose, (100, 100), "depth")
    center_depth = depth_img[45:55, 45:55].mean()
    # NOTE(review): presumably z - depth is the surface height below the
    # drone and z_margin is a clearance above it; confirm against cropclient.
    return Pose4D(base_pose.x, base_pose.y, base_pose.z - center_depth + z_margin, 0)

from src.gsamllavanav.space import Point2D, Point3D, Pose4D
from src.gsamllavanav.teacher.algorithm.lookahead import lookahead_discrete_action
from src.gsamllavanav.teacher.trajectory import _moved_pose
def move(pose: Pose4D, dst: Pose4D, iterations: int):
    """Step from ``pose`` toward ``dst`` using discrete actions, keeping altitude.

    ``lookahead_discrete_action`` picks each action toward the (x, y) of
    ``dst`` at ``pose``'s own z. The walk ends after ``iterations`` actions or
    as soon as a STOP action is chosen.

    Returns:
        The poses reached after each non-STOP action, in order. The starting
        pose is not included, so the list is empty if the first action is STOP.
    """
    goal = Point3D(dst.x, dst.y, pose.z)
    visited = []
    current = pose
    for _ in range(iterations):
        action = lookahead_discrete_action(current, [goal])
        if action.name == 'STOP':
            break
        current = _moved_pose(current, *action.value)
        visited.append(current)
    return visited

def calculate_mean_metrics(results, nums, splits=("easy", "medium", "hard")):
    """Average per-split metrics, weighting each split by its episode count.

    Args:
        results: Mapping split -> metrics object exposing
            ``mean_final_pos_to_goal_dist``, ``success_rate_final_pos_to_goal``,
            ``success_rate_oracle_pos_to_goal`` and
            ``success_rate_weighted_by_path_length``.
        nums: Mapping split -> number of evaluated episodes (the weight).
        splits: Splits to combine. Defaults to the three difficulty splits.

    Returns:
        Tuple ``(NE, SR, OSR, SPL)`` of weighted averages. If the total
        episode count is zero, every value is ``nan`` instead of raising
        ZeroDivisionError.
    """
    total = sum(nums[s] for s in splits)
    if total == 0:
        nan = float("nan")
        return nan, nan, nan, nan

    # A split with no episodes contributes nothing. Skipping it also keeps a
    # NaN metric from an empty split from poisoning the average (NaN * 0 is NaN).
    weighted_splits = [s for s in splits if nums[s]]

    def weighted(field):
        return sum(getattr(results[s], field) * nums[s] / total for s in weighted_splits)

    NE = weighted("mean_final_pos_to_goal_dist")
    SR = weighted("success_rate_final_pos_to_goal")
    OSR = weighted("success_rate_oracle_pos_to_goal")
    SPL = weighted("success_rate_weighted_by_path_length")
    return NE, SR, OSR, SPL

def run_nav_gym(citynavData, split, step, action_num, max_workers=2, visualize=False):
    """Roll out the agent on every sample of ``citynavData`` using a thread pool.

    Each sample runs for ``step`` planning rounds. In every round:
      1. A fresh NavGym and agent are built.
      2. The agent's reply is parsed into a target pixel.
      3. compute_pose() turns that pixel into a pose.
      4. The drone moves toward it for at most ``action_num`` discrete actions.
    The last pose reached becomes the start of the next round.

    Args:
        citynavData: Dataset supporting ``len()``, integer indexing and
            ``.episodes``.
        split: Split name, used only for the progress-bar label.
        step: Number of planning rounds per sample.
        action_num: Maximum discrete actions per round.
        max_workers: Number of worker threads (default 2).
        visualize: If True, save an annotated map per sample and round under
            ``SAVE_PATH/visualized_image``.

    Returns:
        ``(trajectory, errors, SAVE_PATH)``:
          * ``trajectory`` maps episode id to the list of visited poses,
            starting with the start pose.
          * ``errors`` is the sorted list of sample indices whose rollout raised.
    """
    trajectory = {}
    errors = []

    def rollout(i, sample):
        """Run all planning rounds for one sample and return its pose list."""
        cur_trajectory = []
        last_pose = None
        for round_idx in range(step):
            if last_pose is not None:
                # Start this round where the previous one ended; presumably
                # NavGym derives start_pose from teacher_trajectory[0].
                sample.episode.teacher_trajectory[0] = last_pose
            navGym = NavGym(sample)
            start_pose = navGym.start_pose

            agent = initialize_agent(navGym)
            map_name = navGym.episode.id[0]
            result_str = agent.act(
                cur_whole_map=navGym.cur_whole_map,
                cur_rgb_drone=navGym.cur_rgb_drone,
                cur_position=navGym._get_px(start_pose),
            )
            target_pred_px = parse_location(result_str)
            true_start_px = navGym.px_trajectory[0]

            if visualize:
                # Sample and round indices in the name keep threads and rounds
                # from overwriting each other's images of the same map.
                save_path = (
                    f"{SAVE_PATH}/visualized_image/"
                    f"{i}_{round_idx}_{os.path.basename(navGym.cur_whole_map)}"
                )
                visualize_prediction(
                    navGym,
                    navGym.cur_whole_map,
                    parse_bbox(result_str, "landmark_bbox"),
                    target_pred_px,
                    navGym.target_px,
                    save_path,
                )

            pred_pose = compute_pose(navGym, target_pred_px, true_start_px, map_name)

            if last_pose is None:
                # Nothing has moved yet, so the recorded path (re)starts here.
                cur_trajectory = [start_pose]
            move_trajectory = move(start_pose, pred_pose, action_num)
            if move_trajectory:
                last_pose = move_trajectory[-1]
            cur_trajectory.extend(move_trajectory)
        return cur_trajectory

    def process_sample(i):
        """Return ``(episode_id, poses, None)`` on success, ``(None, None, i)`` on failure."""
        try:
            sample = citynavData[i]
            original_start = sample.episode.teacher_trajectory[0]
            try:
                cur_trajectory = rollout(i, sample)
            finally:
                # rollout() overwrites teacher_trajectory[0] in place. Restore it
                # in case sample.episode is shared with citynavData.episodes,
                # which main() passes to eval_planning_metrics. Only write it
                # back if it was actually changed.
                if sample.episode.teacher_trajectory[0] is not original_start:
                    sample.episode.teacher_trajectory[0] = original_start
            return citynavData.episodes[i].id, cur_trajectory, None
        except Exception as e:
            # Per-sample boundary: record the failure and keep evaluating others.
            print(f"[Error] Sample {i}: {e}")
            return None, None, i

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(process_sample, i) for i in range(len(citynavData))]

        # Results are collected only here, on the main thread.
        for future in tqdm(as_completed(futures), total=len(futures), desc=f"Running {split}"):
            traj_id, cur_trajectory, err_idx = future.result()
            if err_idx is None:
                trajectory[traj_id] = cur_trajectory
            else:
                errors.append(err_idx)

    # as_completed yields in completion order; sort for reproducible output.
    errors.sort()
    return trajectory, errors, SAVE_PATH



def main():
    """Evaluate the agent on the unseen easy/medium/hard test splits.

    For each split:
      * load ``dataset/test_data/citynav_test_unseen_<split>.json``,
      * roll out the agent with run_nav_gym(),
      * score the episodes that produced a trajectory.

    Then print the per-split metrics and their episode-weighted averages, and
    write everything to ``SAVE_PATH/metrics.json``.
    """
    step_num = 2      # planning rounds the agent takes per episode
    action_num = 75   # maximum discrete actions per round
    results, nums = {}, {}

    for split in ("easy", "medium", "hard"):
        citynavData = CityNavData(f"dataset/test_data/citynav_test_unseen_{split}.json")

        traj, errors, image_dir = run_nav_gym(citynavData, split, step_num, action_num)
        print(f"Image Dir: {image_dir}, Errors: {errors}")

        # Failed samples have no trajectory and are excluded from scoring.
        evaluated = [ep for ep in citynavData.episodes if ep.id in traj]
        metrics = eval_planning_metrics(evaluated, traj)
        print(f"{split} result:", metrics)
        results[split] = metrics
        nums[split] = len(evaluated)

    NE, SR, OSR, SPL = calculate_mean_metrics(results, nums)
    print("Final Results:", results)
    print(f'NE:{NE}\nSR:{SR}\nOSR:{OSR}\nSPL:{SPL}')

    save_file = os.path.join(SAVE_PATH, "metrics.json")
    create_dir(save_file)

    def as_plain_dict(metric_obj):
        # NamedTuple-style results expose _asdict(); otherwise use __dict__.
        return metric_obj._asdict() if hasattr(metric_obj, "_asdict") else vars(metric_obj)

    final_results = {
        "per_split": {name: as_plain_dict(m) for name, m in results.items()},
        "nums": nums,
        "averaged": dict(NE=NE, SR=SR, OSR=OSR, SPL=SPL),
    }

    with open(save_file, "w") as f:
        json.dump(final_results, f, indent=4)

    print(f"[Saved] Evaluation results saved to {save_file}")

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()
