import datetime
import sys
import os
import uuid

from rlf.envs.env_interface import get_env_interface
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from d4rl.pointmaze import waypoint_controller
from d4rl.pointmaze import maze_model, maze_layouts
from collections import defaultdict, OrderedDict
import numpy as np
import pickle
import gzip
import h5py
import argparse
from rlf.args import str2bool
import tqdm
import gym
import cv2

from iq_learn.agent.ppo_utils.gym_env import DictWrapper, GymWrapper
from demo_collection.utils.utils import set_up_log_dirs, logging, make_envs
import torch

import goal_prox.envs.gridworld
import goal_prox.gym_minigrid.envs.fourrooms_long_hypotenuse

import acil_envs
import rlf.rl.utils as rutils

import os.path as osp
from iq_learn.utils.utils import gen_frame, save_video

class Rollout:
    """
    Per-episode transition buffer kept as five parallel lists.
    """
    def __init__(self):
        """Start with an empty buffer."""
        self.clear()

    def add(self, obs, next_obs, action, done, ep_found_goal):
        """Append one transition; every field goes to its own list."""
        pairs = (
            (self.obses, obs),
            (self.next_obses, next_obs),
            (self.actions, action),
            (self.dones, done),
            (self.ep_found_goals, ep_found_goal),
        )
        for store, item in pairs:
            store.append(item)

    def get(self):
        """Return the stored lists (the lists themselves, not copies) keyed by field name."""
        keys = ("obs", "next_obs", "actions", "dones", "ep_found_goals")
        stores = (
            self.obses,
            self.next_obses,
            self.actions,
            self.dones,
            self.ep_found_goals,
        )
        return dict(zip(keys, stores))

    def clear(self):
        """Empty the buffer by rebinding every field to a fresh list."""
        self.obses, self.next_obses = [], []
        self.actions, self.dones = [], []
        self.ep_found_goals = []

def npify(data):
    """Convert every value of ``data`` to a NumPy array, in place.

    The 'terminals' entry is stored as a boolean array; every other entry is
    stored as float32. Nothing is returned.
    """
    for key, values in data.items():
        target_dtype = np.bool_ if key == 'terminals' else np.float32
        # Rebinding an existing key does not resize the dict, so it is safe
        # while iterating over it.
        data[key] = np.array(values, dtype=target_dtype)



# def sample_env_and_controller(args):
#     # layout_str = maze_layouts.rand_layout(args.rand_maze_size)
#     # env = maze_model.MazeEnv(layout_str, agent_centric_view=args.agent_centric)
#     # env = gym.make(args.maze)
#     # env = GymWrapper(env=env, from_pixels=False, height=480, width=480, channels_first=True, frame_skip=1, return_state=False)
#     # env = DictWrapper(env, return_state=False)
#     env = make_envs(args)
#     try:
#         layout_str = env.env.env.env.str_maze_spec
#     except AttributeError:
#         layout_str = env.env.env.str_maze_spec

#     if args.mz_box_constrained:
#         action_ub = args.mz_ub
#         print(action_ub)
#     else:
#         action_ub = 1.0
#     controller = waypoint_controller.WaypointController(layout_str, action_ub=action_ub)
#     return env, controller

# Action directions: the list index is the action id, the value is its (dx, dy) step.
DIR_TO_VEC = [
    np.array(step)
    for step in (
        (1, 0),   # Right
        (0, 1),   # Down
        (-1, 0),  # Left
        (0, -1),  # Up
    )
]

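# Per-cell encoding read by the search helpers below (the position of the 1
# is the channel index of that cell type):
# Empty square [1, 0, 0, 0],
# Wall [0, 1, 0, 0],
# Goal [0, 0, 1, 0],
# Agent [0, 0, 0, 1],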

def is_valid(pos, grid):
    """Tell whether ``pos`` = (x, y) lies inside ``grid`` and is not a wall.

    ``grid`` is indexed as ``grid[y, x]`` and must be three-dimensional; a
    cell counts as a wall when its vector equals ``[0, 1, 0]``.
    """
    col, row = pos
    n_rows, n_cols, _ = grid.shape
    wall_cell = [0, 1, 0]
    in_bounds = 0 <= col < n_cols and 0 <= row < n_rows
    # Short-circuit keeps out-of-range coordinates from ever indexing the grid.
    return in_bounds and not np.array_equal(grid[row, col], wall_cell)

def find_position(state, channel):
    """Return (x, y) of the first cell with ``state[channel] == 1``, or None.

    "First" follows ``np.argwhere`` order, i.e. the lowest row index and then
    the lowest column index.
    """
    hits = np.argwhere(state[channel] == 1)
    if hits.shape[0] == 0:
        return None
    row, col = hits[0]
    # argwhere yields (row, col); callers work with (x, y) = (col, row).
    return (col, row)

def dfs(state):
    """
    DFS search from agent to goal.

    The search is iterative (explicit stack) rather than recursive, so long
    paths on large grids cannot exceed Python's recursion limit. Neighbours
    are tried in DIR_TO_VEC order and a cell is marked visited when it is
    first entered, so the returned path is the same one the recursive
    formulation would produce. The path is a valid path, not necessarily the
    shortest one.

    Args:
        state: grid observation indexed as state[channel]; channel 3 marks
            the agent and channel 2 the goal. The first three channels are
            handed to is_valid as the per-cell type vector.
            # assumes a channel-first [C, H, W] array - TODO confirm

    Returns: a list of actions [0, 1, 2, ...] (indices into DIR_TO_VEC) if a
    path is found, else []. [] is also returned when the agent or the goal
    is missing, or when the agent already stands on the goal.
    """
    start = find_position(state, 3)  # Agent channel
    goal = find_position(state, 2)   # Goal channel
    grid = np.transpose(state[:3], (1, 2, 0))  # [H, W, 3] - for checking cell type

    if start is None or goal is None:
        return []
    if start == goal:
        return []

    visited = {start}
    path = []
    num_actions = len(DIR_TO_VEC)
    # Each frame is (cell, index of the next action to try from that cell).
    # Invariant: len(path) equals the depth of the cell on top of the stack.
    stack = [(start, 0)]

    while stack:
        pos, first_action = stack.pop()
        for action in range(first_action, num_actions):
            delta = DIR_TO_VEC[action]
            next_pos = (pos[0] + delta[0], pos[1] + delta[1])
            if next_pos in visited or not is_valid(next_pos, grid):
                continue

            path.append(action)
            if next_pos == goal:
                return path

            visited.add(next_pos)
            # Remember where to resume in this cell if the branch dead-ends,
            # then descend into the neighbour.
            stack.append((pos, action + 1))
            stack.append((next_pos, 0))
            break
        else:
            # Every direction from `pos` is exhausted: backtrack by dropping
            # the action that led here (the start cell has none).
            if path:
                path.pop()

    return []



def reset_env(env, agent_centric=False):
    """Reset ``env`` and return the observation produced by ``env.reset()``.

    When ``agent_centric`` is true the environment is additionally rendered
    100 times in 'rgb_array' mode so that the camera can catch up with the
    agent. The rendered frames are discarded one by one instead of being
    collected in a throwaway list.
    """
    s = env.reset()  # return obs
    if agent_centric:
        for _ in range(100):
            env.render(mode='rgb_array')
    return s

def resize_video(images, dim=64):
    """Resize every frame of a video array to ``dim`` x ``dim``.

    ``images`` is indexed by frame along its first axis. Each frame is
    resized with cubic interpolation into a three-channel output slot, and
    the stacked result is returned as uint8.
    """
    n_frames = images.shape[0]
    target_size = (dim, dim)
    resized = np.zeros((n_frames, dim, dim, 3))

    for idx, frame in enumerate(images):
        resized[idx] = cv2.resize(
            frame, dsize=target_size, interpolation=cv2.INTER_CUBIC
        )

    return resized.astype(np.uint8)
# # import skvideo.io
# def save_video(file_name, frames, fps=20, video_format='mp4', dim=256):
#     images = resize_video(np.array(frames), dim=dim)

#     # save video
#     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
#     video = cv2.VideoWriter(file_name, fourcc, 20.0, (dim, dim))
#     for i in range(images.shape[0]):
#         video.write(images[i])
#     video.release()

def add_args(parser):
    """Register this script's command-line options on ``parser``.

    Options cover rendering/video, logging location, seeding, demo
    collection limits and environment selection. ``--save_video`` defaults to
    True and is a ``store_true`` flag, so on its own it could never be
    switched off; ``--no_save_video`` writes False to the same destination to
    make that possible without changing the existing flag.
    """
    parser.add_argument('--render', action='store_true', help='Render trajectories')
    parser.add_argument('--agent_centric', action='store_true', help='Whether agent-centric images are rendered.')
    parser.add_argument('--save_video', action='store_true', default=True, help='Whether rendered images are saved.')
    parser.add_argument('--no_save_video', dest='save_video', action='store_false', help='Do not save rendered images.')
    parser.add_argument('--log_dir', type=str, default='./scr/', help='Base directory for dataset')
    parser.add_argument('--seed', type=int, default=0, help='')

    # demo collection
    parser.add_argument('--min_traj_len', type=int, default=20, help='Min number of samples per trajectory')
    parser.add_argument('--num_trajs', type=int, default=400, help='Number of trajectories to collect')
    parser.add_argument('--only_save_success_trajs', type=str2bool, default=True, help='Only save successful trajectories')


    # env related
    parser.add_argument('--env_name', type=str, default='MiniGrid-FourRooms-Long-Hypotenuse-v0', help='Or MiniGrid-FourRooms-v0')
    parser.add_argument('--warp-frame', type=str2bool, default=False)
    parser.add_argument("--transpose-frame", type=str2bool, default=True)
    # parser.add_argument('--seed', type=int, default=0, help='')


    parser.add_argument('--dim-filter', type=float, default=1.0, help="how many percent of the dimensions to keep")


def get_default_args():
    """Parse the command line into a single namespace of script and env options.

    The script's own options are parsed first. Whatever that parser does not
    recognise is handed to a second parser whose options are registered by
    the env interface selected through ``--env_name``; the values it parses
    are merged into the first namespace with ``rutils.update_args``.
    """
    script_parser = argparse.ArgumentParser()
    add_args(script_parser)
    args, leftover = script_parser.parse_known_args()

    interface_cls = get_env_interface(args.env_name)
    env_interface = interface_cls(args)

    env_parser = argparse.ArgumentParser()
    env_interface.get_add_args(env_parser)
    # Arguments unknown to both parsers are ignored.
    env_args, _ = env_parser.parse_known_args(leftover)

    rutils.update_args(args, vars(env_args))
    return args


def main():
    """Collect DFS-planned demonstration trajectories and save them.

    For every episode the environment is reset, a path from agent to goal is
    planned with ``dfs`` on the reset observation, and the planned actions are
    replayed until the environment reports ``done``. An episode counts towards
    ``args.num_trajs`` only when it reaches ``done`` and, if
    ``args.only_save_success_trajs`` is set, it also found the goal with more
    than ``args.min_traj_len`` steps. Episodes that do not count are retried.

    All counted transitions are written with ``save_trajectories``; when
    ``args.save_video`` is set, the frames of the last episode are written
    with ``save_video``.

    NOTE(review): retries have no attempt cap, so if resets keep producing
    layouts for which no counted episode is possible, this loop never ends.
    """
    args = get_default_args()

    # Random short id that goes into the prefix passed to set_up_log_dirs.
    unique_id = uuid.uuid4().hex[:8]  # Shortened version of UUID
    prefix = f"{args.seed}-{unique_id}-"
    logdirs = set_up_log_dirs(args, prefix)
    # Only the video directory is used here; the unpacking still checks that
    # six directories were returned.
    _, _, _, _, _, video_save_dir = logdirs

    os.makedirs(args.log_dir, exist_ok=True)

    env = make_envs(args)
    # Initial reset kept from the original script; every episode resets again.
    reset_env(env, agent_centric=args.agent_centric)

    obses, next_obses, actions, dones, ep_found_goals = [], [], [], [], []
    cnt = 0
    # Defined up front so the video step below is safe when num_trajs <= 0.
    frame_buffer = []

    episode = 0
    while episode < args.num_trajs:
        frame_buffer = []
        logging(f"Episode: {episode + 1}")
        rollout = Rollout()
        success = False
        finished = False
        s = reset_env(env, agent_centric=args.agent_centric)

        action_sequence = dfs(s)
        if not action_sequence:
            logging("No path from agent to goal was found. Resetting environment.")

        for act in action_sequence:
            ns, reward, done, info = env.step(act)
            frame_buffer.append(env.render(mode='rgb_array'))

            success = success or info.get('ep_found_goal', 0.0) == 1.0
            rollout.add(s, ns, act, done, success)

            if done:
                finished = True
                break
            s = ns

        if not finished:
            # The plan ran out before the environment signalled `done`; the
            # partial rollout is dropped and the episode is retried.
            if action_sequence:
                logging("Planned actions ended before the episode was done. Resetting environment.")
            continue

        reached_goal = len(rollout.actions) > args.min_traj_len and success
        if args.only_save_success_trajs and not reached_goal:
            # Not counted: `episode` stays unchanged so the slot is retried.
            logging("Failed to find the goal")
            continue

        # add the trajectory to the main rollout
        data = rollout.get()
        obses.extend(data['obs'])
        next_obses.extend(data['next_obs'])
        actions.extend(data['actions'])
        dones.extend(data['dones'])
        ep_found_goals.extend(data['ep_found_goals'])
        cnt += 1
        episode += 1
        if not reached_goal:
            logging("Failed to find the goal")
        logging(f"Trajectory {cnt} collected successfully.")

    # Save collected trajectories
    save_trajectories(obses, next_obses, actions, dones, ep_found_goals, args)

    # frame_buffer holds only the frames of the last episode.
    if args.save_video and frame_buffer:
        save_video(video_save_dir, np.array(frame_buffer), episode_id=0)

def save_demos(env, args, rollouts, num_trajs):
    """Pickle ``rollouts`` to ``rollout_<num_trajs>_trajs.pkl`` under ``args.log_dir``.

    The file goes into a ``demos-<maze>`` (or ``demos-<maze>-noisy``)
    sub-directory, with an extra ``batch_<idx>`` level when
    ``args.batch_idx >= 0``. The directory is created if it does not exist.

    Args:
        env: unused; kept so existing callers keep working.
        args: namespace providing ``maze``, ``noisy``, ``batch_idx`` and
            ``log_dir``.
            # NOTE(review): add_args in this file does not define maze, noisy
            # or batch_idx - confirm they are supplied elsewhere.
        rollouts: iterable of mappings with keys "ob", "ac" and "done".
        num_trajs: number used in the output file name.
    """
    dir_name = 'demos-%s-noisy' % args.maze if args.noisy else 'demos-%s' % args.maze
    if args.batch_idx >= 0:
        dir_name = os.path.join(dir_name, "batch_{}".format(args.batch_idx))
    file_name = os.path.join(args.log_dir, dir_name, "rollout_{}_trajs.pkl".format(num_trajs))

    # Rename the per-rollout keys to the names used in the saved file.
    new_rollouts = [
        {
            "obs": rollout["ob"],
            "actions": rollout["ac"],
            "dones": rollout["done"],
        }
        for rollout in rollouts
    ]

    # Without this, open() fails when the demo directory is missing.
    os.makedirs(os.path.dirname(file_name), exist_ok=True)

    logging("[*] Generating demo: {}".format(file_name))
    with open(file_name, "wb") as f:
        pickle.dump(new_rollouts, f)

def load_demos(file_name):
    """Unpickle the demo file at ``file_name`` and log the expansion counts.

    The per-transition expansion is currently commented out, so both logged
    counts are always 0. The log message is formatted before it is passed to
    ``logging``, matching the single-string calls used elsewhere in this file.

    Returns:
        The object stored in the pickle file.
    """
    # check saved pickle file
    num_demos = 0
    _data = []
    # SECURITY: unpickling can execute arbitrary code - only load files from
    # a trusted source.
    with open(file_name, "rb") as f:
        demos = pickle.load(f)
        # if not isinstance(demos, list):
        #     demos = [demos]
        # for demo in demos:
        #     if len(demo["obs"]) != len(demo["actions"]) + 1:
        #         logger.error(
        #             "Mismatch in # of observations (%d) and actions (%d) (%s)",
        #             len(demo["obs"]),
        #             len(demo["actions"]),
        #             file_name,
        #         )
        #         continue
    
        #     num_demos += 1
    
        #     length = len(demo["actions"])
        #     for i in range(length):
        #         transition = {
        #             "ob": demo["obs"][i],
        #             "ob_next": demo["obs"][i + 1],
        #         }
        #         if isinstance(demo["actions"][i], dict):
        #             transition["ac"] = demo["actions"][i]
        #         else:
        #             transition["ac"] = gym.spaces.unflatten(
        #                 env.action_space, demo["actions"][i]
        #             )
        #         if "rewards" in demo:
        #             transition["rew"] = demo["rewards"][i]
        #         else:
        #             transition["rew"] = 0.0
        #         if "dones" in demo:
        #             transition["done"] = int(demo["dones"][i])
        #         else:
        #             transition["done"] = 1 if i + 1 == length else 0
        #         _data.append(transition)
    logging("Load %d demonstrations with %d states" % (num_demos, len(_data)))
    return demos

def save_trajectories(obses, next_obses, actions, dones, ep_found_goals, args):
    """Write the collected transitions to a single .pt file in ``args.log_dir``.

    The file name is built from ``args.env_name`` and ``args.num_trajs``, with
    an '_only_success' suffix when ``args.only_save_success_trajs`` is set.
    Every field is stored as a tensor; actions are reshaped to one column.
    """
    suffix = '_only_success' if args.only_save_success_trajs else ''
    save_path = os.path.join(args.log_dir, f'{args.env_name}_{args.num_trajs}{suffix}.pt')

    # Entries are listed in the order they appear in the saved dict.
    fields = (
        ('obs', lambda: np.array(obses)),
        ('next_obs', lambda: np.array(next_obses)),
        ('actions', lambda: np.array(actions).reshape(-1, 1)),
        ('done', lambda: np.array(dones)),
        ('ep_found_goal', lambda: np.array(ep_found_goals)),
    )
    data = {}
    for key, build_array in fields:
        data[key] = torch.tensor(build_array())

    torch.save(data, save_path)
    logging(f"Trajectories saved at {save_path}")


# Run the demo collection only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()

    # # check h5 file
    # f = h5py.File("./maze2d-medium/rollout_0.h5", "r")
    # # Print all root level object names (aka keys)
    # # these can be group or dataset names
    # print("Keys: %s" % f.keys())
    # # get first object name/key; may or may NOT be a group
    # a_group_key = list(f.keys())[0]
    #
    # # get the object type for a_group_key: usually group or dataset
    # print(type(f[a_group_key]))
    #
    # # If a_group_key is a group name,
    # # this gets the object names in the group and returns as a list
    # data = list(f[a_group_key])
