#!/usr/bin/env python3
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math
import os
import sys
import time
import pickle as pkl
import tqdm
import wandb
import datetime

# from logger import Logger
from replay_buffer import ReplayBuffer
from reward_model import RewardModel
from agent.sac import SACAgent
from utils.video_utils import save_video

import utils
import hydra

class Workspace(object):
    """Evaluation workspace for a trained SAC agent and its reward model.

    Construction builds the environment, the SAC agent, a replay buffer and
    the reward model from the Hydra config.  ``run`` loads a checkpoint and
    ``evaluate`` rolls out deterministic episodes, saving one video per
    episode and printing aggregate statistics.
    """

    # Substrings looked for in ``cfg.wandb.version``, in priority order.
    _KNOWN_VERSIONS = ("pairwise", "sequential", "root")

    def __init__(self, cfg):
        """Build every component needed for evaluation.

        Args:
            cfg: Hydra/OmegaConf-style config.  It is mutated in place: the
                device and the observation/action dimensions discovered from
                the environment are written back into it.
        """
        # Keep the current directory up to and including the last "exp".
        # Slicing the tail off (rather than str.replace) avoids also deleting
        # earlier occurrences of the same suffix inside the path.
        cwd = os.getcwd()
        tail = cwd.split("exp")[-1]
        self.work_dir = cwd[:len(cwd) - len(tail)]
        print(f'workspace: {self.work_dir}')

        self.cfg = cfg
        self.logger = None

        utils.set_seed_everywhere(cfg.seed)
        cfg.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(cfg.device)
        self.log_success = False
        self.version = None

        # make env
        if 'metaworld' in cfg.env:
            self.env = utils.make_metaworld_env(cfg)
            self.log_success = True
        else:
            self.env = utils.make_env(cfg)

        # Propagate the environment's dimensions into the agent config.
        params = cfg.agent.agent.params
        params.obs_dim = self.env.observation_space.shape[0]
        params.action_dim = self.env.action_space.shape[0]
        params.action_range = [
            float(self.env.action_space.low.min()),
            float(self.env.action_space.high.max())
        ]
        params.critic_cfg.params.obs_dim = params.obs_dim
        params.critic_cfg.params.action_dim = params.action_dim
        params.actor_cfg.params.obs_dim = params.obs_dim
        params.actor_cfg.params.action_dim = params.action_dim

        config_dict = {
            'obs_dim': params.obs_dim,
            'action_dim': params.action_dim,
            'action_range': params.action_range,
            'device': self.device,
            'critic_cfg': params.critic_cfg,
            'actor_cfg': params.actor_cfg,
            'discount': params.discount,
            'init_temperature': params.init_temperature,
            'alpha_lr': params.alpha_lr,
            'alpha_betas': params.alpha_betas,
            'actor_lr': params.actor_lr,
            'actor_betas': params.actor_betas,
            'actor_update_frequency': params.actor_update_frequency,
            'critic_lr': params.critic_lr,
            'critic_betas': params.critic_betas,
            'critic_tau': params.critic_tau,
            'critic_target_update_frequency': params.critic_target_update_frequency,
            'batch_size': params.batch_size,
            'learnable_temperature': params.learnable_temperature,
            'wandb_use': cfg.wandb.use,
        }
        self.agent = SACAgent(config_dict)
        print(params.actor_cfg)

        self.replay_buffer = ReplayBuffer(
            self.env.observation_space.shape,
            self.env.action_space.shape,
            int(cfg.replay_buffer_capacity),
            self.device
        )

        # Persist the config next to the run; the context manager guarantees
        # the file handle is flushed and closed.
        meta_file = os.path.join(self.work_dir, 'metadata.pkl')
        with open(meta_file, "wb") as meta_fh:
            pkl.dump({'cfg': self.cfg}, meta_fh)

        # for logging
        self.total_feedback = 0
        self.labeled_feedback = 0
        self.step = 0

        # instantiating the reward model
        self.reward_model = RewardModel(
            self.env.observation_space.shape[0],
            self.env.action_space.shape[0],
            ensemble_size=cfg.ensemble_size,
            size_segment=cfg.segment,
            activation=cfg.activation,
            lr=cfg.reward_lr,
            mb_size=cfg.reward_batch,
            large_batch=cfg.large_batch,
            label_margin=cfg.label_margin,
            teacher_beta=cfg.teacher_beta,
            teacher_gamma=cfg.teacher_gamma,
            teacher_eps_mistake=cfg.teacher_eps_mistake,
            teacher_eps_skip=cfg.teacher_eps_skip,
            teacher_eps_equal=cfg.teacher_eps_equal,
            reward_model=cfg.reward.model,
            sequential_num_samples=cfg.reward.sequential_num_samples,
            alpha=cfg.reward.alpha,
        )

    def _step_env(self, action):
        """Advance the environment exactly once and normalise the result.

        Accepts both the 4-value ``(obs, reward, done, info)`` and the 5-value
        ``(obs, reward, done, truncated, info)`` return shapes.  The env is
        stepped a single time; the shape is decided from the returned tuple
        instead of by retrying the step after a failed unpack.

        Returns:
            tuple: ``(obs, reward, done, info)``.

        Raises:
            ValueError: if the step result has neither 4 nor 5 elements.
        """
        result = self.env.step(action)
        if len(result) == 5:
            # NOTE(review): the truncation flag is ignored here, matching the
            # previous 5-value unpacking -- confirm whether it should end the
            # episode as well.
            obs, reward, done, _truncated, extra = result
        else:
            obs, reward, done, extra = result
        return obs, reward, done, extra

    def _render_frame(self):
        """Render one video frame of the current environment state."""
        if self.log_success:
            # Camera hints kept from the original: drawer -> corner3, window -> corner.
            return self.env.render(mode='rgb_array', camera_name="corner3")
        width = 480
        height = 480
        # The last axis of the rendered image is reversed before it is stored.
        return self.env.physics.render(height, width, camera_id=0)[:, :, ::-1]

    def evaluate(self):
        """Roll out ``cfg.num_eval_episodes`` deterministic episodes.

        For every episode a video is written below
        ``<work_dir>/evaluation/videos`` and a one-line summary is printed.
        Afterwards the mean episode reward and, when ``self.log_success`` is
        set, the success rate in percent are printed.  Nothing is returned.
        """
        num_episodes = self.cfg.num_eval_episodes
        env_name = self.cfg.env
        is_metaworld = 'metaworld' in env_name

        # Output directories are fixed for the whole evaluation: create them once.
        os.makedirs(
            os.path.join(self.work_dir, "evaluation/videos/{}".format(env_name)),
            exist_ok=True,
        )
        if not self.cfg.real_human.use:
            save_dir = os.path.join(self.work_dir, "evaluation/videos")
        else:
            save_dir = os.path.join(
                self.work_dir,
                "evaluation/videos/real_human/{}".format(self.cfg.real_human.name),
            )
        os.makedirs(os.path.join(save_dir, env_name), exist_ok=True)

        average_episode_reward = 0
        success_rate = 0

        for episode in range(num_episodes):
            obs = self.env.reset()
            self.agent.reset()
            done = False
            episode_reward = 0
            episode_success = 0
            episode_step = 0
            success_timestep = 0
            frames = []

            while not done:
                with utils.eval_mode(self.agent):
                    action = self.agent.act(obs, sample=False)
                obs, reward, done, extra = self._step_env(action)
                frames.append(self._render_frame())

                episode_step += 1
                episode_reward += reward
                if self.log_success:
                    # Record the step of the first success while the counter
                    # is still valid.
                    if episode_success == 0 and extra['success'] == 1:
                        success_timestep = episode_step
                    episode_success = max(episode_success, extra['success'])

                # Metaworld episodes are cut off manually at the step limit.
                if is_metaworld and episode_step + 1 == self.env._max_episode_steps:
                    done = 1

            average_episode_reward += episode_reward
            if self.log_success:
                success_rate += episode_success
                save_video(frames, filename=os.path.join(save_dir, "{}/{}_{}_success_{}_reward_{:.1f}.mp4".format(env_name, self.version, episode, episode_success, episode_reward)))
                print("saved {} video | success: {} | reward: {} | success at {} timestep".format(episode, episode_success, episode_reward, success_timestep))
            else:
                save_video(frames, filename=os.path.join(save_dir, "{}/{}_{}_reward_{:.1f}.mp4".format(env_name, self.version, episode, episode_reward)))
                print("saved {} video | reward: {:.1f}".format(episode, episode_reward))

        average_episode_reward /= num_episodes
        if self.log_success:
            success_rate /= num_episodes
            success_rate *= 100.0

        print("env: {} | version: {}".format(env_name, self.version))
        print("eval/episode_reward: ", average_episode_reward)
        if self.log_success:
            print("eval/success_rate: ", success_rate)

    def run(self):
        """Load the checkpoint selected by the config and evaluate it.

        The checkpoint flavour is the first of ``_KNOWN_VERSIONS`` that occurs
        as a substring of ``cfg.wandb.version``.

        Raises:
            ValueError: if ``cfg.wandb.version`` names none of the known
                versions (previously this surfaced as ``UnboundLocalError``).
        """
        requested = self.cfg.wandb.version
        version = next(
            (name for name in self._KNOWN_VERSIONS if name in requested),
            None,
        )
        if version is None:
            raise ValueError(
                "cannot infer checkpoint version from cfg.wandb.version={!r}; "
                "expected it to contain one of {}".format(requested, self._KNOWN_VERSIONS)
            )
        self.version = version
        print("version: ", version)

        if not self.cfg.real_human.use:
            load_dir = os.path.join(self.work_dir, "evaluation/checkpoints/{}/{}".format(self.cfg.env, version))
        else:
            load_dir = os.path.join(self.work_dir, "evaluation/checkpoints/real_human/{}/{}/{}".format(self.cfg.real_human.name, self.cfg.env, version))

        # Checkpoints are indexed by the final training step.
        self.step = int(self.cfg.num_train_steps)
        self.agent.load(load_dir, self.step)
        self.reward_model.load(load_dir, self.step)
        self.evaluate()


@hydra.main(config_path="config", config_name="train_SeqRank")
def main(cfg):
    """Hydra entry point: finalise the wandb run name, then run evaluation.

    When the tag contains ``debug`` it is prefixed with the current
    timestamp.  The tag is then appended to ``cfg.wandb.version``; the
    version and its type are printed before that happens.  If wandb is
    enabled a run is initialised under the resulting name.  Finally a
    ``Workspace`` is built from ``cfg`` and run.
    """
    # Debug runs get a timestamp prefix on their tag.
    if 'debug' in cfg.wandb.tag:
        stamp = str(datetime.datetime.today()).partition(".")[0].replace(" ", "_")
        cfg.wandb.tag = "_".join((stamp, cfg.wandb.tag))

    print(cfg.wandb.version)
    print(type(cfg.wandb.version))
    cfg.wandb.version = "_".join((cfg.wandb.version, cfg.wandb.tag))

    # wandb initialize
    if cfg.wandb.use:
        init_kwargs = {
            'project': "{}".format(cfg.env),
            'entity': cfg.wandb.username,
            'reinit': True,
        }
        wandb.init(**init_kwargs)
        wandb.run.name = cfg.wandb.version
        wandb.config.update(cfg)

    Workspace(cfg).run()


# Script entry point: call the Hydra-decorated main() only when this file is
# executed directly, not when it is imported.
if __name__ == '__main__':
    main()
