﻿import os
import numpy as np
from tqdm import tqdm
import time
import json
from robotrust.envs.rb_habitat.RBHabEnv_robustd import RBHabEnv, ValidEvalSets
from robotrust.planner.vlm_planner import VLMPlanner
from robotrust.evaluator.summarize_result import average_json_values
from robotrust.evaluator.evaluator_utils import load_saved_data, update_config_with_args
from robotrust.evaluator.config.system_prompts import trust_habitat_multi_system_prompt,trust_habitat_single_action_system_prompt
from robotrust.main import logger

link_path = os.path.join(os.path.dirname(__file__), '../envs/rb_habitat/data')
try:
    os.symlink(link_path, 'data')
except FileExistsError:
    pass 


example_path = os.path.join(os.path.dirname(__file__), 'config/habitat_examples.json')
examples = json.load(open(example_path, 'r+'))



class EB_HabitatEvaluator():
    def __init__(self, config):
        self.model_name = config['model_name']
        self.eval_set = ValidEvalSets[0]
        self.config = config
        self.env = None
        self.planner = None

        self.dynamic_perturbation = config.get('dynamic_perturbation', False)
        self.perturbation_config_path = config.get('perturbation_config_path', 
                                                   'robotrust/configs/dynamic_perturbation_config.yaml')
        self.system_prompt = trust_habitat_single_action_system_prompt if self.config["action_step"] == "single" else trust_habitat_multi_system_prompt


    def check_config_valid(self):
        if self.config['multistep'] + self.config['chat_history'] > 1:
            raise ValueError("Only one of multistep, chat_history can be enabled at a time.")
        
        if self.config['language_only']:
            if self.config['multistep']:
                logger.warning("Language only mode should not have multistep enabled. Setting these arguments to False ...")
                self.config['multistep'] = 0
        
        
    def save_episode_metric(self, episode_info):
        filename = 'episode_{}_final_res.json'.format(self.env._current_episode_num)
        res_path = os.path.join(self.env.log_path, 'results')
        if not os.path.exists(res_path):
            os.makedirs(res_path)
        with open(os.path.join(res_path, filename), 'w', encoding='utf-8') as f:
            json.dump(episode_info, f, ensure_ascii=False)

    def evaluate_main(self):
        valid_eval_sets = self.config.get('eval_sets', ValidEvalSets)
        # valid_eval_sets = list(valid_eval_sets)
        if isinstance(valid_eval_sets, str):
           valid_eval_sets = [valid_eval_sets]
        else:
           valid_eval_sets = list(valid_eval_sets)
        if type(valid_eval_sets) == list and len(valid_eval_sets) == 0:
            valid_eval_sets = ValidEvalSets
            
        for eval_set in valid_eval_sets:
            if self.env is not None:
                self.env.close()
            self.eval_set = eval_set
            logger.info(f'Current eval set: {eval_set}')

            # exp_name = f"{self.model_name.split('/')[-1]}_{self.config['exp_name']}/{eval_set}" if len(self.config['exp_name']) else f"{self.model_name.split('/')[-1]}/{eval_set}"
            exp_suffix = '_dynamic' if self.dynamic_perturbation else ''
            exp_name = f"{self.model_name.split('/')[-1]}_{self.config['exp_name']}{exp_suffix}/{eval_set}" if len(self.config['exp_name']) else f"{self.model_name.split('/')[-1]}{exp_suffix}/{eval_set}"
        
            self.env = RBHabEnv(
                eval_set=self.eval_set, 
                down_sample_ratio=self.config['down_sample_ratio'], 
                exp_name=exp_name,
                start_epi_index=self.config.get('start_epi_index', 0), 
                resolution=self.config.get('resolution', 500),
                perturbation_type=self.config.get('perturbation_type', 'none'),
                dynamic_perturbation=self.dynamic_perturbation,
                perturbation_config_path=self.perturbation_config_path,
                dataset_name=self.config["dataset_name"],
                max_episode_steps=self.config.max_episode_steps
            )

            model_type = self.config.get('model_type', 'remote')
            self.planner = VLMPlanner(
                self.model_name, model_type, self.env.language_skill_set, 
                self.system_prompt, examples, n_shot=self.config['n_shots'], 
                obs_key='head_rgb', chat_history=self.config['chat_history'], 
                language_only=self.config['language_only'], 
                use_feedback=self.config.get('env_feedback', True), 
                multistep=self.config.get('multistep', 0), 
                tp=self.config.get('tp', 1)
            )

            self.evaluate()
            average_json_values(os.path.join(self.env.log_path, 'results'), output_file='summary.json')

    def evaluate(self):
        total_episodes_in_set = self.env.number_of_episodes
        limit = self.config.get('limit_episodes')
        if limit is not None and limit != 'null' and limit > 0:
            episodes_to_run = min(total_episodes_in_set, limit)
            logger.info(f"User requested to limit evaluation to {limit} episodes. Running the first {episodes_to_run} episodes.")
        else:
            episodes_to_run = total_episodes_in_set

        progress_bar = tqdm(total=episodes_to_run, desc="Episodes")
        
        # 娣诲姞鍔ㄦ€佹壈鍔ㄧ粺璁?        perturbation_stats = {
            'total_perturbations_triggered': 0,
            'successful_adaptations': 0
        }

        while self.env._current_episode_num < episodes_to_run:
            logger.info(f"Evaluating episode {self.env._current_episode_num} ...")
            episode_info = {'reward': [], 'num_invalid_actions': 0, 'empty_plan': 0}
            obs = self.env.reset()
            img_path = self.env.save_image(obs)
            user_instruction = self.env.episode_language_instruction
            print(f"Instruction: {user_instruction}")
            
            # 璁板綍鏄惁鏈夊姩鎬佹壈鍔ㄩ厤缃?            if self.dynamic_perturbation and self.env.trigger_action:
                print(f"Dynamic perturbation configured for this episode")
                episode_info['has_dynamic_perturbation'] = True
            else:
                episode_info['has_dynamic_perturbation'] = False

            self.planner.reset()
            done = False
            perturbation_handled = False
            
            while not done:
                try:
                    action, reasoning = self.planner.act(img_path, user_instruction)
                    print(f"Planner Output Action: {action}")

                    if action == -2: # empty plan
                        episode_info['empty_plan'] = 1
                        self.env.episode_log.append({
                            'last_action_success': 0.0,
                            'action_id': -2,
                            'action_description': 'empty plan',
                            'reasoning': reasoning,
                        })
                        info = {
                            'task_success': episode_info.get('task_success', 0),
                            'task_progress': episode_info.get("task_progress", 0),
                            'subgoal_reward': episode_info.get("subgoal_reward", 0),
                            'env_step': self.env._current_step,
                        }
                        break 
                    
                    if action == -1:
                        self.env._cur_invalid_actions += 1
                        episode_info['reward'].append(-1)
                        episode_info['num_invalid_actions'] += 1
                        self.env.episode_log.append({
                            'last_action_success': 0.0,
                            'action_id': -1,
                            'action_description': 'invalid action',
                            'reasoning': reasoning,
                        })
                        info = {
                            'task_success': episode_info.get('task_success', 0),
                            'task_progress': episode_info.get("task_progress", 0),
                            'subgoal_reward': episode_info.get("subgoal_reward", 0),
                            'env_step': self.env._current_step,
                        }
                        if self.env._cur_invalid_actions >= self.env._max_invalid_actions:
                            break
                        continue
                    
                    # 澶勭悊澶氫釜鍔ㄤ綔
                    if type(action) == list:
                        if any(x not in range(0, len(self.env.language_skill_set)) for x in action):
                            self.env._cur_invalid_actions += 1
                            episode_info['reward'].append(-1)
                            episode_info['num_invalid_actions'] += 1
                            self.env.episode_log.append({
                                'last_action_success': 0.0,
                                'action_id': -1,
                                'action_description': 'invalid action',
                                'reasoning': reasoning,
                            })
                            info = {
                                'task_success': episode_info.get('task_success', 0),
                                'task_progress': episode_info.get("task_progress", 0),
                                'subgoal_reward': episode_info.get("subgoal_reward", 0),
                                'env_step': self.env._current_step,
                            }
                            if self.env._cur_invalid_actions >= self.env._max_invalid_actions:
                                break
                            continue
                        for action_single in action[:min(self.env._max_episode_steps - self.env._current_step, len(action))]:
                            obs, reward, done, info = self.env.step(action_single, reasoning=reasoning)
                            action_str = action_single if type(action_single) == str else self.env.language_skill_set[action_single]
                            print(f"Executed action: {action_str}, Task success: {info['task_success']}")
                            
                            # 妫€鏌ユ槸鍚﹁Е鍙戜簡鍔ㄦ€佹壈鍔?                            if info.get('dynamic_perturbation_applied', False) and not perturbation_handled:
                                print(">>> Dynamic perturbation applied! Object position changed.")
                                perturbation_stats['total_perturbations_triggered'] += 1
                                episode_info['dynamic_perturbation_triggered'] = True
                                perturbation_handled = True
                            
                            logger.debug(f"reward: {reward}")
                            logger.debug(f"terminate: {done}\n")
                            
                            self.planner.update_info(info)
                            img_path = self.env.save_image(obs)
                            episode_info['reward'].append(reward)
                            episode_info['num_invalid_actions'] += (info['last_action_success'] == 0)
                            if done or info['last_action_success'] == 0:
                                print("Invalid action or task complete. If invalid then Replanning.")
                                break
                    else:
                        obs, reward, done, info = self.env.step(action, reasoning=reasoning)
                        action_str = action if type(action) == str else self.env.language_skill_set[action]
                        print(f"Executed action: {action_str}, Task success: {info['task_success']}")
                        
                        # 妫€鏌ユ槸鍚﹁Е鍙戜簡鍔ㄦ€佹壈鍔?                        if info.get('dynamic_perturbation_applied', False) and not perturbation_handled:
                            print(">>> Dynamic perturbation applied! Object position changed.")
                            perturbation_stats['total_perturbations_triggered'] += 1
                            episode_info['dynamic_perturbation_triggered'] = True
                            perturbation_handled = True
                        
                        logger.debug(f"reward: {reward}")
                        logger.debug(f"terminate: {done}\n")
                        
                        self.planner.update_info(info)
                        img_path = self.env.save_image(obs)
                        episode_info['reward'].append(reward)
                        episode_info['num_invalid_actions'] += (info['last_action_success'] == 0)
                
                except Exception as e:
                    print(e)
                    time.sleep(30)

            # 璇勪及鎸囨爣
            episode_info['instruction'] = user_instruction
            episode_info['reward'] = np.mean(episode_info['reward']) if episode_info['reward'] else 0
            episode_info['task_success'] = info['task_success']
            episode_info["task_progress"] = info['task_progress']
            episode_info['subgoal_reward'] = info.get('subgoal_reward', 0)
            episode_info['num_steps'] = info["env_step"]
            episode_info['planner_steps'] = self.planner.planner_steps
            episode_info['planner_output_error'] = self.planner.output_json_error
            episode_info["num_invalid_actions"] = episode_info['num_invalid_actions']
            episode_info["num_invalid_action_ratio"] = episode_info['num_invalid_actions'] / info["env_step"] if info['env_step'] > 0 else 0
            episode_info["episode_elapsed_seconds"] = info.get("episode_elapsed_seconds", time.time() - self.env._episode_start_time)
            
            # 鍔ㄦ€佹壈鍔ㄧ浉鍏虫寚鏍?            if perturbation_handled and info['task_success']:
                perturbation_stats['successful_adaptations'] += 1
                episode_info['adapted_to_perturbation'] = True
            elif perturbation_handled:
                episode_info['adapted_to_perturbation'] = False
            
            self.env.save_episode_log()
            self.save_episode_metric(episode_info)
            progress_bar.update()
        
        # 淇濆瓨鎵板姩缁熻
        if self.dynamic_perturbation:
            stats_file = os.path.join(self.env.log_path, 'results', 'perturbation_stats.json')
            os.makedirs(os.path.dirname(stats_file), exist_ok=True)
            with open(stats_file, 'w') as f:
                json.dump(perturbation_stats, f, indent=2)
            print(f"\nDynamic Perturbation Stats:")
            print(f"  Total perturbations triggered: {perturbation_stats['total_perturbations_triggered']}")
            print(f"  Successful adaptations: {perturbation_stats['successful_adaptations']}")


if __name__ == '__main__':
    import argparse
    def parse_arguments():
        parser = argparse.ArgumentParser(description='Change configuration parameters.')
        parser.add_argument('--model_name', type=str, help='Name of the model.')
        parser.add_argument('--n_shots', type=int, help='Number of examples')
        parser.add_argument('--down_sample_ratio', type=float, help='Down sample ratio.')
        parser.add_argument('--model_type', type=str, help='Type of the model.')
        parser.add_argument('--language_only', type=int, help='Set to True for language only mode.')
        parser.add_argument('--exp_name', type=str, help='Name of the experiment.')
        parser.add_argument('--chat_history', type=int, help='Set to True to enable chat history.')
        parser.add_argument('--eval_sets', type=lambda s: s.split(','), help='Comma-separated list of evaluation sets.')
        parser.add_argument('--start_epi_index', type=int, help='Starting episode index.')
        parser.add_argument('--multistep', type=int, help='Number of steps for multi-step reasoning.')
        parser.add_argument('--resolution', type=int, help='Resolution for processing.')
        parser.add_argument('--env_feedback', type=int, help='Set to True to enable environment feedback.')
        parser.add_argument('--tp', type=int, help='number of tensor parallel splits of the model parameters')
        parser.add_argument('--perturbation_type', type=str, default='none', 
                            choices=['none', 'noise', 'blur', 'occlusion', 'grayscale', 'low_res'],
                            help='Type of visual perturbation to apply to the input image.')
        parser.add_argument('--limit_episodes', type=int, help='Limit the number of episodes to evaluate for quick testing.')
        parser.add_argument('--dynamic_perturbation', type=int, default=0,
                          help='Enable dynamic perturbation (1) or disable (0)')
        parser.add_argument('--perturbation_config_path', type=str, 
                          default='config/dynamic_perturbation_config.yaml',
                          help='Path to dynamic perturbation configuration file')
        parser.add_argument("--dataset_name", type=str, help='datasets.yaml')
        parser.add_argument("--action_step", type=str, help="single / multi")
        parser.add_argument("--max_episode_steps", type=int, help="max episode steps")
        return parser.parse_args()

    config = {
        'model_name': 'gpt-4o-mini',  # 'Qwen/Qwen2-VL-7B-Instruct', 
        'n_shots': 10,
        'down_sample_ratio': 1.0, 
        'model_type': 'remote', # 'local'
        'language_only': 0,
        'exp_name': 'vlm_10shots_imgsize500',
        'chat_history': 0,  
        'start_epi_index': 0,
        'eval_sets': ['base', 'common_sense', 'complex_instruction', 'spatial_relationship',  'visual_appearance' , 'long_horizon', 'robust_word', 'single_robust_test', 'robust_robust_error','robust_robust_redu','robust_robust_sema','robust_robust_raw'],
        'multistep':0, 
        'resolution': 500, 
        'env_feedback': 1,
        'tp': 1,
        'perturbation_type': 'none',
        'limit_episodes': None,
        'dynamic_perturbation': 0,  # 鏂板锛氬姩鎬佹壈鍔ㄥ紑鍏?        'perturbation_config_path': 'robotrust/configs/dynamic_perturbation_config.yaml' ,
        "dataset_name": "dataset.yaml",
        "action_step": "single",
        "max_episode_steps": 20,
    }
    args = parse_arguments()
    update_config_with_args(config, args)

    evaluator = EB_HabitatEvaluator(config)
    evaluator.evaluate_main()

    try:
        os.unlink('data')
        print(f"The symbolic link {link_path} has been successfully removed.")
    except FileNotFoundError:
        print(f"Error: The symbolic link {link_path} does not exist.")
    except OSError as e:
        print(f"Error: {e}")





