import argparse
import time
# from common import (
#     INITIAL_QPOS,
#     GRIPPER_CLOSE_STROKE,
#     NUM_PCD_POINTS,
#     PAD_PCD_IF_LESS,
#     PCD_X_RANGE,
#     PCD_Y_RANGE,
#     PCD_Z_RANGE,
#     MOBILE_BASE_VEL_ACTION_MIN,
#     MOBILE_BASE_VEL_ACTION_MAX,
#     GRIPPER_HALF_WIDTH,
#     HORIZON_STEPS,
#     ACTION_REPEAT,
# )
import numpy as np
import torch
# from brs_ctrl.robot_interface import R1Interface
# from brs_ctrl.robot_interface.grippers import GalaxeaR1G1Gripper
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
import brs_algo.utils as U
from brs_algo.learning.policy import WBVIMAPolicy
# from brs_algo.rollout.asynced_rollout import R1AsyncedRollout


import argparse
import json
import numpy as np
import time
import datetime
import os
import shutil
import psutil
import sys
import socket
import traceback
import random
import imageio
import numpy as np
from copy import deepcopy

from collections import OrderedDict
import sys
from io import StringIO

import torch
from torch.utils.data import DataLoader

import robomimic
from robomimic.utils.file_utils import get_env_metadata_from_dataset
import robomimic.macros as Macros
import robomimic.utils.train_utils as TrainUtils
import robomimic.utils.torch_utils as TorchUtils
import robomimic.utils.obs_utils as ObsUtils
import robomimic.utils.env_utils as EnvUtils
import robomimic.utils.file_utils as FileUtils
from robomimic.config import config_factory
from robomimic.algo import algo_factory, RolloutPolicy
from robomimic.utils.log_utils import PrintLogger, DataLogger, flush_warnings
from robomimic.utils.file_utils import load_dict_from_checkpoint
# from doppelmaker.utils.robomimic_utils import get_epochs_trained, doppel_get_exp_dir, get_env_all_objects_initialization_from_dataset
import omnigibson as og
from omnigibson.objects import *

import mimicgen
import robomimic.utils.env_utils as EnvUtils
import mimicgen.utils.file_utils as MG_FileUtils
import mimicgen.utils.robomimic_utils as RobomimicUtils

from mimicgen.configs import MG_TaskSpec
from mimicgen.configs import config_factory as MG_ConfigFactory
from mimicgen.datagen.data_generator import DataGenerator
from mimicgen.env_interfaces.base import make_interface
from mimicgen.train_scripts.plot_init_states import load_init_states
import h5py

import yaml
from brs_algo.rollout.og_rollout import rollout_with_stats
import wandb
import copy
from matplotlib import pyplot as plt
import pickle

# Device that load_brs_model() moves the loaded policy onto.
DEVICE = torch.device("cuda:0")
# Passed to WBVIMAPolicy as ``num_latest_obs``.
NUM_LATEST_OBS = 2
# NOTE(review): not referenced anywhere else in this script — confirm whether it is still needed.
HORIZON = 16
# Passed to WBVIMAPolicy as ``action_prediction_horizon``.
T_action_prediction = 8


def load_openpi_model():
    """Build the OpenPI policy wrapper used for rollouts.

    ``openpi`` is imported inside the function, so the package is only
    required when this function is actually called.

    Returns:
        The constructed ``OpenPIWrapper`` instance.
    """
    from openpi.shared.eval_b1k_wrapper import OpenPIWrapper

    # The openpi server has to be launched before the wrapper is created.
    wrapper_kwargs = dict(
        host="",
        port=8000,
        text_prompt="pick up the mug and place in the sink",
        control_mode="receeding_temporal",
    )
    return OpenPIWrapper(**wrapper_kwargs)


def load_brs_model(args, train_config, epoch_name):
    """Instantiate a ``WBVIMAPolicy`` and load one checkpoint into it.

    Args:
        args: Parsed CLI namespace. ``args.load_checkpoint_folder`` is read and
            ``args.ckpt_path`` is overwritten with the resolved checkpoint path.
        train_config: Training configuration dict; architecture
            hyper-parameters are taken from ``train_config["module"]["policy"]``.
        epoch_name: File name of the checkpoint inside ``<folder>/ckpt/``.

    Returns:
        The policy, moved to ``DEVICE`` and switched to eval mode.
    """
    policy_cfg = train_config["module"]['policy']

    # Per-key action dimensions; the key order also defines ``action_keys``.
    action_key_dims = {
        "mobile_base": 3,
        "torso": 4,
        "left_arm": 6,
        "left_gripper": 1,
        "right_arm": 6,
        "right_gripper": 1,
    }

    ddim_scheduler = DDIMScheduler(
        num_train_timesteps=100,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="squaredcos_cap_v2",
        clip_sample=True,
        set_alpha_to_one=True,
        steps_offset=0,
        prediction_type="epsilon",
    )

    policy = WBVIMAPolicy(
        prop_dim=policy_cfg['prop_dim'],
        prop_keys=policy_cfg['prop_keys'],
        prop_mlp_hidden_depth=2,
        prop_mlp_hidden_dim=256,
        pointnet_n_coordinates=3,
        pointnet_n_color=3,
        pointnet_hidden_depth=policy_cfg['pointnet_hidden_depth'],
        pointnet_hidden_dim=policy_cfg['pointnet_hidden_dim'],
        action_keys=list(action_key_dims),
        action_key_dims=action_key_dims,
        num_latest_obs=NUM_LATEST_OBS,
        use_modality_type_tokens=False,
        xf_n_embd=policy_cfg['xf_n_embd'],
        xf_n_layer=policy_cfg['xf_n_layer'],
        xf_n_head=8,
        # Dropout is set to zero here; the policy is also put in eval mode below.
        xf_dropout_rate=0.0,
        xf_use_geglu=True,
        learnable_action_readout_token=False,
        action_dim=21,
        action_prediction_horizon=T_action_prediction,
        diffusion_step_embed_dim=policy_cfg['diffusion_step_embed_dim'],
        unet_down_dims=policy_cfg['unet_down_dims'],
        unet_kernel_size=5,
        unet_n_groups=8,
        unet_cond_predict_scale=True,
        noise_scheduler=ddim_scheduler,
        noise_scheduler_step_kwargs=None,
        num_denoise_steps_per_inference=16,
        no_pcd=policy_cfg['no_pcd'],
        flash_attention=policy_cfg['flash_attention'],
    )

    # The resolved path is stored on ``args`` as well as being used right away.
    args.ckpt_path = os.path.join(args.load_checkpoint_folder, "ckpt/", f"{epoch_name}")

    checkpoint = U.torch_load(args.ckpt_path, map_location="cpu")
    U.load_state_dict(
        policy,
        checkpoint["state_dict"],
        strip_prefix="policy.",
        strict=True,
    )

    policy = policy.to(DEVICE)
    policy.eval()

    print('policy successfully loaded')
    return policy


def construct_og_env(args, mg_config, train_config):
    """Create and wrap the OmniGibson tidy-table environment for rollouts.

    Args:
        args: Parsed CLI namespace. Reads ``headless``, ``random_type``,
            ``manipulation_only``, ``white_r1``, ``pcd_heuristics``,
            ``baseline`` and ``model_type``.
        mg_config: MimicGen config dict; the source dataset path is read from
            ``mg_config['experiment']['source']['dataset_path']``.
        train_config: Training config handed to the environment wrapper.

    Returns:
        An ``OrderedDict`` with a single entry mapping ``env.name`` to the
        wrapped environment.
    """
    # Environment metadata comes from the source (teleop demo) dataset.
    raw_dataset_path = mg_config['experiment']['source']['dataset_path']
    source_dataset_path = os.path.expandvars(os.path.expanduser(raw_dataset_path))
    env_meta = get_env_metadata_from_dataset(
        dataset_path=source_dataset_path,
        ds_format='robomimic',
    )

    if args.headless:
        from omnigibson.macros import gm
        gm.HEADLESS = True

    env = RobomimicUtils.create_env(
        env_meta=env_meta,
        env_class=None,
        env_name=f'r1_tidy_table_{args.random_type}',
        render=False,
        render_offscreen=False,
        use_image_obs=False,
        use_depth_obs=False,
        init_curobo=False,
        policy_rollout=True,
        manipulation_only=args.manipulation_only,
        # for tidy table, the real robot mode is not needed
        real_robot_mode=False,
        white_r1=args.white_r1,
        pcd_heuristics=args.pcd_heuristics,
        baseline=args.baseline,
        model_type=args.model_type,
    )

    env.with_color = True
    # The return value is not used here; the call itself is kept.
    env.sensor_setup()

    # Print every controller together with the action indices it owns.
    robot = env.env.robots[0]
    for name, controller in robot._controllers.items():
        print('name', name, 'controller', controller, 'action_idx', robot.controller_action_idx[name])

    # Apply the environment wrapper, if applicable.
    env = EnvUtils.wrap_env_from_config_brs(env, config=train_config)

    envs = OrderedDict()
    envs[env.name] = env
    print(env)

    print("og environment successfully constructed")

    return envs


def get_min_val_loss_ckpt(file_list):
    """Pick the checkpoint file name with the smallest encoded validation loss.

    Only names containing ``"val_l1_"`` are considered; the loss is the text
    between the last ``"val_l1_"`` and ``".pth"``, parsed as a float.

    Args:
        file_list: Iterable of checkpoint file names.

    Returns:
        The name with the lowest loss, or ``None`` when the lowest loss found
        is greater than 0.9 (which includes the case of no matching name).
    """
    marker = "val_l1_"
    best_loss = 10000
    best_name = 'last-v1.pth'

    for candidate in file_list:
        if marker not in candidate:
            continue
        loss = float(candidate.split(marker)[-1].split(".pth")[0])
        if loss < best_loss:
            best_loss, best_name = loss, candidate

    return None if best_loss > 0.9 else best_name


def get_latest_ckpt(file_list):
    """Return the validation checkpoint with the highest epoch index.

    Only names containing ``"val_l1_"`` are considered. The epoch index is the
    integer that follows ``"epoch"`` in the part of the name before the first
    ``"-"`` (e.g. ``"epoch12-val_l1_0.31.pth"`` -> ``12``).

    Args:
        file_list: Iterable of checkpoint file names.

    Returns:
        The matching name with the largest epoch index (the first one on
        ties), or ``'last-v1.pth'`` when no name matches.
    """
    latest_ckpt = 'last-v1.pth'
    # Start below every valid epoch so that an epoch-0 checkpoint can win too.
    max_epoch_index = -1
    for file_name in file_list:
        if "val_l1_" not in file_name:
            continue
        epoch_index = int(file_name.split('-')[0].split('epoch')[-1])
        if epoch_index > max_epoch_index:
            max_epoch_index = epoch_index
            latest_ckpt = file_name
    return latest_ckpt


def _list_pth_files(args):
    """Return the ``.pth`` file names found in ``<load_checkpoint_folder>/ckpt``."""
    ckpt_dir = os.path.join(args.load_checkpoint_folder, "ckpt")
    return [f for f in os.listdir(ckpt_dir) if f.endswith('.pth')]


def _epoch_index_of(file_name):
    """Parse the integer following ``"epoch"`` before the first ``"-"`` of *file_name*."""
    return int(file_name.split('-')[0].split('epoch')[-1])


def _val_loss_of(file_name):
    """Parse the trailing ``_<loss>.pth`` token as a float (``inf`` if it is not numeric)."""
    token = file_name.split('_')[-1].split('.pth')[0]
    try:
        return float(token)
    except ValueError:
        # Names without a numeric loss sort after every real loss value.
        return float('inf')


def _val_ckpts_sorted_by_epoch(args):
    """Return the ``val_l1_`` checkpoints ordered by ascending epoch index."""
    val_ckpts = [f for f in _list_pth_files(args) if 'val_l1_' in f]
    return sorted(val_ckpts, key=_epoch_index_of)


def get_epoch_list(args):
    """Select the checkpoint file names to evaluate, based on the CLI flags.

    The first option that is set wins, in this order:

    * ``args.single_epoch``: only that name (no directory access).
    * ``args.eval_all``: every ``.pth`` file, sorted by name.
    * ``args.eval_last_and_val_min_ckpt``: the lowest-validation-loss
      checkpoint plus the latest one (de-duplicated, sorted by name).
    * ``args.eval_last_5``: the five ``val_l1_`` checkpoints with the highest
      epoch index.
    * ``args.eval_uniform_sample``: ``val_l1_`` checkpoints with epoch index
      greater than ``args.eval_start_epoch``, taking every
      ``args.eval_interval``-th one.

    Args:
        args: Parsed CLI namespace; checkpoints are read from
            ``<args.load_checkpoint_folder>/ckpt``.

    Returns:
        List of checkpoint file names; empty when no option is set.
    """
    if args.single_epoch is not None:
        return [args.single_epoch]

    if args.eval_all:
        epoch_list = sorted(_list_pth_files(args))

    elif args.eval_last_and_val_min_ckpt:
        ckpt_files = _list_pth_files(args)
        last_ckpt = get_latest_ckpt(ckpt_files)
        val_ckpts = [f for f in ckpt_files if 'val' in f]
        if val_ckpts:
            # Compare losses numerically: as strings, "10.0" would sort before "9.0".
            min_val_ckpt = min(val_ckpts, key=_val_loss_of)
            epoch_list = sorted({min_val_ckpt, last_ckpt})
        else:
            # No validation checkpoint available: only the latest one is left.
            epoch_list = [last_ckpt]

    elif args.eval_last_5:
        epoch_list = _val_ckpts_sorted_by_epoch(args)[-5:]

    elif args.eval_uniform_sample:
        after_start = [
            f for f in _val_ckpts_sorted_by_epoch(args)
            if _epoch_index_of(f) > args.eval_start_epoch
        ]
        epoch_list = after_start[::args.eval_interval]

    else:
        return []

    print('epoch_list', epoch_list)
    return epoch_list



def plot_init_states(object_init_states, num_episodes, run_folder):
    """Scatter-plot the first two components of each initial state and save the figure.

    Args:
        object_init_states: Indexable collection; element ``i`` must itself be
            indexable, with index 0 plotted on the x axis and index 1 on the
            y axis.
        num_episodes: Number of leading entries of ``object_init_states`` to plot.
        run_folder: Directory in which ``eval_init_states.png`` is written.
    """
    fig = plt.figure(figsize=(10, 10))
    try:
        for episode_idx in range(num_episodes):
            obj_init_state = object_init_states[episode_idx]
            plt.scatter(obj_init_state[0], obj_init_state[1])
        plt.xlabel('x')
        plt.ylabel('y')
        plt.title('Initial states of the objects')
        plt.savefig(os.path.join(run_folder, "eval_init_states.png"))
    finally:
        # pyplot keeps figures alive until they are closed; release this one
        # so repeated calls do not accumulate open figures.
        plt.close(fig)


def wandb_eval_init_metrics(run_id, project_name='brs'):
    """Start (or resume) a wandb run and register the rollout metrics.

    Every ``rollout/*`` metric is defined with ``eval_epoch`` as its step
    metric.

    Args:
        run_id: Used as both the wandb run name and the run id.
        project_name: wandb project to log into.
    """
    wandb.init(project=project_name, name=run_id, id=run_id, resume="allow")

    step_metric = "eval_epoch"
    wandb.define_metric(step_metric)

    rollout_metrics = (
        "rollout/lift_Success_Rate",
        "rollout/touching_Success_Rate",
        "rollout/bddl_Success_Rate",
        "rollout/time",
        "rollout/Time_Episode",
        "rollout/Success_Rate",
        "rollout/Return",
        "rollout/Horizon",
        "rollout/Exception_Rate",
    )
    for metric_name in rollout_metrics:
        wandb.define_metric(metric_name, step_metric=step_metric)


def _store_rollout_log(save_file_name, epoch_index, log_data):
    """Insert *log_data* under ``str(epoch_index)`` in the pickle file, creating it if needed."""
    if os.path.exists(save_file_name):
        # NOTE: pickle.load executes arbitrary code from the file; this file is
        # expected to have been written by a previous run of this script.
        with open(save_file_name, 'rb') as f:
            logs = pickle.load(f)
        logs[str(epoch_index)] = log_data
        message = 'appended rollout logs to'
    else:
        logs = {str(epoch_index): log_data}
        message = 'saved rollout logs to'
    with open(save_file_name, 'wb') as f:
        pickle.dump(logs, f)
    print(message, save_file_name)


def og_rollout(args, mg_config, train_config):
    """Roll out the selected checkpoints in OmniGibson and record the results.

    For each checkpoint to evaluate, the policy is loaded, rolled out with
    ``rollout_with_stats`` on the validation initial states, and the returned
    log is stored in a pickle file inside ``args.load_checkpoint_folder``.
    For ``model_type == 'brs'`` the log is also sent to wandb. A checkpoint is
    skipped when its validation video already exists or (brs only) when its
    epoch index is already a key of the pickle file.

    Args:
        args: Parsed CLI namespace. ``args.pcd_heuristics`` is overwritten
            depending on whether ``'no_heuristic'`` occurs in
            ``train_config['data_dir']``.
        mg_config: MimicGen config dict, forwarded to ``construct_og_env``.
        train_config: Training config dict.

    Raises:
        AssertionError: If ``args.random_type`` or ``args.baseline`` is not
            one of the supported values.
        ValueError: If ``args.model_type`` is neither ``'brs'`` nor ``'openpi'``.
        KeyError: If the constructed environment is not registered under the
            expected ``r1_tidy_table_<random_type>`` name.
    """
    args.pcd_heuristics = 'no_heuristic' not in train_config['data_dir']

    assert args.random_type in ['D0', 'D1'], "random_type should be either D0 or D1"
    assert args.baseline in ['momagen', 'mimicgen', 'skillgen'], "baseline should be either momagen, mimicgen, or skillgen"
    if args.model_type not in ('brs', 'openpi'):
        # Fail before the (expensive) environment construction; previously an
        # unknown type only surfaced later as a NameError on the policy.
        raise ValueError("model_type should be either brs or openpi, got {!r}".format(args.model_type))

    # One key is shared by the environment name, the video name and the log lookup.
    env_key = f'r1_tidy_table_{args.random_type}'

    obj_init_states_file = f"./datasets/r1_tidy_table/r1_tidy_table_cup_init_pose_val_{args.random_type}.npy"
    object_init_states_val = np.load(obj_init_states_file, allow_pickle=True)

    # get epoch names to evaluate
    if args.model_type == 'brs':
        save_file_name = os.path.join(args.load_checkpoint_folder, "{}_{}.pkl".format('eval_log', 'val'))
        epoch_list = get_epoch_list(args)
        print("epoch_list", epoch_list)

        # skip the epochs that already have an entry in the evaluation log
        if os.path.exists(save_file_name):
            # NOTE: pickle.load executes arbitrary code from the file; this file
            # is expected to have been written by a previous run of this script.
            with open(save_file_name, 'rb') as f:
                evaluated_epochs = set(pickle.load(f))
            epoch_list = [
                epoch for epoch in epoch_list
                if epoch.split('-')[0].split('epoch')[-1] not in evaluated_epochs
            ]
            print("epoch_list that have not been evaluated", epoch_list)

        # init wandb run, resume if already exists; normpath keeps the run id
        # non-empty when the folder is given with a trailing slash
        run_id = os.path.basename(os.path.normpath(args.load_checkpoint_folder))
        wandb_eval_init_metrics(run_id=run_id, project_name=args.wandb_project_name)
    else:
        save_file_name = os.path.join(args.load_checkpoint_folder, "{}_{}_pi.pkl".format('eval_log', 'openpi'))
        epoch_list = ['ridy_table_pi_D0']

    # construct the environment
    envs = construct_og_env(args, mg_config, train_config)
    if env_key not in envs:
        raise KeyError(env_key)

    # start rollout in the og environment
    video_dir = os.path.join(args.load_checkpoint_folder, "videos")
    os.makedirs(video_dir, exist_ok=True)

    for epoch_name in epoch_list:
        # An existing validation video marks the epoch as done; check this
        # before loading the policy so skipped epochs cost nothing.
        video_path = os.path.join(video_dir, "{}{}".format(env_key, epoch_name + '_val.mp4'))
        if os.path.exists(video_path):
            print('val video already exists, skip this epoch', epoch_name)
            continue

        # load the policy
        if args.model_type == 'brs':
            rollout_policy = load_brs_model(args, train_config, epoch_name)
        else:
            print('waiting for loading openpi')
            rollout_policy = load_openpi_model()

        all_rollout_logs, _ = rollout_with_stats(
            policy=rollout_policy,
            envs=envs,
            horizon=args.horizon,
            use_goals=False,
            num_episodes=args.num_episodes,
            render=False,
            video_dir=video_dir,
            epoch=epoch_name,
            video_skip=1,
            terminate_on_success=True,
            demo_actions=None,
            check_action_plot=None,
            verbose=True,
            init_states_list=object_init_states_val,
            train_init_states=False,
            val_init_states=True,
            model_type=args.model_type,
            mobile_base_vel_action_min=np.array(train_config['data_module']['mobile_base_vel_action_min']),
            mobile_base_vel_action_max=np.array(train_config['data_module']['mobile_base_vel_action_max']),
        )
        # Look the log up under the environment that was actually rolled out
        # (this used to be hard-coded to the D0 environment).
        log_data = all_rollout_logs[env_key]
        epoch_index = epoch_name.split('-')[0].split('epoch')[-1]

        if args.model_type == 'brs':
            try:
                eval_epoch = int(epoch_index)
            except ValueError:
                # e.g. 'last-v1.pth' carries no numeric epoch; keep the local
                # log below instead of losing the finished rollout.
                print('skip wandb logging, no numeric epoch index in', epoch_name)
            else:
                for key in log_data:
                    wandb.log({'rollout/' + key: log_data[key], 'eval_epoch': eval_epoch})
                    print('logging key: rollout', key, 'value', log_data[key], 'epoch_index', epoch_index)

        # save data locally
        _store_rollout_log(save_file_name, epoch_index, log_data)

    wandb.finish()
    time.sleep(10)
    import omnigibson as og
    og.shutdown()



def main(args):
    """Load the MimicGen and training configs, then run the rollout evaluation.

    Any ``Exception`` raised by ``og_rollout`` is caught and printed together
    with its traceback instead of being propagated.

    Args:
        args: Parsed CLI namespace; ``args.mg_config`` is the path of the
            MimicGen json and ``args.load_checkpoint_folder`` must contain
            ``conf.yaml``.
    """
    # MimicGen config
    with open(args.mg_config, "r") as mg_file:
        ext_cfg = json.load(mg_file)
    # config generator from robomimic generates this part of config unused by MimicGen
    if "meta" in ext_cfg:
        del ext_cfg["meta"]

    # training config stored next to the checkpoints
    with open(os.path.join(args.load_checkpoint_folder, "conf.yaml"), "r") as yaml_file:
        train_config = yaml.safe_load(yaml_file)

    print('mg config and train config loaded')

    try:
        og_rollout(args=args, mg_config=ext_cfg, train_config=train_config)
    except Exception as err:
        print("run failed with error:\n{}\n\n{}".format(err, traceback.format_exc()))



if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    # Options marked "(not read by this script)" are still accepted so that
    # existing command lines keep working.
    parser = argparse.ArgumentParser()

    ### brs specific arguments
    # parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--action_execute_start_idx", type=int, default=1)

    ### MimicGen specific arguments
    # External config file that overwrites default config
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="(optional) path to a config json that will be used to override the default settings. "
             "If omitted, default settings are used. This is the preferred way to run experiments.",
    )

    parser.add_argument(
        "--model_type",
        type=str,
        default='brs',
        choices=['brs', 'openpi'],
        help="(optional) the type of model to use, either brs or openpi",
    )

    # MimicGen config; main() opens it unconditionally, so it is required
    parser.add_argument(
        "--mg_config",
        type=str,
        required=True,
        help="path to a mimicgen config json",
    )

    # Algorithm Name
    parser.add_argument(
        "--algo",
        type=str,
        help="(optional) name of algorithm to run. Only needs to be provided if --config is not provided",
    )

    # Experiment Name (for tensorboard, saving models, etc.)
    parser.add_argument(
        "--name",
        type=str,
        default=None,
        help="(optional) if provided, override the experiment name defined in the config",
    )

    # Dataset path, to override the one in the config
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        help="(optional) if provided, override the dataset path defined in the config",
    )

    # Output path, to override the one in the config
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="(optional) if provided, override the output folder path defined in the config",
    )

    # force delete the experiment folder if it exists
    parser.add_argument(
        "--auto-remove-exp",
        action='store_true',
        help="force delete the experiment folder if it exists",
    )

    # debug mode
    parser.add_argument(
        "--debug",
        action='store_true',
        help="set this flag to plot actions for debugging purposes",
    )

    # Folder holding conf.yaml and ckpt/; main() reads it unconditionally, so it is required
    parser.add_argument(
        "--load_checkpoint_folder",
        type=str,
        required=True,
        help="the checkpoint folder used to load the model",
    )

    ### checkpoint selection (the first one set wins, in the order used by get_epoch_list)
    parser.add_argument(
        "--single_epoch",
        type=str,
        default=None,
        help="the single model to evaluate",
    )

    parser.add_argument(
        "--eval_all",
        action='store_true',
        help="evaluate every .pth checkpoint found in <load_checkpoint_folder>/ckpt",
    )

    parser.add_argument(
        "--eval_last_and_val_min_ckpt",
        action='store_true',
        help="evaluate the latest checkpoint and the checkpoint with the lowest validation loss",
    )

    parser.add_argument(
        "--eval_last_5",
        action='store_true',
        help="set this flag to run policy rollout, evaluate the last 5 saved validation epochs",
    )

    parser.add_argument(
        "--eval_uniform_sample",
        action='store_true',
        help="evaluate every --eval_interval-th validation checkpoint after --eval_start_epoch",
    )

    parser.add_argument(
        "--eval_start_epoch",
        type=int,
        default=300,
        help="the start epoch to evaluate the model (used with --eval_uniform_sample)",
    )

    parser.add_argument(
        "--eval_interval",
        type=int,
        default=5,
        help="the interval of epochs to evaluate the model (used with --eval_uniform_sample)",
    )

    parser.add_argument(
        "--eval_val_min",
        action='store_true',
        help="(not read by this script)",
    )

    parser.add_argument(
        "--sample_epoch",
        action='store_true',
        help="(not read by this script)",
    )

    ### environment / rollout options
    parser.add_argument(
        "--manipulation_only",
        action='store_true',
        help="whether to run the simulation in manipulation only mode",
    )

    parser.add_argument(
        "--eval_on_train_init_states",
        action='store_true',
        help="whether to initialize the environment with the initial states in the training dataset",
    )

    parser.add_argument(
        "--eval_on_val_init_states",
        action='store_true',
        help="whether to initialize the environment with the initial states in the validation dataset",
    )

    parser.add_argument(
        "--headless",
        action='store_true',
        help="whether to run the simulation in headless mode",
    )

    parser.add_argument(
        "--num_episodes",
        type=int,
        default=20,
        help="number of episodes to evaluate the model",
    )

    parser.add_argument(
        "--horizon",
        type=int,
        default=1000,
        help="horizon length of each episode",
    )

    parser.add_argument(
        "--white_r1",
        action='store_true',
        help="whether to change to white r1",
    )

    parser.add_argument(
        "--real_robot_mode",
        action='store_true',
        help="whether to change to old r1 qpos",
    )

    parser.add_argument(
        "--wandb_project_name",
        type=str,
        default='brs-tidy-table',
        help="the wandb project name",
    )

    parser.add_argument(
        "--random_type",
        type=str,
        default='D0',
        choices=['D0', 'D1'],
        help="the random type to use for the initial states",
    )

    parser.add_argument(
        "--baseline",
        type=str,
        default='momagen',
        choices=['momagen', 'mimicgen', 'skillgen'],
        help="the baselines can be momagen, mimicgen, or skillgen",
    )

    args = parser.parse_args()

    main(args)
