from utils.trainer import OnlineTrainer, OfflineTrainer
from typing import Dict, Any, cast
from utils.reporter import get_reporter
import json
from utils.env import seed
from agent.args import args, DEVICE1, DEVICE2
from agent.algm import Robert
from agent.config import CONFIGS
from os import path
import torch
from agent.params import PARAMS
from agent.common import NEW_DATASET_FOLDER
from robert.agent.dataprocess import preprocess, enumer2, transform
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from utils.common import get_device

# Re-bind the imported DEVICE2 at module level. This is a no-op; presumably it
# is kept so linters treat the import as used — TODO confirm intent.
DEVICE2 = DEVICE2
# Absolute path of the ``datasets`` directory one level above the directory
# that contains this file.
DATASET_FOLDER = path.abspath(path.join(path.dirname(__file__), path.pardir, "datasets"))


def _pooled_mean_std(seq):
    """Return the mean and population std of ``seq`` reduced over axes 0 and 1.

    The reduction visits one ``seq[:, i]`` slice at a time, so the centred
    temporary is never materialised for the whole tensor at once.

    Args:
        seq: tensor of rank >= 2; axes 0 and 1 are reduced and the remaining
            axes are kept.

    Returns:
        ``(mean, std)``: float32 tensors of shape ``seq.shape[2:]`` that live
        on ``seq.device``.
    """
    n_groups = seq.shape[1]
    # Allocate on the input's device so the subtraction below never mixes
    # CPU and accelerator tensors.
    group_means = torch.zeros(
        (n_groups, *seq.shape[2:]), dtype=torch.float32, device=seq.device
    )
    for i in range(n_groups):
        group_means[i] = seq[:, i].mean(dim=0)
    # Every slice has seq.shape[0] rows, so the mean of the slice means equals
    # the global mean over axes 0 and 1.
    mean = group_means.mean(dim=0)

    # Keep the per-slice mean squared deviation (not its sqrt) so the pooled
    # std needs only a single sqrt at the end.
    group_sq_dev = torch.zeros_like(group_means)
    for i in range(n_groups):
        group_sq_dev[i] = ((seq[:, i] - mean) ** 2).mean(dim=0)
    std = group_sq_dev.mean(dim=0).sqrt()
    return mean, std


def analyze_state_action_seqs(state_seq, action_seq):
    """Print the pooled mean/std of state and action sequence tensors.

    Each tensor is reduced over its first two axes independently. The action
    statistics therefore do not require the action tensor to share the
    state tensor's feature dimension or sequence length.

    Args:
        state_seq: state sequences, e.g. ``(ts, si, sl, state_dim)`` as
            produced by ``enumer2`` — TODO confirm layout.
        action_seq: action sequences with the same leading-axis convention.

    Returns:
        ``(state_mean, state_std, action_mean, action_std)``. The values are
        also printed, so callers that ignore the return value still see them.
    """
    state_mean, state_std = _pooled_mean_std(state_seq)
    action_mean, action_std = _pooled_mean_std(action_seq)

    print(f"state_seq mean is: {state_mean}, std is: {state_std}")
    print(f"action_seq mean is: {action_mean}, std is {action_std}")
    return state_mean, state_std, action_mean, action_std


def analyze_transformed_state_seqs(state_seq, action_seq):
    """Print the pooled mean/std of ``transform``-ed state sequences and of their absolute values.

    ``state_seq`` is unpacked as ``(ts, si, sl, dim)``. Statistics are pooled
    over axes 0 and 1 by visiting one ``state_seq[:, i]`` slice at a time.
    Each slice is passed through ``transform`` once per pass (two passes in
    total), and that result is shared by the signed and absolute-value
    statistics.

    Args:
        state_seq: 4-D state sequence tensor ``(ts, si, sl, dim)``.
        action_seq: accepted for signature parity with
            ``analyze_state_action_seqs``; it is not used.

    Returns:
        ``(mean, std, abs_mean, abs_std)``, each of shape ``(sl, dim)``. The
        values are also printed.
    """
    ts, si, sl, dim = state_seq.shape
    # NOTE(review): assumes ``transform`` returns a tensor of the input's
    # shape on the input's device. The shape is asserted below; the device is not.
    means = torch.zeros((si, sl, dim), dtype=torch.float32, device=state_seq.device)
    abs_means = torch.zeros_like(means)
    for i in range(si):
        t = transform(state_seq[:, i])
        means[i] = t.mean(dim=0)
        abs_means[i] = t.abs().mean(dim=0)
    # All slices have ``ts`` rows, so the mean of slice means is the global mean.
    mean = means.mean(dim=0)
    abs_mean = abs_means.mean(dim=0)
    assert mean.shape == (sl, dim)
    assert abs_mean.shape == (sl, dim)

    # Second pass: squared deviations around the global means.
    sq_dev = torch.zeros_like(means)
    abs_sq_dev = torch.zeros_like(means)
    for i in range(si):
        t = transform(state_seq[:, i])
        sq_dev[i] = ((t - mean) ** 2).mean(dim=0)
        abs_sq_dev[i] = ((t.abs() - abs_mean) ** 2).mean(dim=0)
    std = sq_dev.mean(dim=0).sqrt()
    abs_std = abs_sq_dev.mean(dim=0).sqrt()
    assert std.shape == (sl, dim)
    assert abs_std.shape == (sl, dim)

    print(f"state transformed seqs mean is: {mean}, std is: {std}")
    print(f"abs state transformed seqs mean is: {abs_mean}, std is: {abs_std}")
    return mean, std, abs_mean, abs_std


# %%
def _load_train_split(task):
    """Load the raw training ``states`` / ``actions`` tensors saved for ``task``."""
    states = torch.load(f"{NEW_DATASET_FOLDER}/{task}/states_train.pt")
    actions = torch.load(f"{NEW_DATASET_FOLDER}/{task}/actions_train.pt")
    return states, actions


def _plot_per_dim_histograms(samples, out_dir="./data_analyze"):
    """Save one histogram PNG per column of the 2-D ``samples`` tensor into ``out_dir``."""
    import os

    # savefig does not create missing directories, so the first call would
    # fail on a fresh checkout.
    os.makedirs(out_dir, exist_ok=True)
    plt.rcParams["figure.figsize"] = (12.8, 7.2)
    for i in range(samples.size(1)):
        sns.histplot(pd.DataFrame(samples[:, i].numpy(force=True)))
        plt.savefig(f"{out_dir}/{i}.png")
        plt.close()


def _print_flat_stats(flat_states, flat_actions):
    """Print mean/std and first-two-column min/max of sampled state and action vectors."""
    xs, ys = flat_states[:, 0], flat_states[:, 1]
    a0, a1 = flat_actions[:, 0], flat_actions[:, 1]
    print(
        f"state mean: {flat_states.mean(dim=0).tolist()}, "
        f"state std: {flat_states.std(dim=0).tolist()}"
    )
    print(f"min x: {xs.min().item()}, max x: {xs.max().item()}")
    print(f"min y: {ys.min().item()}, max y: {ys.max().item()}")
    print(
        f"action mean: {flat_actions.mean(dim=0).tolist()}, "
        f"action std: {flat_actions.std(dim=0).tolist()}"
    )
    print(
        f"min action: {a0.min().item()}, {a1.min().item()}, "
        f"max action {a0.max().item()}, {a1.max().item()}"
    )


def _plot_state_scatter(states, state_dim, stride=1000, out_path="./state.png"):
    """Scatter the first two coordinates of every ``stride``-th row of ``states``, coloured by step index."""
    # NOTE(review): assumes ``states`` is (num_rows, steps, state_dim).
    sampled = states[::stride]
    flat = sampled.reshape((-1, state_dim))
    # ``sampled`` has ceil(N / stride) rows, so the step index must be
    # repeated by that count (not floor(N / stride)) to match ``flat``.
    step_idx = list(range(sampled.size(1))) * sampled.size(0)
    df = pd.DataFrame(
        {
            "x": flat[:, 0].tolist(),
            "y": flat[:, 1].tolist(),
            "i": step_idx,
        }
    )
    plt.rcParams["figure.figsize"] = (12.8, 7.2)
    sns.scatterplot(
        data=df,
        x="x",
        y="y",
        hue="i",
        palette=sns.color_palette("viridis", as_cmap=True),
    )
    plt.savefig(out_path)
    plt.close()


def _plot_state_seq_tsne(state_seqs, flat_dim, perplexity=3, out_path="./tsne.png"):
    """Embed ~1% of the flattened state sequences with t-SNE and save a 2-D scatter plot."""
    print("start ploting tsne...")
    flat = state_seqs.reshape((-1, flat_dim))
    n_samples = flat.size(0) // 100
    # sklearn's TSNE requires perplexity < n_samples.
    if n_samples <= perplexity:
        print(
            f"too few state sequences sampled ({n_samples}) for t-SNE "
            f"with perplexity={perplexity}; skipping"
        )
        return
    # Sampling is with replacement, so duplicates are possible.
    sampled_idx = torch.randint(0, flat.size(0), (n_samples,))
    tsne = TSNE(
        n_components=2, learning_rate="auto", init="random", perplexity=perplexity
    )
    points = tsne.fit_transform(flat[sampled_idx].numpy(force=True))

    df = pd.DataFrame({"x": points[:, 0].tolist(), "y": points[:, 1].tolist()})
    plt.rcParams["figure.figsize"] = (12.8, 7.2)
    sns.scatterplot(data=df, x="x", y="y")
    plt.savefig(out_path)
    plt.close()


def _ask(prompt):
    """Return True if the user answers the interactive ``prompt`` with y/yes."""
    return input(prompt).strip().lower() in ("y", "yes")


def main(args: Dict[str, Any], *, full_report: bool = False):
    """Load a task's training data and write exploratory plots and statistics.

    By default the function writes per-dimension histograms of up to 20000
    randomly sampled state vectors to ``./data_analyze/<dim>.png`` and then
    returns.

    Args:
        args: parsed CLI arguments as a dict; ``"seed"`` and ``"task"`` are read.
        full_report: when True, additionally print flat summary statistics,
            run ``analyze_state_action_seqs`` and
            ``analyze_transformed_state_seqs``, and interactively offer the
            state scatter plot and the t-SNE plot.
    """
    seed(args["seed"])

    states, actions = _load_train_split(args["task"])
    states, actions = preprocess(states, actions)
    print(f"states.shape == {states.shape}")
    print(f"actions.shape == {actions.shape}")

    state_seq_len = PARAMS["pre_steps"] + PARAMS["post_steps"] + 1
    action_seq_len = PARAMS["pre_steps"] + PARAMS["post_steps"]
    state_seqs = enumer2(states, state_seq_len)
    action_seqs = enumer2(actions, action_seq_len)
    print("data preprocess finish")

    state_dim, action_dim = CONFIGS["state_dim"], CONFIGS["action_dim"]
    flat_s = states.reshape((-1, state_dim))
    flat_a = actions.reshape((-1, action_dim))
    flat_s = flat_s[torch.randperm(flat_s.size(0))[:20000]]
    flat_a = flat_a[torch.randperm(flat_a.size(0))[:20000]]

    _plot_per_dim_histograms(flat_s)

    # The rest is opt-in: it is slow and prompts on stdin.
    if not full_report:
        return

    _print_flat_stats(flat_s, flat_a)
    analyze_state_action_seqs(state_seqs, action_seqs)
    analyze_transformed_state_seqs(state_seqs, action_seqs)

    if _ask("print plot of state distribution?"):
        _plot_state_scatter(states, state_dim)

    if _ask("print tsne of state seq distribution?"):
        _plot_state_seq_tsne(state_seqs, state_seq_len * state_dim)


if __name__ == "__main__":
    # main() expects a plain dict, so unwrap the parsed CLI namespace first.
    cli_args = vars(args)
    main(cli_args)
