import numpy as np
import torch
import torch.nn.functional as F
import gym
import os
import json
import random
import math
import dmc2gym
# import imageio
#import metaworld
#import metaworld.envs.mujoco.env_dict as _env_dict

from collections import deque
from gym.wrappers import TimeLimit, RescaleAction
from gym import spaces
from rlkit.envs.wrappers import NormalizedBoxEnv
from collections import deque
from typing import Any

from torch import nn
from torch import distributions as pyd

from rlkit.envs.wrappers import ContinuousGridWorld
from dmc2gym import wrappers as dmc_wrappers

from reward_components import RewardComponentsWrapper

import gymnasium 
import gymnasium_robotics

def make_env(cfg, seed=None):
    """Create the environment selected by ``cfg.env``.

    Dispatch is by substring of ``cfg.env``:

    * ``"gym"``: ``gym.make(cfg.env[4:])`` with actions rescaled to
      [-1, 1] and a 200-step ``TimeLimit``.
    * ``"ContGridWorld"``: a wall-free ``ContinuousGridWorld`` with a
      40-step ``TimeLimit`` (returned directly, without the action-range
      asserts below).
    * ``"AntMaze"`` / ``"PointMaze"``: gymnasium-robotics mazes adapted to
      the old gym API through ``GymnasiumToGymWrapper``.
    * anything else: a dm_control task ``<domain>_<task>`` built with
      ``dmc2gym`` and wrapped in ``RewardComponentsWrapper``.

    Args:
        cfg: config object; ``cfg.env`` and ``cfg.seed`` are read.
        seed: optional seed overriding ``cfg.seed``. ``None`` means
            "use ``cfg.seed``"; an explicit ``0`` is honoured.

    Returns:
        The constructed (and seeded) environment.
    """
    # Compare against None rather than truthiness so that seed=0 is not
    # silently replaced by cfg.seed.
    env_seed = cfg.seed if seed is None else seed

    if "gym" in cfg.env:
        # cfg.env[4:] drops the first four characters of the name
        # (presumably a "gym_" prefix -- TODO confirm with callers).
        env = gym.make(cfg.env[4:])
        env = RescaleAction(env, min_action=-1, max_action=1)
        env = TimeLimit(env, 200)
    elif "ContGridWorld" in cfg.env:
        env = ContinuousGridWorld(walls=False)
        env.seed(env_seed)
        return TimeLimit(env, 40)
    elif "AntMaze" in cfg.env:
        gymnasium.register_envs(gymnasium_robotics)
        env = gymnasium.make('AntMaze_UMazeDense-v5')
        achieved_goal_shape = env.observation_space["achieved_goal"].shape[0]
        env.observation_space = env.observation_space["observation"]
        # NOTE(review): the declared space includes the achieved-goal
        # dimensions, but the wrapper below is built with its default
        # add_goal=False and so returns only obs["observation"] -- confirm
        # which of the two is intended.
        env.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(env.observation_space.shape[0]+achieved_goal_shape,),dtype=np.float64)
        env = GymnasiumToGymWrapper(env)
        env._max_episode_steps = 700
    elif "PointMaze" in cfg.env:
        gymnasium.register_envs(gymnasium_robotics)
        env = gymnasium.make('PointMaze_UMaze-v3')
        env.observation_space = env.observation_space["observation"]
        # Fixed start and goal cells, forwarded to every reset by the wrapper.
        options = {"reset_cell":np.array([3,1]), "goal_cell":np.array([1,1])}
        env = GymnasiumToGymWrapper(env, options)
        env = RewardWrapper(env)
        env._max_episode_steps = 200
    else:
        if cfg.env == 'ball_in_cup_catch':
            # The domain name itself contains underscores, so the generic
            # split on the first '_' would be wrong here.
            domain_name = 'ball_in_cup'
            task_name = 'catch'
        else:
            domain_name = cfg.env.split('_')[0]
            task_name = '_'.join(cfg.env.split('_')[1:])

        # Use the resolved seed (not cfg.seed) so an explicit `seed`
        # argument also reaches the dm_control task construction.
        env = dmc2gym.make(domain_name=domain_name,
                        task_name=task_name,
                        seed=env_seed,
                        visualize_reward=False)
        env = RewardComponentsWrapper(env)

    env.seed(env_seed)
    assert env.action_space.low.min() >= -1
    assert env.action_space.high.max() <= 1

    return env

def get_grid():
    """Return a new ``ContinuousGridWorld`` constructed with ``walls=False``."""
    return ContinuousGridWorld(walls=False)


def ppo_make_env(env_id, seed):
    """Create a dm_control environment through ``dmc2gym``.

    ``env_id`` has the form ``<domain>_<task>``; the text before the first
    underscore is the domain and everything after it is the task.
    ``ball_in_cup_catch`` is special-cased because its domain name contains
    underscores. The environment is built with ``visualize_reward=True``,
    seeded with ``seed`` and its action bounds are asserted to lie in [-1, 1].
    """
    if env_id == 'ball_in_cup_catch':
        domain_name, task_name = 'ball_in_cup', 'catch'
    else:
        # partition() splits on the first '_' only.
        domain_name, _, task_name = env_id.partition('_')

    make_kwargs = dict(
        domain_name=domain_name,
        task_name=task_name,
        seed=seed,
        visualize_reward=True,
    )
    env = dmc2gym.make(**make_kwargs)
    env.seed(seed)

    assert env.action_space.low.min() >= -1
    assert env.action_space.high.max() <= 1
    return env

def tie_weights(src, trg):
    """Make ``trg`` share ``src``'s ``weight`` and ``bias`` objects.

    Both layers must be of exactly the same type; afterwards the two layers
    reference the same parameter objects rather than copies.
    """
    src_type, trg_type = type(src), type(trg)
    assert src_type == trg_type
    trg.weight, trg.bias = src.weight, src.bias
    
def make_metaworld_env(cfg):
    """Build a Meta-World environment named by ``cfg.env``.

    The ``metaworld_`` prefix is stripped from ``cfg.env`` and the remaining
    name is looked up first in the V2 registry, then in the V1 registry. The
    environment is seeded with ``cfg.seed``, wrapped in ``NormalizedBoxEnv``
    and limited to ``env.max_path_length`` steps.

    NOTE(review): ``_env_dict`` is referenced here, but its import appears to
    be commented out at the top of this file -- confirm it is in scope before
    calling this function.
    """
    task_id = cfg.env.replace('metaworld_','')
    if task_id in _env_dict.ALL_V2_ENVIRONMENTS:
        registry = _env_dict.ALL_V2_ENVIRONMENTS
    else:
        registry = _env_dict.ALL_V1_ENVIRONMENTS

    env = registry[task_id]()

    # Flags set directly on the environment instance before use.
    env._freeze_rand_vec = False
    env._set_task_called = True
    env.seed(cfg.seed)

    normalized = NormalizedBoxEnv(env)
    return TimeLimit(normalized, env.max_path_length)

def ppo_make_metaworld_env(env_id, seed):
    """Build a Meta-World environment named by ``env_id``.

    The ``metaworld_`` prefix is stripped and the remaining name is looked up
    first in the V2 registry, then in the V1 registry. The environment is
    seeded with ``seed`` and limited to ``env.max_path_length`` steps. No
    action normalisation wrapper is applied here.

    NOTE(review): ``_env_dict`` is referenced here, but its import appears to
    be commented out at the top of this file -- confirm it is in scope before
    calling this function.
    """
    task_id = env_id.replace('metaworld_','')
    if task_id in _env_dict.ALL_V2_ENVIRONMENTS:
        registry = _env_dict.ALL_V2_ENVIRONMENTS
    else:
        registry = _env_dict.ALL_V1_ENVIRONMENTS

    env = registry[task_id]()

    # Flags set directly on the environment instance before use.
    env._freeze_rand_vec = False
    env._set_task_called = True
    env.seed(seed)

    horizon = env.max_path_length
    return TimeLimit(env, horizon)

class eval_mode(object):
    """Context manager that switches models to eval mode temporarily.

    On entry every model's ``training`` flag is recorded and
    ``train(False)`` is called; on exit each model gets its recorded flag
    back. Exceptions raised inside the ``with`` body are not suppressed.
    """

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        self.prev_states = saved = []
        for net in self.models:
            saved.append(net.training)
            net.train(False)

    def __exit__(self, *args):
        for net, was_training in zip(self.models, self.prev_states):
            net.train(was_training)
        # Returning False lets any in-flight exception propagate.
        return False


class train_mode(object):
    """Context manager that switches models to train mode temporarily.

    On entry every model's ``training`` flag is recorded and
    ``train(True)`` is called; on exit each model gets its recorded flag
    back. Exceptions raised inside the ``with`` body are not suppressed.
    """

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        self.prev_states = saved = []
        for net in self.models:
            saved.append(net.training)
            net.train(True)

    def __exit__(self, *args):
        for net, was_training in zip(self.models, self.prev_states):
            net.train(was_training)
        # Returning False lets any in-flight exception propagate.
        return False

def soft_update_params(net, target_net, tau):
    """Blend ``net``'s parameters into ``target_net`` in place.

    Each target parameter becomes ``tau * source + (1 - tau) * target``.
    Parameters are paired in the order yielded by ``parameters()``.
    """
    pairs = zip(net.parameters(), target_net.parameters())
    for source, target in pairs:
        blended = tau * source.data + (1 - tau) * target.data
        target.data.copy_(blended)

def set_seed_everywhere(seed):
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    When CUDA is available, all CUDA devices are seeded as well and cuDNN is
    set to deterministic mode with benchmarking disabled.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def make_dir(*path_parts):
    """Join ``path_parts`` into a directory path and make sure it exists.

    Missing parent directories are created too, and an already existing
    directory is not an error.

    Args:
        *path_parts: path components passed to ``os.path.join``.

    Returns:
        The joined directory path.

    Raises:
        OSError: if the directory cannot be created (for example due to
            permissions, or because the path exists and is not a directory).
    """
    dir_path = os.path.join(*path_parts)
    # exist_ok covers the "already there" case the old blanket
    # `except OSError: pass` was meant for, without hiding real failures
    # that would leave the caller with a path that does not exist.
    os.makedirs(dir_path, exist_ok=True)
    return dir_path

def weight_init(m):
    """Initialise ``nn.Linear`` layers: orthogonal weights, zero bias.

    Modules of any other type are left untouched, so this can be passed to
    ``nn.Module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.orthogonal_(m.weight.data)
    # A layer built with bias=False has bias None, which has no `.data`.
    bias = m.bias
    if hasattr(bias, 'data'):
        bias.data.fill_(0.0)

class MLP(nn.Module):
    """Module wrapper around the ``mlp`` builder with ``weight_init`` applied.

    Args:
        input_dim: size of the input features.
        hidden_dim: width of each hidden layer.
        output_dim: size of the output features.
        hidden_depth: number of hidden layers (0 gives a single linear map).
        output_mod: optional module appended after the final linear layer.
    """

    def __init__(self,
                 input_dim,
                 hidden_dim,
                 output_dim,
                 hidden_depth,
                 output_mod=None):
        super().__init__()
        # output_mod must be passed by keyword: the fifth positional
        # parameter of mlp() is `layernorm`, so a positional pass would
        # silently drop the output module and enable LayerNorm instead.
        self.trunk = mlp(input_dim, hidden_dim, output_dim, hidden_depth,
                         output_mod=output_mod)
        self.apply(weight_init)

    def forward(self, x):
        """Run ``x`` through the trunk network."""
        return self.trunk(x)

class TanhTransform(pyd.transforms.Transform):
    """Bijective ``tanh`` transform from the reals onto (-1, 1)."""

    domain = pyd.constraints.real
    codomain = pyd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, cache_size=1):
        super().__init__(cache_size=cache_size)

    @staticmethod
    def atanh(x):
        """Inverse hyperbolic tangent computed via ``log1p``."""
        log_plus = torch.log1p(x)
        log_minus = torch.log1p(-x)
        return 0.5 * (log_plus - log_minus)

    def __eq__(self, other):
        # Any two TanhTransform instances compare equal.
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return torch.tanh(x)

    def _inverse(self, y):
        # No clamping to the boundary here, as that may degrade the
        # performance of certain algorithms; rely on `cache_size=1` instead.
        return self.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        # Numerically stable form of log(1 - tanh(x)^2); see
        # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7
        log_two = math.log(2.0)
        return 2.0 * (log_two - x - F.softplus(-2.0 * x))
    
class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
    """Normal distribution passed through a ``TanhTransform``.

    Args:
        loc: location of the underlying Normal.
        scale: scale of the underlying Normal.
    """

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

        self.base_dist = pyd.Normal(loc, scale)
        super().__init__(self.base_dist, [TanhTransform()])

    @property
    def mean(self):
        """``loc`` pushed through every transform of the distribution."""
        squashed = self.loc
        for transform in self.transforms:
            squashed = transform(squashed)
        return squashed
    
class TorchRunningMeanStd:
    """Running mean and variance tracker backed by torch tensors.

    Args:
        epsilon: initial value of the sample count.
        shape: shape of the tracked mean/variance tensors.
        device: device on which the tensors are allocated.
    """

    def __init__(self, epsilon=1e-4, shape=(), device=None):
        self.count = epsilon
        self.mean = torch.zeros(shape, device=device)
        self.var = torch.ones(shape, device=device)

    def update(self, x):
        """Fold a batch ``x`` (reduced over its first axis) into the statistics."""
        with torch.no_grad():
            n_rows = x.shape[0]
            mean_of_batch = x.mean(dim=0)
            var_of_batch = x.var(dim=0)
            self.update_from_moments(mean_of_batch, var_of_batch, n_rows)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Merge precomputed batch moments via the module-level helper."""
        merged = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count
        )
        self.mean, self.var, self.count = merged

    @property
    def std(self):
        """Square root of the tracked variance."""
        return self.var.sqrt()


def update_mean_var_count_from_moments(
    mean, var, count, batch_mean, batch_var, batch_count
):
    """Merge running moments with the moments of a new batch.

    Args:
        mean: running mean.
        var: running variance.
        count: number of samples behind ``mean``/``var``.
        batch_mean: mean of the new batch.
        batch_var: variance of the new batch.
        batch_count: number of samples in the new batch.

    Returns:
        Tuple ``(new_mean, new_var, new_count)`` for the combined data.
    """
    delta = batch_mean - mean
    tot_count = count + batch_count

    # The mean moves toward the batch mean in proportion to the batch's
    # share of the total count (delta is scaled, not offset, by the ratio).
    new_mean = mean + delta * batch_count / tot_count
    m_a = var * count
    m_b = batch_var * batch_count
    # Sum of both groups' squared deviations plus the between-group term.
    M2 = m_a + m_b + torch.pow(delta, 2) * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count

    return new_mean, new_var, new_count

def mlp(input_dim, hidden_dim, output_dim, hidden_depth, layernorm=False, output_mod=None,):
    """Build a ReLU multilayer perceptron as an ``nn.Sequential``.

    With ``hidden_depth == 0`` the network is a single linear layer from
    ``input_dim`` to ``output_dim`` (``layernorm`` has no effect). Otherwise
    each hidden layer is ``Linear -> [LayerNorm] -> ReLU`` followed by a final
    linear projection to ``output_dim``. ``output_mod``, when given, is
    appended last.
    """
    layers = []
    if hidden_depth == 0:
        layers.append(nn.Linear(input_dim, output_dim))
    else:
        # The first hidden layer reads the input; the rest read hidden_dim.
        fan_ins = [input_dim] + [hidden_dim] * (hidden_depth - 1)
        for fan_in in fan_ins:
            layers.append(nn.Linear(fan_in, hidden_dim))
            if layernorm:
                layers.append(nn.LayerNorm(hidden_dim))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(hidden_dim, output_dim))

    if output_mod is not None:
        layers.append(output_mod)

    return nn.Sequential(*layers)


def to_np(t):
    """Convert a tensor to a NumPy array.

    ``None`` is passed through unchanged, a tensor with no elements becomes
    ``np.array([])``, and any other tensor is detached, moved to the CPU and
    converted.
    """
    if t is None:
        return None
    if t.nelement() == 0:
        return np.array([])
    return t.detach().cpu().numpy()

def copy_agent_target(agent, target):
    """Copy ``agent``'s networks, optimizers and ``log_alpha`` into ``target``.

    The actor, critic and target critic, together with the actor, critic and
    log-alpha optimizers, are transferred through their state dicts;
    ``log_alpha`` itself is copied in place.
    """
    stateful_attrs = (
        "actor",
        "actor_optimizer",
        "critic",
        "critic_optimizer",
        "critic_target",
    )
    for attr in stateful_attrs:
        source_state = getattr(agent, attr).state_dict()
        getattr(target, attr).load_state_dict(source_state)

    target.log_alpha.data.copy_(agent.log_alpha.data)
    alpha_opt_state = agent.log_alpha_optimizer.state_dict()
    target.log_alpha_optimizer.load_state_dict(alpha_opt_state)


def eval_agents(current_agent, env, reward_model, n=1):
    """Roll out ``current_agent`` for ``n`` episodes and average the returns.

    For every step the learned reward is queried on the concatenated
    ``(obs, action)`` pair *before* the environment is stepped, and the true
    reward is taken from ``env.step``.

    This function used to be defined twice in this module with different
    defaults for ``n`` (10, then 1); the later definition shadowed the
    earlier one, so the effective default of 1 is kept here.

    Args:
        current_agent: object with ``reset()`` and ``act(obs, sample=False)``.
        env: old-style gym environment (``reset()`` returns the observation,
            ``step()`` returns a 4-tuple).
        reward_model: object with ``r_hat(sa)``.
        n: number of evaluation episodes.

    Returns:
        Tuple ``(average_r_hat, average_true_reward)`` of per-episode returns
        averaged over the ``n`` episodes.
    """
    average_r_hat = 0
    average_true_reward = 0

    for _ in range(n):
        obs = env.reset()
        current_agent.reset()
        done = False
        r_hat = 0
        true_reward = 0

        while not done:
            action = current_agent.act(obs, sample=False)

            # Predicted reward for the state the action was chosen in.
            predicted_reward = reward_model.r_hat(np.concatenate([obs, action], axis=-1))

            # True reward from the environment; the info dict is not needed.
            obs, reward, done, _ = env.step(action)

            r_hat += predicted_reward
            true_reward += reward

        average_r_hat += r_hat
        average_true_reward += true_reward

    average_r_hat /= n
    average_true_reward /= n

    return (average_r_hat, average_true_reward)


def ground_truth_train_acc(rm, buffer, num_labels):
    """Mean accuracy of the reward ensemble on labelled pairs from ``buffer``.

    ``buffer.sample_sa_and_labels(num_labels)`` supplies two aligned arrays
    of inputs and their labels. For every ensemble member the two predictions
    are concatenated on the last axis and the index of the larger one is
    compared with the label. Data is processed in chunks of 2048 rows.

    Returns:
        Per-member accuracy averaged over the ``rm.de`` ensemble members.
    """
    all_sa_t_1, all_sa_t_2, all_labels = buffer.sample_sa_and_labels(num_labels)
    total_steps = all_sa_t_1.shape[0]

    batch_size = 2048
    hits_per_member = np.array([0] * rm.de)
    seen = 0

    for start in range(0, total_steps, batch_size):
        # Clamp the final chunk to the end of the data.
        stop = min(start + batch_size, total_steps)
        chunk = slice(start, stop)

        sa_t_1 = all_sa_t_1[chunk]
        sa_t_2 = all_sa_t_2[chunk]
        labels = torch.from_numpy(all_labels[chunk].flatten()).long().to(rm.device)
        seen += labels.size(0)

        for member in range(rm.de):
            scores = torch.cat(
                [
                    rm.r_hat_member(sa_t_1, member=member),
                    rm.r_hat_member(sa_t_2, member=member),
                ],
                dim=-1,
            )
            predicted = torch.max(scores.data, dim=-1)[1]
            hits_per_member[member] += (predicted == labels).sum().item()

    accuracy = hits_per_member / seen
    return accuracy.mean()

class GymnasiumToGymWrapper(gym.Wrapper):
    """Expose a goal-conditioned gymnasium env through the old gym API.

    ``reset`` returns only the observation and ``step`` returns the 4-tuple
    ``(obs, reward, done, info)`` with ``done = terminated or truncated``.
    Observations from the wrapped env are dicts; by default only the
    ``"observation"`` entry is returned.

    Args:
        gymnasium_env: the gymnasium environment to adapt.
        options: optional dict forwarded as ``options=`` on every reset.
        add_goal: when true, ``"achieved_goal"`` is appended to the returned
            observation.
    """

    def __init__(self, gymnasium_env, options=None, add_goal=False):
        super().__init__(gymnasium_env)
        self.gymnasium_env = gymnasium_env
        self.options = options
        self.add_goal = add_goal
        # "desired_goal" of the most recent reset (None before any reset).
        self.goal_cell = None
        # Seed requested through seed(); consumed by the next reset().
        self._seed = None

    def _flatten_obs(self, obs_dict):
        """Turn the dict observation into the flat array handed to callers."""
        # Index the original dict for both parts: the previous code first
        # replaced the dict by obs["observation"] and then indexed that
        # array with string keys, which fails whenever add_goal is set.
        if self.add_goal:
            return np.concatenate([obs_dict["observation"], obs_dict["achieved_goal"]])
        return obs_dict["observation"]

    def reset(self, **kwargs):
        """Reset the wrapped env and return only the observation."""
        if self.options:
            # setdefault lets a caller-supplied `options` win instead of
            # raising a duplicate-keyword error.
            kwargs.setdefault("options", self.options)
        if self._seed is not None:
            # gymnasium seeds through reset(); apply a pending seed once.
            kwargs.setdefault("seed", self._seed)
            self._seed = None

        obs, info = self.gymnasium_env.reset(**kwargs)
        self.goal_cell = obs["desired_goal"]
        return self._flatten_obs(obs)

    def _unwrap_env(self, env):
        """Follow ``.env`` attributes down to the innermost environment."""
        while hasattr(env, 'env'):
            env = env.env
        return env

    def step(self, action):
        """Step the wrapped env and collapse terminated/truncated into ``done``."""
        obs, reward, terminated, truncated, info = self.gymnasium_env.step(action)
        done = terminated or truncated
        return self._flatten_obs(obs), reward, done, info

    def render(self, mode='human'):
        """Render the wrapped env; ``mode`` is accepted but not forwarded."""
        return self.gymnasium_env.render()

    def close(self):
        return self.gymnasium_env.close()

    def seed(self, seed=None):
        """Record ``seed`` so that the next ``reset`` applies it."""
        self._seed = seed
        return [seed]
    
class RewardWrapper(gym.Wrapper):
    """Replace the wrapped env's reward with a position-based shaped reward.

    The original reward is preserved in ``info['original_reward']``.
    """

    def __init__(self, env, reward_scale=1.0, reach_reward=10.0, reach_threshold=0.5):
        # NOTE: reward_scale, reach_reward and reach_threshold are accepted
        # but not read anywhere in this class; custom_reward uses its own
        # hard-coded constants.
        super().__init__(env)

    def step(self, action):
        """Step the wrapped env and substitute the shaped reward."""
        observation, original_reward, done, info = self.env.step(action)

        # The first two observation entries are taken as the (x, y) position.
        if isinstance(observation, dict):
            coords = observation["observation"]
        else:
            coords = observation
        position = np.array([coords[0], coords[1]])

        reward = self.custom_reward(position, goal_cell=self.goal_cell)
        info['original_reward'] = original_reward

        return observation, reward, done, info

    def custom_reward(self, position, goal_cell):
        """Return 1.0 within 0.3 of the goal, otherwise ``exp(-dist)``.

        ``dist`` depends on the region ``position`` falls in: a scaled
        distance to the fixed point ``[1, -1]`` or ``[1, 1]`` in the first two
        regions, and a quarter of the goal distance everywhere else.
        """
        x, y = position[0], position[1]
        goal_dist = np.linalg.norm(position - np.array(goal_cell))

        # Goal reached.
        if goal_dist <= 0.3:
            return 1.0

        if y < 0 and -2 < x < 0.5:
            waypoint = np.array([1,-1])
            dist = np.linalg.norm(position - waypoint) * 3.5
        elif y < 0.5 and 0.5 < x < 2:
            waypoint = np.array([1,1])
            dist = np.linalg.norm(position - waypoint) * 1
        else:
            dist = goal_dist / 4

        return np.exp(-dist)
    
class LogRes:
    """Append-only metric logger persisted as a single JSON file.

    Metrics are stored as ``{name: [{"step": int, "value": float}, ...]}`` in
    ``<base_dir>/<experiment_name>/rewards.json``; the file is rewritten on
    every ``log`` call.

    Args:
        experiment_name: sub-directory name; when empty or ``None`` a
            timestamped ``agent_eval_YYYYmmdd_HHMMSS`` name is generated.
        base_dir: directory under which the experiment directory is created.
    """

    def __init__(self, experiment_name=None, base_dir="./rewards"):
        if not experiment_name:
            # Imported here because the module does not import datetime at
            # the top level; the default name previously raised NameError.
            from datetime import datetime
            experiment_name = f"agent_eval_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        self.experiment_name = experiment_name
        self.rewards_dir = os.path.join(base_dir, self.experiment_name)
        self.rewards_file = os.path.join(self.rewards_dir, "rewards.json")
        self.rewards = {}
        # Step attached to the next log() call; advanced after each call.
        self.step = 0

        os.makedirs(self.rewards_dir, exist_ok=True)

    def log(self, metrics_dict):
        """Record every metric at the current step, then advance the step.

        Args:
            metrics_dict: mapping of metric name to a value accepted by
                ``float()``.
        """
        step = self.step

        for key, value in metrics_dict.items():
            self.rewards.setdefault(key, []).append({"step": step, "value": float(value)})

        with open(self.rewards_file, 'w') as f:
            json.dump(self.rewards, f, indent=2)

        self.step = step + 1

    def set_step(self, step):
        """Set the step that the next ``log`` call will record."""
        self.step = step