# source: https://github.com/gwthomas/IQL-PyTorch
# https://arxiv.org/pdf/2110.06169.pdf
import copy
import os
import random
import uuid
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import d4rl
import gym
import numpy as np
import pyrallis
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch.distributions import Normal
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.neighbors import KDTree

from torch import Tensor

class KD_tree:
    """Distance-to-neighbour lookup built on sklearn's ``KDTree``.

    The tree is built over a 2-D array of points; a 3-D input is flattened
    to ``(-1, last_dim)`` first. ``k`` is the number of neighbours requested
    from the tree on every query.
    """

    def __init__(self, data, k):
        # Building the index is exactly what ``update`` does, so reuse it.
        self.update(data)
        self.k = k

    def update(self, data):
        """Replace the indexed points with ``data`` and rebuild the tree."""
        assert isinstance(data, np.ndarray)
        assert data.ndim in (2, 3)
        points = data.reshape(-1, data.shape[-1]) if data.ndim == 3 else data
        self.data = points
        self.kd_tree = KDTree(points)

    def query(self, query_points):
        """Return the last of the ``k`` neighbour distances per query point.

        A single 1-D point is promoted to a batch of one. The result has
        shape ``(n_queries, 1)``. With sklearn's default sorted results the
        last column is the distance to the k-th nearest neighbour.
        """
        assert isinstance(query_points, np.ndarray)
        assert query_points.ndim in (1, 2)
        batch = query_points if query_points.ndim == 2 else query_points[None, :]

        distances, _ = self.kd_tree.query(batch, k=self.k, return_distance=True)

        # Slice (rather than index) the last column to keep the trailing axis.
        return distances[:, -1:]


def soft_update(target: nn.Module, source: nn.Module, tau: float):
    """Blend ``source`` parameters into ``target`` in place.

    Every target parameter becomes ``(1 - tau) * target + tau * source``.
    Parameters are paired positionally through ``zip``, so iteration stops
    at the shorter of the two parameter lists.
    """
    retain = 1 - tau
    for tgt, src in zip(target.parameters(), source.parameters()):
        blended = retain * tgt.data + tau * src.data
        tgt.data.copy_(blended)


def compute_mean_std(states: np.ndarray, eps: float) -> Tuple[np.ndarray, np.ndarray]:
    """Return the mean and the ``eps``-padded std of ``states`` along axis 0.

    ``eps`` is added to the standard deviation so that it can later be used
    as a divisor even for constant columns.
    """
    return states.mean(0), states.std(0) + eps


def normalize_states(states: np.ndarray, mean: np.ndarray, std: np.ndarray):
    """Standardise ``states``: subtract ``mean`` and divide by ``std``."""
    centered = states - mean
    return centered / std


def wrap_env(
    env: gym.Env,
    state_mean: Union[np.ndarray, float] = 0.0,
    state_std: Union[np.ndarray, float] = 1.0,
    reward_scale: float = 1.0,
) -> gym.Env:
    """Wrap ``env`` so observations are standardised and rewards rescaled.

    Observations are always passed through ``(obs - state_mean) / state_std``
    via ``gym.wrappers.TransformObservation``. A reward transform is only
    added when ``reward_scale`` differs from 1.0.
    """

    def _standardize(observation):
        # state_std is expected to already contain the epsilon term.
        centered = observation - state_mean
        return centered / state_std

    wrapped = gym.wrappers.TransformObservation(env, _standardize)
    if reward_scale == 1.0:
        return wrapped

    def _rescale(reward):
        # Careful: the reward is *multiplied* by the scale.
        return reward_scale * reward

    return gym.wrappers.TransformReward(wrapped, _rescale)

def set_env_seed(env: Optional[gym.Env], seed: int):
    """Seed ``env`` and its action space with ``seed``.

    The parameter is declared optional, so ``None`` is accepted and treated
    as "nothing to seed" instead of failing with ``AttributeError``.
    """
    if env is None:
        return
    env.seed(seed)
    env.action_space.seed(seed)

def set_seed(
    seed: int, env: Optional[gym.Env] = None, deterministic_torch: bool = False
):
    """Seed every random source used here with the same ``seed``.

    Seeds the environment (when one is given), exports ``PYTHONHASHSEED``,
    seeds NumPy, ``random`` and PyTorch, and finally switches PyTorch's
    deterministic-algorithms mode to ``deterministic_torch``.
    """
    if env is not None:
        set_env_seed(env, seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_rng in (np.random.seed, random.seed, torch.manual_seed):
        seed_rng(seed)
    torch.use_deterministic_algorithms(deterministic_torch)
    
def is_goal_reached(reward: float, info: Dict) -> bool:
    """Tell whether the goal was reached on this step.

    An explicit ``"goal_achieved"`` entry in ``info`` wins. Without it, a
    strictly positive reward is taken to mean the target was reached.
    """
    return info["goal_achieved"] if "goal_achieved" in info else reward > 0

@torch.no_grad()
def eval_actor(
    env: gym.Env, actor: nn.Module, device: str, n_episodes: int, seed: int
) -> Tuple[np.ndarray, np.ndarray]:
    """Roll out ``actor`` for ``n_episodes`` episodes and collect statistics.

    The environment is seeded once, the actor is put in eval mode for the
    rollouts and switched back to train mode afterwards.

    Returns:
        A pair ``(episode_returns, success_rate)``: the array of undiscounted
        per-episode returns and the mean of the per-episode success flags.
        The success rate is only meaningful for environments with a goal.
    """
    env.seed(seed)
    actor.eval()
    returns_per_episode = []
    success_flags = []
    for _ in range(n_episodes):
        observation = env.reset()
        finished = False
        total_reward = 0.0
        reached_goal = False
        while not finished:
            action = actor.act(observation, device)
            observation, reward, finished, info = env.step(action)
            total_reward += reward
            # Once the goal has been reached the flag sticks for the episode.
            reached_goal = reached_goal or is_goal_reached(reward, info)
        success_flags.append(float(reached_goal))
        returns_per_episode.append(total_reward)

    actor.train()
    return np.asarray(returns_per_episode), np.mean(success_flags)


def return_reward_range(dataset: Dict, max_episode_steps: int) -> Tuple[float, float]:
    """Return the smallest and largest episode return found in ``dataset``.

    Episodes are cut whenever a terminal flag is set or ``max_episode_steps``
    steps have accumulated. A trailing, unfinished episode does not
    contribute a return, but its steps are still counted for the sanity
    check at the end.
    """
    episode_returns = []
    steps_seen = 0
    running_return, running_len = 0.0, 0
    for reward, terminal in zip(dataset["rewards"], dataset["terminals"]):
        running_return += float(reward)
        running_len += 1
        steps_seen += 1
        if terminal or running_len == max_episode_steps:
            episode_returns.append(running_return)
            running_return, running_len = 0.0, 0
    # Every reward must have been visited exactly once.
    assert steps_seen == len(dataset["rewards"])
    return min(episode_returns), max(episode_returns)


def modify_reward(dataset: Dict, env_name: str, max_episode_steps: int = 1000) -> Dict:
    """Rescale ``dataset["rewards"]`` in place according to the environment.

    * halfcheetah / hopper / walker2d: rewards are divided by the spread of
      episode returns and multiplied by ``max_episode_steps``; the statistics
      used are returned so the same transform can be applied elsewhere.
    * antmaze: 1.0 is subtracted from every reward.
    * anything else: rewards are left untouched.

    Returns:
        The dict of statistics for the locomotion case, otherwise ``{}``.
    """
    locomotion_tasks = ("halfcheetah", "hopper", "walker2d")
    is_locomotion = any(task in env_name for task in locomotion_tasks)
    if not is_locomotion:
        if "antmaze" in env_name:
            dataset["rewards"] -= 1.0
        return {}

    min_ret, max_ret = return_reward_range(dataset, max_episode_steps)
    dataset["rewards"] /= max_ret - min_ret
    dataset["rewards"] *= max_episode_steps
    return {
        "max_ret": max_ret,
        "min_ret": min_ret,
        "max_episode_steps": max_episode_steps,
    }


def modify_reward_online(reward: float, env_name: str, **kwargs) -> float:
    """Apply the per-environment reward transform to one online reward.

    For halfcheetah / hopper / walker2d the keyword arguments ``max_ret``,
    ``min_ret`` and ``max_episode_steps`` are required. For antmaze 1.0 is
    subtracted. Any other environment returns the reward unchanged.
    """
    locomotion_tasks = ("halfcheetah", "hopper", "walker2d")
    if any(task in env_name for task in locomotion_tasks):
        return_span = kwargs["max_ret"] - kwargs["min_ret"]
        reward /= return_span
        reward *= kwargs["max_episode_steps"]
        return reward
    if "antmaze" in env_name:
        reward -= 1.0
    return reward

def asymmetric_l2_loss(u: torch.Tensor, tau: float) -> torch.Tensor:
    """Mean squared error with sign-dependent weights.

    Negative entries of ``u`` are weighted by ``|tau - 1|`` and all other
    entries by ``|tau|``.
    """
    is_negative = (u < 0).float()
    weight = torch.abs(tau - is_negative)
    return torch.mean(weight * u**2)

def qlearning_dataset_with_timeouts(env, dataset=None,
                                    terminate_on_end=False,
                                    disable_goal=True,
                                    **kwargs):
    """Turn a flat step dataset into (s, a, s', r, done) transition arrays.

    Args:
        env: only used to fetch the data via ``env.get_dataset()`` when
            ``dataset`` is ``None``.
        dataset: mapping with ``observations``, ``actions``, ``rewards`` and
            ``terminals``, plus either ``infos/goal`` or ``timeouts``.
        terminate_on_end: when ``False`` (default) the last step of every
            episode is dropped; when ``True`` it is kept and flagged as done.
        disable_goal: when ``False`` and the dataset has ``infos/goal``, the
            goal is appended to every observation. NOTE: this rebinds
            ``dataset["observations"]`` on the mapping that was passed in.
        **kwargs: accepted for call compatibility and ignored.

    Returns:
        Dict of arrays: ``observations``, ``actions``, ``next_observations``,
        ``rewards``, ``terminals`` (environment terminal OR episode boundary)
        and ``realterminals`` (environment terminal only). The final sample
        of the dataset is never emitted because it has no successor.
    """
    if dataset is None:
        dataset = env.get_dataset()

    N = dataset['rewards'].shape[0]
    has_goal = "infos/goal" in dataset
    if has_goal and not disable_goal:
        dataset["observations"] = np.concatenate(
            [dataset["observations"], dataset['infos/goal']], axis=1)

    observations = dataset['observations']
    actions = dataset['actions']
    rewards = dataset['rewards']
    terminals = dataset['terminals']
    goals = dataset['infos/goal'] if has_goal else None

    obs_ = []
    next_obs_ = []
    action_ = []
    reward_ = []
    done_ = []
    realdone_ = []

    for i in range(N - 1):
        realdone_bool = bool(terminals[i])
        if has_goal:
            # A change of goal between consecutive steps marks an episode end.
            final_timestep = bool((goals[i] != goals[i + 1]).any())
        else:
            final_timestep = bool(dataset['timeouts'][i])

        if final_timestep and not terminate_on_end:
            # Skip this transition: its "next observation" belongs to the
            # following episode.
            continue

        obs_.append(observations[i])
        next_obs_.append(observations[i + 1])
        action_.append(actions[i])
        reward_.append(rewards[i])
        # Logical OR keeps the flag boolean; adding the two flags could
        # produce 2 when a step is both terminal and an episode boundary.
        done_.append(realdone_bool or final_timestep)
        realdone_.append(realdone_bool)

    return {
        'observations': np.array(obs_),
        'actions': np.array(action_),
        'next_observations': np.array(next_obs_),
        'rewards': np.array(reward_),
        'terminals': np.array(done_),
        'realterminals': np.array(realdone_),
    }
    
def split_into_trajectories(observations, actions, rewards, masks, dones_float,
                            next_observations):
    """Cut flat step arrays into per-episode dicts of stacked arrays.

    A new episode starts after every step whose ``dones_float`` entry equals
    1.0. Only episodes with more than 10 steps are returned. ``rewards``,
    ``masks`` and ``dones`` are stacked as columns (one wrapped value per
    step), the other fields are stacked as-is.
    """
    n_steps = len(observations)
    episodes = []
    current = []
    for i in range(n_steps):
        current.append((observations[i], actions[i], rewards[i], masks[i],
                        dones_float[i], next_observations[i]))
        if dones_float[i] == 1.0 and i + 1 < n_steps:
            episodes.append(current)
            current = []
    episodes.append(current)

    def _stack(episode, field, as_column=False):
        # ``as_column`` wraps each value in a list so it gets its own axis.
        if as_column:
            return np.stack([[step[field]] for step in episode], axis=0)
        return np.stack([step[field] for step in episode], axis=0)

    packed = []
    for episode in episodes:
        if len(episode) <= 10:
            continue
        packed.append({
            'observations': _stack(episode, 0),
            "actions": _stack(episode, 1),
            "rewards": _stack(episode, 2, as_column=True),
            "masks": _stack(episode, 3, as_column=True),
            "dones": _stack(episode, 4, as_column=True),
            "next_observations": _stack(episode, 5),
        })

    return packed


def test_mc(env, actor, device, discount=0.99, n_mc_cutoff=350, max_steps=10000):
    """Roll out ``actor`` and collect Monte-Carlo discounted returns.

    Episodes are run back to back until more than ``max_steps`` environment
    steps have been taken in total. For every *completed* episode the
    rewards are turned into discounted returns-to-go, and the first
    ``n_mc_cutoff`` (return, state, action) triples are kept. The episode
    in progress when the step budget runs out is discarded, including one
    that finishes on that very step.

    Args:
        env: environment exposing ``reset()`` and a 4-tuple ``step(action)``.
        actor: policy exposing ``act(state, device)``.
        device: forwarded to ``actor.act``.
        discount: discount factor used for the returns-to-go.
        n_mc_cutoff: number of leading steps kept per episode.
        max_steps: total step budget across all episodes.

    Returns:
        ``(returns, states, actions)`` as NumPy arrays. ``returns`` is always
        a float64 array, also when no episode completed.
    """
    state = env.reset()
    step_count = 0

    episode_rewards, episode_states, episode_actions = [], [], []
    final_rewards, final_states, final_actions = [], [], []

    while True:
        action = actor.act(state, device)
        next_state, reward, done, _ = env.step(action)
        episode_rewards.append(reward)
        episode_states.append(state)
        episode_actions.append(action)

        step_count += 1
        state = next_state
        if step_count > max_steps:
            break
        if done:
            # Backward pass: overwrite each reward with its return-to-go.
            for i in reversed(range(len(episode_rewards) - 1)):
                episode_rewards[i] = (
                    discount * episode_rewards[i + 1] + episode_rewards[i]
                )
            final_rewards.extend(episode_rewards[:n_mc_cutoff])
            final_states.extend(episode_states[:n_mc_cutoff])
            final_actions.extend(episode_actions[:n_mc_cutoff])

            state = env.reset()
            episode_rewards, episode_states, episode_actions = [], [], []

    return (
        np.array(final_rewards, dtype=np.float64),
        np.array(final_states),
        np.array(final_actions),
    )

def log_bias_evaluation(env, agent, device):
    """Return the mean normalised bias of ``agent.qf`` against MC returns.

    Monte-Carlo returns are gathered with ``test_mc`` using ``agent.actor``.
    The bias of each Q prediction is divided by the absolute Monte-Carlo
    return, with that denominator floored at 10.
    """
    mc_returns, visited_states, taken_actions = test_mc(env, agent.actor, device)
    state_batch = Tensor(visited_states).to(device)
    action_batch = Tensor(taken_actions).to(device)
    with torch.no_grad():
        q_values = agent.qf(state_batch, action_batch)
        q_prediction = q_values.cpu().numpy().reshape(-1)
    bias = q_prediction - mc_returns

    # np.abs already yields a fresh array, so it can be floored in place.
    denominator = np.abs(mc_returns)
    denominator[denominator < 10] = 10

    return np.mean(bias / denominator)