import datetime
import os
import random
import time
from collections import deque
from itertools import count
import types
import uuid

import hydra
import numpy as np
import torch
import torch.nn.functional as F
import wandb
from omegaconf import DictConfig
# from tensorboardX import SummaryWriter

# from utils.logger import Logger
from utils.wandb_logger import wandb_logger as Logger
from make_envs import make_env
from dataset.memory import Memory
from agent import make_agent
from utils.utils import evaluate, eval_mode, get_args, set_up_log_dirs, logging

# Cap PyTorch's CPU intra-op parallelism at 2 threads for this process.
torch.set_num_threads(2)



@hydra.main(config_path="conf", config_name="config_reward_gen")
def main(cfg: DictConfig):
    """Train an agent from scratch by interacting with the environment.

    Builds the logger, the training/evaluation environments, the agent and
    the replay memory from the configuration, then loops over episodes:
    collect a transition, evaluate whenever ``learn_steps`` is a multiple of
    ``args.env.eval_interval``, and update the agent once the replay memory
    holds more than ``args.env.initial_mem`` transitions. Returns after
    ``args.env.learn_steps`` learning steps have been reached.

    Args:
        cfg: Hydra configuration, converted to ``args`` via ``get_args``.
    """
    args = get_args(cfg)
    # Training starts from scratch, so never load a pretrained agent.
    args.pretrain = None

    logger = Logger(args)
    # Only the wandb and video directories are used below; the other entries
    # are unpacked so a change in the returned tuple's length fails loudly.
    (_log_dir, wandb_dir, _agent_save_dir, _agent_best_dir,
     _reward_save_dir, video_save_dir) = set_up_log_dirs(args, logger.prefix)
    logger._create_wandb(log_dir=wandb_dir)

    # Seed every RNG used by this script.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device(args.device)
    if device.type == 'cuda' and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    env_args = args.env
    env = make_env(args)
    eval_env = make_env(args)
    # Different seeds so evaluation does not replay the training episodes.
    env.seed(args.seed)
    eval_env.seed(args.seed + 10)

    replay_capacity = int(env_args.replay_mem)
    initial_memory = int(env_args.initial_mem)
    episode_steps = int(env_args.eps_steps)
    episode_window = int(env_args.eps_window)
    max_learn_steps = int(env_args.learn_steps)

    agent = make_agent(env.observation_space, env.action_space, args, load_agent_path=args.pretrain)
    memory_replay = Memory(replay_capacity, args.seed)

    flag_gw = 'MiniGrid' in args.env.name  # forwarded to evaluate()
    steps = 0          # environment steps taken so far
    learn_steps = 0    # learning-step counter (also bumped after each eval)
    begin_learn = False

    # Returns of the most recent episodes, for the running-average TRAIN log.
    rewards_window = deque(maxlen=episode_window)

    for epoch in count():
        state = env.reset()
        episode_reward = 0
        start_time = time.time()

        for episode_step in range(episode_steps):
            if steps < args.num_seed_steps:
                # Seed the replay buffer with random actions.
                action = env.action_space.sample()
            else:
                with eval_mode(agent):
                    action = agent.choose_action(state, sample=True)
            next_state, reward, done, _ = env.step(action)
            episode_reward += reward
            steps += 1

            if learn_steps % args.env.eval_interval == 0:
                eval_returns, _eval_timesteps, _eval_found_goals = evaluate(
                    agent,
                    eval_env,
                    args,
                    logger=logger,
                    num_episodes=args.eval.eps,
                    flag_gw=flag_gw,
                    video_save_dir=video_save_dir,
                    learn_steps=learn_steps,
                )
                logging('EVAL\tEp {}\tAverage reward: {:.2f}\t'.format(epoch, np.mean(eval_returns)))
                # Bump the counter so the same learn step is not evaluated
                # again on the next environment step (notably at step 0,
                # before learning has started).
                learn_steps += 1

            # Store done=True only for real terminations; an episode cut off
            # by the time limit is stored as non-terminal so bootstrapping
            # continues past it.
            hit_time_limit = (
                'TimeLimit' in env.__class__.__name__
                and episode_step + 1 == env._max_episode_steps
            )
            done_no_lim = 0 if hit_time_limit else done
            memory_replay.add((state, next_state, action, reward, done_no_lim))

            if memory_replay.size() > initial_memory:
                if not begin_learn:
                    logging('Learn begins!')
                    begin_learn = True

                learn_steps += 1
                # learn_steps also advances after every evaluation, so it can
                # step past the target; `>=` guarantees termination where an
                # exact `==` could be skipped over forever.
                if learn_steps >= max_learn_steps:
                    logging('Finished!')
                    wandb.finish()
                    return

                agent.update(memory_replay, logger, learn_steps)
                if learn_steps % args.log_interval == 0:
                    # No episode may have finished yet; avoid numpy's
                    # empty-slice RuntimeWarning while still reporting nan.
                    mean_reward = np.mean(rewards_window) if rewards_window else float('nan')
                    logging('TRAIN\tEp {}\tAverage reward: {:.2f}\t'.format(epoch, mean_reward))

            if done:
                break
            state = next_state

        rewards_window.append(episode_reward)
        logger.wandb_log(
            {
                'train/episode': epoch,
                'train/episode_reward': episode_reward,
                'train/duration': time.time() - start_time,
            },
            learn_steps,
        )
        # NOTE: only the periodic checkpoint is written; best-eval
        # checkpointing is not currently performed.
        save(agent, epoch, args, output_dir='results')


def save(agent, epoch, args, output_dir='results'):
    """Checkpoint the agent every ``args.save_interval`` epochs.

    Args:
        agent: Object exposing ``save(path)``.
        epoch: Current epoch index; nothing is written unless it is a
            multiple of ``args.save_interval``.
        args: Config providing ``save_interval``, ``agent.name`` and
            ``env.name`` (the last two form the checkpoint file name).
        output_dir: Directory to write into; created, including any missing
            parent directories, when it does not exist yet.
    """
    if epoch % args.save_interval != 0:
        return

    # makedirs(exist_ok=True) handles nested paths and avoids the
    # check-then-create race of os.path.exists() followed by os.mkdir().
    os.makedirs(output_dir, exist_ok=True)

    checkpoint_path = f'{output_dir}/{args.agent.name}_{args.env.name}'
    agent.save(checkpoint_path)
    logging(f'Saved model at {checkpoint_path}')


# Script entry point: main() is called without arguments because the
# @hydra.main decorator is responsible for building and passing `cfg`.
if __name__ == "__main__":
    main()
