import torch
import numpy as np
from wrapper import *

class SuccessEvaluator:
    """Collect success flags and report their mean over fixed-size windows.

    Args:
        eval_interval: Number of recorded samples that make up one window.
            Once that many samples are buffered, their mean is returned and
            the buffer is cleared.
    """

    def __init__(self, eval_interval):
        self.eval_interval = eval_interval
        self.record = []

    def eval_success(self, info):
        """Record ``info['is_success']`` and return the window mean when due.

        Args:
            info: Mapping that contains an ``'is_success'`` entry.

        Returns:
            The mean of the buffered values (via ``np.mean``) on the call that
            completes a window of ``eval_interval`` samples, otherwise ``None``.

        Raises:
            KeyError: If ``info`` has no ``'is_success'`` entry.
        """
        # Record before checking so the sample that completes the window is
        # counted instead of being discarded on the reporting call.
        self.record.append(info['is_success'])
        if len(self.record) < self.eval_interval:
            return None
        result = np.mean(self.record)
        self.record = []
        return result


class RewardForwardFilter:
    """Keep a discounted running sum of the rewards passed to ``update``.

    Args:
        gamma: Factor applied to the accumulated value before each new
            reward is added.
    """

    def __init__(self, gamma):
        self.gamma = gamma
        # No accumulated value until the first update arrives.
        self.rewems = None

    def update(self, rews):
        """Fold ``rews`` into the running sum and return the new sum.

        The first call stores ``rews`` as-is; every later call computes
        ``rewems * gamma + rews``.
        """
        previous = self.rewems
        self.rewems = rews if previous is None else previous * self.gamma + rews
        return self.rewems


class RunningMeanStd:
    """Running mean and variance, merged one batch at a time.

    Batches are combined with the parallel variance formula described at
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm

    Args:
        epsilon: Initial value of ``count``.
        shape: Shape of the ``mean`` and ``var`` arrays.
    """

    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, 'float64')
        self.var = np.ones(shape, 'float64')
        self.count = epsilon

    def update(self, x):
        """Merge the batch ``x``; statistics are taken along axis 0."""
        self.update_from_moments(np.mean(x, axis=0), np.var(x, axis=0), x.shape[0])

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Merge a batch described by its mean, variance and sample count."""
        prev_count = self.count
        total = prev_count + batch_count
        shift = batch_mean - self.mean

        # Sum of squared deviations of each side, plus the term that accounts
        # for the gap between the two means.
        weighted_prev = self.var * prev_count
        weighted_batch = batch_var * batch_count
        correction = np.square(shift) * prev_count * batch_count / total
        combined = weighted_prev + weighted_batch + correction

        self.mean = self.mean + shift * batch_count / total
        self.var = combined / total
        self.count = total


def select_gpu():
    """Return the index of the GPU that currently reports the most free memory.

    Every device listed by NVML (through ``pynvml``) is queried. Ties go to
    the lowest index, and ``0`` is returned when no device reports any free
    memory, including when NVML lists no devices at all.

    Returns:
        int: Index of the selected device.
    """
    import pynvml

    pynvml.nvmlInit()
    try:
        best_index, best_free = 0, 0
        for gpu_index in range(pynvml.nvmlDeviceGetCount()):
            handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_index)
            free_bytes = pynvml.nvmlDeviceGetMemoryInfo(handle).free
            # Strict comparison keeps the first (lowest-index) device on ties.
            if free_bytes > best_free:
                best_index, best_free = gpu_index, free_bytes
        return best_index
    finally:
        # Release NVML even when one of the device queries raised.
        pynvml.nvmlShutdown()
        

def make_env(env_id, idx, capture_video, run_name, gamma, mode, distance_threshold):
    """Build a zero-argument factory that creates one wrapped environment.

    Args:
        env_id: Environment id handed to ``gym.make``. Ids containing
            ``"Fetch"`` are created through ``gymnasium``, all others through
            ``gym``; ids containing ``"Dense"`` set ``dense_reward=True`` on
            the Fetch wrapper.
        idx: Index of this environment; only index 0 records video.
        capture_video: When true, the environment is created with
            ``render_mode="rgb_array"`` and environment 0 is recorded.
        run_name: Sub-directory of ``videos/`` used for recordings.
        gamma: Passed to the ``NormalizeReward`` wrapper.
        mode: Forwarded unchanged to the Fetch task wrapper.
        distance_threshold: Forwarded unchanged to the Fetch task wrapper.

    Returns:
        A callable that takes no arguments and returns the wrapped environment.
    """
    # The backend is resolved when make_env is called, not when the thunk runs.
    if "Fetch" not in env_id:
        import gym
    else:
        import gymnasium as gym

    def thunk():
        make_kwargs = {"render_mode": "rgb_array"} if capture_video else {}
        env = gym.make(env_id, **make_kwargs)
        # Flatten the observation (original note: deals with dm_control's
        # Dict observation space).
        env = gym.wrappers.FlattenObservation(env)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if capture_video and idx == 0:
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        env = gym.wrappers.ClipAction(env)
        # Normalize, then clip observations and rewards to [-10, 10].
        env = gym.wrappers.NormalizeObservation(env)
        env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10))
        env = gym.wrappers.NormalizeReward(env, gamma=gamma)
        env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10))

        # Every Fetch task wrapper takes the same keyword arguments.
        fetch_kwargs = {
            "mode": mode,
            "dense_reward": "Dense" in env_id,
            "distance_threshold": distance_threshold,
        }
        if "FetchReach" in env_id:
            return FetchReachEnvWrapper(env, **fetch_kwargs)
        if "FetchPush" in env_id:
            return FetchPushEnvWrapper(env, **fetch_kwargs)
        if "FetchPickAndPlace" in env_id:
            return FetchPickAndPlaceEnvWrapper(env, **fetch_kwargs)
        if "FetchSlide" in env_id:
            return FetchSlideEnvWrapper(env, **fetch_kwargs)
        return env

    return thunk