import pickle
import random

import jax
import gym
import jax.numpy as jnp
import d4rl  # noqa

import wandb
import haiku as hk

import numpy as np
EPS = np.finfo(np.float32).eps

from pathlib import Path
from typing import Any, Dict
from scipy.stats import binned_statistic

from tqdm import tqdm
from dotenv import load_dotenv
from pydantic import BaseSettings, Field

from policies.ebm_policy import IRCP
from policies.rc_policy import RCP
from utils import MinMaxNormalizationLayer, MinMaxDenormalizationLayer

from utils import get_job_config
from eval_episodes import evaluate_episode_rtg



# Load variables from a dotenv file before any settings are constructed.
load_dotenv()

# Registry of policy classes, looked up by the name given in `conf.algo`.
algos = {
    "RCP": RCP,
    "IRCP": IRCP,
}


def discount_cumsum(x, gamma):
    """Return the discounted reverse cumulative sum of ``x`` along axis 0.

    Element ``t`` of the result equals
    ``x[t] + gamma * x[t + 1] + gamma**2 * x[t + 2] + ...``.

    Args:
        x: NumPy array; accumulation runs over its first axis.
        gamma: Discount factor applied once per step.

    Returns:
        An array with the same shape and dtype as ``x`` (allocated with
        ``np.zeros_like``). An empty input yields an empty output.
    """
    out = np.zeros_like(x)
    if x.shape[0] == 0:
        # Nothing to accumulate; indexing x[-1] below would raise IndexError.
        return out

    # Seed the recursion with the last element, then walk backwards.
    out[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        out[t] = x[t] + gamma * out[t + 1]
    return out


def experiment(conf):
    """Set up training and evaluation of a return-conditioned policy.

    Creates the gym environment named
    ``f"{conf.env_name}-{conf.dataset}-{conf.ds_version}"``, gathers the
    offline trajectories, normalizes states, actions and returns-to-go (RTG),
    instantiates the agent selected by ``conf.algo`` and attaches a ``run``
    closure to this function object.

    Args:
        conf: Settings object. Fields are read by attribute, and the whole
            object is forwarded to the agent constructor via ``conf.dict()``.

    Returns:
        A ``(experiment, agent_state)`` tuple: this function itself (carrying
        the ``run`` attribute) and the freshly initialized agent state. Call
        ``experiment.run(agent_state)`` to train and then evaluate.

    Raises:
        ValueError: if no trajectories could be loaded, or if
            ``conf.rtg_norm`` is not ``'norm_tanh'``, ``'cdf'`` or
            ``'min_max'``.
    """
    wab_log = conf.wab_log
    env_name = conf.env_name
    dataset_mode = conf.dataset
    ds_version = conf.ds_version
    job_id = get_job_config()["id"]
    name = f"{env_name}-{dataset_mode}-{ds_version}"

    # Everything that must survive a restart lives under this directory.
    chkp_path = Path("/checkpoints") / job_id
    chkp_path.mkdir(parents=True, exist_ok=True)
    wab_info = chkp_path / "wab.sav"
    tree_file = chkp_path / "tree.pkl"
    arrays_file = chkp_path / "arrays.npy"

    env = gym.make(name)

    rng_seq = hk.PRNGSequence(conf.seed)

    state_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    group_name = f"{env_name}-{dataset_mode}"
    exp_prefix = f"{group_name}-{random.randint(int(1e5), int(1e6) - 1)}"

    # NOTE(review): the code previously read `trajectories` without ever
    # assigning it, so this function failed with UnboundLocalError. The
    # loading step is reconstructed here with D4RL's per-episode splitter;
    # confirm that this matches the intended data pipeline.
    trajectories = list(d4rl.sequence_dataset(env))
    if not trajectories:
        raise ValueError(f"No trajectories found for dataset {name}")

    ep_return = np.array([path["rewards"].sum() for path in trajectories])
    if conf.filter_:
        # Keep only the better half of the episodes, ranked by total reward.
        order = np.argsort(ep_return)
        order = order[len(order) // 2:]
        trajectories = [trajectories[i] for i in order]

    ep_len = np.array([len(path["rewards"]) for path in trajectories])
    max_ep_len = np.max(ep_len)

    all_states = []
    all_actions = []
    all_rtg = []
    for path in trajectories:
        rewards = path["rewards"]
        if conf.average_reward:
            # Remaining reward divided by the steps left until max_ep_len.
            rtg = np.array([rewards[si:].sum() / (max_ep_len - si) for si in range(len(rewards))])
        else:
            # Plain undiscounted return-to-go.
            rtg = np.array([rewards[si:].sum() for si in range(len(rewards))])

        all_states.append(path["observations"])
        all_actions.append(path["actions"])
        all_rtg.append(rtg)

    all_states = np.concatenate(all_states, 0)
    all_actions = np.concatenate(all_actions, 0)
    all_rtg = np.concatenate(all_rtg, 0)

    # NOTE(review): the scale below is sqrt(std), not std. It is left as-is
    # because the same statistic is handed to evaluation; confirm the intent.
    state_mean = all_states.mean(0, keepdims=True)
    state_stddev = np.sqrt(all_states.std(0, keepdims=True))
    all_states = (all_states - state_mean) / np.clip(state_stddev, EPS, 1e7)

    min_action = np.min(all_actions, 0)
    max_action = np.max(all_actions, 0)
    action_range = max_action - min_action
    if conf.act_norm == 'min_max':
        # Widen the observed action range by a relative margin on both sides.
        min_action -= conf.uniform_buffer * action_range
        max_action += conf.uniform_buffer * action_range

        norm_act = MinMaxNormalizationLayer(min_action, max_action)
        denorm_act = MinMaxDenormalizationLayer(min_action, max_action)
        all_actions = norm_act(all_actions)
    else:
        # Actions are used unnormalized, so denormalization is the identity.
        denorm_act = lambda a: a

    if conf.rtg_norm == 'norm_tanh':
        # NOTE(review): same sqrt(std) scaling as for the states above.
        all_rtg = (all_rtg - all_rtg.mean()) / np.clip(np.sqrt(all_rtg.std()), EPS, 1e7)
        all_rtg = np.tanh(all_rtg)
    elif conf.rtg_norm == 'cdf':
        # Empirical CDF: fraction of samples <= each value. A sorted lookup
        # gives the same numbers as comparing every pair of samples.
        sorted_rtg = np.sort(all_rtg)
        all_rtg = np.searchsorted(sorted_rtg, all_rtg, side="right") / all_rtg.shape[0]
        all_rtg *= 2
        all_rtg -= 1
    elif conf.rtg_norm == 'min_max':
        min_rtg = all_rtg.min()
        max_rtg = all_rtg.max()
        range_rtg = max_rtg - min_rtg
        min_rtg -= conf.rtg_uniform_buffer * range_rtg
        max_rtg += conf.rtg_uniform_buffer * range_rtg
        norm_rtg = MinMaxNormalizationLayer(min_rtg, max_rtg)
        all_rtg = norm_rtg(all_rtg)
    else:
        raise ValueError('Not such rtg normalization')

    all_rtg = np.clip(all_rtg, -1, 1)
    all_rtg = all_rtg.reshape(-1, 1)

    algo = algos[conf.algo]
    agent = algo(state_dims=state_dim,
                 actions_dims=act_dim,
                 denorm_act=denorm_act,
                 **conf.dict())
    agent_state = agent.get_init_state(next(rng_seq), learning_rate=conf.lr)

    print("=" * 50)
    print("Starting new experiment:")
    print(f"Dataset: {name}")
    print(f"Method: {conf.algo}")
    print("-" * 25)
    print("=" * 50)

    def get_batch(batch_size):
        """Sample (states, actions, rtg) rows uniformly with replacement."""
        # Uses NumPy's global RNG, which this function does not seed.
        idx = np.random.choice(all_states.shape[0], batch_size, replace=True)
        return all_states[idx], all_actions[idx], all_rtg[idx]

    def load_wab():
        """Read the persisted wandb run id."""
        with open(wab_info, "r") as f:
            # Strip so a stray trailing newline cannot corrupt the id.
            return f.read().strip()

    def save_wab(wab_id):
        """Persist the wandb run id next to the checkpoint."""
        with open(wab_info, "w") as f:
            f.write(wab_id)

    def save_chkp(state) -> None:
        """Write the state leaves and the pytree structure to disk."""
        with open(arrays_file, "wb") as f:
            for leaf in jax.tree_util.tree_leaves(state):
                np.save(f, leaf, allow_pickle=False)

        # Same pytree with every leaf replaced by 0: only structure is pickled.
        tree_struct = jax.tree_util.tree_map(lambda t: 0, state)
        with open(tree_file, "wb") as f:
            pickle.dump(tree_struct, f)

        print("Checkpoint saved.")

    def load_chkp(agent_state):
        """Return the checkpointed state, or ``agent_state`` if there is none."""

        def non_empty(path):
            return path.exists() and path.stat().st_size > 0

        # Both files are required. With only arrays.npy present there is no
        # structure to rebuild the state from.
        if not (non_empty(tree_file) and non_empty(arrays_file)):
            print("> No state file found; starting new training")
            return agent_state

        # NOTE: pickle executes code on load; only read trusted checkpoints.
        with open(tree_file, "rb") as f:
            tree_struct = pickle.load(f)
        leaves, treedef = jax.tree_util.tree_flatten(tree_struct)

        # Arrays were appended one np.save per leaf, so read them back in order.
        with open(arrays_file, "rb") as f:
            flat_state = [np.load(f) for _ in leaves]

        restored = jax.tree_util.tree_unflatten(treedef, flat_state)
        print(f"> Loaded state from {exp_prefix}")
        return restored

    def train(agent_state) -> tuple:
        """Run ``conf.num_steps_per_iter`` SGD steps.

        Returns ``(agent_state, stats)`` where ``stats`` comes from the last
        step (empty if no step was run).
        """
        stats = {}
        with tqdm(range(conf.num_steps_per_iter), unit="batch") as titer:
            for step in titer:
                states, actions, rtg = get_batch(conf.batch_size)
                titer.set_description(f"Train iteration {step}")
                batch = {"states": states, "actions": actions, "rtg": rtg}
                agent_state, stats = agent.sgd_step(next(rng_seq), agent_state, batch)
                mse = (stats['MSE a_t'] + stats['MSE g_t']) / 2
                titer.set_postfix(loss_mean=stats["loss"],
                                  mse=mse,
                                  refresh=False)

        return agent_state, stats

    def evaluate(agent_state, alpha) -> tuple:
        """Roll out ``conf.num_eval_episodes`` episodes for one ``alpha``.

        Returns ``(agent_state, metrics)`` with raw and normalized return
        statistics.
        """
        returns, returns_norm, avg_rtg_preds = [], [], []

        with tqdm(range(conf.num_eval_episodes), unit="eval") as titer:
            for episode in titer:
                ret, _length, avg_rtg_pred = evaluate_episode_rtg(
                    conf.seed,
                    agent_state,
                    env,
                    act_dim,
                    agent,
                    alpha=alpha,
                    max_ep_len=10000,
                    state_mean=state_mean,
                    state_stddev=state_stddev,
                )

                if episode % 10 == 0:
                    titer.set_description(f"Eval iteration {episode}")
                returns.append(ret)
                returns_norm.append(env.get_normalized_score(ret) * 100)
                avg_rtg_preds.append(avg_rtg_pred)
                # Show the running mean over the episodes finished so far.
                titer.set_postfix(norm_ret_mean=np.mean(returns_norm))

        return agent_state, {
            "evaluation/return_mean_norm": np.mean(returns_norm),
            "evaluation/return_std_norm": np.std(returns_norm),
            "evaluation/return_mean": np.mean(returns),
            "evaluation/return_std": np.std(returns),
            'alpha': alpha,
            'evaluation/avg_rtg_pred': np.mean(avg_rtg_preds),
        }

    if wab_log:
        if wab_info.exists() and wab_info.stat().st_size > 0:
            wab_id = load_wab()
            print(f"> Loaded wandb id ({wab_id}) from {wab_info}")
        else:
            wab_id = wandb.util.generate_id()
            save_wab(wab_id)
            print(f"> New wandb id generated: {wab_id}")
        # Pass the persisted id so a restarted job continues the same run.
        wandb.init(
            id=wab_id,
            resume="allow",
            name=exp_prefix,
            group=group_name,
            project=conf.wandb_project,
            entity=conf.wandb_entity,
            config=conf,
        )

    def run(agent_state) -> None:
        """Resume from a checkpoint if present, train, then evaluate."""
        agent_state = load_chkp(agent_state)
        logs = {}
        while agent_state.step < (conf.max_iters * conf.num_steps_per_iter):
            agent_state, train_out = train(agent_state=agent_state)
            logs.update(train_out)
            if conf.save_ckpt:
                save_chkp(state=agent_state)
            if wab_log:
                wandb.log(logs)

        logs = {}
        # A negative conf.alpha requests a sweep over a fixed grid of values.
        alphas = [0., 0.25, 0.5, 0.75, 1., 2.5, 5., 10.] if (conf.alpha < 0) else [conf.alpha]
        for alpha in alphas:
            agent_state, eval_out = evaluate(agent_state, alpha)
            logs.update(eval_out)
            if wab_log:
                wandb.log(logs)

    # The entry point is exposed as an attribute of this function object.
    experiment.run = run

    return experiment, agent_state


# Environment variables always take priority over values loaded from a dotenv
# file; keep this in mind when working with job.yaml.
class Settings(BaseSettings):
    """Experiment configuration read from the environment and ``.env``.

    Every field is required (none has a default). Fields without a comment
    are not read directly in this file; they reach the policy constructor
    through ``conf.dict()``.
    """

    # env_name, dataset and ds_version are joined with "-" to form the gym id.
    env_name: str = Field(..., description="D4RL env name")
    dataset: str
    ds_version: str
    # Forwarded to wandb.init as project / entity.
    wandb_project: str
    wandb_entity: str
    # Key into the module-level `algos` registry.
    algo: Any
    loss_type: str
    # 'min_max' enables min-max action normalization; anything else disables it.
    act_norm: str
    # One of 'norm_tanh', 'cdf', 'min_max'; other values raise ValueError.
    rtg_norm: str
    activation: str
    # Learning rate passed to the agent's get_init_state.
    lr: float
    # Evaluation alpha; a negative value triggers a sweep over a fixed grid.
    alpha: float
    scale: float
    density_penalty: float
    # Seeds the Haiku PRNG sequence and is passed to each evaluation episode.
    seed: int
    batch_size: int
    num_eval_episodes: int
    # Training stops once step >= max_iters * num_steps_per_iter.
    max_iters: int
    num_layers: int
    dims: int
    num_steps_per_iter: int
    num_mcmc_chains: int
    num_action_samples: int
    medium_level_traj: int
    expert_level_traj: int
    warmup: int
    weight_decay: float
    max_g: float
    ema: float
    eta: float
    temperature: float
    grad_penalty: float
    # Relative margins added around the observed action / RTG ranges.
    uniform_buffer: float
    rtg_uniform_buffer: float
    gradient_scaling: float
    use_skip: bool
    spectral_norm: bool
    # Keep only the top half of the episodes, ranked by total reward.
    filter_: bool
    init_g_randomly: bool
    use_bias: bool
    use_layer_norm: bool
    # Divide each return-to-go by the steps remaining until the max length.
    average_reward: bool
    u_net: bool
    # Enable Weights & Biases logging.
    wab_log: bool
    lr_schedule: bool
    all_grad_penalty: bool
    # Write a checkpoint after every training iteration.
    save_ckpt: bool

    class Config:
        env_prefix = ""
        # Was misspelled `case_sentive`, which pydantic silently ignored.
        case_sensitive = False
        env_file = ".env"
        env_file_encoding = "utf-8"


if __name__ == "__main__":
    # Script entry point: build the settings, set up the experiment, then
    # train and evaluate through the `run` closure attached by experiment().
    settings = Settings()
    runner, initial_state = experiment(conf=settings)
    runner.run(initial_state)
