
from PIL import Image
import numpy as np
import math
import torch

def check(input):
    """Convert a NumPy array to a torch tensor; pass anything else through.

    Args:
        input: a ``np.ndarray`` (converted with ``torch.from_numpy``, so the
            resulting tensor shares memory with the array) or any other
            object.

    Returns:
        A ``torch.Tensor`` when ``input`` is an ndarray, otherwise ``input``
        itself, unchanged.
    """
    if isinstance(input, np.ndarray):
        return torch.from_numpy(input)
    # Previously non-array inputs (e.g. values that are already tensors) fell
    # off the end of the function and silently became None.
    return input
        
def get_gard_norm(it):
    """Return the global L2 norm of the gradients of the items in ``it``.

    Items whose ``.grad`` is ``None`` are ignored. The result is the square
    root of the summed squared per-item gradient norms, as a Python float
    (``0.0`` when no item has a gradient).
    """
    squared_total = sum(
        param.grad.norm() ** 2 for param in it if param.grad is not None
    )
    return math.sqrt(squared_total)

def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
    """Set every param group's learning rate on a linear decay schedule.

    The rate falls from ``initial_lr`` at ``epoch == 0`` to zero at
    ``epoch == total_num_epochs``. ``optimizer.param_groups`` is updated in
    place; nothing is returned.
    """
    progress = epoch / float(total_num_epochs)
    decayed_lr = initial_lr - (initial_lr * progress)
    for group in optimizer.param_groups:
        group['lr'] = decayed_lr

def huber_loss(e, d):
    """Element-wise Huber loss of an error tensor.

    Args:
        e: error tensor (must support ``abs`` and produce comparison results
            with a ``.float()`` method, e.g. a torch tensor).
        d: threshold between the quadratic and linear regimes.

    Returns:
        A tensor shaped like ``e`` holding ``e**2 / 2`` where ``|e| <= d``
        and ``d * (|e| - d / 2)`` where ``|e| > d``.
    """
    abs_e = abs(e)
    quadratic_mask = (abs_e <= d).float()
    # The linear branch must be selected on |e| > d. Testing the signed
    # error (e > d) left both masks at zero for e < -d, so large negative
    # errors contributed no loss at all.
    linear_mask = (abs_e > d).float()
    return quadratic_mask * e ** 2 / 2 + linear_mask * d * (abs_e - d / 2)

def mse_loss(e):
    """Return half the squared error, ``e**2 / 2``, element-wise."""
    squared_error = e ** 2
    return squared_error / 2

def tile_images(img_nhwc):
    """Arrange a batch of N equally sized images into one grid image.

    Args:
        img_nhwc: array-like with shape ``(N, h, w, c)``.

    Returns:
        An array of shape ``(rows * h, cols * w, c)`` where
        ``rows = ceil(sqrt(N))`` and ``cols = ceil(N / rows)``. Images fill
        the grid row by row; unused cells are filled with the first image
        multiplied by zero.
    """
    batch = np.asarray(img_nhwc)
    count, height, width, channels = batch.shape
    rows = int(np.ceil(np.sqrt(count)))
    cols = int(np.ceil(float(count) / rows))

    # Pad the batch up to a full rows x cols grid with blank tiles.
    blank_tile = batch[0] * 0
    filler = [blank_tile] * (rows * cols - count)
    padded = np.array(list(batch) + filler)

    # (rows, cols, h, w, c) -> (rows, h, cols, w, c) -> (rows*h, cols*w, c)
    grid = padded.reshape(rows, cols, height, width, channels)
    interleaved = grid.transpose(0, 2, 1, 3, 4)
    return interleaved.reshape(rows * height, cols * width, channels)

def evaluate(agent, env, environment, num_evaluation=10, max_steps=None):
    """Roll out ``agent`` in ``env`` for several episodes and average the reward.

    Args:
        agent: object whose ``step(obs)`` accepts a float32 NumPy array and
            returns a value exposing ``.numpy()`` (presumably a CPU torch
            tensor; confirm against the agent implementation).
        env: object whose ``reset()`` returns a 3-tuple starting with the
            observation, and whose ``step(action)`` returns a 6-tuple
            ``(next_obs, _, reward, done, _, _)``.
        environment: environment family name; ``"mujoco"`` enables a default
            cap of 1000 steps per episode.
        num_evaluation: number of episodes to run.
        max_steps: per-episode step cap. Required unless ``environment`` is
            ``"mujoco"``.

    Returns:
        ``np.mean`` of the per-episode reward sums. Only ``reward[0, 0, 0]``
        is accumulated and only ``done[0, 0]`` terminates an episode.

    Raises:
        ValueError: if ``max_steps`` is ``None`` and no default applies.
    """
    if max_steps is None and environment == "mujoco":
        max_steps = 1000
    # Raise explicitly rather than assert: asserts are stripped under
    # ``python -O``, which would let ``range(None)`` fail obscurely below.
    if max_steps is None:
        raise ValueError(
            "max_steps must be provided when environment is not 'mujoco'"
        )

    episode_rewards = []
    for _ in range(num_evaluation):
        obs, _, _ = env.reset()
        episode_reward = 0
        for _ in range(max_steps):
            actions = agent.step(np.array(obs).astype(np.float32))
            action = actions.numpy()

            next_obs, _, reward, done, _, _ = env.step(action)
            episode_reward += reward[0, 0, 0]

            if done[0, 0]:
                break
            obs = next_obs
        episode_rewards.append(episode_reward)

    return np.mean(episode_rewards)

