import datetime
import os
import random
import time
from collections import deque
from itertools import count
import types
import uuid

import hydra
import numpy as np
import torch
import torch.nn.functional as F
import wandb
from omegaconf import DictConfig

from iq_learn.utils.wandb_logger import wandb_logger as Logger
from iq_learn.make_envs import make_env
from iq_learn.dataset.memory import Memory
from iq_learn.agent import make_agent
from iq_learn.utils.utils import evaluate, eval_mode, gen_frame, get_args, save_video, set_up_log_dirs, logging

torch.set_num_threads(2)

import gym
import heapq
import numpy as np

def dijkstra_minigrid(env):
    """Plan a maximum-reward route from the agent's cell to the goal cell.

    Runs a Dijkstra-style best-first search over the cells of ``env.grid``.
    Every move adds ``env.step_cost`` to the accumulated reward, entering the
    trap cell additionally adds ``env.trap_penalty``, and ``env.goal_reward``
    is added once the goal cell is popped from the queue.

    The result is only guaranteed to be optimal when ``step_cost`` and
    ``trap_penalty`` are non-positive (moving never increases the reward),
    which is the usual Dijkstra precondition.

    Args:
        env: Environment exposing ``step_cost``, ``trap_penalty``,
            ``goal_reward``, ``agent_pos``, ``trap_pos`` and a ``grid`` with
            ``width``, ``height`` and ``get(i, j)`` returning ``None`` or an
            object with a ``type`` attribute.

    Returns:
        A ``(path, reward, found)`` tuple. ``path`` is the list of action ids
        (keys of the delta table below) leading from the start to the goal,
        ``reward`` is the accumulated reward including ``goal_reward``, and
        ``found`` is ``True``. When the goal is unreachable the result is
        ``([], -inf, False)``.

    Raises:
        ValueError: If the grid contains no cell of type ``'goal'``.
    """
    step_cost = env.step_cost
    trap_penalty = env.trap_penalty
    goal_reward = env.goal_reward

    grid = env.grid
    grid_width, grid_height = grid.width, grid.height

    start_pos = tuple(env.agent_pos)
    raw_trap_pos = env.trap_pos
    # Positions are compared as tuples below; a list-valued trap position
    # would otherwise never compare equal to a candidate cell.
    try:
        trap_pos = tuple(raw_trap_pos)
    except TypeError:
        # Not iterable (e.g. None): keep as-is, it simply never matches.
        trap_pos = raw_trap_pos

    # Single scan of the grid: locate the goal and mark every non-wall cell
    # (including empty ``None`` cells) as visitable.
    goal_pos = None
    visitable_grid = np.zeros((grid_width, grid_height), dtype=bool)
    for i in range(grid_width):
        for j in range(grid_height):
            cell = grid.get(i, j)
            if cell is None:
                visitable_grid[i, j] = True
                continue
            if cell.type != 'wall':
                visitable_grid[i, j] = True
            # Keep the first goal encountered in scan order.
            if goal_pos is None and cell.type == 'goal':
                goal_pos = (i, j)

    if goal_pos is None:
        raise ValueError("The grid does not contain a cell of type 'goal'.")

    logging(f"Start position: {start_pos}")
    logging(f"Trap position: {raw_trap_pos}")
    logging(f"Goal position: {goal_pos}")

    # Best reward known so far and the action sequence achieving it.
    best_reward = {start_pos: 0}
    paths = {start_pos: []}
    visited = set()

    # heapq is a min-heap, so rewards are stored negated: the smallest
    # entry is the cell with the highest accumulated reward.
    pq = [(0, start_pos)]

    # Action id -> (dx, dy) cell offset. The ids are assumed to match the
    # action encoding expected by env.step -- confirm against the env.
    action_deltas = {
        0: (1, 0),   # +x
        1: (0, 1),   # +y
        2: (-1, 0),  # -x
        3: (0, -1),  # -y
    }

    while pq:
        neg_reward, current_pos = heapq.heappop(pq)

        # Stale queue entries for already-finalised cells are skipped.
        if current_pos in visited:
            continue
        visited.add(current_pos)
        current_reward = -neg_reward

        if current_pos == goal_pos:
            current_reward += goal_reward
            print(f"Goal reached with reward: {current_reward}")
            return paths[current_pos], current_reward, True

        for action, (dx, dy) in action_deltas.items():
            nx, ny = current_pos[0] + dx, current_pos[1] + dy

            # Explicit bounds check: negative indices would silently wrap
            # around in NumPy and too-large ones would raise IndexError.
            if not (0 <= nx < grid_width and 0 <= ny < grid_height):
                continue
            if not visitable_grid[nx, ny]:
                continue

            next_pos = (nx, ny)
            # A finalised cell already holds its best route.
            if next_pos in visited:
                continue

            new_reward = current_reward + step_cost
            if next_pos == trap_pos:
                new_reward += trap_penalty

            # Relax only when the new route is strictly better (higher
            # reward); otherwise a worse route would overwrite the stored
            # path while the queue still carries the better reward.
            if next_pos not in best_reward or new_reward > best_reward[next_pos]:
                best_reward[next_pos] = new_reward
                paths[next_pos] = paths[current_pos] + [action]
                heapq.heappush(pq, (-new_reward, next_pos))

    # The queue ran dry without ever popping the goal cell.
    print("No path to the goal found.")
    return [], -float('inf'), False

@hydra.main(config_path="../conf", config_name="expert_gen/minigrid_deceptive")
def main(cfg: DictConfig):
    """Generate planner-driven expert trajectories and write them to disk.

    For each of ``args.expert_gen.num_episodes`` episodes the environment is
    reset, ``dijkstra_minigrid`` plans an action sequence, and that sequence
    is replayed with ``env.step``. Every transition is collected and finally
    saved as a dict of tensors; rendered frames of the first ``video_ep``
    episodes are saved as a single video.

    Args:
        cfg: Hydra configuration, converted to ``args`` via ``get_args``.
    """
    args = get_args(cfg)
    # Trajectories are generated from scratch, so no pretrained model is used.
    args.pretrain = None

    logger = Logger(args)
    (
        _log_dir,
        wandb_dir,
        _agent_save_dir,
        _agent_best_dir,
        reward_save_dir,
        video_save_dir,
    ) = set_up_log_dirs(args, logger.prefix)
    logger._create_wandb(log_dir=wandb_dir)

    # Seed Python, NumPy and PyTorch with the same value.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(args.seed)
    print(f"Seed set to {args.seed}")

    device = torch.device(args.device)
    use_deterministic_cudnn = (
        device.type == 'cuda'
        and torch.cuda.is_available()
        and args.cuda_deterministic
    )
    if use_deterministic_cudnn:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # assert 'MiniGrid-Deceptive' in args.env.name, "This script is only for MiniGrid-Deceptive environment"
    env = make_env(args)
    env.seed(args.seed)

    # Frames are recorded only for episodes with index below video_ep.
    video_ep = 50
    frame_buffer = []

    # Flat per-transition buffers, concatenated across all episodes.
    obs_log = []
    next_obs_log = []
    done_log = []
    action_log = []

    for episode in range(args.expert_gen.num_episodes):
        logging(f"Episode: {episode}")
        obs = env.reset()
        path, _ep_reward, _ep_found_goal = dijkstra_minigrid(env)

        record_frames = episode < video_ep
        if record_frames:
            # Initial frame, before any action has been taken.
            frame_buffer.append(env.render('rgb_array'))

        # Replay the planned actions until they run out or the env reports done.
        for action in path:
            next_obs, reward, done, info = env.step(action)

            if record_frames:
                frame_buffer.append(gen_frame(env.render('rgb_array'), true_reward=reward))

            obs_log.append(obs)
            next_obs_log.append(next_obs)
            done_log.append(done)
            action_log.append(action)

            if done:
                break
            obs = next_obs

    # save video
    video_save_path = save_video(video_save_dir, np.array(frame_buffer), episode_id=0)
    logging(f"Video saved at {video_save_path}")

    # save trajectories
    # NOTE: 'ep_found_goal' holds the per-step done flags (same values as
    # 'done'); the planner's own found flag is not stored.
    weights = {
        'obs': torch.tensor(np.array(obs_log)),
        'next_obs': torch.tensor(np.array(next_obs_log)),
        'done': torch.tensor(np.array(done_log)),
        'actions': torch.tensor(np.array(action_log).reshape(-1, 1)),
        'ep_found_goal': torch.tensor(np.array(done_log)),
    }
    # save weights as pt
    file_name = f'{args.env.name}_{args.expert_gen.num_episodes}.pt'
    save_path = os.path.join(reward_save_dir, file_name)
    torch.save(weights, save_path)
    logging(f"Trajectories saved at {save_path}")


# Script entry point: call the Hydra-decorated main() only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()
