from typing import Iterable, List, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch as th
import torch.nn.functional as F
import jax.numpy as jnp
import flax.linen as nn
from gymnasium.spaces import Box, Discrete
from pymoo.indicators.hv import HV
from scipy.spatial import ConvexHull



def eval(agent, env, render=False):
    """Run one episode of ``agent`` in ``env`` and measure its reward.

    NOTE: the name shadows the builtin ``eval``; it is kept as-is for
    compatibility with existing callers.

    Args:
        agent: Object exposing ``eval(obs)`` (returns an action) and ``gamma``.
        env: Environment whose ``reset()`` returns ``(obs, info)`` and whose
            ``step(action)`` returns ``(obs, reward, terminated, truncated, info)``.
        render: When True, ``env.render()`` is called before every step.

    Returns:
        Tuple ``(total_reward, discounted_return)`` for the episode.
    """
    observation, _ = env.reset()
    undiscounted, discounted = 0.0, 0.0
    # Running discount factor: agent.gamma ** t at time step t.
    discount = 1.0
    while True:
        if render:
            env.render()
        action = agent.eval(observation)
        observation, reward, terminated, truncated, _ = env.step(action)
        undiscounted += reward
        discounted += discount * reward
        discount *= agent.gamma
        if terminated or truncated:
            break
    return undiscounted, discounted


def eval_mo(agent, env, w):
    """Run one episode of a multi-objective agent and measure its vector reward.

    The per-step reward vector is read from ``info["vector_reward"]``; the
    scalar reward returned by ``env.step`` is ignored.

    Args:
        agent: Object exposing ``eval(obs, w)`` (returns an action) and ``gamma``.
        env: Environment whose ``reset()`` returns ``(obs, info)`` and whose
            ``step(action)`` returns ``(obs, reward, terminated, truncated, info)``.
        w: Weight vector used both to condition the agent and to scalarize
            the accumulated reward vectors.

    Returns:
        Tuple ``(w . total_vec_reward, w . vec_return, total_vec_reward, vec_return)``
        where ``vec_return`` is the ``agent.gamma``-discounted sum of reward vectors.
    """
    # Accumulate in a floating dtype: with integer weights (e.g. np.array([1, 0]))
    # a plain zeros_like(w) yields integer accumulators and the in-place additions
    # of float rewards below fail. Floating weight dtypes are preserved as-is.
    w_dtype = np.asarray(w).dtype
    acc_dtype = w_dtype if np.issubdtype(w_dtype, np.floating) else np.float64
    total_vec_reward = np.zeros_like(w, dtype=acc_dtype)
    vec_return = np.zeros_like(w, dtype=acc_dtype)

    obs, _ = env.reset()
    done = False
    # Running discount factor: agent.gamma ** t at time step t.
    gamma = 1.0
    while not done:
        obs, _, terminated, truncated, info = env.step(agent.eval(obs, w))
        done = terminated or truncated
        total_vec_reward += info["vector_reward"]
        vec_return += gamma * info["vector_reward"]
        gamma *= agent.gamma
    return (
        np.dot(w, total_vec_reward),
        np.dot(w, vec_return),
        total_vec_reward,
        vec_return,
    )


def policy_evaluation_mo(agent, env, w, rep=5, return_scalarized_value=False):
    """Estimate the value of the policy as the mean discounted return over ``rep`` episodes.

    Args:
        agent: Agent passed through to ``eval_mo``.
        env: Environment passed through to ``eval_mo``.
        w: Weight vector passed through to ``eval_mo``.
        rep: Number of evaluation episodes to average over.
        return_scalarized_value: When True, average the scalarized discounted
            return; otherwise average the vector discounted return.

    Returns:
        ``np.mean`` over episodes (axis 0) of the selected return.
    """
    # eval_mo returns (scalar total, scalar discounted, vector total, vector discounted).
    selected = 1 if return_scalarized_value else 3
    episode_returns = []
    for _ in range(rep):
        episode_returns.append(eval_mo(agent, env, w)[selected])
    return np.mean(episode_returns, axis=0)


def eval_test_tasks(agent, env, tasks, rep=10):
    """Return the mean scalarized value of the policy across a set of tasks.

    Args:
        agent: Agent passed through to ``policy_evaluation_mo``.
        env: Environment passed through to ``policy_evaluation_mo``.
        tasks: Iterable of weight vectors; each one is evaluated separately.
        rep: Number of evaluation episodes per weight vector.

    Returns:
        ``np.mean`` (axis 0) of the per-task scalarized values.
    """
    per_task_values = []
    for weights in tasks:
        value = policy_evaluation_mo(agent, env, weights, rep=rep, return_scalarized_value=True)
        per_task_values.append(value)
    return np.mean(per_task_values, axis=0)


def best_vector(values, w):
    """Return the element of ``values`` with the largest product ``v @ w``.

    Ties are resolved in favour of the earliest element, since a later
    candidate only replaces the current best when it is strictly better.
    """
    winner = values[0]
    winner_score = winner @ w
    for idx in range(1, len(values)):
        candidate = values[idx]
        candidate_score = candidate @ w
        if candidate_score > winner_score:
            winner, winner_score = candidate, candidate_score
    return winner


def visualize_eval_jax(
    agent,
    env,
    model=None,
    w=None,
    horizon=10,
    init_obs=None,
    compound=True,
    deterministic=True,
    show=False,
    filename=None,
):
    """Plot a rollout of ``agent`` in ``env``, optionally overlaid with model predictions.

    The agent is rolled out for at most ``horizon`` steps, stopping early when
    the episode terminates or is truncated. When ``model`` is given, the same
    action sequence is replayed through a ``ModelEnv`` built around it and the
    predictions are drawn over the real trajectory, with a shaded band of
    +/- sqrt of ``info['var_obs']`` / ``info['var_rewards']``.

    One panel is drawn per observation dimension, one per reward component
    (``len(w)`` components, or a single one when ``w`` is None) and a last one
    for the actions.

    Args:
        agent: Object exposing ``eval(obs)``, or ``eval(obs, w)`` when ``w`` is given.
        env: Environment stepped with the agent's actions. The number of
            observation panels is ``env.observation_space.shape[0]``.
        model: Optional dynamics model wrapped in ``ModelEnv``; when None only
            the real trajectory is plotted.
        w: Optional weight vector. When given, reward panels show the components
            of ``info['vector_reward']``; otherwise the scalar step reward.
        horizon: Maximum number of environment steps.
        init_obs: Starting observation. When None the environment is reset and
            its initial observation is used. NOTE(review): when it is supplied
            the environment is not reset here, so it presumably has to match
            the environment's current state -- confirm with callers.
        compound: When True the model is fed its own previous prediction at
            every step; when False it is fed the real observation of the
            previous step (the first step always starts from ``init_obs``).
        deterministic: Forwarded to ``ModelEnv.step``.
        show: When True, ``plt.show()`` is called.
        filename: When given, the figure is saved to ``filename + '.pdf'``.

    Returns:
        The ``matplotlib.pyplot`` module (the figure is not closed).
    """
    if init_obs is None:
        init_obs, _ = env.reset()
    obs_dim = env.observation_space.shape[0]
    rew_dim = 1 if w is None else len(w)

    # --- Rollout in the real environment -------------------------------------
    actions, real_obs, real_rewards, real_phis = [], [], [], []
    obs = init_obs.copy()
    for _ in range(horizon):
        act = agent.eval(obs) if w is None else agent.eval(obs, w)
        actions.append(act)
        obs, r, terminated, truncated, info = env.step(act)
        real_obs.append(obs.copy())
        real_rewards.append(r)
        if 'vector_reward' in info:
            real_phis.append(info['vector_reward'])
        if terminated or truncated:
            break

    # --- Replay of the same actions through the learned model ----------------
    model_obs, model_obs_stds = [], []
    model_rewards, model_rewards_stds = [], []
    if model is not None:
        # Imported lazily so that plotting a model-free rollout does not require
        # the dynamics package to be importable.
        from gpi.dynamics.util_jax import ModelEnv

        model_env = ModelEnv(model, env_id=env.unwrapped.spec.id, rew_dim=rew_dim)
        acts = actions
        if isinstance(env.action_space, Discrete):
            acts = nn.one_hot(acts, num_classes=env.action_space.n)
        obs = init_obs.copy()
        # The model's own termination signal is ignored on purpose: the model is
        # always queried for as many steps as the real rollout lasted.
        for step in range(len(real_obs)):
            model_input = obs if (compound or step == 0) else real_obs[step - 1]
            obs, r, _, info = model_env.step(model_input, acts[step], deterministic=deterministic)
            model_obs.append(obs.copy())
            model_obs_stds.append(np.sqrt(info['var_obs'].copy()))
            model_rewards_stds.append(np.sqrt(info['var_rewards'].copy()))
            model_rewards.append(r)

    # --- Plotting --------------------------------------------------------------
    num_plots = obs_dim + rew_dim + 1
    # Lay the panels out on a roughly square grid.
    num_cols = int(np.ceil(np.sqrt(num_plots)))
    num_rows = int(np.ceil(num_plots / num_cols))
    x = np.arange(0, len(real_obs))
    _, axs = plt.subplots(num_rows, num_cols, figsize=(20, 15))
    axs = np.array(axs).reshape(-1)

    def draw_model(ax, means, stds):
        """Draw the model's prediction with a +/- one-std shaded band."""
        ax.plot(x, means, label='Model', color='blue')
        upper = [m + s for m, s in zip(means, stds)]
        lower = [m - s for m, s in zip(means, stds)]
        ax.fill_between(x, upper, lower, alpha=0.2, facecolor='blue')

    for i in range(num_plots):
        ax = axs[i]
        if i == num_plots - 1:
            # Last panel: the action sequence.
            ax.set_ylabel("Action")
            ax.grid(alpha=0.25)
            ax.plot(x, [actions[step] for step in x], label='Action', color='orange')
        elif i >= obs_dim:
            # Reward panels follow the observation panels.
            k = i - obs_dim
            ax.set_ylabel(f"Reward {k}")
            ax.grid(alpha=0.25)
            if w is not None:
                ax.plot(x, [real_phis[step][k] for step in x], label='Environment', color='black')
            else:
                ax.plot(x, [real_rewards[step] for step in x], label='Environment', color='black')
            if model is not None:
                draw_model(
                    ax,
                    [model_rewards[step][k] for step in x],
                    [model_rewards_stds[step][k] for step in x],
                )
        else:
            ax.set_ylabel(f"State {i}")
            ax.grid(alpha=0.25)
            ax.plot(x, [real_obs[step][i] for step in x], label='Environment', color='black')
            if model is not None:
                draw_model(
                    ax,
                    [model_obs[step][i] for step in x],
                    [model_obs_stds[step][i] for step in x],
                )
    sns.despine()
    if filename is not None:
        plt.savefig(filename + '.pdf', format='pdf', bbox_inches='tight')
    if show:
        plt.show()
    return plt

def hypervolume(ref_point: np.ndarray, points: List[np.ndarray]) -> float:
    """Compute the hypervolume of ``points`` with respect to ``ref_point``.

    Both the reference point and the points are multiplied by -1 before being
    handed to pymoo's ``HV`` indicator -- presumably because that indicator
    assumes minimization while the rewards here are maximized (confirm against
    the pymoo documentation).

    Args:
        ref_point: Reference point of the hypervolume.
        points: Points whose hypervolume is measured.

    Returns:
        The value returned by the ``HV`` indicator on the negated inputs.
    """
    negated_ref = ref_point * -1
    indicator = HV(ref_point=negated_ref)
    negated_points = np.array(points) * -1
    return indicator(negated_points)


def make_gif(env, agent, weight: np.ndarray, fullpath: str, fps: int = 50, lenght: int = 300):
    """Roll out ``agent`` in ``env`` and save the rendered frames as ``<fullpath>.gif``.

    The environment is reset, rendered before every step and closed once the
    episode ends or ``lenght`` frames have been collected.

    Args:
        env: Environment to record; ``env.render()`` must return a frame usable
            by MoviePy's ``ImageSequenceClip``.
        agent: Object exposing ``eval(obs, weight)`` that returns an action.
        weight: Weight vector passed to the agent at every step.
        fullpath: Output path without extension; ``'.gif'`` is appended.
        fps: Frame rate of the clip and of the written GIF.
        lenght: Maximum number of frames to record (the misspelled name is kept
            so existing keyword callers keep working).
    """
    frames = []
    state, _ = env.reset()
    terminated, truncated = False, False
    while not (terminated or truncated) and len(frames) < lenght:
        frames.append(env.render())
        action = agent.eval(state, weight)
        state, _, terminated, truncated, _ = env.step(action)
    env.close()

    # MoviePy 2.x removed the ``moviepy.editor`` namespace; fall back to the
    # top-level import so both 1.x and 2.x installations work.
    try:
        from moviepy.editor import ImageSequenceClip
    except ImportError:
        from moviepy import ImageSequenceClip

    gif_path = fullpath + '.gif'
    clip = ImageSequenceClip(frames, fps=fps)
    clip.write_gif(gif_path, fps=fps)
    print("Saved gif at: " + gif_path)
