import sys
sys.path.append('./') 
sys.path.insert(0, './policy/Diffusion-Policy') 
sys.path.insert(0, './policy/3D-Diffusion-Policy/3D-Diffusion-Policy')
import torch  
import os
import numpy as np
import hydra
from pathlib import Path
from collections import deque
import traceback

import yaml
from datetime import datetime
import importlib
import dill
from argparse import ArgumentParser
import pathlib
from diffusion_policy.workspace.robotworkspace import RobotWorkspace
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from diffusion_policy.env_runner.dp_runner import DPRunner

from dp3_policy import *
from envs import *

# Absolute path of this script and of the directory that contains it.
# get_camera_config() resolves the camera config file relative to
# parent_directory.
current_file_path = os.path.abspath(__file__)
parent_directory = os.path.dirname(current_file_path)

def get_policy(checkpoint, output_dir, device):
    """Load a Diffusion-Policy checkpoint and return its policy in eval mode.

    Args:
        checkpoint: Checkpoint path relative to ``./policy/Diffusion-Policy/``.
        output_dir: Output directory forwarded to the workspace constructor
            (may be ``None``).
        device: Anything accepted by ``torch.device`` (e.g. ``'cuda:0'``).

    Returns:
        ``workspace.ema_model`` when ``cfg.training.use_ema`` is truthy,
        otherwise ``workspace.model``; moved to ``device`` and put in eval
        mode.
    """
    checkpoint_path = './policy/Diffusion-Policy/' + checkpoint

    # Use a context manager so the file handle is closed as soon as the
    # payload has been deserialised (a bare open() would leak it).
    # NOTE: the payload is unpickled with dill, which can execute arbitrary
    # code -- only load checkpoints that come from a trusted source.
    with open(checkpoint_path, 'rb') as checkpoint_file:
        payload = torch.load(checkpoint_file, pickle_module=dill)

    # The workspace class is named by the config stored in the checkpoint.
    cfg = payload['cfg']
    cls = hydra.utils.get_class(cfg._target_)
    workspace: RobotWorkspace = cls(cfg, output_dir=output_dir)
    workspace.load_payload(payload, exclude_keys=None, include_keys=None)

    # Prefer the EMA weights when the run was trained with EMA enabled.
    policy = workspace.model
    if cfg.training.use_ema:
        policy = workspace.ema_model

    device = torch.device(device)
    policy.to(device)
    policy.eval()

    return policy

class DP:
    """Pair a loaded Diffusion-Policy model with a ``DPRunner``.

    Every method forwards to ``self.runner``, passing ``self.policy`` where
    the runner call takes a policy.
    """

    def __init__(self, task_name, head_camera_type: str, checkpoint_num: int, expert_data_num: int, seed: int):
        """Load the checkpoint identified by the arguments and create a runner.

        The checkpoint path is
        ``checkpoints/<task>_<camera>_<num>_<seed>_rgb_only/<checkpoint_num>.ckpt``
        and the policy is loaded onto ``'cuda:0'``.
        """
        run_name = f'{task_name}_{head_camera_type}_{expert_data_num}_{seed}_rgb_only'
        checkpoint_path = f'checkpoints/{run_name}/{checkpoint_num}.ckpt'
        self.policy = get_policy(checkpoint_path, None, 'cuda:0')
        self.runner = DPRunner(output_dir=None)

    def update_obs(self, observation):
        """Forward ``observation`` to the runner's ``update_obs``."""
        self.runner.update_obs(observation)

    def get_action(self, observation=None):
        """Return the runner's ``get_action`` result for the wrapped policy."""
        return self.runner.get_action(self.policy, observation)

    def get_last_obs(self):
        """Return the last element of the runner's ``obs`` sequence."""
        return self.runner.obs[-1]

    def prepare_data(self, observation=None):
        """Return the runner's ``prepare_data`` result for the wrapped policy."""
        return self.runner.prepare_data(self.policy, observation)


def class_decorator(task_name):
    """Instantiate the environment class for ``task_name``.

    Despite its name this is a factory, not a decorator: it imports the
    module ``envs.<task_name>``, looks up the attribute with the same name in
    that module and calls it with no arguments.

    Args:
        task_name: Name of both the ``envs`` submodule and the class in it.

    Returns:
        A new instance of the task's environment class.

    Raises:
        SystemExit: With the message ``"No Task"`` when the module has no
            attribute named ``task_name`` or when the constructor raises an
            ordinary exception. The original error is chained as the cause.
        ImportError: Propagated unchanged when ``envs.<task_name>`` cannot be
            imported.
    """
    envs_module = importlib.import_module(f'envs.{task_name}')

    try:
        env_class = getattr(envs_module, task_name)
    except AttributeError as err:
        raise SystemExit("No Task") from err

    # Only ordinary exceptions are converted: KeyboardInterrupt / SystemExit
    # raised while constructing the environment must propagate untouched.
    try:
        return env_class()
    except Exception as err:
        # SystemExit does not display its cause, so print the real failure
        # before exiting instead of hiding it behind "No Task".
        traceback.print_exc()
        raise SystemExit("No Task") from err

def test_composed_policy(task_name, Demo_class, args, dp: DP, dp3, st_seed, test_num=20):
    """Roll out the composed DP + DP3 policy on ``test_num`` accepted seeds.

    Seeds are tried in increasing order starting at ``st_seed``. Each seed is
    first put through an expert check (``setup_demo`` + ``play_once`` with
    ``args['render_freq']`` temporarily set to 0). A seed is skipped when
    that check raises, or when ``Demo_class.plan_success`` /
    ``Demo_class.check_success()`` is falsy afterwards. On every accepted
    seed the environment is set up again with the caller's render frequency
    and ``Demo_class.apply_composed_policy(dp, dp3, args)`` is run.

    Args:
        task_name: Label used in the progress line printed per episode.
        Demo_class: Task environment instance; its ``suc`` and ``test_num``
            attributes are reset to 0 here (presumably updated inside
            ``apply_composed_policy`` -- confirm in the environment class).
        args: Task configuration dict, forwarded as keyword arguments to
            ``setup_demo``. ``args['render_freq']`` is modified temporarily
            and restored before each loop iteration ends.
        dp: Diffusion-Policy wrapper; its runner observations are reset after
            every episode.
        dp3: DP3 wrapper; its ``env_runner`` observations are reset after
            every episode.
        st_seed: First seed to try.
        test_num: Number of accepted seeds to evaluate.

    Returns:
        ``(next_seed, Demo_class.suc)`` where ``next_seed`` is the first seed
        that has not been tried yet.
    """
    expert_check = True
    print("Task name: ", args["task_name"])

    Demo_class.suc = 0
    Demo_class.test_num = 0

    episode_idx = 0
    accepted = 0
    accepted_seeds = []
    seed = st_seed

    while accepted < test_num:
        # Rendering is switched off for the expert check; the saved value is
        # written back on every path that leaves this iteration.
        saved_render_freq = args['render_freq']
        args['render_freq'] = 0

        if expert_check:
            try:
                Demo_class.setup_demo(now_ep_num=episode_idx, seed=seed, is_test=True, **args)
                Demo_class.play_once()
                Demo_class.close()
            except Exception:
                print(' -------------')
                print('Error: ', traceback.format_exc())
                print(' -------------')
                Demo_class.close()
                seed += 1
                args['render_freq'] = saved_render_freq
                print('error occurs !')
                continue

        expert_ok = (not expert_check) or (Demo_class.plan_success and Demo_class.check_success())
        if not expert_ok:
            # The expert check did not succeed on this seed: try the next one.
            seed += 1
            args['render_freq'] = saved_render_freq
            continue

        accepted += 1
        accepted_seeds.append(seed)
        args['render_freq'] = saved_render_freq

        # Fresh environment for the actual policy rollout on the same seed.
        Demo_class.setup_demo(now_ep_num=episode_idx, seed=seed, is_test=True, **args)
        Demo_class.apply_composed_policy(dp, dp3, args)
        episode_idx += 1

        Demo_class.close()
        if Demo_class.render_freq:
            Demo_class.viewer.close()

        # Both observation buffers must be cleared before the next episode.
        dp.runner.reset_obs()
        dp3.env_runner.reset_obs()

        print(f"{task_name} success rate: {Demo_class.suc}/{Demo_class.test_num}, current seed: {seed}\n")
        Demo_class._take_picture()
        seed += 1

    return seed, Demo_class.suc

def get_camera_config(camera_type):
    """Return the entry for ``camera_type`` from ``_camera_config.yml``.

    The file is read from ``../task_config/_camera_config.yml`` relative to
    this script's directory.

    Args:
        camera_type: Top-level key to look up in the camera config file.

    Returns:
        The value stored under ``camera_type`` in the parsed YAML document.

    Raises:
        AssertionError: If the config file does not exist or has no entry for
            ``camera_type``.
    """
    config_file = os.path.join(parent_directory, '../task_config/_camera_config.yml')
    assert os.path.isfile(config_file), "task config file is missing"

    with open(config_file, 'r', encoding='utf-8') as stream:
        raw_text = stream.read()
    camera_table = yaml.load(raw_text, Loader=yaml.FullLoader)

    assert camera_type in camera_table, f'camera {camera_type} is not defined'
    return camera_table[camera_type]

# Name of the task being evaluated; main() overwrites it with cfg.task.name.
TASK = None

def _load_camera_settings(args, mount):
    """Copy fovy/w/h of the camera named by ``args['<mount>_camera_type']`` into ``args``."""
    camera = get_camera_config(args[f'{mount}_camera_type'])
    for field in ('fovy', 'w', 'h'):
        args[f'{mount}_camera_{field}'] = camera[field]


def _print_camera_settings(args, mount):
    """Print the type, fovy, width and height stored in ``args`` for one camera mount."""
    cam_type, fovy, width, height = (
        args[f'{mount}_camera_{field}'] for field in ('type', 'fovy', 'w', 'h')
    )
    print(f'{mount.capitalize()} Camera Config:\n    type: {cam_type}\n    fovy: {fovy}'
          f'\n    camera_w: {width}\n    camera_h: {height}')


@hydra.main(
    version_base=None,
    config_path=str(pathlib.Path(__file__).parent.joinpath(
        '../../policy/3D-Diffusion-Policy/3D-Diffusion-Policy/diffusion_policy_3d', 'config'))
)
def main(cfg):
    """Evaluate the composed DP + DP3 policy on one task and save the result.

    Reads the task YAML from ``./task_config/<task>.yml``, fills in the
    head/wrist/front camera parameters, builds both policies, runs
    ``test_composed_policy`` for 100 accepted seeds and writes a summary text
    file under ``eval_result/composed_policy/``.

    Args:
        cfg: Hydra config. The fields read here are ``task.name``,
            ``checkpoint_num``, ``expert_data_num``, ``training.seed``,
            ``head_camera_type``, ``dp_w`` and ``dp3_w``.
    """
    global TASK
    TASK = cfg.task.name
    print('Task name:', TASK)
    print('Loading Compositional Policy')
    checkpoint_num = cfg.checkpoint_num
    expert_data_num = cfg.expert_data_num
    seed = cfg.training.seed
    head_camera_type = cfg.head_camera_type
    task_name = cfg.task.name

    # NOTE: FullLoader is not meant for untrusted YAML; the task config is
    # assumed to be a local, trusted file.
    with open(f'./task_config/{task_name}.yml', 'r', encoding='utf-8') as f:
        args = yaml.load(f.read(), Loader=yaml.FullLoader)

    # The head camera type comes from the Hydra config; the wrist and front
    # camera types come from the task YAML itself.
    args['head_camera_type'] = head_camera_type
    camera_mounts = ('head', 'wrist', 'front')
    for mount in camera_mounts:
        _load_camera_settings(args, mount)

    # output camera config
    print('============= Camera Config =============\n')
    for mount in camera_mounts:
        _print_camera_settings(args, mount)
    print('\n=======================================')

    args['expert_seed'] = seed
    args['expert_data_num'] = expert_data_num

    print('\n============= Composition Config =============\n')
    args['dp_w'] = cfg.dp_w
    args['dp3_w'] = cfg.dp3_w
    print('DP Policy Weight:' , args['dp_w'])
    print('DP3 Policy Weight:', args['dp3_w'])
    print('\n=======================================')

    task = class_decorator(args['task_name'])

    st_seed = 100000 * (1+seed)
    suc_nums = []
    test_num = 100
    topk = 1

    # NOTE(review): DP always loads checkpoint 300 while DP3 uses
    # cfg.checkpoint_num -- confirm this asymmetry is intended.
    dp = DP(task_name, head_camera_type, 300, expert_data_num, seed)
    dp3 = DP3(cfg, checkpoint_num)

    st_seed, suc_num = test_composed_policy(task_name, task, args, dp, dp3, st_seed, test_num=test_num)
    suc_nums.append(suc_num)

    topk_success_rate = sorted(suc_nums, reverse=True)[:topk]

    # Take a single timestamp so the file name and the file content agree.
    now = datetime.now()
    current_time = now.strftime("%Y-%m-%d %H:%M:%S")
    current_time_save = now.strftime("%Y-%m-%d-%H:%M:%S")
    dp_w = args['dp_w']
    dp3_w = args['dp3_w']

    save_dir = Path(f'eval_result/composed_policy/{task_name}_{head_camera_type}/{expert_data_num}')
    file_path = save_dir / f'ckpt_{checkpoint_num}_seed_{seed}_{current_time_save}_weight_{dp_w}_{dp3_w}.txt'
    # exist_ok makes this safe without a separate (racy) existence check.
    file_path.parent.mkdir(parents=True, exist_ok=True)

    with open(file_path, 'w') as file:
        file.write(f'Timestamp: {current_time}\n\n')

        file.write(f'Checkpoint Num: {checkpoint_num}\n')

        file.write('Successful Rate of Diffenent checkpoints:\n')
        file.write('\n'.join(map(str, np.array(suc_nums) / test_num)))
        file.write('\n\n')
        file.write(f'TopK {topk} Success Rate (every):\n')
        file.write('\n'.join(map(str, np.array(topk_success_rate) / test_num)))
        file.write('\n\n')
        file.write(f'TopK {topk} Success Rate:\n')
        file.write('\n'.join(map(str, np.array(topk_success_rate) / (topk * test_num))))
        file.write('\n\n')

    print(f'Data has been saved to {file_path}')



# Script entry point.
if __name__ == "__main__":
    # Imported here, so test_render is only loaded when this file is run as a
    # script rather than imported.
    from test_render import Sapien_TEST
    # Runs before the evaluation starts.
    # NOTE(review): presumably a SAPIEN rendering sanity check -- confirm in
    # test_render.
    Sapien_TEST()
    # Called without arguments: the @hydra.main decorator on main is expected
    # to supply cfg.
    main()
    # Disabled argparse-based entry point, superseded by the Hydra call above:
    # parser = ArgumentParser()
    # parser.add_argument('task_name', type=str, default='block_hammer_beat')
    # parser.add_argument('head_camera_type', type=str)
    # parser.add_argument('expert_data_num', type=int, default=20)
    # parser.add_argument('checkpoint_num', type=int, default=1000)
    # parser.add_argument('seed', type=int, default=0)
    # usr_args = parser.parse_args()
    
    # main(usr_args)
