import datetime
import os
import random
import time
from collections import deque
from itertools import count
import types
import uuid

import hydra
import numpy as np
import torch
import torch.nn.functional as F
import wandb
from omegaconf import DictConfig, OmegaConf
from tensorboardX import SummaryWriter

from make_reward_models import make_reward_model
from utils.logger import Logger
from make_envs import make_env
from dataset.memory import Memory
from agent import make_agent
from utils.utils import evaluate, eval_mode, get_args
from rlf.exp_mgr import config_mgr


from agent.ppo_utils.rollouts import RolloutRunner

# Cap the number of threads PyTorch uses for CPU intra-op parallelism.
torch.set_num_threads(2)



def get_predicted_reward(reward_model, state, next_state, action, done, device):
    """Query the reward model for the reward of a single transition.

    Every input is given a leading axis of size one and converted to a
    ``float32`` tensor on ``device`` before the model is called. The call
    runs with gradient tracking disabled.

    Args:
        reward_model: Callable invoked as
            ``reward_model(obs, action, next_obs, done)``.
        state: Observation before the step; indexed as ``state[None, :]``,
            so it must support NumPy-style indexing.
        next_state: Observation after the step; handled like ``state``.
        action: Action taken; wrapped in an array and reshaped to ``(1, -1)``.
        done: Termination flag; wrapped and reshaped to ``(1, -1)``.
        device: Torch device on which the input tensors are created.

    Returns:
        The model output as a Python scalar (via ``Tensor.item()``, so the
        model must produce exactly one element).
    """

    def _as_float_tensor(array):
        # All model inputs share the same dtype and device.
        return torch.tensor(array, dtype=torch.float32, device=device)

    with torch.no_grad():
        obs_batch = _as_float_tensor(state[None, :])
        next_obs_batch = _as_float_tensor(next_state[None, :])
        action_batch = _as_float_tensor(np.array([action]).reshape(1, -1))
        done_batch = _as_float_tensor(np.array([done]).reshape(1, -1))

        # Argument order expected by the model: obs, action, next_obs, done.
        reward = reward_model(obs_batch, action_batch, next_obs_batch, done_batch)
        return reward.item()

@hydra.main(config_path="conf", config_name="config_reward_gen")
def main(cfg: DictConfig):
    """Train an agent on rewards predicted by a pretrained reward model.

    The agent interacts with ``env``; for every transition the reward stored
    in the replay buffer is the reward model's prediction (``irl_reward``),
    while the environment's own reward is only accumulated for logging.
    Evaluation runs on a separate, differently seeded ``eval_env`` whenever
    ``learn_steps`` is a multiple of ``args.env.eval_interval``. Checkpoints
    are written through ``save`` after each episode and whenever the
    evaluation return improves.

    Args:
        cfg: Hydra configuration; converted to the working ``args`` object by
            ``get_args``.

    Returns:
        None. The function returns once ``learn_steps`` reaches
        ``args.env.learn_steps``.
    """
    args = get_args(cfg)
    config_dict = OmegaConf.to_container(args, resolve=True)

    # create a unique prefix for the run
    unique_id = uuid.uuid4().hex[:8]  # Shortened version of UUID
    run_id = f"{args.seed}_test_reward_{unique_id}"

    if not args.if_debug:
        config_mgr.init("./config.yaml")
        wb_proj_name = config_mgr.get_prop("proj_name")
        wb_entity = config_mgr.get_prop("wb_entity")
        wandb.init(project=wb_proj_name, entity=wb_entity, name=run_id,
                   sync_tensorboard=True, reinit=True, config=config_dict)

    # set seeds
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device(args.device)
    if device.type == 'cuda' and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    env_args = args.env
    env = make_env(args)
    eval_env = make_env(args)
    # Seed envs (evaluation uses a different seed than training)
    env.seed(args.seed)
    eval_env.seed(args.seed + 10)

    REPLAY_MEMORY = int(env_args.replay_mem)
    INITIAL_MEMORY = int(env_args.initial_mem)
    EPISODE_STEPS = int(env_args.eps_steps)
    EPISODE_WINDOW = int(env_args.eps_window)
    LEARN_STEPS = int(env_args.learn_steps)

    agent = make_agent(env.observation_space, env.action_space, args, load_agent_path=None)

    # load reward model
    reward_model = make_reward_model(ob_space=env.observation_space, action_space=env.action_space, device=device, args=args, load_pretrained_model=True)

    memory_replay = Memory(REPLAY_MEMORY, args.seed)

    # Setup logging directories: everything for this run lives under log_dir.
    log_dir = os.path.join(args.log_dir, run_id)
    wandb_dir = os.path.join(log_dir, 'wandb')
    # setup agent log dir
    agent_save_dir = os.path.join(log_dir, 'agent_model')
    agent_best_dir = os.path.join(log_dir, 'results_best')

    # makedirs creates missing parents and tolerates already-existing dirs.
    for directory in (args.log_dir, log_dir, wandb_dir, agent_save_dir, agent_best_dir):
        os.makedirs(directory, exist_ok=True)

    writer = SummaryWriter(log_dir=wandb_dir)
    print(f'--> Saving logs at: {wandb_dir}')
    logger = Logger(wandb_dir,
                    log_frequency=args.log_interval,
                    writer=writer,
                    save_tb=True,
                    agent=args.agent.name)

    steps = 0
    # Grid-world runs additionally track whether the goal was reached.
    flag_gw = args.env.name == 'MiniGrid-FourRooms-v0'
    learn_steps = 0
    begin_learn = False

    # track mean reward and scores
    rewards_window = deque(maxlen=EPISODE_WINDOW)  # last N rewards
    best_eval_returns = -np.inf

    for epoch in count():
        state = env.reset()
        episode_reward = 0
        episode_irl_reward = 0
        done = False
        # Reset per episode so an episode that ends by exhausting
        # EPISODE_STEPS (never reaching `done`) logs 0 instead of raising
        # NameError or reusing the previous episode's value.
        found_goal = 0

        start_time = time.time()
        for episode_step in range(EPISODE_STEPS):

            if steps < args.num_seed_steps:
                # Seed replay buffer with random actions
                action = env.action_space.sample()
            else:
                with eval_mode(agent):
                    action = agent.choose_action(state, sample=True)
            next_state, reward, done, _ = env.step(action)
            episode_reward += reward
            # The learner is trained on the predicted reward, not on `reward`.
            irl_reward = get_predicted_reward(reward_model, state, next_state, action, done, device)
            episode_irl_reward += irl_reward
            steps += 1
            if flag_gw and done:
                found_goal = 1 if reward > 0 else 0

            if learn_steps % args.env.eval_interval == 0:
                eval_returns, eval_timesteps, eval_found_goals = evaluate(agent, eval_env, num_episodes=args.eval.eps, flag_gw=flag_gw)
                returns = np.mean(eval_returns)
                learn_steps += 1  # To prevent repeated eval at timestep 0
                logger.log('eval/episode_reward', returns, learn_steps)
                logger.log('eval/episode', epoch, learn_steps)
                if flag_gw:
                    found_goal_rate = np.mean(eval_found_goals)
                    logger.log('eval/episode_found_goal', found_goal_rate, learn_steps)
                logger.dump(learn_steps)
                # print('EVAL\tEp {}\tAverage reward: {:.2f}\t'.format(epoch, returns))

                if returns > best_eval_returns:
                    # Store best eval returns
                    best_eval_returns = returns
                    if not args.if_debug:
                        wandb.run.summary["best_returns"] = best_eval_returns
                    # Keep best checkpoints inside this run's directory,
                    # alongside the periodic ones in agent_save_dir.
                    save(agent, epoch, args, output_dir=agent_best_dir)

            # only store done true when episode finishes without hitting timelimit (allow infinite bootstrap)
            done_no_lim = done
            if str(env.__class__.__name__).find('TimeLimit') >= 0 and episode_step + 1 == env._max_episode_steps:
                done_no_lim = 0
            memory_replay.add((state, next_state, action, irl_reward, done_no_lim))

            if memory_replay.size() > INITIAL_MEMORY:
                # Start learning
                if begin_learn is False:
                    print('Learn begins!')
                    begin_learn = True

                learn_steps += 1
                # `>=` rather than `==`: the evaluation branch above also
                # increments learn_steps, so an exact target can be skipped.
                if learn_steps >= LEARN_STEPS:
                    print('Finished!')
                    wandb.finish()
                    return

                losses = agent.update(memory_replay, logger, learn_steps)

                if learn_steps % args.log_interval == 0:
                    for key, loss in losses.items():
                        writer.add_scalar(key, loss, global_step=learn_steps)

            if done:
                break
            state = next_state

        rewards_window.append(episode_reward)
        logger.log('train/episode', epoch, learn_steps)
        logger.log('train/episode_reward', episode_reward, learn_steps)
        logger.log('train/episode_irl_reward', episode_irl_reward, learn_steps)
        logger.log('train/duration', time.time() - start_time, learn_steps)
        if flag_gw:
            logger.log('train/episode_found_goal', found_goal, learn_steps)
        logger.dump(learn_steps, save=begin_learn)
        # print('TRAIN\tEp {}\tAverage reward: {:.2f}\t'.format(epoch, np.mean(rewards_window)))
        save(agent, epoch, args, output_dir=agent_save_dir)



def save(agent, epoch, args, output_dir='results'):
    """Checkpoint the agent on epochs that are multiples of ``save_interval``.

    The checkpoint path passed to ``agent.save`` is
    ``{output_dir}/{args.agent.name}_{prefix}_{args.env.name}_{epoch}``, where
    ``prefix`` is ``sqil`` when ``args.method.type == "sqil"`` and ``iq``
    otherwise.

    Args:
        agent: Object exposing ``save(path)``.
        epoch: Current epoch index; nothing happens unless
            ``epoch % args.save_interval == 0``.
        args: Config providing ``save_interval``, ``method.type``,
            ``env.name`` and ``agent.name``.
        output_dir: Directory for the checkpoint; created (including missing
            parents) if it does not exist.
    """
    if epoch % args.save_interval != 0:
        return

    prefix = 'sqil' if args.method.type == "sqil" else 'iq'
    name = f'{prefix}_{args.env.name}'

    # makedirs(exist_ok=True) avoids the exists()/mkdir() race and also works
    # when parent directories are missing.
    os.makedirs(output_dir, exist_ok=True)
    agent.save(f'{output_dir}/{args.agent.name}_{name}_{epoch}')


# Run the Hydra-decorated entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
