"""
Copyright 2022 Div Garg. All rights reserved.

Example training code for IQ-Learn which minimally modifies `train_rl.py`.
"""

import datetime
import os
import random
import time
from collections import deque
from itertools import count
import types

import hydra
import numpy as np
import torch
import torch.nn.functional as F
import wandb
from omegaconf import DictConfig, OmegaConf
from tensorboardX import SummaryWriter

from make_reward_models import make_reward_model
from wrappers.atari_wrapper import LazyFrames
from make_envs import make_env
from dataset.memory import Memory
from agent import make_agent
from utils.utils import eval_mode, split_expert_memory, get_irl_reward, get_args, gen_frame, save_video, set_up_log_dirs
# from utils.logger import Logger
from utils.wandb_logger import wandb_logger as Logger
from iq import iq_loss

import uuid

# Cap the number of threads PyTorch uses for intra-op parallelism on CPU.
# This runs at import time, so it applies to the whole process.
torch.set_num_threads(2)




def train_reward_model(batch_size, expert_memory_replay_train, reward_model, agent, device, GAMMA, logger, step, args):
    """Run one reward-model update on a batch sampled from the training memory.

    Args:
        batch_size: Number of transitions requested from the memory.
        expert_memory_replay_train: Memory exposing `get_samples(batch_size, device)`.
        reward_model: Model exposing `update(agent, batch)`, which returns a
            mapping of metric name to value.
        agent: Agent handed through to `reward_model.update`.
        device: Device handed through to `get_samples`.
        GAMMA: Not used by this function; kept for a uniform call signature.
        logger: Logger exposing `wandb_log(dict, step)`.
        step: Current training step, used for the logging cadence and as the log index.
        args: Config; only `args.reward_gen.train.log_interval` is read.
    """
    batch = expert_memory_replay_train.get_samples(batch_size, device)
    metrics = reward_model.update(agent, batch)

    # Metrics are reported only on every `log_interval`-th step.
    if step % args.reward_gen.train.log_interval != 0:
        return

    to_log = {f'train/{key}': value for key, value in metrics.items()}
    for key, value in metrics.items():
        print(f'--> Train {key}: {value}, step: {step}')
    logger.wandb_log(to_log, step)

def eval_reward_model(batch_size, expert_memory_replay_val, reward_model, agent, device, GAMMA, logger, step, args):
    """Log the MSE between the reward model's output and the agent-derived IRL reward.

    A batch is drawn from the validation memory, the target reward is computed
    with `get_irl_reward`, and the reward model's prediction is compared to it.
    Nothing is trained here: all computation runs under `torch.no_grad()`.

    Args:
        batch_size: Number of transitions requested from the memory.
        expert_memory_replay_val: Memory exposing `get_samples(batch_size, device)`;
            the batch must unpack into (obs, next_obs, action, reward, done).
        reward_model: Callable as `reward_model(obs, action, next_obs, done)`.
        agent: Agent handed to `get_irl_reward`.
        device: Device handed through to `get_samples`.
        GAMMA: Discount value handed to `get_irl_reward`.
        logger: Logger exposing `wandb_log(dict, step)`.
        step: Step index attached to the logged value.
        args: Not used by this function; kept for a uniform call signature.
    """
    # The reward stored in the batch is not needed for this comparison.
    obs, next_obs, action, _, done = expert_memory_replay_val.get_samples(batch_size, device)

    with torch.no_grad():
        # Target: reward recovered from the agent.
        target_reward = get_irl_reward(agent, obs, next_obs, action, done, GAMMA)
        # Prediction: output of the reward model on the same transitions.
        model_reward = reward_model(obs, action, next_obs, done)
        reward_loss = F.mse_loss(model_reward, target_reward)

    logger.wandb_log({'eval/reward_loss': reward_loss}, step)
    print(f'--> Eval reward loss: {reward_loss}, step: {step}')

def eval_gen_video(agent, env, reward_model, GAMMA, device, logger, step, video_save_dir):
    """Roll out one episode and save a video annotated with both reward signals.

    At every environment step the IRL reward (from `get_irl_reward`) and the
    reward model's prediction are computed for the transition just taken and
    passed to `gen_frame` together with the rendered RGB frame. The resulting
    frames are written out with `save_video`, using `step` as the episode id.

    Args:
        agent: Agent exposing `choose_action(state, sample=False)`.
        env: Environment with `reset`, `step` (4-tuple return) and
            `render(mode='rgb_array')`.
        reward_model: Callable as `reward_model(obs, action, next_obs, done)`.
        GAMMA: Discount value handed to `get_irl_reward`.
        device: Device on which the per-step tensors are created.
        logger: Not used at the moment (video logging is still a TODO).
        step: Training step, used as the saved video's episode id.
        video_save_dir: Destination handed to `save_video`.
    """
    def _to_tensor(array):
        # Every model input is a float32 tensor on the target device.
        return torch.tensor(array, dtype=torch.float32, device=device)

    rendered = []
    obs = env.reset()
    episode_over = False
    while not episode_over:
        action = agent.choose_action(obs, sample=False)
        next_obs, _, episode_over, _ = env.step(action)

        with torch.no_grad():
            # Add a leading batch dimension of size 1 to each quantity.
            obs_t = _to_tensor(obs[None, :])
            next_obs_t = _to_tensor(next_obs[None, :])
            action_t = _to_tensor(np.array([action]).reshape(1, -1))
            done_t = _to_tensor(np.array([episode_over]).reshape(1, -1))

            irl_reward = get_irl_reward(agent, obs_t, next_obs_t, action_t, done_t, GAMMA, device=device).item()
            learned_reward = reward_model(obs_t, action_t, next_obs_t, done_t).item()

        rendered.append(gen_frame(frame=env.render(mode='rgb_array'), irl_reward=irl_reward, learned_reward=learned_reward))
        obs = next_obs

    save_video(video_save_dir, rendered, episode_id=step)

    # Rearrange axes (0, 3, 1, 2) in preparation for logging the video.
    # The result is not consumed yet.
    rendered = np.array(rendered)
    rendered = rendered.transpose(0, 3, 1, 2)
    # logger.log_video('eval/videos', frames, step) ## TODO: log video
    

def _collect_online_data(env, agent, memory, num_transitions, rand_threshold):
    """Roll out `agent` in `env` and append each transition to `memory`.

    Each action comes from `agent.choose_action(state, sample=False)` and is
    replaced, with probability `rand_threshold`, by a sample from the
    environment's action space. The environment is reset whenever an episode
    ends. Progress is printed every 1000 transitions.
    """
    data_index = 0
    state = env.reset()
    while data_index < num_transitions:
        action = agent.choose_action(state, sample=False)
        if random.random() < rand_threshold:
            action = env.action_space.sample()
        next_state, reward, done, _ = env.step(action)
        data_index += 1
        memory.add((state, next_state, action, reward, done))
        state = env.reset() if done else next_state
        if data_index % 1000 == 0:
            print(f'--> added {data_index / num_transitions * 100}% of online data')


@hydra.main(config_path="conf", config_name="config_reward_gen")
def main(cfg: DictConfig):
    """Train a reward model against the IRL reward of an agent.

    Steps, in order:
      1. Build the run config, the wandb logger and the log directories.
      2. Seed `random`, NumPy, PyTorch and both environments.
      3. Build the agent (with `load_agent_path=args.pretrain`) and load the
         expert demonstrations into a replay memory.
      4. Optionally extend that memory with transitions collected online.
      5. Split the memory 0.8 / 0.2 into training and validation parts.
      6. For `learn_steps` iterations: update the reward model, and at the
         configured intervals evaluate it, render an annotated video, and
         checkpoint it.

    Args:
        cfg: Hydra config composed from `conf/config_reward_gen`.
    """
    args = get_args(cfg)
    # NOTE(review): `config_dict` is not read anywhere below. The call is kept
    # because `resolve=True` forces interpolation resolution; confirm whether
    # it can be dropped.
    config_dict = OmegaConf.to_container(args, resolve=True)

    logger = Logger(args)
    logdirs = set_up_log_dirs(args, logger.prefix)
    log_dir, wandb_dir, agent_save_dir, agent_best_dir, reward_save_dir, video_save_dir = logdirs
    logger._create_wandb(log_dir=wandb_dir)

    # set seeds
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device(args.device)
    if device.type == 'cuda' and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    env_args = args.env
    env = make_env(args)
    eval_env = make_env(args)

    # Seed envs; the eval env gets a different seed from the training env.
    env.seed(args.seed)
    eval_env.seed(args.seed + 10)
    # set action space seed
    if hasattr(env.action_space, 'seed'):
        env.action_space.seed(args.seed)
        eval_env.action_space.seed(args.seed + 10)

    # Memory capacity: half the configured replay size, plus room for the
    # online transitions when those are enabled.
    REPLAY_MEMORY = int(env_args.replay_mem) // 2
    if args.reward_gen.train.add_online_data:
        REPLAY_MEMORY += args.reward_gen.train.online_data_size

    agent = make_agent(env.observation_space, env.action_space, args, load_agent_path=args.pretrain)
    GAMMA = args.gamma

    # Load expert data.
    # `to_absolute_path` does not expand `~`: a path starting with `~` is
    # relative, so it would be resolved under the original working directory.
    # Expanding the home directory first yields the intended absolute path,
    # which `to_absolute_path` then returns unchanged.
    expert_path = os.path.expanduser(
        f'~/projects/iq_learn/expert_dataset/iq_learn/experts/{args.env.demo}')
    expert_memory_replay = Memory(REPLAY_MEMORY, args.seed)
    expert_memory_replay.load(hydra.utils.to_absolute_path(expert_path),
                              num_trajs=args.expert.demos,
                              sample_freq=args.expert.subsample_freq,
                              seed=args.seed + 42)
    print(f'--> Expert memory size: {expert_memory_replay.size()}')

    # add more data to expert memory by running the expert agent
    if args.reward_gen.train.add_online_data:
        _collect_online_data(env=env,
                             agent=agent,
                             memory=expert_memory_replay,
                             num_transitions=args.reward_gen.train.online_data_size,
                             rand_threshold=args.reward_gen.train.online_rand_threshold)
        print(f'--> Expert memory size after adding online data: {expert_memory_replay.size()}')

    # split expert memory into two parts, one for training reward model, one for validation
    expert_memory_replay_train, expert_memory_replay_val = split_expert_memory(expert_memory_replay, 0.8, args)

    # Fresh (not pretrained) reward model.
    reward_model = make_reward_model(env.observation_space, env.action_space, device=device, args=args, load_pretrained_model=False)
    batch_size = args.reward_gen.train.batch

    for steps in range(args.reward_gen.train.learn_steps):
        # One update on a batch sampled from the training split.
        train_reward_model(batch_size=batch_size,
                           expert_memory_replay_train=expert_memory_replay_train,
                           reward_model=reward_model,
                           agent=agent,
                           device=device,
                           GAMMA=GAMMA,
                           logger=logger,
                           step=steps,
                           args=args)

        if steps % args.reward_gen.train.eval_interval == 0:
            eval_reward_model(batch_size=batch_size,
                              expert_memory_replay_val=expert_memory_replay_val,
                              reward_model=reward_model,
                              agent=agent,
                              device=device,
                              GAMMA=GAMMA,
                              logger=logger,
                              step=steps,
                              args=args)

        if steps % args.reward_gen.train.video_eval_interval == 0:
            eval_gen_video(agent=agent,
                           env=eval_env,
                           reward_model=reward_model,
                           GAMMA=GAMMA,
                           device=device,
                           logger=logger,
                           step=steps,
                           video_save_dir=video_save_dir)

        # `save` applies its own `save_interval` check, so it is called every step.
        save(reward_model, steps, args, output_dir=reward_save_dir)


def save(reward_model, epoch, args, output_dir='results'):
    """Checkpoint the reward model when `epoch` is a multiple of the save interval.

    Nothing happens on other epochs, including directory creation. The
    checkpoint path is
    `{output_dir}/{agent name}_{reward model type}_{env name}_{epoch}`.

    Args:
        reward_model: Object exposing `save(path)`.
        epoch: Current step/epoch index.
        args: Config; reads `reward_gen.train.save_interval`,
            `reward_gen.model.type`, `env.name` and `agent.name`.
        output_dir: Directory for checkpoints; created, with any missing
            parents, if it does not exist yet.
    """
    if epoch % args.reward_gen.train.save_interval != 0:
        return

    name = f'{args.reward_gen.model.type}_{args.env.name}'

    # `makedirs(..., exist_ok=True)` creates missing parent directories and
    # does not fail if the directory already exists or appears concurrently.
    os.makedirs(output_dir, exist_ok=True)

    model_path = f'{output_dir}/{args.agent.name}_{name}_{epoch}'
    reward_model.save(model_path)
    print(f'--> Saved model at {epoch} steps, path:{model_path}')


# Colour palette used by `grid_img_gen`: despite the name, this maps a cell's
# 4-element one-hot code (key) to the 3-element colour drawn for it (value).
NODE_TO_ONE_HOT = {
    # Empty square
    (1, 0, 0, 0): [0, 0, 0], # black
    # Wall
    (0, 1, 0, 0): [100, 100, 100], # gray
    # Goal
    (0, 0, 1, 0): [0, 255, 0], # green
    # Agent
    (0, 0, 0, 1): [255, 0, 0], # red
}


def grid_img_gen(grids_th, name):
    """Render a one-hot grid tensor as an upscaled colour image and save it as a JPEG.

    Each cell's channel vector is looked up in `NODE_TO_ONE_HOT` to get its
    colour. The image is enlarged 10x per side with nearest-neighbour
    resampling and written to
    `~/projects/goal_prox_il/iq_learn/debug_img/grid_{name}.jpg`, with `~`
    expanded to the user's home directory.

    Args:
        grids_th: Tensor laid out channels-first, e.g. shape (4, 19, 19).
            Every cell's channel vector must be a key of `NODE_TO_ONE_HOT`.
        name: Suffix used in the output file name.

    Raises:
        KeyError: If a cell's channel vector is not in `NODE_TO_ONE_HOT`.
    """
    from PIL import Image

    # Channels-first -> channels-last so each [row, col] entry is one cell's code.
    grids = np.transpose(grids_th.cpu().numpy(), (1, 2, 0))
    height, width = grids.shape[:2]

    pixels = [
        NODE_TO_ONE_HOT[tuple(cell.astype(int))]
        for row in grids
        for cell in row
    ]
    # Use the input's own grid size rather than a fixed 19x19.
    img = np.array(pixels).astype(np.uint8).reshape(height, width, 3)

    # Enlarge so individual cells are visible. PIL sizes are (width, height).
    img = Image.fromarray(img)
    img_resized = img.resize((width * 10, height * 10), Image.NEAREST)

    # `~` is not expanded by file APIs, so expand it explicitly. Also create
    # the target directory, since saving fails if it is missing.
    out_path = os.path.expanduser(f'~/projects/goal_prox_il/iq_learn/debug_img/grid_{name}.jpg')
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    img_resized.save(out_path)


# Script entry point: run training only when executed directly, not on import.
# `main` is called without arguments because the `@hydra.main` decorator
# builds the config and passes it in.
if __name__ == "__main__":
    main()
