#-*- coding:utf-8 -*-
# Evaluation script for a retrieval-augmented EDM (Karras) diffusion policy
#

import sys 
import os

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusion.denoisers import Denoiser, append_dims, get_sigmas_karras, sample_dpmpp_2m
from diffusion.dit import *
from diffusion.edm_model import ConditionalKarrasUnet1D
from diffusion import TransformerForDiffusion, ModelType, MultiModelType
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
from diffusion_policy.gym_util.async_vector_env import AsyncVectorEnv
from diffusion_policy.rotation_transformer import RotationTransformer
from diffusion_policy.robomimic_lowdim_wrapper import RobomimicLowdimWrapper
from IPython.display import Video
from skvideo.io import vwrite
from dataset.robomimic_lowdim_dataset import RobomimicReplayLowdimDataset
from dataset.pusht_dataset import normalize_data, unnormalize_data, PushTStateDataset
from dataset.tasks import TaskTypes 
from dataset.tasks import *
from tqdm.auto import tqdm
from diffusion.model import ConditionalUnet1D
from pprint import pprint
from env import PushTEnv
from glob import glob

from vecdb import RobotFAISS

import robomimic.utils.file_utils as FileUtils
import robomimic.utils.env_utils as EnvUtils
import robomimic.utils.obs_utils as ObsUtils
import wandb.sdk.data_types.video as wv
import collections
import numpy as np
import argparse
import random
import pathlib
import torch
import time

# Control type handed to every task constructor below as ``ctype``.
CONTROL_TYPE = ControlType.STATE

# Command-line options as (flags, add_argument keyword arguments) pairs.
# They are registered on the parser in exactly this order.
_CLI_OPTIONS = (
    # Test options
    (('-n', '--test-samples'), dict(type=int, default=100)),

    # General options
    (('--checkpoint',), dict(type=str, default="./weights/t-push-diffusion-epoch50.pt")),
    (('-e', '--export_video'), dict(type=str, default="results/vis.mp4")),
    (('-d', '--dataset'), dict(type=str, default="./data/pusht_cchi_v7_replay.zarr.zip")),
    (('--max_steps',), dict(type=int, default=200)),
    (('--diffusion_timesteps',), dict(type=int, default=40)),
    (('-m', '--model_type'), dict(type=str, default="CNN")),
    (('--task_type',), dict(type=str, default="PUSHT")),
    (('--task_tag',), dict(type=str, default="")),
    (('--last-checkpoints',), dict(type=int, default=1)),
    (('--seed',), dict(type=int, default=42)),
    (('--use_max_action',), dict(action='store_true')),

    # Karras (EDM) options
    (('--sigma_data',), dict(type=float, default=0.5)),
    (('--sigma_sample_density_mean',), dict(type=float, default=-1.2)),
    (('--sigma_sample_density_std',), dict(type=float, default=1.2)),
    (('--sigma_max',), dict(type=float, default=80)),
    (('--sigma_min',), dict(type=float, default=0.0002)),
    (('--rho',), dict(type=float, default=7.0)),

    # RAG options
    (('--index-name',), dict(type=str, default="toolhang.index")),
    (('--diffuse-rate',), dict(type=float, default=0.5)),
    (('-r', '--retrieve-every'), dict(type=int, default=1)),
    (('--rag-vp',), dict(action='store_true')),
)

parser = argparse.ArgumentParser()
for _flags, _settings in _CLI_OPTIONS:
    parser.add_argument(*_flags, **_settings)
# Parsed at import time: every function in this file reads ``opt`` as a global.
opt = parser.parse_args()

# Seed torch, NumPy and the stdlib RNG from --seed so runs are repeatable.
torch.manual_seed(opt.seed)
np.random.seed(opt.seed)
random.seed(opt.seed)
# Prefer the GPU, but fall back to the CPU so checkpoint loading
# (map_location=device) and tensor creation still work on CUDA-less hosts.
# NOTE: timings printed by the evaluation are not comparable across devices.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Resolve --task_tag: the empty string selects TaskTags.NONE, otherwise the
# option value is the member name itself.
if opt.task_tag not in ("", "PH", "MH"):
    raise NotImplementedError(f"Task Tag {opt.task_tag} Not implemented")
task_tag = getattr(TaskTags, opt.task_tag or "NONE")

# Resolve --task_type: each supported name maps to a lazy constructor so only
# the selected task object is ever built. PushT and ToolHang take no tag.
_TASK_BUILDERS = {
    'PUSHT': lambda: PushT(ctype=CONTROL_TYPE),
    'LIFT': lambda: Lift(ctype=CONTROL_TYPE, tag=task_tag),
    'CAN': lambda: Can(ctype=CONTROL_TYPE, tag=task_tag),
    'SQUARE': lambda: Square(ctype=CONTROL_TYPE, tag=task_tag),
    'TRANSPORT': lambda: Transport(ctype=CONTROL_TYPE, tag=task_tag),
    'TOOLHANG': lambda: ToolHang(ctype=CONTROL_TYPE),
}
if opt.task_type not in _TASK_BUILDERS:
    raise NotImplementedError(f"Task {opt.task_type} Not implemented")
# The option value doubles as the TaskTypes member name.
task_type = getattr(TaskTypes, opt.task_type)
task = _TASK_BUILDERS[opt.task_type]()

def create_inner_model(model_type:ModelType = ModelType.CNN):
    """Build an untrained denoiser backbone for a ``ModelType``.

    NOTE: a second ``create_inner_model`` defined further down in this file
    rebinds the name, so after the module has finished executing this
    version is no longer reachable through ``create_inner_model``.

    Horizons are read from ``task``; the command-line parser in this file
    defines no ``obs_horizon`` / ``pred_horizon`` options.

    Args:
        model_type: ``ModelType.CNN`` for the conditional Karras U-Net or
            ``ModelType.TRANSFORMER`` for the small DiT backbone.

    Returns:
        The constructed (not yet device-placed) network.

    Raises:
        NotImplementedError: for any other ``model_type``.
    """
    obs_dim = task.obs_dim
    action_dim = task.action_dim

    # create network object
    if model_type == ModelType.CNN:
        inner_model = ConditionalKarrasUnet1D(
            input_dim=action_dim,
            # the observation window is passed as one flattened vector
            global_cond_dim=obs_dim*task.obs_horizon,
            diffusion_step_embed_dim=256,
            down_dims=[256,512,1024]
        )
    elif model_type == ModelType.TRANSFORMER:
        inner_model = DiT_S_2(
            input_size=action_dim,
            horizon=task.pred_horizon,
            obs_horizon=task.obs_horizon,
            obs_dim=obs_dim,
        )
    else:
        raise NotImplementedError
    return inner_model


def create_inner_model(model_type:MultiModelType = MultiModelType.CNN):
    """Build the untrained denoiser backbone selected by ``model_type``.

    Observation/action sizes and horizons are read from the global ``task``.

    Args:
        model_type: one of the ``MultiModelType`` members handled below
            (CNN, MINGPT, DiT_S, DiT_B, DiT_L, DiT_XL).

    Returns:
        The constructed (not yet device-placed) network.

    Raises:
        NotImplementedError: for any other ``model_type``.
    """
    obs_dim = task.obs_dim
    action_dim = task.action_dim

    # create network object
    if model_type == MultiModelType.CNN:
        inner_model = ConditionalKarrasUnet1D(
            input_dim=action_dim,
            # the observation window is passed as one flattened vector
            global_cond_dim=obs_dim*task.obs_horizon,
            diffusion_step_embed_dim=256,
            down_dims=[256,512,1024],
        )
    elif model_type == MultiModelType.MINGPT:
        inner_model = TransformerForDiffusion(
            input_dim=action_dim,
            output_dim=action_dim,
            horizon=task.pred_horizon,
            n_obs_steps=task.obs_horizon,
            cond_dim=obs_dim,
            causal_attn=True,
            n_cond_layers=4
        )
    elif model_type == MultiModelType.DiT_S:
        inner_model = DiT_S_2(
            input_size=action_dim,
            horizon=task.pred_horizon,
            obs_horizon=task.obs_horizon,
            obs_dim=obs_dim,
        )
    elif model_type == MultiModelType.DiT_B:
        inner_model = DiT_B_2(
            input_size=action_dim,
            horizon=task.pred_horizon,
            obs_horizon=task.obs_horizon,
            obs_dim=obs_dim,
        )
    elif model_type == MultiModelType.DiT_L:
        inner_model = DiT_L_2(
            input_size=action_dim,
            horizon=task.pred_horizon,
            obs_horizon=task.obs_horizon,
            obs_dim=obs_dim,
        )
    elif model_type == MultiModelType.DiT_XL:
        # Previously this branch built DiT_L_2, i.e. the L-sized network.
        # NOTE(review): assumes diffusion.dit exports DiT_XL_2 alongside the
        # S/B/L factories, and that "DiT_XL" checkpoints were trained with a
        # real XL backbone -- confirm against the training script.
        inner_model = DiT_XL_2(
            input_size=action_dim,
            horizon=task.pred_horizon,
            obs_horizon=task.obs_horizon,
            obs_dim=obs_dim,
        )
    else:
        raise NotImplementedError
    return inner_model

def load_edm_model(ckpt_path, model_type:MultiModelType = MultiModelType.CNN):
    """Restore a ``Denoiser`` from a checkpoint and put it in eval mode.

    Args:
        ckpt_path: path of the saved state dict.
        model_type: backbone architecture handed to ``create_inner_model``.

    Returns:
        The ``Denoiser`` wrapping the backbone, with weights loaded.

    NOTE: ``torch.load`` unpickles the file, so only load trusted checkpoints.
    """
    weights = torch.load(ckpt_path, map_location=device)
    backbone = create_inner_model(model_type=model_type)

    denoiser = Denoiser(
        inner_model=backbone.to(device),
        sigma_data=opt.sigma_data,
    )
    denoiser.load_state_dict(weights)
    denoiser.eval()
    print('Pretrained EDM model loaded.')
    return denoiser

def create_env(
        seed:int=0, 
        enable_render:bool=True, 
        output_dir:str="results", 
        render_hw=(256,256),
        render_camera_name="agentview",
        fps:int=10,
        crf:int=22
    ):
    """Create and seed one evaluation environment for the selected task.

    For PushT a bare ``PushTEnv`` is returned and every argument except
    ``seed`` is ignored. For the other tasks a robomimic environment is built
    from the dataset metadata and wrapped for low-dim observations and video
    recording.

    Args:
        seed: environment seed.
        enable_render: when true, a video path under ``<output_dir>/media`` is
            assigned to the recording wrapper.
        output_dir: directory that receives the ``media`` sub-directory; it is
            created if missing.
        render_hw: render size forwarded to ``RobomimicLowdimWrapper``.
        render_camera_name: camera name forwarded to the same wrapper.
        fps: video frame rate; also determines how many env steps pass
            between rendered frames.
        crf: quality setting forwarded to the H.264 recorder.

    Returns:
        The seeded environment.
    """
    if task_type == TaskTypes.PUSHT:
        env = PushTEnv()
        # use a seed >200 to avoid initial states seen in the training dataset
        env.seed(seed)
        return env 
    else:
        env_meta = FileUtils.get_env_metadata_from_dataset(task.dataset_path)
        env_meta['env_kwargs']['controller_configs']['control_delta'] = False
        # disable object state observation
        env_meta['env_kwargs']['use_object_obs'] = False
        modality_mapping = collections.defaultdict(list)
        ObsUtils.initialize_obs_modality_mapping_from_dict(modality_mapping)

        ObsUtils.initialize_obs_modality_mapping_from_dict(
            {'low_dim': task.obs_keys})
        _env = EnvUtils.create_env_from_metadata(
            env_meta=env_meta,
            render=False, 
            # only way to not show collision geometry
            # is to enable render_offscreen
            # which uses a lot of RAM.
            render_offscreen=False,
            use_image_obs=False, 
        )
        # Robosuite's hard reset causes excessive memory consumption.
        # Disabled to run more envs.
        # https://github.com/ARISE-Initiative/robosuite/blob/92abf5595eddb3a845cd1093703e5a3ccd01e77e/robosuite/environments/base.py#L247-L248
        _env.hard_reset = False
        robosuite_fps = 20
        # render every k-th simulation step so the video plays at `fps`
        steps_per_render = max(robosuite_fps // fps, 1)

        env = VideoRecordingWrapper(
            RobomimicLowdimWrapper(
                env=_env,
                obs_keys=task.obs_keys,
                init_state=None,
                render_hw=render_hw,
                render_camera_name=render_camera_name
            ),
            video_recoder=VideoRecorder.create_h264(
            fps=fps,
            codec='h264',
            input_pix_fmt='rgb24',
            crf=crf,
            thread_type='FRAME',
            thread_count=1
            ),
            file_path=None,
            steps_per_render=steps_per_render
        )

        env.video_recoder.stop()
        env.file_path = None
        if enable_render:
            filename = pathlib.Path(output_dir).joinpath(
                'media', opt.task_type + "-" + opt.task_tag + "_" + wv.util.generate_id() + ".mp4")
            # parents=True: `output_dir` itself may not exist yet on a fresh
            # checkout; with parents=False this raised FileNotFoundError.
            filename.parent.mkdir(parents=True, exist_ok=True)
            filename = str(filename)
            env.file_path = filename

        # switch to seed reset
        assert isinstance(env.env, RobomimicLowdimWrapper)
        # NOTE(review): this sets `init_state` on the object wrapped by
        # RobomimicLowdimWrapper, not on the wrapper that received
        # `init_state=None` above -- confirm `env.env.init_state` was not
        # the intended target.
        env.env.env.init_state = None
        env.seed(seed)
        return env

def get_stats():
    """Return the normalization source for the selected task.

    The return type depends on the task: for PushT it is the ``stats``
    attribute of a ``PushTStateDataset``; for every other task it is the
    ``RobomimicReplayLowdimDataset`` object itself (callers use its
    ``normalizer``).
    """
    if task_type != TaskTypes.PUSHT:
        return RobomimicReplayLowdimDataset(
            dataset_path=task.dataset_path,
            horizon=task.pred_horizon,
            obs_keys=task.obs_keys,
            abs_action=True,
            pad_before=1,
            pad_after=7,
        )

    pusht_dataset = PushTStateDataset(
        dataset_path=task.dataset_path,
        pred_horizon=task.pred_horizon,
        obs_horizon=task.obs_horizon,
        action_horizon=task.action_horizon
    )
    # training data statistics used for (un)normalization
    return pusht_dataset.stats

def undo_transform_action(action, rotation_transformer:RotationTransformer):
    """Map the rotation part of an action back through ``rotation_transformer.inverse``.

    The last axis is read as ``[3 position | rotation | 1 gripper]``. An
    input whose last axis has 20 entries is treated as two arms of 10 entries
    each; both arms are converted and re-flattened to 14 entries.

    Args:
        action: array whose last axis holds the action layout above.
        rotation_transformer: object providing ``inverse(rot)``.

    Returns:
        Array with the converted rotation, same leading dimensions as
        ``action``.
    """
    original_shape = action.shape
    is_dual_arm = original_shape[-1] == 20
    if is_dual_arm:
        # split into (..., arm, per-arm action)
        action = action.reshape(-1,2,10)

    # everything between the 3 position entries and the final gripper entry
    rot_width = action.shape[-1] - 4
    rot_end = 3 + rot_width
    restored_rot = rotation_transformer.inverse(action[..., 3:rot_end])
    pieces = [action[..., :3], restored_rot, action[..., [-1]]]
    uaction = np.concatenate(pieces, axis=-1)

    if not is_dual_arm:
        return uaction
    # merge the two arms back into one flat action vector
    return uaction.reshape(*original_shape[:-1], 14)

@torch.no_grad()
def generate_edm(model, timesteps:int, batchsize:int, nobs:torch.Tensor, model_type:MultiModelType = MultiModelType.CNN):
    """Sample a normalized action sequence from pure noise.

    Args:
        model: denoiser passed to ``sample_dpmpp_2m``.
        timesteps: number of noise levels requested from ``get_sigmas_karras``.
        batchsize: number of action sequences to sample.
        nobs: normalized observation window; a leading batch axis is added
            here.
        model_type: backbone type. The body compares against
            ``MultiModelType``, so the default is a member of that enum
            (it used to be ``ModelType.CNN``).

    Returns:
        Tensor of shape ``(batchsize, task.pred_horizon, task.action_dim)``
        as returned by the sampler.
    """
    obs_cond = nobs.unsqueeze(0)
    if model_type == MultiModelType.CNN:
        # the CNN backbone takes the observation window as one flat vector
        obs_cond = obs_cond.flatten(start_dim=1)

    # initialize action from Gaussian noise scaled to the largest noise level
    noisy_action = torch.randn(
        (batchsize, task.pred_horizon, task.action_dim), device=device)
    naction = noisy_action * opt.sigma_max

    sigmas = get_sigmas_karras(timesteps, opt.sigma_min, opt.sigma_max, rho=opt.rho, device=device)

    naction = sample_dpmpp_2m(model, naction, sigmas, disable=True, extra_args={'global_cond':obs_cond})
    return naction 

@torch.no_grad()
def generate_edm_steps(
        model, 
        timesteps:int, 
        batchsize:int, 
        nobs:torch.Tensor, 
        naction:torch.Tensor, 
        diffuse_rate:float = 0.5,
        model_type:MultiModelType = MultiModelType.CNN
    ):
    """Refine a given action by re-noising it and running the tail of the schedule.

    The full ``timesteps``-level schedule is built, the first
    ``int(timesteps * diffuse_rate)`` levels are skipped, ``naction`` is
    perturbed with noise at the first remaining level, and the sampler runs
    over the remaining levels.

    Args:
        model: denoiser passed to ``sample_dpmpp_2m``.
        timesteps: length of the full noise schedule.
        batchsize: batch size of the sampled noise.
        nobs: normalized observation window; a leading batch axis is added
            here.
        naction: starting action (e.g. a retrieved one); must broadcast with
            ``(batchsize, task.pred_horizon, task.action_dim)``.
        diffuse_rate: fraction of the schedule to skip.
        model_type: backbone type. The body compares against
            ``MultiModelType``, so the default is a member of that enum
            (it used to be ``ModelType.CNN``).

    Returns:
        The refined action tensor returned by the sampler.
    """
    obs_cond = nobs.unsqueeze(0)
    if model_type == MultiModelType.CNN:
        # the CNN backbone takes the observation window as one flat vector
        obs_cond = obs_cond.flatten(start_dim=1)

    sigmas = get_sigmas_karras(timesteps, opt.sigma_min, opt.sigma_max, rho=opt.rho, device=device)
    start_step = int(timesteps * diffuse_rate)
    sigmas = sigmas[start_step:]
    sigma = sigmas[0]

    # perturb the provided action with Gaussian noise at the entry noise level
    noise = torch.randn(
        (batchsize, task.pred_horizon, task.action_dim), device=device)
    naction = naction + noise * sigma
    naction = sample_dpmpp_2m(model, naction, sigmas, disable=True, extra_args={'global_cond':obs_cond})
    return naction 

@torch.no_grad()
def generate_edm_steps_ve(
        model, 
        timesteps:int, 
        batchsize:int, 
        nobs:torch.Tensor, 
        naction:torch.Tensor, 
        diffuse_rate:float = 0.5,
        model_type:MultiModelType = MultiModelType.CNN
    ):
    """Refine a given action with a shortened schedule over the full noise range.

    Unlike ``generate_edm_steps``, the schedule is not truncated: a fresh
    schedule of ``int(timesteps * (1 - diffuse_rate))`` levels is built with
    the same ``sigma_min`` / ``sigma_max``, ``naction`` is perturbed with
    noise at that schedule's first level, and the sampler runs over all of
    it.

    Args:
        model: denoiser passed to ``sample_dpmpp_2m``.
        timesteps: nominal schedule length before shortening.
        batchsize: batch size of the sampled noise.
        nobs: normalized observation window; a leading batch axis is added
            here.
        naction: starting action (e.g. a retrieved one); must broadcast with
            ``(batchsize, task.pred_horizon, task.action_dim)``.
        diffuse_rate: fraction by which the number of levels is reduced.
        model_type: backbone type. The body compares against
            ``MultiModelType``, so the default is a member of that enum
            (it used to be ``ModelType.CNN``).

    Returns:
        The refined action tensor returned by the sampler.
    """
    obs_cond = nobs.unsqueeze(0)
    if model_type == MultiModelType.CNN:
        # the CNN backbone takes the observation window as one flat vector
        obs_cond = obs_cond.flatten(start_dim=1)

    fixed_timesteps = int(timesteps * (1 - diffuse_rate))
    sigmas = get_sigmas_karras(fixed_timesteps, opt.sigma_min, opt.sigma_max, rho=opt.rho, device=device)
    # presumably equals sigma_max whenever fixed_timesteps >= 1 -- TODO confirm
    sigma = sigmas[0]

    # perturb the provided action with Gaussian noise at the entry noise level
    noise = torch.randn(
        (batchsize, task.pred_horizon, task.action_dim), device=device)
    naction = naction + noise * sigma
    naction = sample_dpmpp_2m(model, naction, sigmas, disable=True, extra_args={'global_cond':obs_cond})
    return naction 

def test():
    """Evaluate retrieval-augmented EDM action sampling on the selected task.

    For every checkpoint, one episode is rolled out per seed returned by
    ``task.get_eval_seeds()``. At each inference step the normalized
    observation window is either used to retrieve an action from the FAISS
    index, which is then refined by the diffusion model (every
    ``--retrieve-every``-th inference), or the action is sampled from pure
    noise. Per-checkpoint and global statistics are printed.

    An episode counts as "done" when its final reward is greater than zero.

    Raises:
        NotImplementedError: for an unknown ``--model_type``.
        ValueError: if the task provides no evaluation seeds.
        FileNotFoundError: if no checkpoint file is found.
    """
    max_steps = opt.max_steps
    retrieve_every = opt.retrieve_every

    # queries are flattened observation windows
    length = task.obs_dim * task.obs_horizon
    index_name = opt.index_name
    vecdb = RobotFAISS(index_name=index_name, vector_dimensions=length)
    vecdb.load()

    if opt.model_type == 'CNN':
        model_type = MultiModelType.CNN 
    elif opt.model_type == 'minGPT':
        model_type = MultiModelType.MINGPT
    elif opt.model_type == 'DiT_S':
        model_type = MultiModelType.DiT_S
    elif opt.model_type == 'DiT_B':
        model_type = MultiModelType.DiT_B
    elif opt.model_type == 'DiT_L':
        model_type = MultiModelType.DiT_L
    elif opt.model_type == 'DiT_XL':
        model_type = MultiModelType.DiT_XL
    else:
        raise NotImplementedError

    stats = get_stats()
    obs_horizon = task.obs_horizon
    if opt.use_max_action:
        # execute every predicted action from the current step onwards
        action_horizon = task.pred_horizon - task.obs_horizon + 1 
    else:
        action_horizon = task.action_horizon

    # NOTE: the evaluation set is the task's own seed list; the
    # -n/--test-samples option is not consulted.
    env_seeds = task.get_eval_seeds()
    num_test_samples = len(env_seeds)
    if num_test_samples == 0:
        raise ValueError("task.get_eval_seeds() returned no seeds; nothing to evaluate")
    rotation_transformer = RotationTransformer('axis_angle', 'rotation_6d')

    if opt.last_checkpoints == 1:
        weight_files = [opt.checkpoint]
    else:
        # --checkpoint is treated as a directory of 'last*' files
        weight_files = glob(os.path.join(opt.checkpoint, 'last*'))
    if not weight_files:
        raise FileNotFoundError(f"No 'last*' checkpoints found in {opt.checkpoint}")

    global_dones = []
    global_rewards = []
    global_max_rewards = []

    print("Using RAG-VP:", opt.rag_vp)

    for checkpoint in weight_files:
        edm = load_edm_model(checkpoint, model_type=model_type) if checkpoint != "" else None 
        dones = []
        rewards = []
        max_rewards = []
        times = []
        action_steps = []
        episode_times = []
        seed_success = {}
        with tqdm(total=num_test_samples, desc=f"Eval EDM {opt.task_type}") as pbar:
            for env_seed in env_seeds:
                start_time = time.time()
                env = create_env(seed=env_seed)

                seed_success.setdefault(env_seed, False)

                obs = env.reset()
                # keep a queue of the last obs_horizon observations,
                # pre-filled with the initial one
                obs_deque = collections.deque(
                    [obs] * task.obs_horizon, 
                    maxlen=task.obs_horizon
                )
                step_idx = 0
                done = False
                _register_action_time = 0
                inference_step_idx = 0
                # best single-step reward seen in this episode
                episode_max_reward = float('-inf')
                with tqdm(total=max_steps, desc='Generating Actions', leave=False) as abar:
                    while not done:
                        use_rag = inference_step_idx % retrieve_every == 0
                        B = 1
                        obs_seq = np.stack(obs_deque)
                        if task_type == TaskTypes.PUSHT:
                            nobs = normalize_data(obs_seq, stats=stats['obs'])
                            nobs = torch.from_numpy(nobs).to(device, dtype=torch.float32)
                        else:
                            nobs = torch.from_numpy(obs_seq).to(device, dtype=torch.float32)
                            nobs = stats.normalizer['obs'].normalize(nobs).to(nobs.device)
                        
                        # time only retrieval + sampling + transfer to host
                        t1 = time.time()
                        if use_rag:
                            query_vector = nobs.cpu().reshape(-1).detach().numpy()
                            expert_action = vecdb.search(query_vector, k=1)[0]
                            expert_action = torch.from_numpy(expert_action).to(device).reshape(1, -1, task.action_dim)
                            if not opt.rag_vp:
                                naction = generate_edm_steps_ve(
                                    edm,
                                    opt.diffusion_timesteps, 
                                    naction=expert_action,
                                    batchsize=B, 
                                    nobs=nobs, 
                                    diffuse_rate=opt.diffuse_rate,
                                    model_type=model_type
                                )
                            else:
                                naction = generate_edm_steps(
                                    edm,
                                    opt.diffusion_timesteps, 
                                    naction=expert_action,
                                    batchsize=B, 
                                    nobs=nobs, 
                                    diffuse_rate=opt.diffuse_rate,
                                    model_type=model_type
                                )
                        else:
                            naction = generate_edm(edm, opt.diffusion_timesteps, batchsize=B, nobs=nobs, model_type=model_type)
                        if task_type != TaskTypes.PUSHT:
                            naction = stats.normalizer['action'].unnormalize(naction)
                        naction = naction.detach().to('cpu').numpy()
                        naction = naction[0]
                        t2 = time.time()
                        _register_action_time += t2 - t1

                        if task_type == TaskTypes.PUSHT:
                            action_pred = unnormalize_data(naction, stats=stats['action'])
                        else:
                            action_pred = undo_transform_action(naction, rotation_transformer=rotation_transformer)

                        # execute the slice that starts at the current step
                        start = obs_horizon - 1
                        end = start + action_horizon
                        action = action_pred[start:end,:]
                        inference_step_idx += 1

                        for step_action in action:
                            obs, reward, done, _info = env.step(step_action)
                            obs_deque.append(obs)
                            episode_max_reward = max(episode_max_reward, reward)

                            step_idx += 1
                            abar.update(1)
                            abar.set_postfix(reward=reward)
                            if step_idx > max_steps:
                                done = True
                            if done or reward == 1:
                                done = True
                                break
                end_time = time.time()
                episode_times.append(end_time - start_time)
                rewards.append(reward) # final reward 
                # Per-episode maximum. Previously np.max(rewards) was stored,
                # i.e. the running max of *final* rewards over all episodes
                # so far, which made later entries depend on earlier episodes.
                max_rewards.append(episode_max_reward)
                dones.append(1 if reward > 0 else 0)
                if reward > 0:
                    seed_success[env_seed] = True
                times.append(_register_action_time)
                action_steps.append(step_idx)

                pbar.update(1)
                pbar.set_postfix(done_num=np.sum(dones))
        # inference seconds spent per executed environment step
        action_time = np.mean([t / a for a, t in zip(action_steps, times)])
        print("SUCCESS per seeds:")
        pprint(seed_success)

        # extra reset of the last environment, kept from the original flow
        _ = env.reset()
        task_time = np.mean(episode_times)
        done_rate = np.sum(dones) / num_test_samples 
        reward_mean = np.mean(rewards)
        max_reward_mean = np.mean(max_rewards)

        global_dones.append(done_rate)
        global_rewards.append(reward_mean)
        global_max_rewards.append(max_reward_mean)

        print("EDM Test result")
        print("Time [Sec/Task] : ", task_time)
        print("Done Rate : ", done_rate)
        print("Reward Mean : ", reward_mean)
        print("Max Rewared :", max_reward_mean)
        print("Action Time [Sec/Action] : ", action_time)
        # NOTE(review): `times` holds the summed inference time per episode,
        # so this is seconds per episode rather than per inference call.
        print("Inference Time [Sec/Inference] : ", np.mean(times))
        torch.cuda.empty_cache()

    print("Global Dones:", np.mean(global_dones))
    print("Global Rewareds:", np.mean(global_rewards))
    print("Global MaxRewards:", np.mean(global_max_rewards))

# Run the evaluation when executed as a script. Command-line parsing and task
# construction already happen at module import time, above.
if __name__ == '__main__':
    test()