import glob
import os

import torch
import torch.nn as nn

# from a2c_ppo_acktr.envs import VecNormalize
from PIL import Image
import math


class AdaptiveKLController:
    """Proportional controller that adapts a KL-penalty coefficient.

    Each ``update`` multiplies the coefficient by
    ``1 + clip(current / target - 1, -0.2, 0.2) * n_steps / horizon``: it grows
    when the observed value exceeds ``target`` and shrinks when it falls
    below. The coefficient is floored at ``kl_beta_lb`` after every update.

    Args:
        init_kl_coef: Starting value of the coefficient (exposed as ``value``).
        target: Desired value of the observed quantity; must be non-zero.
        horizon: Scales the step size; a larger horizon means slower
            adaptation for the same ``n_steps``.
        kl_beta_lb: Lower bound on the coefficient.
    """

    def __init__(self, init_kl_coef, target, horizon, kl_beta_lb):
        self.value = init_kl_coef  # current coefficient; read by callers
        self.target = target
        self.horizon = horizon
        self.kl_beta_lb = kl_beta_lb

    def update(self, current, n_steps):
        """Adjust ``self.value`` from the latest observed value.

        Args:
            current: Observed value; a Python number or a single-element tensor.
            n_steps: Number of steps the observation covers.
        """
        # float() accepts plain numbers as well as single-element tensors.
        proportional_error = float(current) / self.target - 1.0
        # Clip the relative error so one outlier cannot move the coefficient
        # by more than 20% * n_steps / horizon.
        proportional_error = min(max(proportional_error, -0.2), 0.2)
        mult = 1.0 + proportional_error * n_steps / self.horizon
        # Apply the floor *after* the multiplicative step so the coefficient
        # never ends an update below the bound and no step is thrown away.
        self.value = max(self.value * mult, self.kl_beta_lb)


class RunningNormalizer:
    """Streaming mean/variance tracker for scalars (Welford's online update).

    Values are fed one at a time through ``update``; ``normalize`` then
    standardizes a value against the statistics seen so far.

    Args:
        epsilon: Added to the variance before the square root in ``std`` to
            keep the denominator of ``normalize`` away from zero.
    """

    def __init__(self, epsilon=1e-8):
        self.epsilon = epsilon
        self.count = 0      # number of values observed
        self.mean = 0.0     # running mean of observed values
        self.M2 = 0.0       # running sum of squared deviations from the mean

    def update(self, x: float):
        """Fold one new observation ``x`` into the running statistics."""
        self.count += 1
        # Deviation from the mean *before* and *after* the mean moves; their
        # product is the Welford increment to M2.
        shift_before = x - self.mean
        self.mean += shift_before / self.count
        shift_after = x - self.mean
        self.M2 += shift_before * shift_after

    @property
    def variance(self):
        """Population variance (M2 / count), or 0.0 before any update."""
        if self.count <= 0:
            return 0.0
        return self.M2 / self.count

    @property
    def std(self):
        """Standard deviation with ``epsilon`` added to the variance."""
        return math.sqrt(self.variance + self.epsilon)

    def normalize(self, x: float) -> float:
        """Return ``(x - mean) / std``; 0.0 until at least two updates."""
        if self.count >= 2:
            return (x - self.mean) / self.std
        return 0.0


# Find the render function of the first underlying env by unwrapping venv/env wrapper layers
def get_render_func(venv):
    """Return ``envs[0].render`` from the first layer that exposes ``envs``.

    Starting at ``venv``, wrapper layers are peeled off by following a
    ``venv`` attribute, or else an ``env`` attribute, until an object with an
    ``envs`` sequence is found. The bound ``render`` of its first element is
    returned (not called). Returns ``None`` if no such layer is reached.
    """
    if hasattr(venv, "envs"):
        return venv.envs[0].render
    # Probe wrapper attributes in priority order: "venv" before "env".
    for inner_attr in ("venv", "env"):
        if hasattr(venv, inner_attr):
            return get_render_func(getattr(venv, inner_attr))
    return None


# def get_vec_normalize(venv):
#     if isinstance(venv, VecNormalize):
#         return venv
#     elif hasattr(venv, "venv"):
#         return get_vec_normalize(venv.venv)

#     return None


# Necessary for my KFAC implementation.
class AddBias(nn.Module):
    """Add a learnable per-channel bias to its input.

    The bias is stored as a ``(C, 1)`` parameter and broadcast along
    dimension 1 of the input. This covers ``(N, C)`` activations and
    ``(N, C, *spatial)`` activations of any spatial rank, e.g. the outputs of
    1-D, 2-D or 3-D convolutions.

    Args:
        bias: 1-D tensor of length ``C`` holding the initial bias values.
    """

    def __init__(self, bias):
        super(AddBias, self).__init__()
        self._bias = nn.Parameter(bias.unsqueeze(1))

    def forward(self, x):
        # Reshape to (1, C, 1, ..., 1) with one trailing singleton per spatial
        # dim of x. A fixed 4-D view would mis-broadcast 3-D and 5-D inputs.
        trailing = (1,) * max(x.dim() - 2, 0)
        bias = self._bias.t().view(1, -1, *trailing)
        return x + bias


def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
    """Set every param group's ``"lr"`` on a linear decay from ``initial_lr``.

    At ``epoch == 0`` the rate is ``initial_lr``; at
    ``epoch == total_num_epochs`` it is ``0``. Values of ``epoch`` beyond
    ``total_num_epochs`` are not clamped. The optimizer is modified in place
    and nothing is returned.
    """
    progress = epoch / float(total_num_epochs)
    new_lr = initial_lr - (initial_lr * progress)
    for group in optimizer.param_groups:
        group["lr"] = new_lr


def init(module, weight_init, bias_init, gain=1):
    """Initialize a layer's weight and bias in place and return the layer.

    Args:
        module: Layer with a ``weight`` parameter and a ``bias`` that is a
            parameter or ``None`` (layers built with ``bias=False``).
        weight_init: Called as ``weight_init(weight_tensor, gain=gain)``.
        bias_init: Called as ``bias_init(bias_tensor)``; skipped when the
            module has no bias.
        gain: Forwarded to ``weight_init``.

    Returns:
        The same ``module`` object, to allow inline use when building layers.
    """
    weight_init(module.weight.data, gain=gain)
    # Bias-free layers expose ``module.bias = None``; ``None.data`` would raise.
    if module.bias is not None:
        bias_init(module.bias.data)
    return module


def cleanup_log_dir(log_dir):
    """Ensure ``log_dir`` exists and delete stale ``*.monitor.csv`` files in it.

    A freshly created directory is empty, so files are only searched for
    when the directory already existed.

    Raises:
        OSError: If the directory cannot be created for any reason other
            than already existing (e.g. permission denied, or a component of
            the path is a regular file).
    """
    try:
        os.makedirs(log_dir)
        return  # newly created, so there is nothing to clean up
    except FileExistsError:
        # Only "already exists" is expected; every other OSError is a real
        # failure and propagates to the caller.
        pass
    for stale_file in glob.glob(os.path.join(log_dir, "*.monitor.csv")):
        os.remove(stale_file)


def image_wrap(image_np_array):
    """Build a PIL image from ``image_np_array[0]``.

    NOTE(review): presumably the input is a batch of frames and only the
    first frame is wrapped; confirm against callers.
    """
    first_frame = image_np_array[0]
    return Image.fromarray(first_frame)
