from utils.reporter import get_reporter, get_reporter_dir
from os import path
import jax
import jax.random as jrnd
import jax.numpy as jnp
from envs import make_env
from algos import make_algo
from typing import cast, Dict, Any, Optional
from utils.jax_util import extract_in_out_trees, build_merge_trees
from utils.saver import saver
import numpy as np
import torch
from pathlib import Path
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
import logging
from utils.progress import tqdm
from omegaconf import DictConfig
from utils.torch import take_per_row
from utils.time import Timeit
from torch.utils.data import Dataset as PthDataset
from torch.utils.data import DataLoader
from accelerate import Accelerator


def jax_to_torch(array_dtypes, device):
    """Turn ``(array, torch_dtype)`` pairs into torch tensors placed on ``device``.

    Each array is first materialized with ``np.array`` and then copied into a
    new tensor with ``torch.tensor``. The result is a list with one tensor per
    input pair, in input order.
    """

    def _convert(pair):
        source, target_dtype = pair
        host_copy = np.array(source)
        return torch.tensor(host_copy, device=device, dtype=target_dtype)

    return [_convert(pair) for pair in array_dtypes]


class CustomDataset(PthDataset):
    """Torch ``Dataset`` adapter that returns one whole batch per index.

    The ``idx`` passed to ``__getitem__`` is ignored; every call asks the
    wrapped ``dataset`` for a new batch through ``dataset.sample``.

    For the "ed" and "xl" algorithms, sampling is sequential:
    - A set of ``cfg.algo.batch_size`` row indices ``i1`` is drawn once.
    - The column cursor ``i2`` starts at 0 and advances by
      ``cfg.algo.ctx_len`` after every call.
    - Once ``i2`` equals ``dcfg["steps_trained"]`` for every row, fresh rows
      are drawn and the cursor restarts at 0.

    Every other algorithm lets ``dataset.sample()`` pick its own indices.

    ``__getitem__`` returns ``(trajs, (i1, i2))``. In the sequential case,
    ``i2`` is the cursor *after* advancing, i.e. the start of the next chunk.
    """

    def __init__(self, dataset, cfg, dcfg):
        super().__init__()
        self._dataset = dataset
        self.cfg = cfg
        self.dcfg = dcfg
        # Sampling cursor: row indices and start columns; None means "not started".
        self.i1 = None
        self.i2 = None

    def __len__(self):
        # One "item" per training step; each item is already a full batch.
        return self.cfg.training_steps

    def __getitem__(self, idx):
        if self.cfg.algo.name not in ("ed", "xl"):
            # Random row-wise sampling: the dataset chooses both index sets.
            trajs, (self.i1, self.i2) = self._dataset.sample()
            return trajs, (self.i1, self.i2)

        # Sequential sampling: restart from new rows once the cursor reached the end.
        cursor_at_end = self.i2 is not None and bool(
            (self.i2 == self.dcfg["steps_trained"]).all()
        )
        if cursor_at_end:
            self.i1 = self.i2 = None

        if self.i1 is None and self.i2 is None:
            n_rows = self.cfg.algo.batch_size
            self.i1 = torch.randint(
                0,
                self._dataset.num_train_envs,
                (n_rows,),
                device=self.cfg.device,
            )
            self.i2 = torch.zeros(
                (n_rows,), dtype=torch.int32, device=self.cfg.device
            )

        trajs, (_, start_cols) = self._dataset.sample(self.i1, self.i2)
        self.i2 = start_cols + self.cfg.algo.ctx_len
        return trajs, (self.i1, self.i2)


class Dataset:
    """Offline trajectory store that serves fixed-length windows as torch tensors.

    All arrays are converted with ``torch.as_tensor`` onto ``cfg.device``.
    Dtypes are chosen as follows:
    - Observations, actions and rewards are ``torch.int32`` when the matching
      ``dcfg["<x>_is_concrete"]`` flag is set, ``torch.float32`` otherwise.
    - Termination flags are always ``torch.int32``.
    - Non-concrete rewards get a trailing singleton axis.

    Args:
        obs, acts, rwds, terms: array-likes. ``acts`` and ``rwds`` must have
            leading shape ``(cfg.num_train_envs, dcfg["steps_trained"])``
            (asserted). NOTE(review): ``obs``/``terms`` are presumably laid
            out the same way — this is not checked here.
        _truncs: accepted so the stored tuple can be unpacked positionally;
            ignored.
        batch_size: number of windows :meth:`sample` draws when no indices
            are supplied.
        cfg: run config; reads ``device``, ``env.name``, ``algo.ctx_len`` and
            ``num_train_envs``.
        dcfg: dataset config; reads the ``*_is_concrete`` flags and
            ``steps_trained``.
    """

    def __init__(
        self,
        obs,
        acts,
        rwds,
        terms,
        _truncs,
        batch_size: int,
        cfg: DictConfig,
        dcfg: Dict[str, Any],
    ) -> None:
        device = cfg.device

        def _to_device(values, concrete_flag):
            # Discrete quantities are stored as int32, continuous ones as float32.
            wanted = torch.int32 if dcfg[concrete_flag] else torch.float32
            return torch.as_tensor(values, dtype=wanted, device=device)

        observations = _to_device(obs, "obs_is_concrete")
        if cfg.env.name == "dark_key_to_door":
            # Drop the extra state component (has_key_or_not) from observations.
            observations = observations[..., :2]
        self.obs = observations
        self.acts = _to_device(acts, "act_is_concrete")

        rewards = _to_device(rwds, "rwd_is_concrete")
        self.rwds = rewards if dcfg["rwd_is_concrete"] else rewards.unsqueeze(-1)

        self.terms = torch.as_tensor(terms, dtype=torch.int32, device=device)
        self.device = device
        self.eval_ctx_length = cfg.algo.ctx_len

        expected_lead = (cfg.num_train_envs, dcfg["steps_trained"])
        assert self.acts.shape[:2] == self.rwds.shape[:2] == expected_lead

        self.batch_size = batch_size
        self.num_train_envs = cfg.num_train_envs
        self.cfg = cfg
        self.dcfg = dcfg

    def sample(
        self,
        i1: Optional[torch.Tensor] = None,
        i2: Optional[torch.Tensor] = None,
    ):
        """Return windows of ``self.eval_ctx_length`` consecutive steps.

        Args:
            i1: row indices into the first axis, one per window.
            i2: start step of each window.

        When both indices are ``None``, ``batch_size`` rows and start steps
        are drawn uniformly. Start steps are bounded so every window ends
        inside the trajectory.

        Returns:
            ``((obs, acts, rwds, terms, positions), (i1, i2))``, where:
            - each trajectory part is ``take_per_row(x[i1], i2, ctx_len)``;
            - ``positions[b, t] == i2[b] + t``.
        """
        if i1 is None and i2 is None:
            draw_shape = (self.batch_size,)
            last_start_exclusive = self.dcfg["steps_trained"] - self.eval_ctx_length + 1
            i1 = torch.randint(
                0, self.cfg.num_train_envs, draw_shape, device=self.device
            )
            i2 = torch.randint(
                0, last_start_exclusive, draw_shape, device=self.device
            )
            assert i1.ndim == i2.ndim == 1

        assert i1 is not None and i2 is not None

        window = self.eval_ctx_length
        obs_w, acts_w, rwds_w, terms_w = (
            take_per_row(source[i1], i2, window)
            for source in (self.obs, self.acts, self.rwds, self.terms)
        )
        offsets = torch.arange(0, window, device=self.device)
        positions = i2.unsqueeze(1) + offsets
        return (obs_w, acts_w, rwds_w, terms_w, positions), (i1, i2)


def train_handle(cfg: DictConfig, dcfg: Dict[str, Any]):
    """Train ``cfg.algo`` offline on a saved dataset and evaluate it in JAX envs.

    Steps:
      1. Build the "eval" environment and replace its ``sys`` field with the
         one restored from ``{cfg.load_path}/eval_env_sys``.
      2. Load the tensors stored in ``{cfg.load_path}/dataset/ds.pth`` into a
         :class:`Dataset` and serve batches through a :class:`CustomDataset`
         ``DataLoader``.
      3. Build the algorithm with ``make_algo``. Pass its model, optimizer,
         scheduler and the dataloader through ``accelerator.prepare``, and
         write the prepared objects back onto ``algo``.
      4. Evaluate once before training. Then run ``cfg.split_nums`` splits of
         ``dcfg["frames_per_split"]`` update steps each, evaluating after
         every split.

    Outputs:
      - Evaluation results ``(rewards, dones)`` are saved to
        ``{output_dir}/models/rlts/{split}/all_results.pth``.
      - When ``cfg.save_algo`` is set, ``algo.model.save()`` is stored at
        ``{output_dir}/models/model/{split}.pth``.

    For the "xl" and "ed" algorithms, the memories returned by the model
    (``emb_mems`` / ``eval_mems``) are carried across consecutive
    context-length chunks, both in training and in evaluation.
    """
    uses_mems = cfg.algo.name in ["xl", "ed"]

    # NOTE(review): TB_DIR is not used below; the call may be needed for
    # reporter setup — confirm before removing it.
    TB_DIR = get_reporter_dir()
    output_dir = HydraConfig.get().runtime.output_dir
    MODEL_DIR = path.abspath(f"{output_dir}/models/")

    key = jax.random.PRNGKey(cfg.seed)

    eval_envs_states = {}
    for k in ["eval"]:
        key, env_key = jrnd.split(key)
        eval_env, _env_state, _env_tree = make_env(
            cfg.env.name, cfg, dcfg, "eval", env_key
        )

        eval_envs_states[k] = _env_state

    # Overwrite the freshly built env systems with the saved ones (presumably
    # those used when the dataset was generated).
    load_env_sys = saver.restore(
        f"{cfg.load_path}/eval_env_sys",
        item={k: v.sys for k, v in eval_envs_states.items()},
    )

    eval_envs_states = {
        k: v.replace(sys=load_env_sys[k]) for k, v in eval_envs_states.items()
    }
    del load_env_sys
    print({k: v.sys for k, v in eval_envs_states.items()})

    dataset = Dataset(
        *(torch.load(f"{cfg.load_path}/dataset/ds.pth", map_location=cfg.device)),
        batch_size=cfg.algo.batch_size,
        cfg=cfg,
        dcfg=dcfg,
    )

    key, algo_key = jrnd.split(key)
    algo, _ = make_algo(
        cfg.algo.name,
        cfg,
        {
            **dcfg,
            "state_dim": eval_env.unwrapped.state_dim,
            "act_dim": eval_env.unwrapped.act_dim,
        },
        algo_key,
    )

    merge_trees = jax.jit(build_merge_trees(_env_tree))

    dataloader = DataLoader(CustomDataset(dataset, cfg, dcfg), batch_size=1)
    accelerator = Accelerator(mixed_precision=cfg.precision, cpu=cfg.device == "cpu")
    # Every prepared object must replace the attribute it came from; in
    # particular the prepared scheduler goes back to ``algo.scheduler`` so the
    # accelerate-wrapped scheduler is the one ``algo`` actually steps.
    algo.model, algo.optim, dataloader, algo.scheduler = accelerator.prepare(
        algo.model, algo.optim, dataloader, algo.scheduler
    )

    def _append_mems(old_mems, new_mems):
        """Concatenate per-layer memories along dim 1 and keep the newest part.

        Returns ``new_mems`` unchanged when there are no previous memories.
        NOTE(review): the window is ``cfg.algo.mem_len * 4`` entries; the
        factor 4 is not explained here — confirm intent.
        """
        if old_mems is None:
            return new_mems
        return [
            torch.cat((old_mems[layer], new_mems[layer]), dim=1)[
                :, -cfg.algo.mem_len * 4 :
            ]
            for layer in range(cfg.algo.layers)
        ]

    def reset_env(key, env_state):
        """Reset only the envs whose ``terms`` or ``truncs`` flag is set.

        Leaves of ``_env_tree`` that are ``None`` get mask ``0``; the others
        get the per-env ``_need_reset`` mask. ``merge_trees`` then presumably
        picks old vs. reset values per mask (its semantics live in
        ``build_merge_trees``).
        """
        _reset_keys = jrnd.split(key, cfg.num_eval_envs)
        _reset_env_state = eval_env.reset(jnp.array(_reset_keys), env_state)
        _need_reset = jnp.logical_or(
            env_state.terms,
            env_state.truncs,
        )
        _reset_tree = jax.tree_map(
            lambda x: jnp.array(0) if x is None else _need_reset,
            _env_tree,
            is_leaf=lambda x: x is None,
        )
        env_state = merge_trees(
            _reset_tree,
            env_state,
            _reset_env_state,
        )
        return env_state

    @jax.jit
    def _reset_env(key, env_state):
        # Skip the reset work entirely when no env is done.
        env_state = jax.lax.cond(
            jnp.any(jnp.logical_or(env_state.terms, env_state.truncs)),
            lambda c: reset_env(key, c),
            lambda c: c,
            env_state,
        )
        return env_state

    def eval_one_step(env_state, pos, key, eval_mems):
        """Act once in every eval env and record the transition in ``algo``.

        Returns ``(rewards, env_state, key, done_mask, new_mems)``, where
        ``new_mems`` is ``None`` for algorithms without memories.
        """
        obs = torch.tensor(
            np.array(env_state.obs),
            dtype=torch.int32 if dcfg["obs_is_concrete"] else torch.float32,
            device=cfg.device,
        )
        # remove extra state (has_key_or_not), matching Dataset preprocessing
        if cfg.env.name == "dark_key_to_door":
            obs = obs[..., :2]

        algo.enroll_obs(obs, pos)
        if uses_mems:
            acts, new_mems = algo.make_action(emb_mems=eval_mems)
        else:
            acts, _ = algo.make_action(emb_mems=None)
            new_mems = None

        env_state = jax.device_get(eval_env.step(env_state, acts.numpy(force=True)))

        (acts, rwd, term, truncs) = jax_to_torch(
            (
                (
                    env_state.acts,
                    torch.int32 if dcfg["act_is_concrete"] else torch.float32,
                ),
                (
                    env_state.rwds,
                    torch.int32 if dcfg["rwd_is_concrete"] else torch.float32,
                ),
                (env_state.terms, torch.int32),
                (env_state.truncs, torch.int32),
            ),
            cfg.device,
        )
        rwd = rwd.unsqueeze(-1) if not dcfg["rwd_is_concrete"] else rwd
        algo.enroll_rest(acts, rwd, term, pos)

        key, reset_key = jrnd.split(key)
        env_state = _reset_env(reset_key, env_state)

        return rwd.clone(), env_state, key, torch.logical_or(term, truncs), new_mems

    def eval_all(key, env_state):
        """Roll out ``dcfg["steps_trained"]`` steps in all eval envs.

        Returns ``(rewards, dones)``, each of shape
        ``(cfg.num_eval_envs, dcfg["steps_trained"])``.
        """
        key, *reset_keys = jrnd.split(key, 1 + cfg.num_eval_envs)
        env_state = eval_env.reset(jnp.array(reset_keys), env_state)
        rewards = torch.zeros(
            (cfg.num_eval_envs, dcfg["steps_trained"]),
            dtype=torch.float32,
            device=cfg.device,
        )
        dones = torch.zeros(
            (cfg.num_eval_envs, dcfg["steps_trained"]),
            dtype=torch.int32,
            device=cfg.device,
        )
        # Stays None for algorithms without memories.
        eval_mems = None
        algo.eval()
        algo.reset()

        for step in tqdm(range(dcfg["steps_trained"]), desc="eval"):
            pos = (
                torch.ones((cfg.num_eval_envs,), device=cfg.device, dtype=torch.int32)
                * step
            )
            _rewards, env_state, key, _dones, new_eval_mems = eval_one_step(
                env_state, pos, key, eval_mems
            )
            rewards[:, step] = (
                _rewards.squeeze(-1) if not dcfg["rwd_is_concrete"] else _rewards
            )
            dones[:, step] = _dones
            if uses_mems and algo.ctx.is_full:
                # reset context to make it step as a whole block; the block's
                # memories are carried into the next one
                algo.reset()
                eval_mems = _append_mems(eval_mems, new_eval_mems)

        algo.train()
        return rewards, dones

    def train_one_step(trajs, emb_mems):
        """Run one ``algo.update`` on a batch from the dataloader.

        The leading batch axis of size 1 added by the ``DataLoader`` is
        squeezed away before the update.
        """
        algo.train()

        _rlt = algo.update(
            [t.squeeze(0) for t in trajs], accelerator, emb_mems=emb_mems
        )

        return _rlt

    key, eval_key = jrnd.split(key)
    rewards, dones = eval_all(eval_key, eval_envs_states["eval"])

    print(f"reward at 0: {rewards.mean(0).sum()}", flush=True)
    print(f"dones at 0: {dones.sum()}", flush=True)

    Path(f"{MODEL_DIR}/rlts/0").mkdir(exist_ok=True, parents=True)
    torch.save((rewards, dones), f"{MODEL_DIR}/rlts/0/all_results.pth")

    # A single iterator is shared by all splits, so the sampler state persists.
    _data_iter = iter(dataloader)

    if uses_mems:
        # Sequential chunks must tile each trajectory exactly.
        assert dcfg["steps_trained"] % cfg.algo.ctx_len == 0
    if cfg.save_algo:
        Path(f"{MODEL_DIR}/model").mkdir(exist_ok=True, parents=True)
        torch.save(algo.model.save(), f"{MODEL_DIR}/model/0.pth")
    for i in tqdm(range(1, 1 + cfg.split_nums), desc="total training"):
        with Timeit("train for 1 epoch"):
            # NOTE(review): memories are cleared at every split, but the
            # CustomDataset cursor is not, so a split may start mid-trajectory
            # without memory — confirm this is intended.
            emb_mems = None

            for _ in tqdm(range(dcfg["frames_per_split"]), desc="single train split"):
                _data, (_, i2) = next(_data_iter)
                _loss, _new_emb_mem = train_one_step(_data, emb_mems)
                get_reporter().add_scalars(dict(loss=_loss), "train")

                if uses_mems:
                    if (i2 == dcfg["steps_trained"]).all():
                        # Last chunk of these trajectories was consumed; the
                        # next batch starts new rows, so drop the memories.
                        emb_mems = None
                    else:
                        emb_mems = _append_mems(emb_mems, _new_emb_mem)

        if uses_mems:
            algo.reset()

        with Timeit("eval for 1 epoch"):
            key, eval_key = jrnd.split(key)
            rewards, dones = eval_all(eval_key, eval_envs_states["eval"])

        print(f"reward at {i}: {rewards.mean(0).sum()}", flush=True)
        print(f"dones at {i}: {dones.sum()}", flush=True)

        Path(f"{MODEL_DIR}/rlts/{i}").mkdir(exist_ok=True, parents=True)
        torch.save((rewards, dones), f"{MODEL_DIR}/rlts/{i}/all_results.pth")
        if cfg.save_algo:
            torch.save(algo.model.save(), f"{MODEL_DIR}/model/{i}.pth")
