import os
import math
import numpy as np
import torch as th
import random
from torch import nn
from typing import Iterable, List, Union
from pymoo.indicators.hv import HV
import rl.successor_features.pytorch_util as ptu


# Module-wide torch device: the current CUDA device when CUDA is available, else the CPU.
device = th.device('cuda') if th.cuda.is_available() else th.device('cpu')

def layer_init(layer, method='xavier', weight_gain=1, bias_const=0):
    """Initialise an ``nn.Linear`` layer in place; any other module is left untouched.

    Args:
        layer: module to initialise. Only ``nn.Linear`` instances are modified.
        method: ``"xavier"`` (Xavier/Glorot uniform) or ``"orthogonal"`` weight
            initialisation. Any other value leaves the weights unchanged.
        weight_gain: gain forwarded to the weight initialiser.
        bias_const: constant the bias is filled with. Skipped when the layer
            was built with ``bias=False``.
    """
    if not isinstance(layer, nn.Linear):
        return
    if method == "xavier":
        th.nn.init.xavier_uniform_(layer.weight, gain=weight_gain)
    elif method == "orthogonal":
        th.nn.init.orthogonal_(layer.weight, gain=weight_gain)
    # nn.Linear(..., bias=False) stores ``bias`` as None; constant_ would fail on it.
    if layer.bias is not None:
        th.nn.init.constant_(layer.bias, bias_const)

def polyak_update(params: Iterable[th.nn.Parameter], target_params: Iterable[th.nn.Parameter], tau: float) -> None:
    """Soft-update ``target_params`` towards ``params`` in place.

    Each target becomes ``(1 - tau) * target + tau * param``; ``tau == 1`` is a
    plain copy. Pairs are formed with ``zip``, so extra entries in the longer
    iterable are ignored. No gradients are recorded.
    """
    hard_copy = tau == 1
    with th.no_grad():
        for source, target in zip(params, target_params):
            if hard_copy:
                target.data.copy_(source.data)
                continue
            target.data.mul_(1.0 - tau).add_(source.data, alpha=tau)

def huber(x, min_priority=0.01):
    """Mean of a piecewise loss: ``0.5 * x**2`` below ``min_priority``, ``min_priority * x`` otherwise.

    The gradient is continuous at the threshold (``x`` vs ``min_priority``), although the
    loss value itself jumps there. ``x`` is presumably a non-negative error magnitude
    (e.g. absolute TD error); confirm with callers.
    """
    quadratic_part = 0.5 * x.pow(2)
    linear_part = min_priority * x
    below_threshold = x < min_priority
    return th.where(below_threshold, quadratic_part, linear_part).mean()

def generate_weights(count=1, n=3, m=1):
    """Draw a sequence of ``n``-dimensional Dirichlet(1) weight vectors from ``np.random``.

    ``count // m`` anchor vectors are sampled. With ``m == 1`` the anchors themselves are
    returned; otherwise each anchor contributes ``m`` vectors that linearly interpolate
    from the previous anchor to the new one (the last of the ``m`` equals the anchor).
    An extra draw seeds the "previous" anchor before the first one.

    Source: https://github.com/axelabels/DynMORL/blob/db15c29bc2cf149c9bda6b8890fee05b1ac1e19e/utils.py#L281
    """
    weights = []
    previous = np.random.dirichlet(np.ones(n), 1)[0]
    for _ in range(count // m):
        current = np.random.dirichlet(np.ones(n), 1)[0]
        if m == 1:
            weights.append(current)
        else:
            weights.extend(
                current * (i + 1) / float(m) + previous * (m - i - 1) / float(m)
                for i in range(m)
            )
        previous = current.copy()
    return weights

def random_weights(dim, seed=None, n=1):
    """Sample weight vectors from a flat Dirichlet distribution (all alphas = 1).

    Args:
        dim: length of each weight vector.
        seed: if given, a fresh ``np.random.default_rng(seed)`` is used; otherwise the
            global ``np.random`` state is consumed.
        n: number of vectors to draw.

    Returns:
        A single array when ``n == 1``, otherwise a list of ``n`` arrays.
    """
    rng = np.random if seed is None else np.random.default_rng(seed)
    samples = [rng.dirichlet(np.ones(dim)) for _ in range(n)]
    return samples[0] if n == 1 else samples


def eval(agent, env, render=False):
    """Run one episode choosing actions with ``agent.eval(obs)``.

    Returns:
        (sum of rewards, discounted sum of rewards using ``agent.gamma``).
    """
    obs = env.reset()
    total_reward = 0.0
    discounted_return = 0.0
    discount = 1.0
    done = False
    while not done:
        if render:
            env.render()
        action = agent.eval(obs)
        obs, reward, done, _ = env.step(action)
        total_reward += reward
        discounted_return += discount * reward
        discount *= agent.gamma
    return total_reward, discounted_return


def eval_mo(agent, env, w, render=False):
    """Run one episode with ``agent.eval(obs, w)`` and accumulate the feature vectors ``info['phi']``.

    The scalar reward returned by ``env.step`` is ignored; everything is computed from
    ``info['phi']``.

    Args:
        agent: object exposing ``eval(obs, w)`` -> action and a discount ``gamma``.
        env: environment whose ``step`` returns ``(obs, reward, done, info)`` with ``info['phi']``.
        w: weight vector conditioning the agent and scalarizing the returns.
        render: call ``env.render()`` before every step.

    Returns:
        w.total_reward, w.return, total vec reward, vec return
    """
    obs = env.reset()
    done = False
    # Float accumulators: np.zeros_like(w) would inherit an integer dtype from an
    # integer weight vector (e.g. a one-hot task), and the in-place += of float
    # features would then raise a numpy casting error.
    total_vec_reward = np.zeros_like(w, dtype=float)
    vec_return = np.zeros_like(w, dtype=float)
    gamma = 1.0
    while not done:
        if render:
            env.render()
        obs, _, done, info = env.step(agent.eval(obs, w))
        phi = info['phi']
        total_vec_reward += phi
        vec_return += gamma * phi
        gamma *= agent.gamma
    return np.dot(w, total_vec_reward), np.dot(w, vec_return), total_vec_reward, vec_return

def policy_evaluation_mo(agent, env, w, rep=5, return_scalarized_value=False):
    """Average ``eval_mo`` over ``rep`` episodes.

    Returns the element-wise mean of the discounted vector returns (item 3 of
    ``eval_mo``), or the mean of the scalarized discounted returns (item 1) when
    ``return_scalarized_value`` is true.
    """
    which = 1 if return_scalarized_value else 3
    samples = [eval_mo(agent, env, w)[which] for _ in range(rep)]
    return np.mean(samples, axis=0)

def reward_eval_mo(agent, env, w, render=False):
    """Like ``eval_mo`` but for envs whose ``reset`` returns a tuple and whose observations need flattening.

    The first element of ``env.reset()`` is used as the initial observation, and every
    observation is flattened with ``reshape(-1)`` before being passed to
    ``agent.eval(obs, w)``. Values are accumulated from ``info['phi']``; the scalar
    reward from ``env.step`` is ignored.

    Returns:
        w.total_reward, w.return, total vec reward, vec return
    """
    obs = env.reset()[0].reshape(-1)
    done = False
    # Float accumulators so an integer weight vector (e.g. one-hot) does not make the
    # in-place += of float features fail with a numpy casting error.
    total_vec_reward = np.zeros_like(w, dtype=float)
    vec_return = np.zeros_like(w, dtype=float)
    gamma = 1.0
    while not done:
        if render:
            env.render()
        next_obs, _, done, info = env.step(agent.eval(obs, w))
        phi = info['phi']
        total_vec_reward += phi
        vec_return += gamma * phi
        gamma *= agent.gamma
        obs = next_obs.reshape(-1)
    return np.dot(w, total_vec_reward), np.dot(w, vec_return), total_vec_reward, vec_return

def reward_evaluation_mo(agent, env, w, rep=5, return_scalarized_value=False):
    """Average ``reward_eval_mo`` over ``rep`` episodes.

    Returns the element-wise mean of the discounted vector returns (item 3 of
    ``reward_eval_mo``), or the mean of the scalarized discounted returns (item 1)
    when ``return_scalarized_value`` is true.
    """
    which = 1 if return_scalarized_value else 3
    collected = []
    for _ in range(rep):
        collected.append(reward_eval_mo(agent, env, w)[which])
    return np.mean(collected, axis=0)

def get_tau(states, actions, tau_type='iqn', num_quantiles=32, fp=None):
    """Build quantile fractions for a distributional critic.

    Args:
        states: states passed to ``fp`` (only used when ``tau_type == 'fqf'``).
        actions: actions passed to ``fp`` (only used when ``tau_type == 'fqf'``).
            A scalar integer action is wrapped into a one-element tensor.
        tau_type: how the per-quantile widths ``presum_tau`` are produced:
            ``'fix'`` -> every width is ``1 / num_quantiles``;
            ``'iqn'`` -> random widths (``ptu.rand(...) + 0.1``) normalised to sum to 1;
            ``'fqf'`` -> widths returned by ``fp(states, actions)``.
        num_quantiles: number of widths for ``'fix'`` and ``'iqn'``.
        fp: callable ``fp(states, actions)`` returning widths; required for ``'fqf'``
            (presumably an FQF fraction-proposal network — confirm with callers).

    Returns:
        ``(tau, tau_hat, presum_tau)`` where ``tau`` is the cumulative sum of
        ``presum_tau`` along dim 1 and ``tau_hat`` holds the midpoints between
        consecutive ``tau`` values (with an implicit 0 before the first), computed
        without gradient. For ``'fix'``/``'iqn'`` the tensors are presumably shaped
        ``(1, num_quantiles)``.

    Raises:
        ValueError: unknown ``tau_type``, or ``tau_type == 'fqf'`` without ``fp``.
    """
    if tau_type == 'fix':
        presum_tau = ptu.zeros(1, num_quantiles) + 1. / num_quantiles
    elif tau_type == 'iqn':
        # The +0.1 offset keeps every width away from zero before normalisation
        # (assuming ptu.rand samples from [0, 1)).
        presum_tau = ptu.rand(1, num_quantiles) + 0.1
        presum_tau /= presum_tau.sum(dim=-1, keepdim=True)
    elif tau_type == 'fqf':
        if fp is None:
            raise ValueError("get_tau(tau_type='fqf') requires a fraction proposal callable `fp`")
        if isinstance(actions, (int, np.integer)):
            # Scalar action: wrap it into a 1-element tensor, then give fp's
            # output a leading batch dimension.
            actions = th.unsqueeze(th.tensor(actions), dim=0).to(device)
            states = th.tensor(states).to(device)
            presum_tau = th.unsqueeze(fp(states, actions), dim=0)
        else:
            actions = th.tensor(actions).to(device)
            states = th.tensor(states).to(device)
            presum_tau = fp(states, actions)
    else:
        raise ValueError(f"Unknown tau_type {tau_type!r}; expected 'fix', 'iqn' or 'fqf'")

    tau = th.cumsum(presum_tau, dim=1)  # (N, T), note that they are tau1...tauN in the paper
    with th.no_grad():
        tau_hat = ptu.zeros_like(tau)
        tau_hat[:, 0:1] = tau[:, 0:1] / 2.
        tau_hat[:, 1:] = (tau[:, 1:] + tau[:, :-1]) / 2.
    return tau, tau_hat, presum_tau

def dsf_eval_mo(agent, env, w, tau, presum_tau, render=False):
    """Run one episode with ``agent.eval(obs, w, tau, presum_tau)`` and accumulate ``info['phi']``.

    ``tau`` and ``presum_tau`` are forwarded unchanged to the agent at every step.
    The scalar reward returned by ``env.step`` is ignored.

    Returns:
        w.total_reward, w.return, total vec reward, vec return
    """
    obs = env.reset()
    done = False
    # Float accumulators so an integer weight vector (e.g. one-hot) does not make the
    # in-place += of float features fail with a numpy casting error.
    total_vec_reward = np.zeros_like(w, dtype=float)
    vec_return = np.zeros_like(w, dtype=float)
    gamma = 1.0
    while not done:
        if render:
            env.render()
        action = agent.eval(obs, w, tau, presum_tau)
        obs, _, done, info = env.step(action)
        phi = info['phi']
        total_vec_reward += phi
        vec_return += gamma * phi
        gamma *= agent.gamma
    return np.dot(w, total_vec_reward), np.dot(w, vec_return), total_vec_reward, vec_return

def dsf_policy_evaluation_mo(agent, env, w, tau, presum_tau, rep=5, return_scalarized_value=False):
    """Average ``dsf_eval_mo`` over ``rep`` episodes with fixed ``tau`` / ``presum_tau``.

    Returns the element-wise mean of the discounted vector returns (item 3 of
    ``dsf_eval_mo``), or the mean of the scalarized discounted returns (item 1)
    when ``return_scalarized_value`` is true.
    """
    pick = 1 if return_scalarized_value else 3
    episodes = (dsf_eval_mo(agent, env, w, tau, presum_tau) for _ in range(rep))
    return np.mean([episode[pick] for episode in episodes], axis=0)

def dsf_reward_eval_mo(agent, env, w, tau, presum_tau, render=False):
    """Like ``dsf_eval_mo`` but for envs whose ``reset`` returns a tuple and whose observations need flattening.

    The first element of ``env.reset()`` is the initial observation, and every
    observation is flattened with ``reshape(-1)`` before ``agent.eval(obs, w, tau,
    presum_tau)``. Values are accumulated from ``info['phi']``; the scalar reward
    from ``env.step`` is ignored.

    Returns:
        w.total_reward, w.return, total vec reward, vec return
    """
    obs = env.reset()[0].reshape(-1)
    done = False
    # Float accumulators so an integer weight vector (e.g. one-hot) does not make the
    # in-place += of float features fail with a numpy casting error.
    total_vec_reward = np.zeros_like(w, dtype=float)
    vec_return = np.zeros_like(w, dtype=float)
    gamma = 1.0
    while not done:
        if render:
            env.render()
        action = agent.eval(obs, w, tau, presum_tau)
        obs, _, done, info = env.step(action)
        obs = obs.reshape(-1)
        phi = info['phi']
        total_vec_reward += phi
        vec_return += gamma * phi
        gamma *= agent.gamma
    return np.dot(w, total_vec_reward), np.dot(w, vec_return), total_vec_reward, vec_return

def dsf_reward_evaluation_mo(agent, env, w, tau, presum_tau, rep=5, return_scalarized_value=False):
    """Average ``dsf_reward_eval_mo`` over ``rep`` episodes with fixed ``tau`` / ``presum_tau``.

    Returns the element-wise mean of the discounted vector returns (item 3 of
    ``dsf_reward_eval_mo``), or the mean of the scalarized discounted returns
    (item 1) when ``return_scalarized_value`` is true.
    """
    pick = 1 if return_scalarized_value else 3
    results = []
    for _ in range(rep):
        episode = dsf_reward_eval_mo(agent, env, w, tau, presum_tau)
        results.append(episode[pick])
    return np.mean(results, axis=0)

def eval_test_tasks(agent, env, tasks, rep=10):
    """Mean scalarized discounted return of ``agent`` across the weight vectors in ``tasks``.

    Each task is evaluated with ``policy_evaluation_mo(..., return_scalarized_value=True)``
    over ``rep`` episodes, and the per-task averages are averaged again.
    """
    per_task_values = []
    for task_w in tasks:
        value = policy_evaluation_mo(agent, env, task_w, rep=rep, return_scalarized_value=True)
        per_task_values.append(value)
    return np.mean(per_task_values, axis=0)

def moving_average(interval: Union[np.array,List], window_size: int) -> np.array:
    """Smooth ``interval`` with a box filter of width ``window_size``.

    Uses ``np.convolve(..., 'same')``, so the output is centred on the input. Values
    near the edges are pulled towards zero by the implicit zero padding. A window of
    1 returns ``interval`` itself, unconverted. NOTE(review): if ``window_size``
    exceeds ``len(interval)``, numpy's 'same' mode returns ``window_size`` samples.
    """
    if window_size != 1:
        kernel = np.full(int(window_size), 1.0 / float(window_size))
        return np.convolve(interval, kernel, 'same')
    return interval

def linearly_decaying_epsilon(initial_epsilon, decay_period, step, warmup_steps, final_epsilon):
    """Returns the current epsilon for the agent's epsilon-greedy policy.

    Linear schedule in the style of the Nature DQN paper (Mnih et al., 2015):
      * ``initial_epsilon`` until ``warmup_steps`` steps have been taken;
      * then linear decay from ``initial_epsilon`` to ``final_epsilon`` over
        ``decay_period`` steps;
      * ``final_epsilon`` from there on.

    Args:
        initial_epsilon: float, epsilon during warm-up and at the start of the decay.
        decay_period: float, the period over which epsilon is decayed.
        step: int, the number of training steps completed so far.
        warmup_steps: int, the number of steps taken before epsilon is decayed.
        final_epsilon: float, the final value to which epsilon decays.

    Returns:
        The current epsilon value (a numpy float) computed according to the schedule.
    """
    steps_left = decay_period + warmup_steps - step
    bonus = (initial_epsilon - final_epsilon) * steps_left / decay_period
    # Cap at (initial - final), not (1 - final). Otherwise epsilon overshoots
    # initial_epsilon during warm-up whenever initial_epsilon < 1. The result is
    # identical for initial_epsilon == 1.
    bonus = np.clip(bonus, 0., initial_epsilon - final_epsilon)
    return final_epsilon + bonus

def hypervolume(ref_point: np.ndarray, points: List[np.ndarray]) -> float:
    """Hypervolume of ``points`` with respect to ``ref_point`` for maximisation objectives.

    pymoo's ``HV`` indicator assumes minimisation, so the reference point and all
    points are negated before calling it.

    Args:
        ref_point: reference point (array-like, one entry per objective).
        points: sequence of objective vectors.

    Returns:
        The hypervolume computed by ``pymoo.indicators.hv.HV``.
    """
    # np.asarray guards against a plain-list ref_point, for which ``* -1``
    # would silently produce [] instead of negating the coordinates.
    neg_ref_point = np.asarray(ref_point) * -1
    neg_points = np.asarray(points) * -1
    return HV(ref_point=neg_ref_point).do(neg_points)

