import numpy as np
import torch
import torch.nn.functional as F
import gym
import os
import random
import math

try:
    import dmc2gym
except:
    dmc2gym = None
    print("Skipping dmc2gym")

try:
    import metaworld
    import metaworld.envs.mujoco.env_dict as _env_dict
except:
    metaworld = None
    _env_dict = None
    print("Skipping Metaworld imports")

from collections import deque
from gym.wrappers.time_limit import TimeLimit
from rlkit.envs.wrappers import NormalizedBoxEnv
from collections import deque
from skimage.util.shape import view_as_windows
from torch import nn
from torch import distributions as pyd
from matplotlib import pyplot as plt
import io
from PIL import Image


def make_env(cfg):
    """Build the environment selected by ``cfg.env.name``.

    Three families are handled, checked in this order:

    * names containing ``'gridworld'``: a ``custom_gridworld`` grid world
      wrapped with the Manhattan-distance reward wrapper and
      ``NpActionWrapper``;
    * names containing ``'gym'``: ``gym.make`` on the name with every
      ``'gym_'`` occurrence removed;
    * anything else: a dm_control task created through ``dmc2gym``.

    :param cfg: config object. Reads ``cfg.env.name``,
        ``cfg.env.env_time_limit`` (``-1`` means "do not add a time limit"),
        ``cfg.env.params`` (gridworld only), ``cfg.env.env_action_space``
        (gym only) and ``cfg.seed`` (dm_control only).
    :return: the constructed, possibly wrapped, environment.
    """
    # Kept unconditional on purpose: the import ran for every branch before.
    # NOTE(review): presumably it also registers custom gym ids - confirm
    # before moving it into the gridworld branch.
    import custom_gridworld

    env_name = cfg.env.name
    time_limit = cfg.env.env_time_limit

    if 'gridworld' in env_name:
        env = custom_gridworld.gridworld_base.GridWorldEnv(size=cfg.env.params.size,
                                                           agent_loc=cfg.env.params.agent_location,
                                                           target_loc=cfg.env.params.target_location)
        env = custom_gridworld.reward_wrappers.ManhattanDistanceToGoalWrapper(env)
        if time_limit != -1:
            # Wrap exactly once. The previous code stacked two identical
            # TimeLimit wrappers; with the classic gym implementation the
            # outer one sees done=True from the inner one and resets
            # info['TimeLimit.truncated'] to False.
            env = gym.wrappers.time_limit.TimeLimit(env, max_episode_steps=time_limit)
        env = NpActionWrapper(env)
        return env

    if 'gym' in env_name:
        domain_name = env_name.replace('gym_', '')
        env = gym.make(domain_name)
        # env = custom_gridworld.RGBObservationWrapper(env)
        if time_limit != -1:
            env = gym.wrappers.time_limit.TimeLimit(env, max_episode_steps=time_limit)

        if cfg.env.env_action_space == 'discrete':
            env = NpActionWrapper(env)

        # env = custom_gridworld.observation_wrappers.ProcessFrame(env)
        # env = custom_gridworld.observation_wrappers.FlattenObservation(env)
        return env

    if env_name == 'ball_in_cup_catch':
        # The domain itself contains underscores, so the generic split below
        # would cut it in the wrong place.
        domain_name = 'ball_in_cup'
        task_name = 'catch'
    else:
        domain_name = env_name.split('_')[0]
        task_name = '_'.join(env_name.split('_')[1:])

    env = dmc2gym.make(domain_name=domain_name,
                       task_name=task_name,
                       seed=cfg.seed,
                       visualize_reward=False)
    env.seed(cfg.seed)
    # Sanity check that the action bounds lie inside [-1, 1].
    assert env.action_space.low.min() >= -1
    assert env.action_space.high.max() <= 1

    return env


def ppo_make_env(env_id, seed):
    """Create a seeded dm_control environment through ``dmc2gym``.

    ``env_id`` is split at its first underscore into domain and task, except
    for ``'ball_in_cup_catch'`` whose domain name itself contains underscores.

    :param env_id: environment name of the form ``<domain>_<task>``.
    :param seed: seed passed both to ``dmc2gym.make`` and ``env.seed``.
    :return: the created environment (built with ``visualize_reward=True``).
    """
    if env_id == 'ball_in_cup_catch':
        domain_name, task_name = 'ball_in_cup', 'catch'
    else:
        # Text before the first underscore is the domain, the rest the task.
        domain_name, _, task_name = env_id.partition('_')

    env = dmc2gym.make(domain_name=domain_name,
                       task_name=task_name,
                       seed=seed,
                       visualize_reward=True)
    env.seed(seed)

    # Sanity check that the action bounds lie inside [-1, 1].
    assert env.action_space.low.min() >= -1
    assert env.action_space.high.max() <= 1
    return env


def tie_weights(src, trg):
    """Make ``trg`` reference the very same ``weight`` and ``bias`` as ``src``.

    Both arguments must be of exactly the same type (checked with an assert).
    """
    src_type, trg_type = type(src), type(trg)
    assert src_type == trg_type
    # Assign in the same order as before: weight first, then bias.
    for attr_name in ('weight', 'bias'):
        setattr(trg, attr_name, getattr(src, attr_name))


def make_metaworld_env(cfg):
    """Create a seeded, normalised and time-limited Metaworld environment.

    The ``'metaworld_'`` marker is removed from ``cfg.env.name`` and the
    remaining name is looked up in the V2 registry first, then in V1.

    :param cfg: config object; reads ``cfg.env.name`` and ``cfg.seed``.
    :return: the environment wrapped in ``NormalizedBoxEnv``, ``TimeLimit``
        and ``AddEnvSpecForMaxSteps``.
    """
    env_name = cfg.env.name.replace('metaworld_', '')
    v2_registry = _env_dict.ALL_V2_ENVIRONMENTS
    registry = v2_registry if env_name in v2_registry else _env_dict.ALL_V1_ENVIRONMENTS
    base_env = registry[env_name]()

    base_env._freeze_rand_vec = False
    base_env._set_task_called = True
    base_env.seed(cfg.seed)

    limited_env = TimeLimit(NormalizedBoxEnv(base_env), base_env.max_path_length)
    # Adds a `spec` carrying max_episode_steps on top of the wrapped env.
    return AddEnvSpecForMaxSteps(limited_env, limited_env.max_path_length)


def ppo_make_metaworld_env(env_id, seed):
    """Create a seeded, time-limited Metaworld environment.

    The ``'metaworld_'`` marker is removed from ``env_id`` and the remaining
    name is looked up in the V2 registry first, then in V1.

    :param env_id: environment name, optionally carrying ``'metaworld_'``.
    :param seed: seed passed to ``env.seed``.
    :return: the environment wrapped in ``TimeLimit`` using its own
        ``max_path_length``.
    """
    env_name = env_id.replace('metaworld_', '')
    v2_registry = _env_dict.ALL_V2_ENVIRONMENTS
    registry = v2_registry if env_name in v2_registry else _env_dict.ALL_V1_ENVIRONMENTS
    env = registry[env_name]()

    env._freeze_rand_vec = False
    env._set_task_called = True
    env.seed(seed)

    horizon = env.max_path_length
    return TimeLimit(env, horizon)


class eval_mode(object):
    """Context manager that switches the given models to evaluation mode.

    On entry each model's ``training`` flag is recorded and ``train(False)``
    is called; on exit every model is put back into its recorded mode.
    Exceptions raised inside the ``with`` body are not suppressed.
    """

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        self.prev_states = []
        for module in self.models:
            was_training = module.training
            self.prev_states.append(was_training)
            module.train(False)

    def __exit__(self, *args):
        for module, was_training in zip(self.models, self.prev_states):
            module.train(was_training)
        # Returning False lets any exception propagate.
        return False


class train_mode(object):
    """Context manager that switches the given models to training mode.

    On entry each model's ``training`` flag is recorded and ``train(True)``
    is called; on exit every model is put back into its recorded mode.
    Exceptions raised inside the ``with`` body are not suppressed.
    """

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        self.prev_states = []
        for module in self.models:
            was_training = module.training
            self.prev_states.append(was_training)
            module.train(True)

    def __exit__(self, *args):
        for module, was_training in zip(self.models, self.prev_states):
            module.train(was_training)
        # Returning False lets any exception propagate.
        return False


def soft_update_params(net, target_net, tau):
    """Blend ``net``'s parameters into ``target_net`` in place.

    Each target parameter becomes ``tau * param + (1 - tau) * target_param``;
    parameters are paired in ``parameters()`` order.
    """
    keep = 1 - tau
    for source, target in zip(net.parameters(), target_net.parameters()):
        blended = tau * source.data + keep * target.data
        target.data.copy_(blended)


def set_seed_everywhere(seed):
    """Seed Python's ``random``, NumPy and PyTorch (all CUDA devices too, when available)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def make_dir(*path_parts):
    """Join ``path_parts`` into a path, make sure that directory exists, return the path.

    Missing parent directories are created as well, and an already existing
    directory is fine. The previous ``os.mkdir`` + ``except OSError: pass``
    silently did nothing for nested paths and handed back a path that did
    not exist.

    :param path_parts: path components, joined with ``os.path.join``.
    :return: the joined directory path.
    :raises OSError: if the directory cannot be created (e.g. permission
        denied, or the path exists but is not a directory).
    """
    dir_path = os.path.join(*path_parts)
    os.makedirs(dir_path, exist_ok=True)
    return dir_path


def weight_init(m):
    """Orthogonally initialise an ``nn.Linear`` layer and zero its bias.

    Modules of any other type are left untouched, which makes the function
    suitable for ``module.apply(weight_init)``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.orthogonal_(m.weight.data)
    bias = m.bias
    # `bias` is None for layers built with bias=False, so guard the access.
    if hasattr(bias, 'data'):
        bias.data.fill_(0.0)


class MLP(nn.Module):
    """Module wrapper around the network built by ``mlp``.

    The network is stored as ``self.trunk`` and ``weight_init`` is applied to
    every submodule right after construction.
    """

    def __init__(self,
                 input_dim,
                 hidden_dim,
                 output_dim,
                 hidden_depth,
                 output_mod=None):
        super().__init__()
        trunk = mlp(input_dim, hidden_dim, output_dim, hidden_depth, output_mod)
        self.trunk = trunk
        self.apply(weight_init)

    def forward(self, x):
        """Run ``x`` through the trunk and return its output."""
        out = self.trunk(x)
        return out


class TanhTransform(pyd.transforms.Transform):
    """Bijective ``tanh`` transform from the reals to the interval (-1, 1)."""

    domain = pyd.constraints.real
    codomain = pyd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, cache_size=1):
        super().__init__(cache_size=cache_size)

    @staticmethod
    def atanh(x):
        """Inverse hyperbolic tangent written with ``log1p``."""
        log_one_plus = torch.log1p(x)
        log_one_minus = torch.log1p(-x)
        return 0.5 * (log_one_plus - log_one_minus)

    def __eq__(self, other):
        # Any two TanhTransform instances compare equal.
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return torch.tanh(x)

    def _inverse(self, y):
        # No clamping to the boundary here, as it may degrade the performance
        # of certain algorithms; rely on `cache_size=1` instead.
        return self.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        # Numerically more stable formulation, see
        # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7
        softplus_term = F.softplus(-2.0 * x)
        return 2.0 * (math.log(2.0) - x - softplus_term)


class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
    """Normal distribution pushed through a ``TanhTransform``.

    :param loc: location of the underlying ``pyd.Normal``.
    :param scale: scale of the underlying ``pyd.Normal``.
    """

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

        self.base_dist = pyd.Normal(loc, scale)
        super().__init__(self.base_dist, [TanhTransform()])

    @property
    def mean(self):
        """``loc`` passed through every transform in order.

        This is the transformed location, not a computed expectation of the
        squashed distribution.
        """
        squashed = self.loc
        for transform in self.transforms:
            squashed = transform(squashed)
        return squashed


class TorchRunningMeanStd:
    """Running mean and variance of batches of tensors.

    :param epsilon: initial pseudo-count; keeps the first merge well defined.
    :param shape: shape of the tracked ``mean`` / ``var`` tensors.
    :param device: device on which ``mean`` / ``var`` are allocated.
    """

    def __init__(self, epsilon=1e-4, shape=(), device=None):
        self.mean = torch.zeros(shape, device=device)
        self.var = torch.ones(shape, device=device)
        self.count = epsilon

    def update(self, x):
        """Merge the statistics of batch ``x`` (reduced over dim 0) into the running ones."""
        with torch.no_grad():
            batch_mean = torch.mean(x, dim=0)
            # Population variance is required here: the merge multiplies the
            # variance by the sample count to recover the sum of squared
            # deviations. The default unbiased estimator breaks that identity
            # and evaluates to NaN for a single-row batch, which would poison
            # the running variance permanently.
            batch_var = torch.var(x, dim=0, unbiased=False)
            batch_count = x.shape[0]
            self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Merge precomputed batch moments into the running statistics."""
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count
        )

    @property
    def std(self):
        """Square root of the running variance."""
        return torch.sqrt(self.var)


def update_mean_var_count_from_moments(
        mean, var, count, batch_mean, batch_var, batch_count
):
    """Merge running moments with the moments of a new batch.

    Uses the parallel mean/variance combination: with ``delta`` the
    difference of the two means and ``n = count + batch_count``,

    * ``new_mean = mean + delta * batch_count / n``
    * ``M2 = var * count + batch_var * batch_count
      + delta**2 * count * batch_count / n`` and ``new_var = M2 / n``

    Both variances are interpreted as population variances.

    :param mean: running mean.
    :param var: running variance.
    :param count: number of samples behind the running moments.
    :param batch_mean: mean of the new batch.
    :param batch_var: variance of the new batch.
    :param batch_count: number of samples in the new batch.
    :return: tuple ``(new_mean, new_var, new_count)``.
    """
    delta = batch_mean - mean
    tot_count = count + batch_count

    # The mean moves by the batch's share of the gap. The earlier
    # `delta + batch_count / tot_count` added the raw gap plus a ratio.
    new_mean = mean + delta * batch_count / tot_count
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + torch.pow(delta, 2) * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count

    return new_mean, new_var, new_count


def mlp(input_dim, hidden_dim, output_dim, hidden_depth, output_mod=None):
    """Build a fully connected ``nn.Sequential``.

    With ``hidden_depth == 0`` the network is a single linear layer from
    ``input_dim`` to ``output_dim``. Otherwise it has ``hidden_depth``
    ``Linear`` + in-place ``ReLU`` pairs of width ``hidden_dim`` followed by a
    final linear layer. ``output_mod``, when given, is appended last.
    """
    if hidden_depth == 0:
        layers = [nn.Linear(input_dim, output_dim)]
    else:
        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
        for _ in range(hidden_depth - 1):
            layers.extend((nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)))
        layers.append(nn.Linear(hidden_dim, output_dim))

    if output_mod is not None:
        layers.append(output_mod)
    return nn.Sequential(*layers)


def to_np(t):
    """Convert a tensor to a NumPy array.

    ``None`` is passed through unchanged, and a tensor without elements
    becomes an empty 1-D array regardless of its original shape.
    """
    if t is None:
        return None
    if t.nelement() == 0:
        return np.array([])
    host_tensor = t.cpu()
    return host_tensor.detach().numpy()


def get_linear_fn(start: float, end: float, end_fraction: float):
    """
    Build a schedule that moves linearly from ``start`` to ``end``.

    The returned function takes ``progress_remaining`` (1 at the beginning,
    0 at the end). It interpolates while the elapsed progress
    ``1 - progress_remaining`` is at most ``end_fraction`` and returns
    ``end`` afterwards. Used in DQN to anneal the exploration rate
    (epsilon of the epsilon-greedy strategy).

    :params start: value returned when ``progress_remaining`` = 1
    :params end: value returned once the elapsed progress exceeds
        ``end_fraction``
    :params end_fraction: fraction of the whole run over which the
        interpolation happens, e.g. 0.1 means ``end`` is reached after 10%
        of the training process.
    :return: the schedule function
    """

    def func(progress_remaining: float) -> float:
        elapsed = 1 - progress_remaining
        if elapsed > end_fraction:
            return end
        return start + elapsed * (end - start) / end_fraction

    return func


class NpActionWrapper(gym.Wrapper):
    """Wrapper that converts NumPy-array actions to a plain ``int`` before stepping."""

    def __init__(self, env):
        super().__init__(env)
        self.env = env

    def step(self, action):
        """Step the wrapped env; array actions are squeezed and cast to ``int``, others pass through."""
        env_action = int(action.squeeze()) if isinstance(action, np.ndarray) else action
        return self.env.step(env_action)


def to_np_action(action):
    """Return ``action`` as a NumPy array with at least one dimension.

    Non-array inputs are converted with ``np.array``; a 0-d result gets a
    leading axis, anything else is returned as is.
    """
    as_array = action if isinstance(action, np.ndarray) else np.array(action)
    if len(as_array.shape) == 0:
        return np.expand_dims(as_array, axis=0)
    return as_array


def fig2img(fig):
    """Render a Matplotlib figure into an in-memory buffer and open it as a PIL Image."""
    buffer = io.BytesIO()
    fig.savefig(buffer)
    # Rewind so PIL reads from the start of what was just written.
    buffer.seek(0)
    return Image.open(buffer)


def generate_gridworld_reward_plot(reward_model, env, title=""):
    """Render a heat map of per-cell rewards of a grid world as an image array.

    The agent is placed on every cell in turn. With a reward model, the cell
    value is the largest ``reward_model.r_hat`` over all actions, evaluated on
    the concatenation of observation and action. Without one, the cell value
    is ``env.reward(0)``.

    :param reward_model: object exposing ``r_hat``, or ``None`` to use the
        environment's own reward.
    :param env: grid world exposing ``size``, ``action_space.n``,
        ``set_agent_location``, ``get_obs`` and ``reward``.
    :param title: title drawn above the heat map.
    :return: RGB image as a channels-first NumPy array.
    """
    env.reset()
    grid_size = env.size
    actions = env.action_space.n
    grid_rewards = np.zeros((grid_size, grid_size))
    for ix in range(grid_size):
        for jx in range(grid_size):
            env.set_agent_location(np.array([ix, jx]))
            obs = env.get_obs()
            if reward_model is None:
                # we just use r(s) instead of r(s,a)
                rewards = [env.reward(0)]
            else:
                rewards = [reward_model.r_hat(np.concatenate([obs, to_np_action(a)])) for a in range(actions)]

            # get reward for best action :
            grid_rewards[ix, jx] = max(rewards)

    # creating plot
    fig, ax = plt.subplots()
    try:
        ax.imshow(grid_rewards)
        # Loop over data dimensions and create text annotations.
        for i in range(grid_size):
            for j in range(grid_size):
                ax.text(j, i, f"{grid_rewards[i, j]:.2f}", ha="center", va="center", color="w")
        ax.set_title(title)
        fig.tight_layout()
        # convert() loads the pixel data, so the figure can be released after.
        img = fig2img(fig).convert('RGB')
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # each call leaks one figure.
        plt.close(fig)
    np_img = np.array(img)
    np_img = np.moveaxis(np_img, 2, 0)  # channels first
    return np_img


def compute_reward_epc(reward_model, env):
    """Placeholder: always returns the ``NotImplemented`` singleton.

    Intended plan (not implemented): perform n rollouts and compute epc
    between these.
    """
    result = NotImplemented
    return result


def generate_reward_similarity(reward_model, env):
    """
    Placeholder for a similarity measure between ``reward_model`` and the
    environment rewards: prints a warning and returns 0.
    """
    warning = "[WARN] Attempt to call reward similarity, Not Implemented"
    print(warning)
    return 0


def plot_images(images, cols=1, save=None, show=True):
    """Draw ``images`` on a grid of subplots with ``cols`` columns.

    :param images: list of images accepted by ``plt.imshow``.
    :param cols: number of grid columns; rows are derived from it.
    :param save: when not ``None``, target passed to ``plt.savefig``.
    :param show: when true, ``plt.show()`` is called at the end.
    """
    assert isinstance(images, list), f"images should be list type, given {type(images)}"
    rows = math.ceil(len(images) / cols)
    # Subplot positions are 1-based.
    for position, image in enumerate(images, start=1):
        plt.subplot(rows, cols, position)
        plt.gca().set_title(f'Im:{position}')
        plt.axis('off')
        plt.imshow(image)
    if save is not None:
        plt.savefig(save)
    if show:
        plt.show()


class MetaworldRGBObservationWrapper(gym.ObservationWrapper):
    """Observation wrapper that replaces observations with ``rgb_array`` renders.

    The observation space is a ``Box`` in [0, 255] whose shape is taken from
    one observation obtained by resetting the environment at construction.
    """

    def __init__(self, env):
        super().__init__(env)
        self.env = env
        obs_shape = self._get_obs_shape()
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=obs_shape)

    def _get_obs_shape(self):
        # Resets the environment to obtain one observation to measure.
        return self.reset().shape

    def observation(self, obs):
        # The incoming observation is discarded in favour of the render.
        return self.env.render("rgb_array")


from addict import Dict


class AddEnvSpecForMaxSteps(gym.Wrapper):
    """Wrapper whose ``spec`` is a small ``Dict`` holding ``max_episode_steps``."""

    def __init__(self, env, max_episode_steps):
        super().__init__(env)
        self.max_episode_steps = max_episode_steps
        spec_fields = {'max_episode_steps': max_episode_steps}
        self.myspec = Dict(spec_fields)

    @property
    def spec(self):
        """The ``Dict`` created at construction time."""
        return self.myspec


class ERC_For_GT:
    """Correlation-based distance between learned and ground-truth returns.

    Segment pairs are drawn from ``reward_model.get_queries``. Ground-truth
    and predicted rewards are summed per segment, and the distance is
    ``sqrt((1 - pearson) / 2)``: 0 for perfectly correlated returns, 1 for
    perfectly anti-correlated ones.

    :param reward_model: object exposing ``get_queries(mb_size=...)`` and
        ``r_hat_batch(sa)``; it stores the trajectories being sampled.
    :param size: number of queries requested per ``compute`` call.
    """

    def __init__(self, reward_model, size):
        self.reward_model = reward_model  # reward model stores trajectories & we can use the get_queries method.
        self.size = size

    @staticmethod
    def pearsonr(x, y):
        """
        Mimics `scipy.stats.pearsonr`
        Source : https://gist.github.com/ncullen93/58e71c4303b89e420bd8e0b0aa54bf48
        Arguments
        ---------
        x : 1D torch.Tensor
        y : 1D torch.Tensor (same dtype as x)
        Returns
        -------
        r_val : 0-d torch.Tensor
            pearsonr correlation coefficient between x and y, limited to
            [-1, 1]; 0 when either input has no variation
        """
        xm = x - torch.mean(x)
        ym = y - torch.mean(y)
        r_num = xm.dot(ym)
        r_den = torch.norm(xm, 2) * torch.norm(ym, 2)

        if r_den == 0:
            # A constant input has no defined correlation; report 0. This
            # replaces the old `r_num / (r_den + 1e6)` branch, which could
            # only be reached with a zero numerator anyway.
            return r_num.new_zeros(())
        # Rounding can push the ratio marginally outside [-1, 1], which would
        # turn sqrt((1 - r) / 2) into NaN downstream.
        return torch.clamp(r_num / r_den, -1.0, 1.0)

    def compute(self):
        """
        Compute the distance between predicted and ground-truth returns.

        Assumptions :
        the reward model maps (s, a) -> R
        gt rewards are present in the trajectory buffer
        get_queries returns NumPy arrays whose reward entries end in a
        size-1 axis, and r_hat_batch returns a NumPy array of the same
        layout (both are squeezed on the last axis and summed over the
        segment axis) -- TODO confirm against the reward model

        :return: 0-d tensor ``sqrt((1 - pearson) / 2)``
        """

        sa_t_1, sa_t_2, r_t_1, r_t_2 = self.reward_model.get_queries(mb_size=self.size)

        # append sa_t_1 and sa_t_2, r_t_1 and r_t_2
        sa = np.concatenate([sa_t_1, sa_t_2], axis=0)
        r_gt = np.concatenate([r_t_1, r_t_2], axis=0)
        # get returns; cast to float64 so both operands of the dot product in
        # pearsonr share a dtype even when the buffer stores float32 rewards
        g_gt = torch.from_numpy(r_gt.squeeze(axis=-1).sum(axis=-1)).to(torch.float64)

        with torch.no_grad():
            r_pred = self.reward_model.r_hat_batch(sa).squeeze(axis=-1)
            g_pred = torch.from_numpy(r_pred.sum(axis=-1)).to(torch.float64)

        p = self.pearsonr(g_gt, g_pred)
        distance = torch.sqrt((1 - p) / 2)
        return distance

    def log(self, logger, step):
        """Compute the distance and log it under ``'train/erc'`` at ``step``."""
        distance = self.compute()
        logger.log('train/erc', distance, step)
