from utils.trainer import OnlineTrainer, OfflineTrainer
from typing import Dict, Any, cast
from utils.reporter import get_reporter
import json
from utils.env import seed
from agent.args import args, DEVICE1, DEVICE2, DEVICE3
from agent.algm import Robert
from agent.config import CONFIGS
from os import path
import torch
from agent.params import PARAMS
from agent.common import NEW_DATASET_FOLDER
from robert.agent.dataprocess import (
    preprocess,
    dataset_split2,
    enumer2,
    calc_state_seq_mean_std,
)
from envs.env import dmc
from utils.common import get_device
import math


# Absolute path of the ``datasets`` directory located one level above the
# directory that contains this file.
DATASET_FOLDER = path.abspath(
    path.join(path.dirname(__file__), "..", "datasets")
)

# Number of validation episodes: it is subtracted from the configured episode
# count when checking the training tensors, and the validation tensors and
# masks must have exactly this many entries along their first dimension.
VALID_SIZE = 6


# %%
def _select_env(task: str) -> "tuple[str, int]":
    """Map a task name to its environment name and ``static_start`` value.

    Args:
        task: One of ``"shadow"``, ``"pointmaze"`` or ``"ur5e"``.

    Returns:
        ``(env_name, static_start)`` as passed to ``dmc.make``.

    Raises:
        ValueError: If ``task`` is not one of the recognized names.
    """
    # Only pointmaze deviates from the default static_start of 50.
    known_tasks = {
        "shadow": ("shadowhand_dummy", 50),
        "pointmaze": ("point_mass_maze_reach_bottom_left", 20),
        "ur5e": ("ur5e_dummy", 50),
    }
    try:
        return known_tasks[task]
    except KeyError:
        raise ValueError(f"unrecognized task: {task}") from None


def _build_valid_commands(valid_states, valid_half_idx, valid_full_idx, count):
    """Build the no-mask / half-mask / full-mask evaluation command dicts.

    Args:
        valid_states: Validation state tensor, indexed ``[episode, step, ...]``.
        valid_half_idx: Per-episode "half" masks, indexed by episode.
        valid_full_idx: Per-episode "full" masks, indexed by episode.
        count: Number of validation episodes to turn into commands.

    Returns:
        A ``(no_mask, half_mask, full_mask)`` tuple of dicts. Keys are
        ``valid_task_{i}-no-mask`` / ``-half-mask`` / ``-full-mask``. The
        no-mask values are tensors; the masked values are
        ``(tensor, mask)`` tuples.
    """
    # Express every step relative to the episode's first step; the first step
    # itself is dropped from the command.
    relative = valid_states[:, 1:] - valid_states[:, [0]]

    no_mask = {}
    half_mask = {}
    full_mask = {}
    for i in range(count):
        name = f"valid_task_{i}"
        # Each dict gets its own copy so no command shares storage with another.
        no_mask[f"{name}-no-mask"] = relative[i].clone()
        half_mask[f"{name}-half-mask"] = (relative[i].clone(), valid_half_idx[i])
        full_mask[f"{name}-full-mask"] = (relative[i].clone(), valid_full_idx[i])
    return no_mask, half_mask, full_mask


def main(args: Dict[str, Any]):
    """Train and evaluate a ``Robert`` agent offline on a stored dataset.

    The function seeds the run, builds the test environment for the chosen
    task, loads the training and validation tensors from
    ``{NEW_DATASET_FOLDER}/{task}``, normalizes the actions with the training
    mean/std, and hands everything to ``OfflineTrainer.train_and_eval``.
    Hyper-parameters and the best grades are written through the reporter at
    the end.

    Args:
        args: Run options. The keys read here are ``seed``, ``name``,
            ``task``, ``shrink``, ``eval_ratio``, ``crop_ratio``, ``no_mask``
            and ``model_kind``; the whole dict is also logged.

    Raises:
        ValueError: If ``args["task"]`` is not a recognized task.
        AssertionError: If the loaded tensors do not match the shapes implied
            by ``CONFIGS`` / ``VALID_SIZE``.
    """
    torch.cuda.empty_cache()
    # The full configuration, pretty-printed, is used as the experiment text.
    exp_name = json.dumps({**args, **PARAMS, **CONFIGS}, indent=4, sort_keys=True)
    random_seed = args["seed"]
    seed(random_seed)

    reporter, logdir = get_reporter(args["name"], exp_name)

    task = args["task"]
    env_name, static_start = _select_env(task)

    test_env = dmc.make(
        env_name,
        # Offset so the test environment does not reuse the training seed.
        seed=random_seed + 50000,
        static_start=static_start,
        env_args={"targets": ["0"] * 6} if task == "pointmaze" else None,
    )
    print(f"use {env_name} environment")

    # ---- training data -------------------------------------------------
    read_dir = f"{NEW_DATASET_FOLDER}/{task}"
    states = torch.load(f"{read_dir}/states_train.pt")
    actions = torch.load(f"{read_dir}/actions_train.pt")

    dataset_cfg = CONFIGS["dataset"]
    train_episodes = dataset_cfg["episode_num"] - VALID_SIZE
    expected_states_shape = (
        train_episodes,
        dataset_cfg["episode_length"],
        dataset_cfg["episode_steps"],
        CONFIGS["state_dim"],
    )
    # There is one action fewer than states along the step dimension.
    expected_actions_shape = (
        train_episodes,
        dataset_cfg["episode_length"],
        dataset_cfg["episode_steps"] - 1,
        CONFIGS["action_dim"],
    )
    assert states.shape == expected_states_shape, (
        f"states_train.pt has shape {tuple(states.shape)}, "
        f"expected {expected_states_shape}"
    )
    assert actions.shape == expected_actions_shape, (
        f"actions_train.pt has shape {tuple(actions.shape)}, "
        f"expected {expected_actions_shape}"
    )

    shrink = args["shrink"]
    eval_ratio = args["eval_ratio"]
    crop_ratio = args["crop_ratio"]

    # Keep every ``shrink``-th episode.
    states = states[::shrink]
    actions = actions[::shrink]
    states, actions = preprocess(states, actions)

    if crop_ratio >= 0.05:
        # "Cropping" keeps the dataset size but overwrites a random subset of
        # entries with copies of a single entry (the first one drawn).
        crop_num = math.floor(states.size(0) * crop_ratio)
        # crop_ratio is a fraction, so format it as a percentage.
        print(f"crop dataset by {crop_ratio:.1%} and num. of {crop_num}")

        if crop_num > 0:
            duplicate_idx = torch.randperm(states.size(0))[:crop_num]
            source = duplicate_idx[0]
            states[duplicate_idx] = states[source].clone()
            reporter.add_text("muggle_idx", str(duplicate_idx.shape))
            actions[duplicate_idx] = actions[source].clone()
        else:
            # Too few entries for the requested ratio: indexing the empty
            # permutation would raise IndexError, so leave the data untouched.
            print("crop skipped: dataset too small for the requested crop ratio")

    # CPU-side copy taken before the device transfer; it is normalized below
    # and handed to the algorithm as ``train_raw_actions``.
    raw_actions = actions.clone()
    states = states.to(DEVICE1)
    actions = actions.to(DEVICE2)

    print(
        f"in total, states shape are: {states.shape}, actions shape are: {actions.shape}"
    )
    reporter.add_text("states shape", str(states.shape))
    reporter.add_text("actions shape", str(actions.shape))

    # ---- validation data -----------------------------------------------
    valid_states = torch.load(f"{read_dir}/states_valid.pt")
    valid_actions = torch.load(f"{read_dir}/actions_valid.pt")
    valid_half_idx = torch.load(f"{read_dir}/valid_half_masks.pt")
    valid_full_idx = torch.load(f"{read_dir}/valid_full_masks.pt")

    assert (
        valid_states.size(0)
        == valid_actions.size(0)
        == VALID_SIZE
        == valid_half_idx.size(0)
        == valid_full_idx.size(0)
    ), f"validation tensors and masks must all contain {VALID_SIZE} entries"

    # NOTE: the "length" tags log full shapes (the train ones duplicate the
    # "shape" tags above); they are kept so existing logs stay comparable.
    reporter.add_text("states length", str(states.shape))
    reporter.add_text("actions length", str(actions.shape))
    reporter.add_text("valid states length", str(valid_states.shape))
    reporter.add_text("valid actions length", str(valid_actions.shape))

    # Truncate validation episodes to (eval_ratio + 1) * post_steps states,
    # and one action fewer.
    valid_steps = eval_ratio * PARAMS["post_steps"] + PARAMS["post_steps"]
    valid_states = valid_states[:, :valid_steps].to(DEVICE3)
    valid_actions = valid_actions[:, : valid_steps - 1].to(DEVICE3)

    # ---- action normalization ------------------------------------------
    # Statistics come from the training actions only and are applied to the
    # training, validation and raw copies alike.
    action_dim = CONFIGS["action_dim"]
    flat_actions = actions.reshape((-1, action_dim))
    action_mean = flat_actions.mean(dim=0)
    action_std = flat_actions.std(dim=0) + 1e-8  # epsilon avoids division by zero

    actions.sub_(action_mean).div_(action_std)
    valid_actions.sub_(action_mean.to(DEVICE3)).div_(action_std.to(DEVICE3))
    raw_actions.sub_(action_mean.cpu()).div_(action_std.cpu())
    print("action norm finish")

    # ---- sequence construction -----------------------------------------
    # A state sequence is one element longer than the matching action sequence.
    state_seq_len = PARAMS["pre_steps"] + PARAMS["post_steps"] + 1
    action_seq_len = PARAMS["pre_steps"] + PARAMS["post_steps"]

    states = enumer2(states, state_seq_len)
    actions = enumer2(actions, action_seq_len)

    state_seq_mean, state_seq_std = calc_state_seq_mean_std(states)

    print("data preprocess finish")

    # ---- training ------------------------------------------------------
    algorithm = Robert(
        "Robert",
        {**CONFIGS, "log_dir": logdir},
        {
            **PARAMS,
            "no_mask": args["no_mask"],
            "model_kind": args["model_kind"],
            "eval_ratio": eval_ratio,
            "seed": random_seed,
            "state_seq_mean": state_seq_mean.to(DEVICE3),
            "state_seq_std": state_seq_std.to(DEVICE3),
            "action_mean": action_mean.to(DEVICE3),
            "action_std": action_std.to(DEVICE3),
            "train_raw_actions": raw_actions,
        },
    )
    trainer = OfflineTrainer(algorithm, with_reporter=reporter)
    # Drop the local reference; the algorithm received the tensor via its params.
    del raw_actions
    torch.cuda.empty_cache()

    (
        valid_commands_no_mask,
        valid_commands_half_mask,
        valid_commands_full_mask,
    ) = _build_valid_commands(
        valid_states, valid_half_idx, valid_full_idx, VALID_SIZE
    )

    total_training_frames = PARAMS["batch_size"] * PARAMS["train_iters"]
    trainer.train_and_eval(
        dict(
            train_states=states,
            train_actions=actions,
            no_masked_commands=valid_commands_no_mask,
            half_masked_commands=valid_commands_half_mask,
            full_masked_commands=valid_commands_full_mask,
        ),
        test_env,
        random_seed,
        # One hundredth of the total frame budget per training round.
        single_train_frames=int(total_training_frames / 100),
        total_train_frames=total_training_frames,
    )
    reporter._writer.add_hparams(
        {**args, "n_params": trainer.algm.n_params},
        {"best_grades": trainer.algm.best_grades},
    )


if __name__ == "__main__":
    # main() indexes its argument like a dict, so expose the attributes of the
    # imported ``args`` object as a plain mapping before calling it.
    run_options = vars(args)
    main(run_options)
