import cv2
import numpy as np
import torch

try:
    from torch._six import inf
except ImportError:  # torch._six was removed in PyTorch 2.0
    from math import inf


def preprocessing(img):
    """Convert a frame to an 84x84 single-channel grayscale image.

    The colour conversion interprets ``img`` as RGB channel order; the resize
    uses OpenCV's area interpolation.
    """
    target_size = (84, 84)
    grayscale = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    return cv2.resize(grayscale, target_size, interpolation=cv2.INTER_AREA)


def stack_states(stacked_frames, state, is_new_episode):
    """Maintain a rolling stack of the four most recent preprocessed frames.

    ``state`` is passed through ``preprocessing`` first. On a new episode the
    previous stack is ignored and the new frame is repeated four times along a
    new leading axis. Otherwise the oldest frame (index 0) of ``stacked_frames``
    is dropped and the new frame is appended at the end, so the result is
    ordered oldest -> newest.

    Returns the new stack; ``stacked_frames`` itself is not modified.
    """
    newest = preprocessing(state)[np.newaxis, ...]
    if not is_new_episode:
        return np.concatenate([stacked_frames[1:, ...], newest], axis=0)
    # Fresh episode: every slot starts out as the first observation.
    return np.repeat(newest, 4, axis=0)

def explained_variance(ypred, y):
    """Return ``1 - Var[y - ypred] / Var[y]`` for 1-D arrays.

    1.0 means ``ypred`` matches ``y`` exactly; 0.0 means the residual has as
    much variance as ``y`` itself (e.g. a constant prediction). Returns
    ``np.nan`` when ``y`` has zero variance, since the ratio is undefined.
    """
    assert y.ndim == 1 and ypred.ndim == 1
    total_var = np.var(y)
    if total_var == 0:
        return np.nan
    residual_var = np.var(y - ypred)
    return 1 - residual_var / total_var





class RunningMeanStd:
    """Track a running mean and variance of batched samples (NumPy backend).

    Update mode is selected by ``momentum``:

    * ``0 < momentum < 1``: exponential moving average. The mean is blended
      with the batch mean, then the variance is blended with the mean squared
      deviation of the batch from that *updated* mean. ``count`` is left as is.
    * anything else (including ``None``): exact combination of the stored
      moments with each batch's moments, weighted by ``count`` (Chan et al.'s
      parallel variance algorithm).

    The array backend is supplied by the ``mean_func`` / ``var_func`` /
    ``set_mean_var`` hooks so subclasses can swap NumPy for another library.
    """

    def __init__(self, epsilon=1e-4, shape=(), momentum=None):
        """
        Args:
            epsilon: initial pseudo-count for the stored statistics, so the
                first combined update never divides by zero.
            shape: shape of a single sample (batch axis excluded).
            momentum: EMA weight kept on the old statistics; see class doc.
        """
        self.shape = shape
        self.set_mean_var()
        self.count = epsilon
        self.momentum = momentum

    def set_mean_var(self):
        """Reset to mean 0 and variance 1 (float32 arrays of ``self.shape``)."""
        self.mean = np.zeros(self.shape, 'float32')
        self.var = np.ones(self.shape, 'float32')

    @staticmethod
    def mean_func(x, axis=0):
        """Mean of ``x`` along ``axis``."""
        return np.mean(x, axis=axis)

    @staticmethod
    def var_func(x, axis=0):
        """Population variance (ddof=0) of ``x`` along ``axis``."""
        return np.var(x, axis=axis)

    def set_from_checkpoint(self, checkpoint):
        """Restore ``mean``, ``var`` and ``count`` from a mapping with those keys."""
        self.mean = checkpoint["mean"]
        self.var = checkpoint["var"]
        self.count = checkpoint['count']

    @staticmethod
    def update_ema(old_data, new_data, momentum):
        """Blend ``new_data`` into ``old_data``; ``momentum`` weights the old value.

        Returns ``new_data`` unchanged when ``old_data`` is ``None``.
        """
        if old_data is None:
            return new_data
        return old_data * momentum + new_data * (1.0 - momentum)

    def update(self, x):
        """Fold a batch ``x`` (samples stacked along axis 0) into the statistics."""
        batch_count = x.shape[0]
        if batch_count == 0:
            # An empty batch carries no information; its mean would be NaN and
            # would poison every later update.
            return
        batch_mean = self.mean_func(x, 0)
        if self.momentum is not None and 0 < self.momentum < 1:
            self.mean = self.update_ema(self.mean, batch_mean, self.momentum)
            # ``** 2`` rather than ``np.square`` so backends that override
            # ``mean_func`` (e.g. torch tensors on a GPU) also work here.
            new_var = self.mean_func((x - self.mean) ** 2)
            self.var = self.update_ema(self.var, new_var, self.momentum)
        else:
            batch_var = self.var_func(x, 0)
            self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Combine the stored statistics with a batch's mean/variance/count in place."""
        self.mean, self.var, self.count = self.update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)

    @staticmethod
    def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
        """Merge two (mean, variance, count) summaries into one.

        Both variances must be population variances: ``var * count`` is used
        as the sum of squared deviations (M2) of each part.

        Returns:
            ``(new_mean, new_var, new_count)`` of the union of both parts.
        """
        delta = batch_mean - mean
        tot_count = count + batch_count

        new_mean = mean + delta * batch_count / tot_count
        m_a = var * count
        m_b = batch_var * batch_count
        # Cross term accounts for the two parts having different means.
        M2 = m_a + m_b + (delta ** 2) * count * batch_count / tot_count
        new_var = M2 / tot_count
        new_count = tot_count

        return new_mean, new_var, new_count

class RunningMeanStdTorch(RunningMeanStd):
    """Torch-tensor variant of ``RunningMeanStd`` whose statistics live on ``device``."""

    def __init__(self, device, epsilon=1e-4, shape=(), momentum=None):
        # ``device`` must be set before ``super().__init__`` because the base
        # constructor calls ``set_mean_var``, which reads ``self.device``.
        self.device = device
        super().__init__(epsilon, shape, momentum)

    @staticmethod
    def mean_func(x, axis=0):
        """Mean of ``x`` along ``axis``; ``* 1.0`` promotes integer tensors to float."""
        return torch.mean(x * 1.0, dim=axis)

    @staticmethod
    def var_func(x, axis=0):
        """Population variance (no Bessel correction) of ``x`` along ``axis``.

        The inherited moment-combination formula treats ``var * count`` as a
        sum of squared deviations, which only holds for the population
        variance (as ``np.var`` returns in the NumPy base class). The default
        unbiased ``torch.var`` overstates each batch's contribution and returns
        NaN for a single-sample batch, corrupting the running statistics.
        """
        return torch.var(x * 1.0, dim=axis, unbiased=False)

    def set_mean_var(self):
        """Reset to mean 0 / variance 1 float tensors of ``self.shape`` on ``self.device``."""
        self.mean = torch.zeros(self.shape, device=self.device, dtype=torch.float)
        self.var = torch.ones(self.shape, device=self.device, dtype=torch.float)


def clip_grad_norm_(parameters, norm_type: float = 2.0):
    """Return the total gradient norm of ``parameters`` without clipping anything.

    This is the official clip_grad_norm implemented in pytorch but the max_norm part has been removed.
    https://github.com/pytorch/pytorch/blob/52f2db752d2b29267da356a06ca91e10cd732dbc/torch/nn/utils/clip_grad.py#L9

    Args:
        parameters: a single tensor or an iterable of tensors; entries whose
            ``.grad`` is ``None`` are skipped.
        norm_type: order of the norm; ``float("inf")`` gives the max-abs norm.

    Returns:
        0-d tensor holding the norm of all gradients taken together, on the
        device of the first gradient; ``tensor(0.)`` when no gradient exists.
        Gradients are only read (detached), never modified.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if not grads:
        return torch.tensor(0.)
    device = grads[0].device
    # float("inf") rather than torch._six.inf: torch._six was removed in PyTorch 2.0.
    if norm_type == float("inf"):
        return max(g.abs().max().to(device) for g in grads)
    per_param = torch.stack([torch.norm(g, norm_type).to(device) for g in grads])
    return torch.norm(per_param, norm_type)
