from agent.args import args, DEVICE1, DEVICE2, DEVICE3
from utils.trainer import OnlineTrainer, OfflineTrainer
from utils.env import seed
from envs.env import dmc
from dm_control import viewer
from agent.algm import Robert
import numpy as np
import torch
from agent.params import PARAMS
from agent.config import CONFIGS
from agent.common import NEW_DATASET_FOLDER
from video import VideoRecorder
from os import path
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from typing import cast
from pathlib import Path
from dataclasses import dataclass

# from robert.agent.commands_pointmaze import COMMANDS, LONG
from robert.agent.commands import COMMANDS, LONG
from typing import Dict, Any
import json
from robert.agent.dataprocess import (
    extract_observation,
)

# Number of validation trajectories; `inspect` asserts that the first dimension
# of every loaded validation tensor equals this value.
VALID_SIZE = 6

# Device on which observation tensors are created during rollouts.
DEVICE = DEVICE1

# Checkpoints are read from "<MODEL_PATH>/<task>".
MODEL_PATH = "./model"

# Absolute path of the "datasets" directory that sits next to this file's parent.
DATASET_FOLDER = path.abspath(path.join(path.dirname(__file__), "..", "datasets"))


def random_policy(*args, **kwargs):
    """Return one action drawn uniformly from [-1, 1).

    All positional and keyword arguments are accepted and ignored so the
    function can be called through the same interface as a learned policy.

    Returns:
        A float32 tensor with a leading batch dimension of size 1 followed by
        ``CONFIGS["action_dim"]`` entries.
    """
    sample = np.random.uniform(-1, 1, size=CONFIGS["action_dim"])
    action = torch.as_tensor(sample, dtype=torch.float32)
    return torch.unsqueeze(action, 0)


def zero_policy(*args, **kwargs):
    """Return an all-zero action.

    All positional and keyword arguments are accepted and ignored so the
    function can be called through the same interface as a learned policy.

    Returns:
        A float32 tensor of zeros with a leading batch dimension of size 1
        followed by ``CONFIGS["action_dim"]`` entries.
    """
    zeros = np.zeros((CONFIGS["action_dim"],))
    action = torch.as_tensor(zeros, dtype=torch.float32)
    return torch.unsqueeze(action, 0)


def nn_model(task):
    """Build a ``Robert`` agent, load its checkpoint for ``task`` and wrap it.

    Args:
        task: Task name; the checkpoint is read from ``f"{MODEL_PATH}/{task}"``.

    Returns:
        A callable ``take_action(mode, state, env, mask)`` that forwards to the
        loaded agent's ``take_action`` method.
    """

    def placeholder():
        # A fresh tensor per statistic so no two entries share storage.
        return torch.zeros((1,)).to(DEVICE3)

    configs = {**CONFIGS, "log_dir": None}
    params = {
        **PARAMS,
        "seed": args.seed,
        "no_mask": args.no_mask,
        "eval_ratio": args.eval_ratio,
        "model_kind": args.model_kind,
        # NOTE(review): the normalisation statistics and raw-action buffer are
        # zero-filled here; presumably `load` replaces them with the values
        # stored in the checkpoint -- TODO confirm.
        "state_seq_mean": placeholder(),
        "state_seq_std": placeholder(),
        "action_mean": placeholder(),
        "action_std": placeholder(),
        "train_raw_actions": torch.zeros((15151, CONFIGS["action_dim"])),
    }

    agent = Robert("Robert", configs, params)
    # NOTE(review): map_location is fixed to "cuda:0" while the placeholders
    # above live on DEVICE3 -- confirm these refer to the same device.
    agent.load(f"{MODEL_PATH}/{task}", map_location="cuda:0")

    def take_action(mode, state, env, mask):
        return agent.take_action(mode, state, env, mask)

    return take_action


@torch.no_grad()
def inspect(args):
    """Roll out a policy on the stored validation commands and print tracking scores.

    The validation state trajectories are loaded from
    ``f"{NEW_DATASET_FOLDER}/{task}"``, converted into displacement commands
    relative to their first state, and evaluated in three variants
    ("no-mask", "half-mask", "full-mask") together with the entries of
    ``COMMANDS``.  For every command the mean absolute difference between the
    achieved observations and the commanded coordinates is recorded; the
    per-section means and their overall mean are printed.

    Args:
        args: Mapping with the keys ``"seed"``, ``"task"`` (one of
            ``"shadow"``, ``"pointmaze"``, ``"ur5e"``), ``"eval_ratio"`` and
            ``"eval_policy"`` (one of ``"random"``, ``"zero"``, ``"model"``).

    Raises:
        ValueError: If ``args["task"]`` is not a recognised task.
        KeyError: If ``args["eval_policy"]`` is not a recognised policy name.
    """
    torch.cuda.empty_cache()
    random_seed = args["seed"]
    seed(random_seed)

    static_start = 50
    if args["task"] == "shadow":
        env_name = "shadowhand_dummy"
    elif args["task"] == "pointmaze":
        env_name = "point_mass_maze_reach_bottom_left"
        static_start = 20
    elif args["task"] == "ur5e":
        env_name = "ur5e_dummy"
    else:
        raise ValueError(f"unrecognized task: {args['task']}")

    eval_ratio = args["eval_ratio"]

    policy = dict(random=random_policy, zero=zero_policy, model=nn_model)[
        args["eval_policy"]
    ]
    if args["eval_policy"] == "model":
        # `nn_model` is a factory; the other two entries are policies already.
        policy = policy(args["task"])

    test_env = dmc.make(
        env_name,
        seed=random_seed + 50000,
        static_start=static_start,
        env_args={"targets": ["0"] * 6} if args["task"] == "pointmaze" else None,
    )
    print(f"use {env_name} environment")

    read_dir = f'{NEW_DATASET_FOLDER}/{args["task"]}'

    valid_states = torch.load(f"{read_dir}/states_valid.pt")
    # The actions are only needed for the size consistency check below.
    valid_actions = torch.load(f"{read_dir}/actions_valid.pt")
    valid_half_idx = torch.load(f"{read_dir}/valid_half_masks.pt")
    valid_full_idx = torch.load(f"{read_dir}/valid_full_masks.pt")
    assert (
        valid_states.size(0)
        == valid_actions.size(0)
        == VALID_SIZE
        == valid_half_idx.size(0)
        == valid_full_idx.size(0)
    )

    valid_states = valid_states[
        :, : eval_ratio * PARAMS["post_steps"] + PARAMS["post_steps"]
    ].to(DEVICE3)

    print("data preprocess finish")
    state_seq_len = PARAMS["pre_steps"] + PARAMS["post_steps"] + 1

    assert valid_states.size(0) == VALID_SIZE
    # Commands are displacements relative to the first state of each trajectory.
    valid_states = valid_states[:, 1:] - valid_states[:, [0]]
    valid_commands = {
        f"valid_task_{i}": valid_states[i].clone() for i in range(valid_states.size(0))
    }
    valid_commands_no_mask = {
        f"{k}-no-mask": v.clone() for k, v in valid_commands.items()
    }
    valid_commands_half_mask = {
        f"{k}-half-mask": (v.clone(), valid_half_idx[i])
        for i, (k, v) in enumerate(valid_commands.items())
    }
    valid_commands_full_mask = {
        f"{k}-full-mask": (v.clone(), valid_full_idx[i])
        for i, (k, v) in enumerate(valid_commands.items())
    }

    def obs_tensor(ts):
        """Convert a time step's extracted observation to a float32 tensor on DEVICE."""
        return torch.as_tensor(
            extract_observation(ts),
            dtype=torch.float32,
            device=DEVICE,
        )

    def run_eval(info: Dict[str, Any]):
        """Evaluate every command in ``info`` and print the per-section scores."""
        env = info["env"]
        action_spec = env.action_spec()

        no_masked_commands = info["no_masked_commands"]
        half_masked_commands = info["half_masked_commands"]
        full_masked_commands = info["full_masked_commands"]
        pre_steps, post_steps, state_dim = (
            PARAMS["pre_steps"],
            PARAMS["post_steps"],
            CONFIGS["state_dim"],
        )
        take_action = info["take_action"]

        # Number of environment steps that are scored, and number of command
        # rows needed so the last scored step still sees a full look-ahead window.
        rollout_steps = post_steps * eval_ratio
        horizon = rollout_steps + post_steps - 1

        # Sentinel row written over command entries that are hidden from the policy.
        # NOTE(review): this tensor lives on DEVICE3 while observations are
        # created on DEVICE; the arithmetic below only works if both name the
        # same device -- confirm.
        mask_tensor = (
            torch.ones((state_dim,), dtype=torch.float32, device=DEVICE3) * -999
        )
        ALL_COMMANDS = {
            **COMMANDS,
            **no_masked_commands,
            **half_masked_commands,
            **full_masked_commands,
        }

        section_grades = {"no-mask": [], "full-mask": [], "half-mask": []}

        for name, _command in ALL_COMMANDS.items():
            time_step = env.reset()
            obs_history = []
            # The section is everything after the first "-" in the command name.
            section_name = name.partition("-")[2]
            assert section_name in ["full-mask", "half-mask", "no-mask"]

            if not isinstance(_command, tuple):
                command = _command[:horizon]
                masks = None
            else:
                command = _command[0][:horizon]
                mask_idx = _command[1]
                mask_idx = mask_idx[mask_idx < horizon]
                assert mask_idx.dtype == torch.int64 and len(mask_idx.shape) == 1
                # True marks a hidden step; the indices in mask_idx stay visible.
                masks = torch.ones((horizon,), dtype=torch.bool)
                masks[mask_idx] = False

            # Absolute target coordinates: initial observation plus displacement.
            coord_commands = obs_tensor(time_step) + command
            if masks is not None:
                masked_coord_commands = coord_commands.clone()
                masked_coord_commands[masks] = mask_tensor
            else:
                masked_coord_commands = coord_commands

            assert command.shape == (horizon, state_dim)

            def act(t):
                """Return the action for time step ``t`` given the history so far."""
                steps_done = len(obs_history)
                if steps_done < pre_steps:
                    # Not enough history yet: stay still.
                    return np.zeros(action_spec.shape)

                history = torch.stack(obs_history[-pre_steps:], dim=0)
                assert len(history) == pre_steps
                current = obs_tensor(t).unsqueeze(0)

                window = slice(
                    steps_done - pre_steps, steps_done - pre_steps + post_steps
                )
                input_corr = masked_coord_commands[window]
                assert input_corr.shape == (post_steps, state_dim)

                if masks is not None:
                    # History and current observation are never masked.
                    _masks = torch.cat(
                        (
                            torch.zeros((pre_steps + 1,), dtype=torch.bool),
                            masks[window],
                        ),
                        dim=0,
                    ).unsqueeze(0)
                    assert _masks.shape == (1, state_seq_len), _masks.shape
                else:
                    _masks = None

                return take_action(
                    "eval",
                    torch.cat((history, current, input_corr), dim=0).unsqueeze(0),
                    env,
                    mask=_masks,
                ).numpy(force=True)[0]

            while len(obs_history) < pre_steps + rollout_steps:
                assert not time_step.last()

                old_timestep = time_step
                a = act(time_step)
                time_step = env.step(a)
                obs_history.append(obs_tensor(old_timestep))
            obs_history.append(obs_tensor(time_step))

            assert len(obs_history) == pre_steps + rollout_steps + 1

            achieved = torch.stack(obs_history[-rollout_steps:])
            assert len(achieved) == rollout_steps
            if masks is not None:
                # Only visible (unmasked) steps contribute to the score.
                visible = torch.logical_not(masks[:rollout_steps])
                masked_achieved = (
                    achieved[visible] - coord_commands[:rollout_steps][visible]
                )
            else:
                masked_achieved = achieved - coord_commands[:rollout_steps]

            score = masked_achieved.abs().mean().item()
            section_grades[section_name].append(score)

        assert sum(len(sg) for sg in section_grades.values()) == len(ALL_COMMANDS)
        assert len(section_grades) == 3

        # Report every section, including half-mask, since all three feed the
        # final grade.
        for section, grades in section_grades.items():
            print(f"in {len(grades)} {section}, mean grade is: {np.mean(grades)}")
        final_grades = np.mean([np.mean(sg) for sg in section_grades.values()])

        print(f"final grades are: {final_grades}")

    info = dict(
        commands=valid_commands,
        no_masked_commands=valid_commands_no_mask,
        half_masked_commands=valid_commands_half_mask,
        full_masked_commands=valid_commands_full_mask,
        env=test_env,
        take_action=policy,
    )
    run_eval(info)


@torch.no_grad()
def view(args):
    """Run the trained model on one named long command and record videos.

    The command ``LONG[args["eval_name"]]`` is tracked either interactively in
    the dm_control viewer or, when ``args["fake_eval"]`` is truthy, in a
    "fake" mode that writes the commanded pose directly into the physics
    state instead of simulating.  After a fixed number of policy calls the
    videos (and, in non-fake mode, the achieved/commanded positions and
    masks) are saved and the process exits with status 0.

    Args:
        args: Mapping with the keys ``"seed"``, ``"task"`` (one of
            ``"shadow"``, ``"pointmaze"``, ``"ur5e"``), ``"eval_name"`` and
            ``"fake_eval"``.

    Raises:
        ValueError: If ``args["task"]`` is not recognised, or if the task is
            ``"pointmaze"`` and ``args["eval_name"]`` has no camera preset.
        SystemExit: Raised from inside the policy callback once the rollout
            has reached its fixed length.
    """
    torch.cuda.empty_cache()
    random_seed = args["seed"]
    seed(random_seed)

    static_start = 50
    targets = None
    if args["task"] == "shadow":
        env_name = "shadowhand_dummy"
    elif args["task"] == "pointmaze":
        env_name = "point_mass_maze_reach_bottom_left"
        static_start = 20
        # Camera preset per command; every entry is a string, like the target
        # entries appended below.
        cameras = {
            "stay": ["0", "-0.2", "0.25", "0", "0", "0.02"],
            "heart": ["0", "-1.3", "2.5", "0", "-0.65", "0.1"],
            "triangles": ["0.92", "-0.5", "2.5", "0.92", "0", "0.1"],
            "circle": ["-0.2", "0", "0.95", "-0.2", "0", "0.05"],
        }
        if args["eval_name"] not in cameras:
            raise ValueError(
                f"unrecognized eval_name for pointmaze: {args['eval_name']}"
            )
        camera = cameras[args["eval_name"]]

        _targets = LONG[args["eval_name"]].tolist()

        targets = camera + [f"{t[0]} {t[1]} 0.01" for t in _targets]

    elif args["task"] == "ur5e":
        env_name = "ur5e_dummy"
    else:
        raise ValueError(f"unrecognized task: {args['task']}")

    policy = nn_model(args["task"])

    test_env = dmc.make(
        env_name,
        seed=random_seed + 50000,
        static_start=static_start,
        env_args={"targets": targets},
    )
    print(f"use {env_name} environment")

    fake_eval = args["fake_eval"]

    print("data preprocess finish")
    state_seq_len = PARAMS["pre_steps"] + PARAMS["post_steps"] + 1

    def obs_tensor(ts):
        """Convert a time step's extracted observation to a float32 tensor on DEVICE."""
        return torch.as_tensor(
            extract_observation(ts),
            dtype=torch.float32,
            device=DEVICE,
        )

    def run_eval(info: Dict[str, Any]):
        """Track the command called ``info["name"]`` and record the rollout."""
        env = info["env"]
        action_spec = env.action_spec()

        # `info["rollout_length"]` is carried along but not consulted; the
        # rollout length is the fixed `episode_length` computed below.
        commands, name = info["commands"], info["name"]
        pre_steps, post_steps, state_dim = (
            PARAMS["pre_steps"],
            PARAMS["post_steps"],
            CONFIGS["state_dim"],
        )
        take_action = info["take_action"]
        # One throw-away forward pass on an all-zero input before the rollout.
        take_action(
            "eval",
            torch.zeros((state_seq_len, state_dim), device=DEVICE).unsqueeze(0),
            env,
            mask=None,
        ).numpy(force=True)[0]
        # Sentinel row written over command entries that are hidden from the policy.
        # NOTE(review): this tensor lives on DEVICE3 while observations are
        # created on DEVICE; this only works if both name the same device --
        # confirm.
        mask_tensor = (
            torch.ones((state_dim,), dtype=torch.float32, device=DEVICE3) * -999
        )
        path_dir = f"./eval_video/{args['task']}/{name}"
        Path(path_dir).mkdir(parents=True, exist_ok=True)
        # pointmaze is recorded from camera 0 only; other tasks from cameras 0-2.
        camera_ids = [0] if args["task"] == "pointmaze" else [0, 1, 2]
        recorders = [
            VideoRecorder(Path(path_dir), render_size=640, camera_id=camera_id)
            for camera_id in camera_ids
        ]

        cnt = 0
        tracking_errors = []

        for _name, _command in commands.items():
            if _name != name:
                continue
            time_step = env.reset()
            obs_history = []
            to_be_achieved = None
            command_pos = []
            achieved_pos = []

            if not isinstance(_command, tuple):
                command = _command
                masks = None
            else:
                command = _command[0]
                mask_idx = _command[1]
                assert mask_idx.dtype == torch.int64 and len(mask_idx.shape) == 1
                # True marks a hidden step; the indices in mask_idx stay visible.
                masks = torch.ones((command.size(0),), dtype=torch.bool)
                masks[mask_idx] = False

            # Absolute target coordinates: initial observation plus displacement.
            coord_commands = obs_tensor(time_step) + command
            if masks is not None:
                masked_coord_commands = coord_commands.clone()
                masked_coord_commands[masks] = mask_tensor
            else:
                masked_coord_commands = coord_commands

            # Number of policy calls after which the rollout is saved and the
            # process exits.
            if name == "heart" or (
                args["task"] == "ur5e"
                and not cast(str, args["eval_name"]).startswith("rotate")
            ):
                episode_length = int(23.5 * 20)
            else:
                episode_length = int(11.5 * 20)

            def act(t):
                """Policy callback: return the action for time step ``t``."""
                nonlocal to_be_achieved, cnt
                steps_done = len(obs_history)
                if steps_done < pre_steps:
                    # Not enough history yet: stay still.
                    action = np.zeros(action_spec.shape)
                else:
                    if steps_done == pre_steps:
                        for r in recorders:
                            r.init(env)
                    else:
                        for r in recorders:
                            r.record(env)
                    history = torch.stack(obs_history[-pre_steps:], dim=0)
                    assert len(history) == pre_steps
                    current = obs_tensor(t).unsqueeze(0)
                    achieved_pos.append(current)

                    if cnt == episode_length:
                        for i, r in enumerate(recorders):
                            r.save(
                                f"./{name}{'' if i == 0 else str(i)}{'' if not fake_eval else '_fake_eval'}.mp4",
                                (masks.argwhere() + 1).flatten().tolist()
                                if fake_eval and masks is not None
                                else None,
                            )
                        if not fake_eval:
                            aps = torch.stack(achieved_pos, dim=0)
                            cps = torch.stack(command_pos, dim=0)
                            print(f"{len(achieved_pos)} achieved pos: {aps}")
                            print(f"{len(command_pos)} commanded pos: {cps}")
                            print(
                                f"{len(obs_history)} obs history: {torch.stack(obs_history, dim=0)}"
                            )
                            print(
                                f"{len(masked_coord_commands)} command history: {masked_coord_commands}"
                            )
                            torch.save(aps, f"{path_dir}/aps.pt")
                            torch.save(masks, f"{path_dir}/masks.pt")
                            torch.save(cps, f"{path_dir}/cps.pt")
                        print("eval end, exit")
                        # Same effect as `exit(0)` without relying on the
                        # interactive helper injected by the `site` module.
                        raise SystemExit(0)

                    window = slice(
                        steps_done - pre_steps, steps_done - pre_steps + post_steps
                    )
                    input_corr = masked_coord_commands[window]
                    assert input_corr.shape == (post_steps, state_dim)

                    to_be_achieved = input_corr[0]
                    command_pos.append(to_be_achieved)

                    if masks is not None:
                        future_masks = masks[window]
                        # History and current observation are never masked.
                        _masks = torch.cat(
                            (
                                torch.zeros((pre_steps + 1,), dtype=torch.bool),
                                future_masks,
                            ),
                            dim=0,
                        )
                        if future_masks[0]:
                            # The immediate target is hidden, so there is
                            # nothing to track at this step.
                            to_be_achieved = None
                        _masks = _masks.unsqueeze(0)
                        assert _masks.shape == (1, state_seq_len), _masks.shape
                    else:
                        _masks = None

                    action = take_action(
                        "eval",
                        torch.cat((history, current, input_corr), dim=0).unsqueeze(0),
                        env,
                        mask=_masks,
                    ).numpy(force=True)[0]

                obs = obs_tensor(t)
                obs_history.append(obs)
                if to_be_achieved is not None:
                    error = (to_be_achieved - obs).abs().mean().item()
                    tracking_errors.append(error)
                cnt += 1
                if fake_eval:
                    # Write the commanded pose straight into the physics state
                    # instead of relying on the simulated dynamics.
                    with env._physics.reset_context():
                        if to_be_achieved is None:
                            env._physics.data.qpos = 0
                        elif args["task"] == "shadow":
                            env._physics.data.qpos[0] = 0
                            env._physics.data.qpos[1:] = to_be_achieved.numpy(
                                force=True
                            )
                        else:
                            env._physics.data.qpos = to_be_achieved.numpy(force=True)

                        env._physics.data.qvel = 0
                return action

            if not fake_eval:
                viewer.launch(env, width=1920, height=1080, policy=act)
            else:

                class FakeTimeStep:
                    observation = torch.zeros((state_dim * 2,), dtype=torch.float32)

                fake_time_step = FakeTimeStep()
                # `act` terminates the process once `episode_length` is reached.
                while True:
                    act(fake_time_step)

    info = dict(
        commands=LONG,
        env=test_env,
        take_action=policy,
        rollout_length=5 * PARAMS["post_steps"],
        name=args["eval_name"],
    )
    run_eval(info)


if __name__ == "__main__":
    # Script entry point: pass the attributes of `args` (imported from
    # agent.args) to `view` as a plain dict.
    cli_options = vars(args)
    view(cli_options)
