import datetime
import cv2
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import make_grid, save_image

from iq_learn.dataset.memory import Memory
import os
from PIL import Image
from omegaconf import DictConfig, OmegaConf

import imageio
imageio.plugins.ffmpeg.download = lambda: None  # to disable auto-download in older versions
# or specify the path
import os
os.environ["IMAGEIO_FFMPEG_EXE"] = "/usr/bin/ffmpeg"
import moviepy.editor as mpy


def get_args(cfg: DictConfig):
    """Fill in runtime-derived config fields, log the config, and return it.

    Sets ``cfg.device`` to "cuda" when CUDA is available (otherwise "cpu")
    and ``cfg.hydra_base_dir`` to the current working directory. When
    ``cfg.exp.gamma_scale`` is truthy, ``cfg.pretrain`` is looked up by
    ``cfg.gamma`` in a gamma -> pretrained-model table.

    NOTE(review): that table is still an empty placeholder, so enabling
    ``exp.gamma_scale`` fails on the lookup until entries are added.

    Args:
      cfg: Configuration object; it is mutated in place.
    Returns:
      The same ``cfg`` object.
    """
    if torch.cuda.is_available():
        cfg.device = "cuda"
    else:
        cfg.device = "cpu"
    cfg.hydra_base_dir = os.getcwd()

    if cfg.exp.gamma_scale:
        # add model path
        gamma_to_pretrain = {}
        cfg.pretrain = gamma_to_pretrain[cfg.gamma]

    logging(OmegaConf.to_yaml(cfg))
    return cfg

def set_up_log_dirs(args, run_id):
    """Create the directory tree for one run and return its paths.

    The run directory is ``<args.log_dir>/<run_id>`` and contains the
    sub-directories ``wandb``, ``agent_model``, ``results_best``,
    ``reward_model`` and ``videos``. Every path is logged after creation.

    Args:
      args: Object exposing ``log_dir``, the root directory for all runs.
      run_id: Name of this run's sub-directory.
    Returns:
      Tuple ``(log_dir, wandb_dir, agent_save_dir, agent_best_dir,
      reward_save_dir, video_save_dir)``.
    """
    log_dir = os.path.join(args.log_dir, run_id)

    # (log label, path) pairs; the order fixes both the log output and the
    # order of the returned tuple.
    named_dirs = (
        ('Log dir', log_dir),
        ('Wandb dir', os.path.join(log_dir, 'wandb')),
        ('Agent save dir', os.path.join(log_dir, 'agent_model')),
        ('Agent best dir', os.path.join(log_dir, 'results_best')),
        ('Reward save dir', os.path.join(log_dir, 'reward_model')),
        ('Video save dir', os.path.join(log_dir, 'videos')),
    )

    # exist_ok=True replaces the former "check then create" helper, which
    # could raise if another process created the directory in between.
    os.makedirs(args.log_dir, exist_ok=True)
    for _, path in named_dirs:
        os.makedirs(path, exist_ok=True)

    for label, path in named_dirs:
        logging(f'{label}: {path}')

    return tuple(path for _, path in named_dirs)

def gen_frame(frame, irl_reward=None, learned_reward=None, true_reward=None, fontC='white'):
    """Draw the supplied reward values as text lines onto ``frame``.

    Each reward that is not ``None`` gets its own line, at y = 30, 60 and 90
    pixels for the IRL, learned and true reward respectively.

    Args:
      frame: Image passed to ``cv2.putText``.
      irl_reward: Value for the "IRL reward" line, or None to skip it.
      learned_reward: Value for the "Learned reward" line, or None to skip it.
      true_reward: Value for the "True reward" line, or None to skip it.
      fontC: 'white' selects (255, 255, 255); anything else selects
        (0, 0, 255).
    Returns:
      The result of the last ``cv2.putText`` call, or ``frame`` unchanged if
      no reward was given.
    """
    import cv2

    # NOTE(review): (0, 0, 255) is red only in BGR channel order; on an RGB
    # frame it would render as blue - confirm which order callers use.
    text_color = (255, 255, 255) if fontC == 'white' else (0, 0, 255)
    typeface = cv2.FONT_HERSHEY_SIMPLEX
    scale = 0.5
    # Passed positionally in the slot that cv2.putText documents as
    # `thickness` (the slot after `color`).
    thickness = 2

    overlays = (
        (irl_reward, f'IRL reward: {irl_reward}', 30),
        (learned_reward, f'Learned reward: {learned_reward}', 60),
        (true_reward, f'True reward: {true_reward}', 90),
    )
    for value, caption, baseline_y in overlays:
        if value is None:
            continue
        frame = cv2.putText(frame, caption, (10, baseline_y), typeface, scale, text_color, thickness)
    return frame

# def save_video(video_save_dir, frames, video_name='video', episode_id=0):
#     # save frames as video
#     import cv2
#     width, height, channels = frames[0].shape
#     frames_len = len(frames)

#     # Define the codec and create a VideoWriter object
#     # fourcc = cv2.VideoWriter_fourcc(*"MP4V")  # 'mp4v' for .mp4
#     fourcc = cv2.VideoWriter_fourcc(*"avc1")
#     video_path = os.path.join(video_save_dir, video_name + '_episode_{}'.format(episode_id) + '.mp4')
#     out = cv2.VideoWriter(video_path, fourcc, 4.0, (width, height))

#     # Loop over each frame in the video
#     for i in range(frames_len):
#         frame = frames[i]
#         out.write(frame)

#     # Release the VideoWriter
#     out.release()
#     print(f'Video saved at {video_path}')
#     return video_path

# def save_video(self, fname, frames, is_ood=False, fps=15.0):
def save_video(video_save_dir, frames, video_name='video', episode_id=0, fps=15.0):
    """Encode ``frames`` into an .mp4 file and return the file's path.

    The file is written to
    ``<video_save_dir>/<video_name>_episode_<episode_id>.mp4``.

    Args:
      video_save_dir: Directory the video is written to.
      frames: Iterable of frames; ``cv2.UMat`` entries are converted with
        ``cv2.UMat.get`` first.
      video_name: Prefix of the output file name.
      episode_id: Identifier embedded in the output file name.
      fps: Frame rate passed to the writer and used in the timing below.
    Returns:
      Path of the written video file.
    """
    frames = [cv2.UMat.get(frame) if isinstance(frame, cv2.UMat) else frame for frame in frames]
    file_name = ''.join((video_name, '_episode_{}'.format(episode_id), '.mp4'))
    video_path = os.path.join(video_save_dir, file_name)
    logging("[*] Generating video: {}".format(video_path))

    num_frames = len(frames)

    def frame_at(t):
        # Playback runs slightly slower than `fps`; the index is clamped so
        # every later timestamp shows the final frame.
        new_fps = 1.0 / (1.0 / fps + 1.0 / num_frames)
        position = int(t * new_fps)
        return frames[min(position, num_frames - 1)]

    # The clip lasts 2 seconds longer than num_frames / fps.
    clip = mpy.VideoClip(frame_at, duration=len(frames) / fps + 2)

    clip.write_videofile(filename=video_path, fps=fps, verbose=False)
    logging("[*] Video saved: {}".format(video_path))
    return video_path

class eval_mode:
    """Context manager that temporarily switches models to evaluation mode.

    On entry every model's ``training`` flag is recorded and the model is
    switched with ``model.train(False)``; on exit each model gets its
    recorded flag back via ``model.train(state)``.

    Args:
      *models: Objects exposing a ``training`` attribute and a
        ``train(mode)`` method.
    """

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        # Record every flag BEFORE switching anything. Recording and
        # switching in one loop would store False for the second occurrence
        # of a model passed twice, leaving it stuck in eval mode after exit.
        self.prev_states = [model.training for model in self.models]
        for model in self.models:
            model.train(False)

    def __exit__(self, *args):
        for model, state in zip(self.models, self.prev_states):
            model.train(state)
        # Never suppress an exception raised inside the with-block.
        return False

def evaluate(actor, env, args,
             logger=None, num_episodes=10, vis=True, flag_gw=False, video_save_dir=None, learn_steps=0):
    """Roll out the policy greedily and collect per-episode statistics.

    Episodes are run until ``num_episodes`` returns have been collected.
    A return is only collected when the environment puts an ``'episode'``
    entry (with keys ``'r'`` and ``'l'``) into ``info``.
    NOTE(review): if the environment never reports ``'episode'``, this loop
    does not terminate - presumably the env is wrapped in a monitor; verify.

    Args:
      actor: Policy exposing ``choose_action(state, sample=False)``; it is
        put into eval mode for the duration of each episode.
      env: Environment with ``reset``, ``step`` and ``render('rgb_array')``.
      args: Config; ``args.eval.video_eval_interval`` is read when
        ``video_save_dir`` is given.
      logger: Optional object with ``wandb_log`` and ``wandb_log_video``.
        When None, nothing is sent to a logger.
      num_episodes: Number of episode returns to collect.
      vis: Unused.
      flag_gw: If True, also record per episode whether the final reward
        was positive ("goal found").
      video_save_dir: Directory for the evaluation video; None disables
        video recording.
      learn_steps: Current training step; used as the logging step, in the
        video file name, and to decide whether a video is due.
    Returns:
      Tuple ``(total_returns, total_timesteps, total_found_goal)`` of
      per-episode lists. ``total_found_goal`` is empty unless ``flag_gw``.
    """
    total_timesteps = []
    total_returns = []
    total_found_goal = []
    # Check the directory first so the interval is only read (and only
    # needs to be a valid non-zero number) when a video can be written.
    flag_video_save = (video_save_dir is not None
                       and learn_steps % args.eval.video_eval_interval == 0)
    frame_buffer = []

    while len(total_returns) < num_episodes:
        state = env.reset()

        if flag_video_save:
            frame_buffer.append(env.render('rgb_array'))
        done = False

        with eval_mode(actor):
            while not done:
                action = actor.choose_action(state, sample=False)
                next_state, reward, done, info = env.step(action)
                if flag_video_save:
                    frame_buffer.append(gen_frame(env.render('rgb_array'), true_reward=reward))
                state = next_state

                if 'episode' in info:
                    total_returns.append(info['episode']['r'])
                    total_timesteps.append(info['episode']['l'])
            if flag_gw:
                # The sign of the episode's last reward marks success.
                total_found_goal.append(1 if reward > 0 else 0)

    video_save_path = None
    if flag_video_save:
        # save_video accepts any iterable of frames, so the list is passed
        # as-is instead of being copied into one large array first.
        video_save_path = save_video(video_save_dir, frame_buffer, episode_id=learn_steps)

    log_dict = {}
    if flag_gw:
        log_dict['eval/episode_found_goal'] = np.mean(total_found_goal)
    log_dict['eval/episode_reward'] = np.mean(total_returns)

    # `logger` defaults to None, so only log when one was supplied.
    if logger is not None:
        logger.wandb_log(log_dict, learn_steps)
        if flag_video_save:
            logger.wandb_log_video(video_save_path, learn_steps, "eval/episode_video")

    logging(f'Evaluated {len(total_returns)} episodes')
    return total_returns, total_timesteps, total_found_goal


def weighted_softmax(x, weights):
    """Softmax along dim 0 with each exponentiated entry scaled by a weight.

    Computes ``weights * exp(x) / sum(weights * exp(x), dim=0)``.

    Args:
      x: Tensor of logits.
      weights: Tensor multiplied element-wise with ``exp(x)``.
    Returns:
      Tensor of weighted, normalised values.
    """
    # Shifting by the max along dim 0 scales numerator and denominator by
    # the same factor, so the result is unchanged while exp() stays bounded.
    shifted = x - x.max(dim=0).values
    weighted_exp = weights * shifted.exp()
    return weighted_exp / weighted_exp.sum(dim=0, keepdim=True)


def soft_update(net, target_net, tau):
    """Blend ``net``'s parameters into ``target_net`` in place.

    Each target parameter becomes ``tau * source + (1 - tau) * target``.
    Parameters are paired in ``parameters()`` order.

    Args:
      net: Module whose parameters are read.
      target_net: Module whose parameters are overwritten.
      tau: Interpolation weight given to ``net``.
    """
    for source, target in zip(net.parameters(), target_net.parameters()):
        blended = tau * source.data + (1 - tau) * target.data
        target.data.copy_(blended)


def hard_update(source, target):
    """Copy every parameter of ``source`` into ``target`` in place.

    Parameters are paired in ``parameters()`` order.

    Args:
      source: Module whose parameters are read.
      target: Module whose parameters are overwritten.
    """
    for src_param, dst_param in zip(source.parameters(), target.parameters()):
        dst_param.data.copy_(src_param.data)


def weight_init(m):
    """Orthogonal-weight / zero-bias initialisation for ``nn.Linear`` layers.

    Intended for use with ``nn.Module.apply``. Modules that are not
    ``nn.Linear`` are left untouched.

    Args:
      m: Module to initialise.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.orthogonal_(m.weight.data)
    # A layer built with bias=False has bias None, which has no `.data`.
    if hasattr(m.bias, 'data'):
        m.bias.data.fill_(0.0)


class MLP(nn.Module):
    """Fully connected network built by ``mlp`` and initialised by ``weight_init``.

    Args:
      input_dim: Size of the input features.
      hidden_dim: Width of each hidden layer.
      output_dim: Size of the output features.
      hidden_depth: Number of hidden layers (0 gives a single linear layer).
      output_mod: Optional module appended after the last linear layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, hidden_depth, output_mod=None):
        super().__init__()
        layers = mlp(input_dim, hidden_dim, output_dim, hidden_depth, output_mod)
        self.trunk = layers
        # Applies the initialiser to every sub-module, including the trunk's layers.
        self.apply(weight_init)

    def forward(self, x):
        """Run ``x`` through the trunk and return its output."""
        return self.trunk(x)


def mlp(input_dim, hidden_dim, output_dim, hidden_depth, output_mod=None):
    """Build a ``nn.Sequential`` multi-layer perceptron with ReLU activations.

    Args:
      input_dim: Size of the input features.
      hidden_dim: Width of each hidden layer.
      output_dim: Size of the output features.
      hidden_depth: Number of hidden layers; 0 yields one linear layer
        mapping ``input_dim`` straight to ``output_dim``.
      output_mod: Optional module appended as the final stage.
    Returns:
      The assembled ``nn.Sequential``.
    """
    if hidden_depth == 0:
        layers = [nn.Linear(input_dim, output_dim)]
    else:
        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
        for _ in range(hidden_depth - 1):
            layers.extend((nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)))
        layers.append(nn.Linear(hidden_dim, output_dim))

    if output_mod is not None:
        layers.append(output_mod)
    return nn.Sequential(*layers)


def get_concat_samples(policy_batch, expert_batch, args):
    """Concatenate a policy batch and an expert batch along dim 0.

    Both batches are 5-tuples ``(state, next_state, action, reward, done)``.
    Policy samples come first in every output tensor. When
    ``args.method.type == "sqil"`` the policy rewards are replaced by zeros
    and the expert rewards by ones before concatenation.

    Args:
      policy_batch: 5-tuple of tensors sampled from the online policy.
      expert_batch: 5-tuple of tensors sampled from the expert.
      args: Config; only ``args.method.type`` is read.
    Returns:
      Tuple ``(state, next_state, action, reward, done, is_expert)`` where
      ``is_expert`` is a bool tensor shaped like the concatenated rewards,
      False for policy rows and True for expert rows.
    """
    pi_state, pi_next_state, pi_action, pi_reward, pi_done = policy_batch
    ex_state, ex_next_state, ex_action, ex_reward, ex_done = expert_batch

    if args.method.type == "sqil":
        # SQIL: policy transitions get reward 0, expert transitions reward 1.
        pi_reward = torch.zeros_like(pi_reward)
        ex_reward = torch.ones_like(ex_reward)

    paired_fields = (
        (pi_state, ex_state),
        (pi_next_state, ex_next_state),
        (pi_action, ex_action),
        (pi_reward, ex_reward),
        (pi_done, ex_done),
    )
    merged = tuple(torch.cat([online, expert], dim=0) for online, expert in paired_fields)

    expert_mask = torch.cat(
        [torch.zeros_like(pi_reward, dtype=torch.bool),
         torch.ones_like(ex_reward, dtype=torch.bool)],
        dim=0)

    return (*merged, expert_mask)


def save_state(tensor, path, num_states=5):
    """Save the stacked frames of the first ``num_states`` states as one image.

    Every channel of every selected state becomes its own single-channel
    image; the images are written to ``path`` with ``nrow=num_states``.

    Args:
      tensor: 4-D tensor; the unpacking below treats it as (B, C, H, W).
      path: Destination passed to ``save_image``.
      num_states: How many leading states to include.
    """
    selected = tensor[:num_states]
    # Unpacking the shape also enforces that the input is 4-D.
    _, _, height, width = selected.shape
    frames = selected.reshape(-1, 1, height, width).cpu()
    save_image(frames, path, nrow=num_states)


def average_dicts(dict1, dict2):
    """Return the key-wise mean of two dicts.

    The result has the union of both key sets; a key missing from one dict
    contributes 0 for that side.
    """
    averaged = {}
    for key in set(dict1) | set(dict2):
        total = dict1.get(key, 0) + dict2.get(key, 0)
        averaged[key] = 1 / 2 * total
    return averaged

def split_expert_memory(expert_memory: Memory, split_percent=0.8, args=None):
    """Split an expert memory into a leading train part and a trailing val part.

    The first ``int(size * split_percent)`` entries of
    ``expert_memory.buffer`` go to the train memory, the rest to the
    validation memory. Entries are copied in buffer order, not shuffled.

    Args:
      expert_memory: Memory to split; it is only read.
      split_percent: Fraction of entries assigned to the train memory.
      args: Object providing ``seed`` for both new memories.
        NOTE(review): despite the ``None`` default, ``args.seed`` is read
        unconditionally, so a real object must be passed.
    Returns:
      Tuple ``(expert_memory_train, expert_memory_val)``.
    """
    total = expert_memory.size()
    train_size = int(total * split_percent)
    train_memory = Memory(train_size, args.seed)
    val_memory = Memory(total - train_size, args.seed)

    for idx in range(total):
        destination = train_memory if idx < train_size else val_memory
        destination.add(expert_memory.buffer[idx])
    return train_memory, val_memory




def get_irl_reward(agent, expert_obs, expert_next_obs, expert_action, expert_done, GAMMA, device=None):
    """Recover the implicit reward ``Q(s, a) - gamma * (1 - done) * V(s')``.

    Computed under ``torch.no_grad()``.

    Args:
      agent: Object exposing ``critic(obs, action)`` and ``getV(obs)``.
      expert_obs: Observations; must have exactly 4 dimensions.
      expert_next_obs: Next observations passed to ``agent.getV``.
      expert_action: Actions passed to ``agent.critic``.
      expert_done: Done flags; ``1 - expert_done`` masks the bootstrap term.
      GAMMA: Discount factor.
      device: Unused.
    Returns:
      The reward tensor ``q - (1 - expert_done) * GAMMA * next_v``.
    """
    assert len(expert_obs.shape) == 4, f'Expert obs shape: {expert_obs.shape}'
    with torch.no_grad():
        q_value = agent.critic(expert_obs, expert_action)
        next_value = agent.getV(expert_next_obs)
        # Terminal transitions (done == 1) contribute no bootstrapped value.
        discounted_next = (1 - expert_done) * GAMMA * next_value
        return q_value - discounted_next


def logging(*msg):
    """Print ``msg`` to stdout behind a ``YYYY-MM-DD HH:MM:SS>`` timestamp.

    The arguments are forwarded to ``print`` after the timestamp, so they
    are separated by single spaces.

    Note: inside this module the name shadows the standard-library
    ``logging`` module.
    """
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    prefix = "{}>".format(timestamp)
    print(prefix, *msg)
