import sys
import os
curr_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(f'{curr_dir}/..')
import pickle
import json
import random
import numpy as np
from pathlib import Path

from envs.unity_environment import UnityEnvironment
from agents import LLM_agent
from arguments import get_args
from algos.arena_mp2 import ArenaMP


def _as_pair(container, idx):
    """Return the entry stored at ``idx`` in ``container``, or ``[]``.

    Args:
        container: ``None``, a dict keyed by index, or a list/tuple.
        idx: Integer index (the callers in this file pass 0 or 1).

    Returns:
        The stored entry, or an empty list when ``container`` is ``None``,
        is of an unsupported type, or has nothing at ``idx``.
    """
    if container is None:
        return []
    if isinstance(container, dict):
        if idx in container:
            return container[idx]
        # A dict that went through json.dump/json.load has its integer keys
        # converted to strings, so retry with the string form of the index.
        return container.get(str(idx), [])
    if isinstance(container, (list, tuple)):
        # Accept the same index range that Python indexing accepts, so an
        # out-of-range index (positive or negative) yields [] instead of
        # raising IndexError.
        if -len(container) <= idx < len(container):
            return container[idx]
    return []

def _load_log_dual(base_path_wo_ext):
    """Load a saved log from ``<base>.pik`` or ``<base>.json``.

    The ``.pik`` file is tried first, as a pickle and then as JSON text; any
    failure there is swallowed and the ``.json`` file is tried next. Errors
    raised while reading the ``.json`` file propagate to the caller.

    Args:
        base_path_wo_ext: Log path without its extension.

    Returns:
        ``(data, tag)`` where ``tag`` is ``".pik"``, ``".pik(JSON)"`` or
        ``".json"`` depending on which reader succeeded, or ``(None, None)``
        when nothing could be loaded.
    """
    pickle_path = base_path_wo_ext + ".pik"
    json_path = base_path_wo_ext + ".json"

    if os.path.isfile(pickle_path):
        # Readers for the .pik file, in order of preference.
        readers = (
            ("rb", pickle.load, ".pik"),
            ("r", json.load, ".pik(JSON)"),
        )
        for mode, reader, tag in readers:
            try:
                with open(pickle_path, mode) as handle:
                    return reader(handle), tag
            except Exception:
                continue

    if not os.path.isfile(json_path):
        return None, None

    with open(json_path, "r") as handle:
        return json.load(handle), ".json"

def _save_log_dual(base_path_wo_ext, saved_info):
    """Persist ``saved_info`` as ``<base>.pik``, falling back to ``<base>.json``.

    The pickle is written to ``<base>.pik.tmp`` and moved into place only once
    it is complete, so a failed dump never leaves a truncated ``.pik`` file.
    If pickling fails for any reason, the data is written as JSON instead,
    with values JSON cannot encode converted via ``str``.

    Args:
        base_path_wo_ext: Target path without its extension.
        saved_info: Object to persist.

    Returns:
        The path of the file that was written (``.pik`` or ``.json``).

    Raises:
        Whatever the JSON fallback raises (e.g. ``OSError``) if it also fails.
    """

    def _remove_if_present(path):
        # Best-effort cleanup; a missing file is not an error here.
        try:
            os.remove(path)
        except OSError:
            pass

    pik = base_path_wo_ext + ".pik"
    tmp = pik + ".tmp"
    try:
        with open(tmp, "wb") as f:
            pickle.dump(saved_info, f)
        os.replace(tmp, pik)
        return pik
    except Exception:
        # Pickling is best-effort: drop the partial temp file and fall back.
        _remove_if_present(tmp)

    js = base_path_wo_ext + ".json"
    with open(js, "w") as f:
        json.dump(saved_info, f, indent=4, default=str)
    # A .pik left over from an earlier save would otherwise be found before
    # the JSON that was just written.
    _remove_if_present(pik)
    return js

def calculate_metrics_from_log(saved_info):
    """Derive summary metrics for one episode from its saved log.

    Reads the ``'success'`` (falling back to ``'finished'``), ``'action'``,
    ``'usage'`` and ``'agent_position'`` entries of ``saved_info``; missing
    entries are treated as "not successful" / empty.

    Args:
        saved_info: Mapping holding the episode log.

    Returns:
        dict with the keys
        ``success`` (bool),
        ``steps`` (length of the longer of the two action lists),
        ``communication_step`` (number of actions containing
        ``"[send_message]"`` over both lists, divided by 2),
        ``action_step`` (``steps`` minus ``communication_step``),
        ``usage`` (sum over all usage values) and
        ``travel_distance`` (summed 2-norm of consecutive position
        differences over both trajectories, divided by 2).
    """
    finished_fallback = saved_info.get('finished', False)
    success = bool(saved_info.get('success', finished_fallback))

    action_log = saved_info.get('action', [[], []])
    first_actions = _as_pair(action_log, 0)
    second_actions = _as_pair(action_log, 1)
    steps_total = max(len(first_actions), len(second_actions))

    def _mentions_message(entry):
        # Entries that are None or cannot be searched count as non-messages.
        if entry is None:
            return False
        try:
            return "[send_message]" in entry
        except Exception:
            return False

    comm_cnt = 0
    for stream in (first_actions, second_actions):
        for entry in stream:
            if _mentions_message(entry):
                comm_cnt += 1
    communication_step = comm_cnt / 2.0
    action_step = steps_total - communication_step

    usage_total = 0.0
    usage_raw = saved_info.get('usage')
    if usage_raw is not None:
        if isinstance(usage_raw, dict):
            usage_values = usage_raw.values()
        else:
            usage_values = usage_raw
        for chunk in usage_values:
            try:
                usage_total += float(np.sum(chunk))
            except Exception:
                # Values that cannot be summed are skipped.
                continue

    travel = 0.0
    position_log = saved_info.get('agent_position', [[], []])
    first_path = _as_pair(position_log, 0)
    second_path = _as_pair(position_log, 1)
    for path in (first_path, second_path):
        for k in range(len(path) - 1):
            try:
                delta = np.array(path[k + 1]) - np.array(path[k])
                travel += float(np.linalg.norm(delta, 2))
            except Exception:
                # Position pairs that cannot be subtracted are skipped.
                pass
    travel = travel / 2.0

    return {
        'success': success,
        'steps': steps_total,
        'communication_step': communication_step,
        'action_step': action_step,
        'usage': usage_total,
        'travel_distance': travel,
    }

def atomic_write_json(path, obj):
    """Write ``obj`` as indented JSON to ``path`` via a temporary file.

    The JSON is first written to ``<path>.tmp`` and then moved over ``path``
    with ``os.replace``, so ``path`` is only replaced once the dump has
    completed. If anything fails, the temporary file is removed, ``path`` is
    not replaced, and the error is re-raised.

    Args:
        path: Destination file path (``str`` or ``os.PathLike``).
        obj: JSON-serialisable object.

    Raises:
        TypeError: If ``obj`` is not JSON-serialisable.
        OSError: If the file cannot be written or replaced.
    """
    path = os.fspath(path)
    tmp = path + ".tmp"
    try:
        with open(tmp, "w") as f:
            json.dump(obj, f, indent=4)
        os.replace(tmp, path)
    except BaseException:
        # Do not leave a half-written temp file behind.
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise



if __name__ == '__main__':
    # Evaluation driver: for each of `args.num_runs` iterations, every selected
    # episode is either re-read from its cached log or run with two LLM agents
    # in the arena; per-iteration and aggregate metrics are written to
    # final_output.json inside the record directory.
    args = get_args()
    # NOTE(security): pickle.load can execute arbitrary code from the file it
    # reads; only point dataset_path at trusted files.
    with open(args.dataset_path, 'rb') as dataset_file:
        env_task_set = pickle.load(dataset_file)

    args.record_dir = f'../test_results/{args.mode}'  # set the record_dir right!
    Path(args.record_dir).mkdir(parents=True, exist_ok=True)
    results_path = os.path.join(args.record_dir, 'results.pik')
    final_output_path = os.path.join(args.record_dir, 'final_output.json')

    if "image" in args.obs_type:
        # NOTE(review): the `export DISPLAY` part runs inside the shell spawned
        # by os.system and does not modify this process's environment; the
        # display number is handed over through 'x_display' below.
        os.system("Xvfb :98 & export DISPLAY=:98")
        import time
        time.sleep(3)  # ensure Xvfb is open
        os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
        executable_args = {
            'file_name': args.executable_file,
            'x_display': '98',
            'no_graphics': False,
            'timeout_wait': 5000,
        }
    else:
        executable_args = {
            'file_name': args.executable_file,
            'no_graphics': True,
        }

    id_run = 0
    random.seed(id_run)
    episode_ids = list(range(len(env_task_set)))
    num_tries = args.num_runs
    # Per-episode history over all iterations: S holds 0/1 success flags,
    # L holds step counts.
    S = [[] for _ in range(len(episode_ids))]
    L = [[] for _ in range(len(episode_ids))]

    def env_fn(env_id):
        """Build the two-agent Unity environment for the given port offset."""
        return UnityEnvironment(
            num_agents=2,
            max_episode_length=args.max_episode_length,
            port_id=env_id,
            env_task_set=env_task_set,
            agent_goals=['LLM', 'LLM'],
            observation_types=[args.obs_type, args.obs_type],
            use_editor=args.use_editor,
            executable_args=executable_args,
            base_port=args.base_port
        )

    def _mean(arr):
        """Return the arithmetic mean of ``arr`` as a float (0.0 if empty)."""
        return float(np.array(arr, dtype=float).mean()) if len(arr) else 0.0

    def _flatten_dict_of_lists(d):
        """Concatenate every list stored as a value of ``d`` into one list."""
        return [item for values in d.values() for item in values]

    args_agent1 = {'agent_id': 1, 'char_index': 0, 'args': args}
    args_agent2 = {'agent_id': 2, 'char_index': 1, 'args': args}

    agents = [lambda x, y: LLM_agent(**args_agent1), lambda x, y: LLM_agent(**args_agent2)]
    arena = ArenaMP(args.max_episode_length, id_run, env_fn, agents, args.record_dir, args.debug)

    # Episode selection: when num_per_task is 10 every episode in the dataset
    # is used, otherwise the ids given in args.test_task.
    if args.num_per_task != 10:
        test_episodes = args.test_task
    else:
        test_episodes = episode_ids

    # Everything below is keyed by iter_id; 'steps' only contains successful
    # episodes, the other per-episode lists contain every episode.
    final_output = {
        'steps': {},
        'communication_step': {},
        'action_step': {},
        'usage': {},
        'travel_distance': {},
        'average_steps': {},
        'average_communication_step': {},
        'average_action_step': {},
        'average_usage': {},
        'average_travel_distance': {},
        'failed_tasks': {}
    }

    for iter_id in range(num_tries):
        print(f"\n==== iter_id {iter_id} 시작 ====")
        steps_list, failed_tasks = [], []
        communication_step_list = []
        action_step_list = []
        usage_list = []
        travel_distance_list = []

        # Start from the previously stored results, if there are any.
        if not os.path.isfile(results_path):
            test_results = {}
        else:
            with open(results_path, 'rb') as results_file:
                test_results = pickle.load(results_file)

        for episode_id in test_episodes:
            task_id = env_task_set[episode_id]['task_id']
            task_name = env_task_set[episode_id]['task_name']
            base = os.path.join(args.record_dir, f'logs_agent_{task_id}_{task_name}_{iter_id}')

            # Reuse an existing log for this (episode, iteration) instead of
            # running the episode again.
            saved_info, used_ext = _load_log_dual(base)
            if saved_info is not None:
                metrics = calculate_metrics_from_log(saved_info)

                is_finished = 1 if metrics['success'] else 0
                S[episode_id].append(is_finished)
                L[episode_id].append(metrics['steps'])
                # NOTE: results.pik is only rewritten after a freshly run
                # episode, not on this cached path.
                test_results[episode_id] = {'S': S[episode_id], 'L': L[episode_id]}

                communication_step_list.append(metrics['communication_step'])
                action_step_list.append(metrics['action_step'])
                usage_list.append(metrics['usage'])
                travel_distance_list.append(metrics['travel_distance'])
                if metrics['success']:
                    steps_list.append(metrics['steps'])
                else:
                    failed_tasks.append(episode_id)

                print(f"episode {episode_id} (iter {iter_id}): 캐시 히트({used_ext}) → success={metrics['success']}, steps={metrics['steps']}")
                continue

            print('episode:', episode_id)
            # Give every agent a seed that differs per agent and per iteration.
            for it_agent, agent in enumerate(arena.agents):
                agent.seed = it_agent + iter_id * 2

            arena.reset(episode_id)
            # The step count returned by the arena is not used; steps are
            # recomputed from the log so that fresh and cached episodes are
            # measured the same way.
            success, _reported_steps, saved_info = arena.run()

            saved_info_to_dump = dict(saved_info)
            saved_info_to_dump['success'] = bool(success)
            m_post = calculate_metrics_from_log(saved_info_to_dump)
            saved_info_to_dump['steps'] = int(m_post['steps'])

            communication_step_list.append(m_post['communication_step'])
            action_step_list.append(m_post['action_step'])
            usage_list.append(m_post['usage'])
            travel_distance_list.append(m_post['travel_distance'])
            if m_post['success']:
                steps_list.append(m_post['steps'])
            else:
                failed_tasks.append(episode_id)

            print('-------------------------------------')
            print('success' if success else 'failure')
            print('steps (recalc):', m_post['steps'])
            print('action_steps:', m_post['action_step'])
            print('communication_step:', m_post['communication_step'])
            print('token_usage:', m_post['usage'])
            print('travel_distance:', m_post['travel_distance'])
            print('-------------------------------------')

            is_finished = 1 if m_post['success'] else 0

            _save_log_dual(base, saved_info_to_dump)
            S[episode_id].append(is_finished)
            L[episode_id].append(m_post['steps'])
            test_results[episode_id] = {'S': S[episode_id], 'L': L[episode_id]}
            with open(results_path, 'wb') as results_file:
                pickle.dump(test_results, results_file)

        final_output['steps'][iter_id] = steps_list
        final_output['communication_step'][iter_id] = communication_step_list
        final_output['action_step'][iter_id] = action_step_list
        final_output['usage'][iter_id] = usage_list
        final_output['travel_distance'][iter_id] = travel_distance_list
        final_output['average_steps'][iter_id] = _mean(steps_list)
        final_output['average_communication_step'][iter_id] = _mean(communication_step_list)
        final_output['average_action_step'][iter_id] = _mean(action_step_list)
        final_output['average_usage'][iter_id] = _mean(usage_list)
        final_output['average_travel_distance'][iter_id] = _mean(travel_distance_list)
        final_output['failed_tasks'][iter_id] = failed_tasks

        # Aggregate over every iteration recorded so far, in two ways: the
        # mean over all individual episodes, and the mean of the
        # per-iteration means.
        all_steps = _flatten_dict_of_lists(final_output['steps'])
        all_comm_steps = _flatten_dict_of_lists(final_output['communication_step'])
        all_action_steps = _flatten_dict_of_lists(final_output['action_step'])
        all_usage = _flatten_dict_of_lists(final_output['usage'])
        all_travel = _flatten_dict_of_lists(final_output['travel_distance'])

        overall_flat_means = {
            'steps': _mean(all_steps),
            'communication_step': _mean(all_comm_steps),
            'action_step': _mean(all_action_steps),
            'usage': _mean(all_usage),
            'travel_distance': _mean(all_travel),
        }

        iter_mean_of_means = {
            'steps': _mean(list(final_output['average_steps'].values())),
            'communication_step': _mean(list(final_output['average_communication_step'].values())),
            'action_step': _mean(list(final_output['average_action_step'].values())),
            'usage': _mean(list(final_output['average_usage'].values())),
            'travel_distance': _mean(list(final_output['average_travel_distance'].values())),
        }

        final_output['summary'] = {
            'overall_flat_means': overall_flat_means,
            'iter_mean_of_means': iter_mean_of_means,
            'num_iters_recorded': len(final_output['average_steps'])
        }

        print(f"\n==== iter_id {iter_id} result summary ====")
        if len(steps_list) > 0:
            print('Individual steps for each successful task:', steps_list)
            print('Average steps for successful tasks:', final_output['average_steps'][iter_id])
            print('Individual communication steps for each task:', communication_step_list)
            print('Average communication steps:', final_output['average_communication_step'][iter_id])
            print('Individual action steps for each task:', action_step_list)
            print('Average action steps:', final_output['average_action_step'][iter_id])
            print('Individual token_usage for each task:', usage_list)
            print('Average token_usage:', final_output['average_usage'][iter_id])
            print('Individual travel distance for each task:', travel_distance_list)
            print('Average travel distance:', final_output['average_travel_distance'][iter_id])
        else:
            print('No tasks were completed successfully in this iteration.')

        print('Failed tasks:', failed_tasks)

        # Checkpoint after every iteration so partial results survive a crash.
        atomic_write_json(final_output_path, final_output)

    # Written once more so the file also exists when num_runs is 0.
    atomic_write_json(final_output_path, final_output)
    print("\nAll experiments completed. final_output.json saved.")
