import gym
import wandb

import copy
import random
import numpy as np

import pprint
import haiku as hk

from typing import Any

from jax.config import config
from dotenv import load_dotenv
from pydantic import BaseSettings, Field

from policies.ebm_policy import IRCP
from policies.rc_policy import RCP


# Load variables from a local .env file (if any) into the process environment before the
# settings object is built in the __main__ section.
load_dotenv()

config.update("jax_platform_name", "gpu")  # Default to GPU.

# Shared stream of JAX PRNG keys with a fixed seed of 42; callers draw a fresh key via next(rng_seq).
rng_seq = hk.PRNGSequence(42)
# Compact pretty-printer; not referenced anywhere in the code visible here.
pp = pprint.PrettyPrinter(width=41, compact=True)


class ParticleEnv(gym.Env):
    """Toy point-particle environment with two candidate goals.

    The particle starts uniformly in ``[-delta, delta]^dims`` and both goals are drawn uniformly in
    ``[-delta_goal, delta_goal]^dims``.  The goal with the larger first coordinate is worth ``+1`` and
    the other one ``-1`` (goal 2 gets ``+1`` on a tie).  An action is a displacement added to the
    state.  Observations are the concatenation ``[state, goal1, goal2]``.

    Args:
        max_steps: step budget used for truncation (see the note in ``step``).
        dims: number of state dimensions; required.
        delta: half-width of the box the state is sampled from and clipped to.
        delta_goal: half-width of the box the goals are sampled from.
        discountinuities: accepted for backward compatibility; currently unused.
        epsilon: when positive, the first coordinate of goal 2 is resampled within ``epsilon`` of the
            first coordinate of goal 1.  ``None`` or a non-positive value disables this.

    Raises:
        TypeError: if ``dims`` is ``None``.
    """

    def __init__(
        self, max_steps=50, dims=None, delta=1.0, delta_goal=0.1, discountinuities=True, epsilon=None
    ):
        if dims is None:
            # np.sqrt(None) below would raise an opaque TypeError; fail with a readable one instead.
            raise TypeError("ParticleEnv requires `dims` (number of state dimensions), got None")
        self.max_steps = max_steps
        self.dims = dims
        self.delta = delta
        self.delta_goal = delta_goal
        self.eps_goal = 1e-2 * np.sqrt(dims)  # to be dims invariant
        self.epsilon = epsilon

    def _observation(self):
        """Return a fresh ``[state, goal1, goal2]`` array (concatenate always copies)."""
        return np.concatenate([self.state, self.goal1, self.goal2], -1)

    def reset(self):
        """Sample a new state and goal pair, reset the step counter and return the observation."""
        self.state = np.random.uniform(-self.delta, self.delta, size=self.dims)
        self.goal1 = np.random.uniform(-self.delta_goal, self.delta_goal, size=self.dims)
        self.goal2 = np.random.uniform(-self.delta_goal, self.delta_goal, size=self.dims)
        # The default epsilon=None used to crash on `None > 0`; treat it as "disabled".
        if self.epsilon is not None and self.epsilon > 0:
            # Scalar draw: assigning a size-1 array to a single element relies on a deprecated
            # array-to-scalar conversion.  A scalar draw consumes the same random stream.
            self.goal2[0] = self.goal1[0] + np.random.uniform(-self.epsilon, self.epsilon)
            self.goal2 = np.clip(self.goal2, -self.delta_goal, self.delta_goal)

        # Goal followed by the scripted policy in get_action(); copied so it never aliases goal1/goal2.
        self.goal = (self.goal1 if np.random.rand() <= 0.5 else self.goal2).copy()
        # The goal with the larger first coordinate is rewarded, the other one penalised.
        if self.goal1[0] > self.goal2[0]:
            self.goal_values = np.array([1.0, -1.0])
        else:
            self.goal_values = np.array([-1.0, 1.0])
        self.total_steps = 0
        return self._observation()

    def compute_reward(self, state):
        """Return the value of the goal that ``state`` has reached, or 0 if it reached neither."""
        distances = np.array([np.linalg.norm(state - g) for g in [self.goal1, self.goal2]])
        close = np.float32(distances < self.eps_goal)
        if np.all(close):
            # Within reach of both goals: only the nearer one counts.
            idx = np.argmax(distances)
            close[idx] = 0
        # TODO what if both goals are very close to each other?
        return close @ self.goal_values

    def get_action(self):
        """Scripted policy: move towards ``self.goal``, clipped to ``[-1, 1]`` per dimension."""
        return np.clip(self.goal - self.state, -1, 1)

    def step(self, action):
        """Apply ``action`` as a displacement and return ``(observation, reward, done, info)``."""
        self.total_steps += 1
        assert self.state.shape == action.shape
        next_state = np.clip(self.state + action, -self.delta, self.delta)

        reward = self.compute_reward(next_state)
        # NOTE(review): the strict `>` lets an episode run for max_steps + 1 steps before it is
        # truncated; kept as-is because changing it would alter episode lengths.
        done = (reward != 0) or (self.total_steps > self.max_steps)
        self.state = next_state
        return self._observation(), reward, done, {}


class ReplayBuffer:
    """Fixed-capacity ring buffer of ``(state, action, rtg)`` tuples with uniform sampling."""

    def __init__(self, capacity=1_000_000):
        self.capacity = capacity
        self.buffer = []
        self.position = 0  # index of the slot the next push writes to

    def push(self, state, action, rtg):
        """Store one transition, overwriting the oldest entry once the buffer is full."""
        slot = self.position
        if len(self.buffer) < self.capacity:
            # Still filling up: grow the list by one placeholder slot.
            self.buffer.append(None)
        self.buffer[slot] = (state, action, rtg)
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct entries and stack them field by field.

        Returns a dict with keys ``"states"``, ``"actions"`` and ``"rtg"``; ``"rtg"`` is reshaped to
        a column of shape ``(-1, 1)``.
        """
        picked = random.sample(self.buffer, batch_size)
        states, actions, returns = (np.stack(column) for column in zip(*picked))
        return {"states": states, "actions": actions, "rtg": np.reshape(returns, (-1, 1))}

    def __len__(self):
        return len(self.buffer)


def get_data(num_samples=100_000):
    """Fill a replay buffer with rollouts of the environment's scripted policy.

    Episodes are collected from a ``ParticleEnv`` built with the module-level ``conf.env_dims`` until
    at least ``num_samples`` transitions have been gathered.  The last episode always runs to
    completion, so the buffer may hold slightly more than ``num_samples`` entries.

    Args:
        num_samples: minimum number of transitions to collect (default matches the previous
            hard-coded budget of 1e5).

    Returns:
        A ``ReplayBuffer`` of ``(observation, action, rtg)`` tuples.
    """
    buffer = ReplayBuffer()
    env = ParticleEnv(dims=conf.env_dims, epsilon=0.0, delta_goal=0.1)

    nb_samples = 0
    while nb_samples < num_samples:
        state = env.reset()
        done = False
        traj = []
        rewards = []
        while not done:
            nb_samples += 1

            action = env.get_action()
            traj.append((state, action))
            state, reward, done, _ = env.step(action)
            rewards.append(reward)

        # Every transition of the episode is labelled with the full episode return.  The episode
        # ends on the first non-zero reward, so this equals the return-to-go at each step.
        # Computed once here instead of once per transition.
        episode_return = np.array(rewards).sum()
        for state, action in traj:
            buffer.push(state, action, episode_return)

    return buffer


def _mean_clipped_reward(agent, agent_state, dims, delta_goal, epsilon, num_episodes=100, max_steps=5):
    """Roll the agent out in a fresh ParticleEnv and return its mean final reward clipped to [0, 1]."""
    env = ParticleEnv(dims=dims, delta_goal=delta_goal, max_steps=max_steps, epsilon=epsilon)
    outcomes = []
    for _ in range(num_episodes):
        state = env.reset()
        done = False
        reward = 0
        while not done:
            action = agent.get_action(next(rng_seq), agent_state, state, alpha=conf.alpha)
            state, reward, done, _info = env.step(np.squeeze(action))

        # Only the reward of the last step counts; clipping maps the -1 outcome to 0.
        outcomes.append(np.clip(reward, 0, 1))

    return np.array(outcomes).mean()


def toy_eval(env, agent, agent_state):
    """Evaluate the agent over a sweep of goal spreads and goal separations, logging to wandb.

    First ``delta_goal`` is swept with ``epsilon=0.0``, then ``epsilon`` is swept with
    ``delta_goal=0.1``.  Each setting runs 100 episodes of at most ``max_steps=5`` and logs one
    record with keys ``"delta_goal"``, ``"reward"`` and ``"epsilon"``.

    Args:
        env: only ``env.dims`` is read; a fresh environment is built for every setting.
        agent: object exposing ``get_action(key, agent_state, state, alpha=...)``.
        agent_state: state passed through to ``agent.get_action``.
    """
    dims = env.dims
    settings = [(delta_goal, 0.0) for delta_goal in [0.1, 0.25, 0.5, 0.75, 1.0]]
    settings += [(0.1, epsilon) for epsilon in [0.25, 0.5, 1.0, 1e-1, 1e-2, 1e-3, 1e-4]]

    for delta_goal, epsilon in settings:
        mean_reward = _mean_clipped_reward(agent, agent_state, dims, delta_goal, epsilon)
        logs = {"delta_goal": delta_goal, "reward": mean_reward, "epsilon": epsilon}
        wandb.log(logs)


def sgd_step(num_steps, agent, agent_state):
    """Run ``num_steps`` gradient updates and return the resulting agent state.

    Mini-batches are sampled from the module-level ``buffer`` with the batch size taken from the
    module-level ``conf``.  The logs returned by the agent are sent to wandb on every 1000th step,
    starting with step 0.
    """
    current_state = agent_state
    for step_idx in range(num_steps):
        minibatch = buffer.sample(conf.batch_size)
        current_state, metrics = agent.sgd_step(next(rng_seq), current_state, minibatch)
        if step_idx % 1000 != 0:
            continue
        wandb.log(metrics)

    print("\nDone running gradient steps.\n")

    return current_state


class Settings(BaseSettings):
    """Experiment configuration read from environment variables and the ``.env`` file.

    No field declares a default, so every one of them must be supplied.  The whole object is later
    forwarded to the agent constructor via ``conf.dict()``.
    """

    env_name: str = Field(..., description="D4RL env name")
    dataset: str
    ds_version: str
    wandb_project: str
    wandb_entity: str
    # Looked up as a key in the algorithm table of the __main__ section.
    algo: Any
    mode: str
    loss_type: str
    lr: float
    clip: float
    data_noise: float
    alpha: float
    eta: float
    scale: float
    density_penalty: float
    seed: int
    batch_size: int
    env_dims: int
    epsilon: float
    num_eval_episodes: int
    max_iters: int
    dims: int
    num_steps_per_iter: int
    num_mcmc_chains: int
    num_action_samples: int
    weight_decay: float
    reward_target: float
    ema: float
    spectral_norm: bool
    use_bias: bool
    use_layer_norm: bool
    use_net2: bool
    wab_log: bool
    all_grad_penalty: bool
    save_ckpt: bool

    class Config:
        env_prefix = ""
        # Was misspelled `case_sentive`, which pydantic silently ignores.  False is also the
        # library default, so environment-variable matching stays case-insensitive.
        case_sensitive = False
        env_file = ".env"
        env_file_encoding = "utf-8"


# Script entry point: collect data, train the selected agent, then evaluate it.
if __name__ == "__main__":
    # `conf` is a module-level global read by get_data(), sgd_step() and toy_eval().
    conf = Settings()

    # Map the configured algorithm name to its policy class.
    algos = {"RCP": RCP, "IRCP": IRCP}
    algo = algos[conf.algo]

    # A random 6-digit suffix keeps run names distinct across launches.
    exp_prefix = f"{conf.env_name}{conf.dataset}-{random.randint(int(1e5), int(1e6) - 1)}"
    wandb.init(name=exp_prefix, project=conf.wandb_project, entity=conf.wandb_entity, config=conf)

    env = ParticleEnv(dims=conf.env_dims, epsilon=0.0, delta_goal=0.1)

    num_steps = conf.max_iters * conf.num_steps_per_iter
    # `buffer` is a module-level global read by sgd_step().
    buffer = get_data()
    # Observations concatenate state, goal1 and goal2, hence three times the state size.
    obs_spec_shape = env.dims * 3
    action_spec_shape = env.dims

    # Every setting is forwarded to the agent constructor as a keyword argument.
    agent = algo(state_dims=obs_spec_shape, actions_dims=action_spec_shape, **conf.dict())

    agent_state = agent.get_init_state(next(rng_seq), conf.lr)

    new_agent_state = sgd_step(num_steps, agent, agent_state)

    toy_eval(env, agent, new_agent_state)
