import numpy as np
import scipy.signal

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
from torch.distributions import MultivariateNormal
import torch.optim as optim
import pandas as pd
from warnings import filterwarnings
import random
from torch.distributions import MultivariateNormal

from utils import *

# Suppress the DeprecationWarning whose message starts with the text below
# (the `np.bool8` alias notice). Note that `message` is matched as a regex
# against the start of the warning text.
filterwarnings(
    "ignore",
    category=DeprecationWarning,
    message="`np.bool8` is a deprecated alias for `np.bool_`",
)

# Bounds for a log standard deviation. These are not referenced in this
# chunk; presumably they clamp a predicted log-std elsewhere in the file.
LOG_STD_MAX = 2
LOG_STD_MIN = -20

# import logging

# # Colors
# RED = "\033[91m"
# GREEN = "\033[92m"
# BLUE = "\033[94m"
# CYAN = "\033[96m"
# YELLOW = "\033[93m"
# MAGENTA = "\033[95m"
# ENDC = "\033[0m"  # Reset color

# logger = logging.getLogger(__name__)


def mlp(sizes, activation, output_activation=nn.Identity):
    """Build a feed-forward stack of ``nn.Linear`` layers with activations.

    Args:
        sizes: Sequence of layer widths. Each consecutive pair
            ``(sizes[i], sizes[i + 1])`` becomes one ``nn.Linear`` layer.
        activation: Module class instantiated after every Linear layer
            except the last.
        output_activation: Module class instantiated after the final
            Linear layer. Defaults to ``nn.Identity``.

    Returns:
        An ``nn.Sequential`` that alternates Linear and activation modules.
        It is empty when ``sizes`` has fewer than two entries.
    """
    final_idx = len(sizes) - 2  # index of the last (fan_in, fan_out) pair
    modules = []
    for idx, (fan_in, fan_out) in enumerate(zip(sizes, sizes[1:])):
        nonlinearity = output_activation if idx == final_idx else activation
        # The Linear layer is built before its activation. This keeps the
        # module-construction (and thus RNG / init) order stable.
        modules.extend((nn.Linear(fan_in, fan_out), nonlinearity()))
    return nn.Sequential(*modules)



class MLPQFunction(nn.Module):
    """MLP that maps an (observation, action) pair to one scalar per input.

    The observation and action are concatenated along the last dimension.
    The result is fed through an MLP whose final layer has width 1. That
    trailing size-1 dimension is then squeezed away.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        # Input width is obs_dim + act_dim; output width is a single unit.
        layer_sizes = [obs_dim + act_dim, *hidden_sizes, 1]
        self.q = mlp(layer_sizes, activation)

    def forward(self, obs, act):
        # Assumes obs and act share all leading dims (e.g. batch) so they can
        # be concatenated on the last axis. TODO confirm against callers.
        joint_input = torch.cat([obs, act], dim=-1)
        # Drop the trailing size-1 output dim so callers get shape (...,)
        # rather than (..., 1).
        return self.q(joint_input).squeeze(-1)


def count_vars(module):
    """Return the total number of scalar elements in ``module``'s parameters.

    Args:
        module: An ``nn.Module`` whose ``parameters()`` are counted.

    Returns:
        A Python ``int``: the sum of ``numel()`` over all parameters, or 0
        when the module has none.
    """
    # Use Tensor.numel() rather than np.prod(p.shape). np.prod yields NumPy
    # scalars, and for a 0-d parameter it yields the float 1.0, which would
    # make the whole count a float.
    return sum(p.numel() for p in module.parameters())

