import torch.nn as nn
import torch
import numpy as np
from typing import Any, Dict, List, Optional, Tuple, Union
import gym
import wandb
import uuid
import os
import random

import d4rl
from sklearn.neighbors import KDTree
import numpy as np
import gym
from torch import Tensor



class KD_tree:
    """Distance-to-k-th-neighbour lookup backed by a ``KDTree`` index.

    The index is built over a 2-D array of points. A 3-D array is accepted
    as well and is flattened so that every vector along its last axis
    becomes one point.
    """

    def __init__(self, data, k):
        """Build the index over ``data`` and store the neighbour rank ``k``."""
        self.update(data)
        self.k = k

    @staticmethod
    def _as_points(data):
        """Validate ``data`` and return it as a 2-D ``(n_points, dim)`` array."""
        assert isinstance(data, np.ndarray)
        assert data.ndim in (2, 3)
        if data.ndim == 3:
            return data.reshape(-1, data.shape[-1])
        return data

    def update(self, data):
        """Replace the stored points with ``data`` and rebuild the index."""
        points = self._as_points(data)
        self.data = points
        self.kd_tree = KDTree(points)

    def query(self, query_points):
        """Return the distance from each query point to its k-th neighbour.

        ``query_points`` is a single point (1-D) or a batch of points (2-D).
        The result always has shape ``(n_queries, 1)``: the last column of
        the distances returned by the index for ``k=self.k``.
        """
        assert isinstance(query_points, np.ndarray)
        assert query_points.ndim in (1, 2)
        # Promote a single point to a batch of one.
        batch = query_points[None, :] if query_points.ndim == 1 else query_points

        distances, _ = self.kd_tree.query(batch, k=self.k, return_distance=True)

        return distances[:, -1:]  # (-1, 1)


def soft_update(target: nn.Module, source: nn.Module, tau: float):
    """Blend ``source`` parameters into ``target`` in place.

    Every target parameter becomes ``(1 - tau) * target + tau * source``.
    Parameters are paired positionally through ``zip``, so pairing stops at
    the shorter of the two parameter lists.
    """
    keep = 1 - tau
    for dst, src in zip(target.parameters(), source.parameters()):
        blended = keep * dst.data + tau * src.data
        dst.data.copy_(blended)


def compute_mean_std(states: np.ndarray, eps: float) -> Tuple[np.ndarray, np.ndarray]:
    """Return the mean and ``eps``-padded standard deviation along axis 0.

    ``eps`` is added to the standard deviation so the result can be used as
    a divisor even where the deviation is zero.
    """
    center = states.mean(axis=0)
    spread = states.std(axis=0) + eps
    return center, spread


def normalize_states(states: np.ndarray, mean: np.ndarray, std: np.ndarray):
    """Standardise ``states``: subtract ``mean``, then divide by ``std``."""
    centered = states - mean
    return centered / std


def wrap_env(
    env: gym.Env,
    state_mean: Union[np.ndarray, float] = 0.0,
    state_std: Union[np.ndarray, float] = 1.0,
    reward_scale: float = 1.0,
    reward_bias: float = 0.0,
) -> gym.Env:
    """Wrap ``env`` with observation normalisation and an affine reward map.

    Observations are always transformed to ``(state - state_mean) / state_std``.
    Rewards are transformed to ``reward_scale * reward + reward_bias``; the
    reward wrapper is only installed when that map is not the identity.

    Args:
        env: Environment to wrap.
        state_mean: Value subtracted from every observation.
        state_std: Divisor applied to every centred observation. Any epsilon
            must already be included by the caller.
        reward_scale: Multiplier applied to every reward.
        reward_bias: Offset added to every scaled reward.

    Returns:
        The wrapped environment.
    """

    # PEP 8: E731 do not assign a lambda expression, use a def
    def normalize_state(state):
        return (
            state - state_mean
        ) / state_std  # epsilon should be already added in std.

    def scale_reward(reward):
        # Please be careful, here reward is multiplied by scale!
        return reward_scale * reward + reward_bias

    env = gym.wrappers.TransformObservation(env, normalize_state)
    # Check the bias as well as the scale: testing only the scale would
    # silently drop a non-zero bias whenever the scale is left at 1.0.
    if reward_scale != 1.0 or reward_bias != 0.0:
        env = gym.wrappers.TransformReward(env, scale_reward)
    return env

def set_env_seed(env: Optional[gym.Env], seed: int):
    """Seed ``env`` and its action space with ``seed``.

    ``env`` is declared optional, so ``None`` is accepted and treated as
    "nothing to seed" instead of failing with ``AttributeError``.
    """
    if env is None:
        return
    env.seed(seed)
    env.action_space.seed(seed)


def set_seed(
    seed: int, env: Optional[gym.Env] = None, deterministic_torch: bool = False
):
    """Seed every random source used here with the same ``seed``.

    Seeds the environment (when one is given), sets ``PYTHONHASHSEED``, and
    seeds NumPy, the ``random`` module and torch. ``deterministic_torch`` is
    forwarded to ``torch.use_deterministic_algorithms``.
    """
    if env is not None:
        set_env_seed(env, seed)

    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)
    torch.use_deterministic_algorithms(deterministic_torch)


def wandb_init(config: dict) -> None:
    """Start a wandb run described by ``config`` and save it.

    ``config`` must contain the keys ``project``, ``group``, ``name``,
    ``entity`` and ``allow_val_change``; the whole dict is also passed as
    the run config. A fresh UUID4 string is used as the run id.
    """
    run_options = {
        key: config[key]
        for key in ("project", "group", "name", "entity", "allow_val_change")
    }
    wandb.init(config=config, id=str(uuid.uuid4()), **run_options)
    wandb.run.save()
    
    
def is_goal_reached(reward: float, info: Dict) -> bool:
    """Tell whether the current step reached the goal.

    Uses ``info["goal_achieved"]`` when the environment reports it;
    otherwise falls back to treating any positive reward as success.
    """
    # Fallback assumes that reaching the target yields a positive reward.
    return info["goal_achieved"] if "goal_achieved" in info else reward > 0


def return_reward_range(dataset: Dict, max_episode_steps: int) -> Tuple[float, float]:
    """Return the smallest and largest episode return found in ``dataset``.

    Rewards are accumulated in order; an episode closes when its terminal
    flag is set or when it reaches ``max_episode_steps`` steps. A trailing
    unfinished episode contributes to the step count check but not to the
    returns.
    """
    episode_returns = []
    episode_lengths = []
    running_return = 0.0
    running_length = 0
    for reward, terminal in zip(dataset["rewards"], dataset["terminals"]):
        running_return += float(reward)
        running_length += 1
        if not (terminal or running_length == max_episode_steps):
            continue
        episode_returns.append(running_return)
        episode_lengths.append(running_length)
        running_return = 0.0
        running_length = 0
    # Leftover steps of an unfinished episode still count towards the total.
    episode_lengths.append(running_length)
    assert sum(episode_lengths) == len(dataset["rewards"])
    return min(episode_returns), max(episode_returns)

def modify_reward(
    dataset: Dict,
    env_name: str,
    max_episode_steps: int = 1000,
    reward_scale: float = 1.0,
    reward_bias: float = 0.0,
) -> Dict:
    """Rescale ``dataset["rewards"]`` and report the statistics used.

    For environments whose name contains ``halfcheetah``, ``hopper`` or
    ``walker2d`` the rewards are first divided, in place, by the spread of
    episode returns and multiplied by ``max_episode_steps``. In every case
    the rewards are then replaced by ``rewards * reward_scale + reward_bias``.

    Returns:
        A dict with ``max_ret``, ``min_ret`` and ``max_episode_steps`` for
        the environments named above, otherwise an empty dict.
    """
    needs_return_scaling = any(
        tag in env_name for tag in ("halfcheetah", "hopper", "walker2d")
    )
    if needs_return_scaling:
        min_ret, max_ret = return_reward_range(dataset, max_episode_steps)
        dataset["rewards"] /= max_ret - min_ret
        dataset["rewards"] *= max_episode_steps
        modification_data = {
            "max_ret": max_ret,
            "min_ret": min_ret,
            "max_episode_steps": max_episode_steps,
        }
    else:
        modification_data = {}

    dataset["rewards"] = dataset["rewards"] * reward_scale + reward_bias
    return modification_data


def modify_reward_online(
    reward: float,
    env_name: str,
    reward_scale: float = 1.0,
    reward_bias: float = 0.0,
    **kwargs,
) -> float:
    """Apply the dataset reward transformation to a single reward.

    For environments whose name contains ``halfcheetah``, ``hopper`` or
    ``walker2d`` the reward is divided by ``max_ret - min_ret`` and
    multiplied by ``max_episode_steps``, all three read from ``kwargs``.
    The result is then mapped to ``reward * reward_scale + reward_bias``.
    """
    uses_return_scaling = any(
        tag in env_name for tag in ("halfcheetah", "hopper", "walker2d")
    )
    if uses_return_scaling:
        return_span = kwargs["max_ret"] - kwargs["min_ret"]
        reward /= return_span
        reward *= kwargs["max_episode_steps"]
    return reward * reward_scale + reward_bias

@torch.no_grad()
def eval_actor(
    env: gym.Env, actor: nn.Module, device: str, n_episodes: int, seed: int
) -> Tuple[np.ndarray, np.ndarray]:
    """Roll out ``actor`` in ``env`` for ``n_episodes`` episodes.

    The environment is seeded once with ``seed``. The actor is switched to
    eval mode for the rollouts and back to train mode afterwards.

    Returns:
        The per-episode summed rewards as an array, and the mean of the
        per-episode success flags. An episode counts as a success as soon
        as ``is_goal_reached`` is true for any of its steps.
    """
    env.seed(seed)
    actor.eval()

    returns_per_episode = []
    success_flags = []
    for _ in range(n_episodes):
        observation = env.reset()
        finished = False
        total_reward = 0.0
        reached = False
        while not finished:
            chosen_action = actor.act(observation, device)
            observation, reward, finished, step_info = env.step(chosen_action)
            total_reward += reward
            # Once reached, the flag sticks and the check is skipped.
            reached = reached or is_goal_reached(reward, step_info)
        # Valid only for environments with goal
        success_flags.append(float(reached))
        returns_per_episode.append(total_reward)

    actor.train()
    return np.asarray(returns_per_episode), np.mean(success_flags)


def qlearning_dataset_with_timeouts(env, dataset=None,
                                    terminate_on_end=False,
                                    disable_goal=True,
                                    **kwargs):
    """Turn a step-wise dataset into (s, a, s', r, done) transitions.

    Args:
        env: Only used to fetch the dataset (``env.get_dataset()``) when
            ``dataset`` is None.
        dataset: Dict with ``observations``, ``actions``, ``rewards`` and
            ``terminals``, plus either ``infos/goal`` or ``timeouts``.
        terminate_on_end: When False, the step on which an episode ends
            (timeout, or goal change when ``infos/goal`` is present) is
            dropped. When True it is kept and flagged in ``terminals``.
        disable_goal: When False and ``infos/goal`` is present, the goal is
            appended to every observation. This overwrites
            ``dataset["observations"]``.
        **kwargs: Accepted and ignored.

    Returns:
        Dict of arrays: ``observations``, ``actions``, ``next_observations``,
        ``rewards``, ``terminals`` (terminal flag plus end-of-episode flag)
        and ``realterminals`` (the dataset's terminal flag only). The last
        row of the dataset never starts a transition because it has no
        successor.
    """
    if dataset is None:
        dataset = env.get_dataset()

    n_steps = dataset['rewards'].shape[0]
    goal_conditioned = "infos/goal" in dataset
    if goal_conditioned and not disable_goal:
        dataset["observations"] = np.concatenate(
            [dataset["observations"], dataset['infos/goal']], axis=1)

    columns = ('observations', 'actions', 'next_observations',
               'rewards', 'terminals', 'realterminals')
    kept = {name: [] for name in columns}

    for t in range(n_steps - 1):
        current_obs = dataset['observations'][t]
        following_obs = dataset['observations'][t + 1]
        action = dataset['actions'][t]
        reward = dataset['rewards'][t]
        true_terminal = bool(dataset['terminals'][t])

        if goal_conditioned:
            # A change of goal between consecutive rows marks an episode end.
            goals = dataset['infos/goal']
            episode_ended = bool((goals[t] != goals[t + 1]).any())
        else:
            episode_ended = dataset['timeouts'][t]

        # Deliberately an addition, not a logical or, to match the stored
        # values of the original implementation.
        done_flag = true_terminal + episode_ended

        if (not terminate_on_end) and episode_ended:
            # Skip this transition and don't apply terminals on the last
            # step of an episode.
            continue

        kept['observations'].append(current_obs)
        kept['next_observations'].append(following_obs)
        kept['actions'].append(action)
        kept['rewards'].append(reward)
        kept['terminals'].append(done_flag)
        kept['realterminals'].append(true_terminal)

    return {name: np.array(values) for name, values in kept.items()}
    
def split_into_trajectories(observations, actions, rewards, masks, dones_float,
                            next_observations, length_threshold=100):
    """Split flat transition arrays into per-trajectory dicts of arrays.

    A new trajectory starts after every transition whose ``dones_float``
    entry equals 1.0. Only trajectories with strictly more than
    ``length_threshold`` transitions are returned.

    Args:
        observations, actions, rewards, masks, dones_float, next_observations:
            Parallel sequences indexed by transition; ``observations``
            determines how many transitions are read.
        length_threshold: Trajectories with at most this many transitions
            are dropped. Defaults to 100, the previously hard-coded cutoff.

    Returns:
        A list of dicts with keys ``observations``, ``actions``, ``rewards``,
        ``masks``, ``dones`` and ``next_observations``. ``rewards``, ``masks``
        and ``dones`` get a trailing axis of size 1.
    """
    n_transitions = len(observations)
    episodes = [[]]
    for i in range(n_transitions):
        episodes[-1].append((observations[i], actions[i], rewards[i], masks[i],
                             dones_float[i], next_observations[i]))
        # Do not open an empty trajectory after the very last transition.
        if dones_float[i] == 1.0 and i + 1 < n_transitions:
            episodes.append([])

    trajs = []
    for episode in episodes:
        # The emptiness check protects np.stack when the input is empty and
        # the threshold is negative.
        if not episode or len(episode) <= length_threshold:
            continue
        obs, act, rew, mask, done, next_obs = zip(*episode)
        trajs.append({
            'observations': np.stack(obs, axis=0),
            "actions": np.stack(act, axis=0),
            "rewards": np.stack([[r] for r in rew], axis=0),
            "masks": np.stack([[m] for m in mask], axis=0),
            "dones": np.stack([[d] for d in done], axis=0),
            "next_observations": np.stack(next_obs, axis=0),
        })

    return trajs


def test_mc(env, actor, device, gamma=0.99, n_mc_cutoff=350, max_steps=10000):
    """Roll out ``actor`` and collect discounted Monte-Carlo returns.

    Episodes are run back to back in ``env``. When an episode finishes, its
    per-step rewards are replaced by discounted returns-to-go and the first
    ``n_mc_cutoff`` steps of the episode are appended to the output. The
    rollout stops once more than ``max_steps`` environment steps have been
    taken; the episode in progress at that point is discarded.

    Args:
        env: Environment exposing ``reset()`` and a 4-tuple ``step()``.
        actor: Policy exposing ``act(state, device)``.
        device: Passed through to ``actor.act``.
        gamma: Discount factor for the returns-to-go.
        n_mc_cutoff: Number of leading steps kept from each finished episode.
        max_steps: Step budget for the whole rollout.

    Returns:
        ``(returns, states, actions)`` as arrays aligned by step. All three
        are empty when no episode finished within the budget.
    """
    state = env.reset()
    steps_taken = 0

    ep_rewards, ep_states, ep_actions = [], [], []

    # Start from an array so the return type does not depend on whether any
    # episode finished.
    final_rewards = np.array([])
    final_states = []
    final_actions = []
    while True:
        action = actor.act(state, device)
        next_state, reward, done, _ = env.step(action)
        ep_rewards.append(reward)
        ep_states.append(state)
        ep_actions.append(action)

        steps_taken += 1
        state = next_state
        if steps_taken > max_steps:
            break
        if done:
            # Backward pass: turn rewards into discounted returns-to-go.
            for i in reversed(range(len(ep_rewards) - 1)):
                ep_rewards[i] = gamma * ep_rewards[i + 1] + ep_rewards[i]
            final_rewards = np.concatenate((final_rewards, ep_rewards[:n_mc_cutoff]))
            final_states = final_states + ep_states[:n_mc_cutoff]
            final_actions = final_actions + ep_actions[:n_mc_cutoff]

            state = env.reset()
            ep_rewards, ep_states, ep_actions = [], [], []
    return final_rewards, np.array(final_states), np.array(final_actions)

def log_bias_evaluation(env, agent, device):
    """Return the mean normalised bias of ``agent.critic_1``.

    Monte-Carlo returns are collected with ``test_mc`` using ``agent.actor``.
    The bias of each visited state-action pair is the critic's prediction
    minus the Monte-Carlo return, divided by the absolute return floored
    at 10.
    """
    mc_returns, visited_states, taken_actions = test_mc(env, agent.actor, device)
    state_batch = Tensor(visited_states).to(device)
    action_batch = Tensor(taken_actions).to(device)
    with torch.no_grad():
        q_values = agent.critic_1(state_batch, action_batch).cpu().numpy().reshape(-1)

    # The floor keeps returns near zero from inflating the normalised bias.
    denominator = np.maximum(np.abs(mc_returns), 10)
    per_state_bias = (q_values - mc_returns) / denominator

    return np.mean(per_state_bias)