import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from d4rl.pointmaze import waypoint_controller
from d4rl.pointmaze import maze_model, maze_layouts
from collections import defaultdict, OrderedDict
import numpy as np
import pickle
import gzip
import h5py
import argparse
from rlf.args import str2bool
import tqdm
import gym
import cv2

from iq_learn.agent.ppo_utils.gym_env import DictWrapper, GymWrapper
from demo_collection.utils.utils import set_up_log_dirs, logging, make_envs
import torch

import acil_envs

class Rollout:
    """
    Per-episode transition buffer.

    Each transition is stored column-wise in five parallel lists:
    ``obses``, ``next_obses``, ``actions``, ``dones`` and ``ep_found_goals``.
    """

    def __init__(self):
        """Create an empty buffer."""
        self.clear()

    def add(self, obs, next_obs, action, done, ep_found_goal):
        """Append one transition; every column grows by exactly one entry."""
        columns = (
            (self.obses, obs),
            (self.next_obses, next_obs),
            (self.actions, action),
            (self.dones, done),
            (self.ep_found_goals, ep_found_goal),
        )
        for column, value in columns:
            column.append(value)

    def get(self):
        """Return the stored columns as a dict (the lists are not copied)."""
        return dict(
            obs=self.obses,
            next_obs=self.next_obses,
            actions=self.actions,
            dones=self.dones,
            ep_found_goals=self.ep_found_goals,
        )

    def clear(self):
        """Drop all stored transitions by rebinding every column to a new list."""
        self.obses, self.next_obses = [], []
        self.actions, self.dones = [], []
        self.ep_found_goals = []

def npify(data):
    """Convert every value of ``data`` to a NumPy array, in place.

    The ``'terminals'`` entry becomes a boolean array; every other entry is
    converted to ``float32``. Nothing is returned; ``data`` itself is updated.
    """
    for key, values in data.items():
        # Only values are replaced (no keys added/removed), so mutating
        # the dict while iterating over it is safe.
        data[key] = np.array(
            values,
            dtype=np.bool_ if key == 'terminals' else np.float32,
        )


def sample_env_and_controller(args):
    """Build the environment and a matching waypoint controller.

    The environment comes from ``make_envs(args)``. The maze layout string is
    read from the wrapped environment's ``str_maze_spec`` attribute, looked up
    three wrapper levels deep first and two levels deep as a fallback.

    The controller's action upper bound is ``args.mz_ub`` (also printed) when
    ``args.mz_box_constrained`` is set, and ``1.0`` otherwise.

    Returns:
        Tuple ``(env, controller)``.
    """
    env = make_envs(args)

    # The number of wrapper layers around the maze env is not fixed;
    # try the deeper nesting first, then one level less.
    try:
        layout_str = env.env.env.env.str_maze_spec
    except AttributeError:
        layout_str = env.env.env.str_maze_spec

    action_ub = 1.0
    if args.mz_box_constrained:
        action_ub = args.mz_ub
        print(action_ub)

    return env, waypoint_controller.WaypointController(layout_str, action_ub=action_ub)


def reset_env(env, agent_centric=False):
    """Reset ``env`` and return the value produced by ``env.reset()``.

    When ``agent_centric`` is true, the environment is rendered 100 times in
    ``'rgb_array'`` mode after the reset so that the camera can catch up with
    the agent; the rendered frames are discarded.
    """
    initial_obs = env.reset()
    if not agent_centric:
        return initial_obs

    for _ in range(100):
        env.render(mode='rgb_array')  # warm-up frames, result intentionally unused
    return initial_obs

def resize_video(images, dim=64):
    """Resize every frame of a video array to ``dim`` x ``dim`` pixels.

    Args:
        images: Array whose first axis indexes frames; each frame is passed
            to ``cv2.resize`` and must yield a ``(dim, dim, 3)`` result.
        dim: Target height and width.

    Returns:
        ``uint8`` array of shape ``(num_frames, dim, dim, 3)``.
    """
    num_frames = images.shape[0]
    # Frames are collected in a float buffer and cast to uint8 at the end.
    resized = np.zeros((num_frames, dim, dim, 3))

    for idx, frame in enumerate(images):
        resized[idx] = cv2.resize(
            frame, dsize=(dim, dim), interpolation=cv2.INTER_CUBIC
        )

    return resized.astype(np.uint8)
# import skvideo.io
def save_video(file_name, frames, fps=20, video_format='mp4', dim=256):
    """Write a sequence of frames to a video file with OpenCV.

    Frames are stacked into one array, resized to ``dim`` x ``dim`` with
    ``resize_video`` and encoded using the ``mp4v`` codec.

    Args:
        file_name: Output path handed to ``cv2.VideoWriter``.
        frames: Sequence of image frames convertible with ``np.array``.
        fps: Frame rate of the written video.
        video_format: Currently unused; the codec is always ``mp4v``.
        dim: Height and width of the written frames.
    """
    images = resize_video(np.array(frames), dim=dim)

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # Honour the caller-supplied frame rate (it used to be hard-coded to 20.0,
    # silently ignoring ``fps``); the default of 20 keeps old behaviour.
    video = cv2.VideoWriter(file_name, fourcc, float(fps), (dim, dim))
    try:
        for image in images:
            video.write(image)
    finally:
        # Always finalize the container, even if a write fails midway.
        video.release()

def add_args(parser):
    """Register this script's command-line options on ``parser``.

    Options are declared as a table of ``(flag, keyword-arguments)`` pairs and
    added in order: rendering/output, demo collection, environment settings
    and finally the dimension filter.
    """
    option_specs = (
        # rendering / output
        ('--render', dict(action='store_true', help='Render trajectories')),
        ('--agent_centric', dict(action='store_true',
                                 help='Whether agent-centric images are rendered.')),
        # NOTE: store_true combined with default=True means this is always True.
        ('--save_images', dict(action='store_true', default=True,
                               help='Whether rendered images are saved.')),
        ('--data_dir', dict(type=str, default="./expert_datasets/constraint_analysis",
                            help='Base directory for dataset')),
        # demo collection
        ('--min_traj_len', dict(type=int, default=20,
                                help='Min number of samples per trajectory')),
        ('--num_trajs', dict(type=int, default=400,
                             help='Number of trajectories to collect')),
        ('--only_save_success_trajs', dict(type=str2bool, default=True,
                                           help='Only save successful trajectories')),
        # env related
        ('--env_name', dict(type=str, default='maze2d-medium-v1', help='Maze type')),
        ('--mz-reward-type', dict(type=str, default='dense', help="dense or sparse")),
        ('--mz-box-constrained', dict(type=str2bool, default=False, help="")),
        ('--mz-ub', dict(type=float, default=0.1, help="")),
        ('--mz-safe-constrained', dict(type=str2bool, default=False, help="")),
        # observation filtering
        ('--dim-filter', dict(type=float, default=1.0,
                              help="how many percent of the dimensions to keep")),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

def main():
    """Collect waypoint-controller trajectories and save them as one dataset.

    Parses the command line, builds the environment and controller, and rolls
    out episodes until ``args.num_trajs`` episodes have been counted:

    * with ``only_save_success_trajs`` an episode is stored and counted only
      if it reached the goal and has more than ``min_traj_len`` steps;
    * otherwise every finished episode is stored and counted.

    Episodes in which the controller raises ``ValueError`` are discarded
    without being counted. All kept transitions are concatenated and handed
    to ``save_trajectories``.
    """
    parser = argparse.ArgumentParser()
    add_args(parser)
    args = parser.parse_args()
    os.makedirs(args.data_dir, exist_ok=True)

    env, controller = sample_env_and_controller(args)
    state = reset_env(env, agent_centric=args.agent_centric)

    # Dataset-wide, flat transition columns.
    obses, next_obses, actions, dones, ep_found_goals = [], [], [], [], []
    cnt = 0

    def store(finished):
        """Append all transitions of a finished rollout to the dataset columns."""
        obses.extend(finished.obses)
        next_obses.extend(finished.next_obses)
        actions.extend(finished.actions)
        dones.extend(finished.dones)
        ep_found_goals.extend(finished.ep_found_goals)

    episode = 0
    while episode < args.num_trajs:
        logging(f"Episode: {episode + 1}")
        rollout = Rollout()
        success = False
        state = reset_env(env, agent_centric=args.agent_centric)

        while True:
            # Observation layout used here: [0:2] position, [2:4] velocity,
            # [4:6] target.
            position, velocity, target = state[0:2], state[2:4], tuple(state[4:6])

            try:
                act, done = controller.get_action(position, velocity, target)
            except ValueError:
                logging("Failed to find valid path to goal. Resetting environment.")
                break

            act = np.clip(act, -1.0, 1.0)
            next_state, _reward, done, info = env.step(act)

            success = success or info.get('goal_achieved', False)

            # Store the action that was really applied when the env reports one.
            save_act = info['real_action'] if args.mz_safe_constrained else act
            rollout.add(state, next_state, save_act, done, success)

            if not done:
                state = next_state
                continue

            # Episode finished: decide whether it is kept and counted.
            reached_goal = len(rollout.actions) > args.min_traj_len and success
            if not args.only_save_success_trajs:
                episode += 1
                store(rollout)
                cnt += 1
                if not reached_goal:
                    logging("Failed to find the goal")
                logging(f"Trajectory {cnt} collected successfully.")
            elif reached_goal:
                episode += 1
                store(rollout)
                cnt += 1
                logging(f"Trajectory {cnt} collected successfully.")
            else:
                # Rejected episode: not counted, so the outer loop retries.
                logging("Failed to find the goal")
            rollout.clear()
            break

    # Save collected trajectories
    save_trajectories(obses, next_obses, actions, dones, ep_found_goals, args)

def save_demos(env, args, rollouts, num_trajs):
    """Pickle a list of rollouts under ``args.data_dir``.

    The file is written to
    ``<data_dir>/demos-<maze>[-noisy][/batch_<idx>]/rollout_<num_trajs>_trajs.pkl``.
    Each rollout is re-keyed from ``ob``/``ac``/``done`` to
    ``obs``/``actions``/``dones`` before being dumped.

    NOTE(review): this reads ``args.maze``, ``args.noisy`` and
    ``args.batch_idx``; ``add_args`` in this file does not define those
    options, so callers must supply a namespace that has them.

    Args:
        env: Unused.
        args: Namespace providing ``maze``, ``noisy``, ``batch_idx`` and
            ``data_dir``.
        rollouts: Iterable of mappings with ``"ob"``, ``"ac"`` and ``"done"``.
        num_trajs: Number embedded in the output file name.

    Raises:
        KeyError: If a rollout lacks one of the expected keys.
    """
    dir_name = 'demos-%s-noisy' % args.maze if args.noisy else 'demos-%s' % args.maze
    if args.batch_idx >= 0:
        dir_name = os.path.join(dir_name, "batch_{}".format(args.batch_idx))
    out_dir = os.path.join(args.data_dir, dir_name)
    file_name = os.path.join(out_dir, "rollout_{}_trajs.pkl".format(num_trajs))

    new_rollouts = [
        {
            "obs": rollout["ob"],
            "actions": rollout["ac"],
            "dones": rollout["done"],
        }
        for rollout in rollouts
    ]

    logging("[*] Generating demo: {}".format(file_name))
    # The demo sub-directory is not created anywhere else; without this,
    # opening the file for writing fails on a fresh data directory.
    os.makedirs(out_dir, exist_ok=True)
    with open(file_name, "wb") as f:
        pickle.dump(new_rollouts, f)

def load_demos(file_name):
    """Unpickle a demonstration file and log summary counts.

    NOTE(review): the per-demo processing inside the ``with`` block is
    commented out, so as written the loaded ``demos`` object is never used,
    ``num_demos`` stays 0, ``_data`` stays empty, and the function returns
    ``None``. The disabled code also refers to ``logger`` and ``env``, which
    are not defined in this function.

    Args:
        file_name: Path of the pickle file to open in binary mode.
    """
    # check saved pickle file
    num_demos = 0
    _data = []
    with open(file_name, "rb") as f:
        # NOTE(review): pickle.load can execute arbitrary code while loading;
        # only use this on trusted demonstration files.
        demos = pickle.load(f)
        # if not isinstance(demos, list):
        #     demos = [demos]
        # for demo in demos:
        #     if len(demo["obs"]) != len(demo["actions"]) + 1:
        #         logger.error(
        #             "Mismatch in # of observations (%d) and actions (%d) (%s)",
        #             len(demo["obs"]),
        #             len(demo["actions"]),
        #             file_name,
        #         )
        #         continue
    
        #     num_demos += 1
    
        #     length = len(demo["actions"])
        #     for i in range(length):
        #         transition = {
        #             "ob": demo["obs"][i],
        #             "ob_next": demo["obs"][i + 1],
        #         }
        #         if isinstance(demo["actions"][i], dict):
        #             transition["ac"] = demo["actions"][i]
        #         else:
        #             transition["ac"] = gym.spaces.unflatten(
        #                 env.action_space, demo["actions"][i]
        #             )
        #         if "rewards" in demo:
        #             transition["rew"] = demo["rewards"][i]
        #         else:
        #             transition["rew"] = 0.0
        #         if "dones" in demo:
        #             transition["done"] = int(demo["dones"][i])
        #         else:
        #             transition["done"] = 1 if i + 1 == length else 0
        #         _data.append(transition)
    # NOTE(review): every other call in this file passes logging() a single
    # pre-formatted string; confirm the helper accepts %-style arguments.
    logging("Load %d demonstrations with %d states", num_demos, len(_data))

def save_trajectories(obses, next_obses, actions, dones, ep_found_goals, args):
    """Convert the collected transition columns to tensors and save them.

    The output file is ``<data_dir>/<env_name>_<num_trajs>[_only_success].pt``;
    the ``_only_success`` suffix is added when ``args.only_save_success_trajs``
    is set. The saved object is a dict with the keys ``obs``, ``next_obs``,
    ``actions``, ``done`` and ``ep_found_goal``.
    """
    suffix = '_only_success' if args.only_save_success_trajs else ''
    save_path = os.path.join(
        args.data_dir, f'{args.env_name}_{args.num_trajs}{suffix}.pt'
    )

    columns = (
        ('obs', obses),
        ('next_obs', next_obses),
        ('actions', actions),
        ('done', dones),
        ('ep_found_goal', ep_found_goals),
    )
    # Each column goes through a NumPy array before becoming a tensor.
    data = {key: torch.tensor(np.array(values)) for key, values in columns}

    torch.save(data, save_path)
    logging(f"Trajectories saved at {save_path}")


# Run the collection pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()

    # # check h5 file
    # f = h5py.File("./maze2d-medium/rollout_0.h5", "r")
    # # Print all root level object names (aka keys)
    # # these can be group or dataset names
    # print("Keys: %s" % f.keys())
    # # get first object name/key; may or may NOT be a group
    # a_group_key = list(f.keys())[0]
    #
    # # get the object type for a_group_key: usually group or dataset
    # print(type(f[a_group_key]))
    #
    # # If a_group_key is a group name,
    # # this gets the object names in the group and returns as a list
    # data = list(f[a_group_key])
