# -*-coding=utf-8-*-
import os
import random
import sys

from baselines import logger
from matplotlib import pyplot as plt, animation
from sklearn.metrics import accuracy_score

# NOTE: '__file__' below is part of a string literal, not the module attribute, so
# dirname('./../__file__') is './..' and `path` resolves to the parent of the
# current working directory (it depends on where the script is launched from).
# Appending it to sys.path is presumably what lets the project-local imports
# further down (`StyleBC`, `main_mod`) resolve -- TODO confirm.
path = os.path.abspath(os.path.dirname('./../__file__'))
sys.path.append(path)
import argparse
import ast
import numpy as np
import pandas as pd
import seaborn as sns
from StyleBC.model import BC
from datetime import datetime
from main_mod import get_env_config, set_random_seed, update_tb, create_base_path, update_logger, env_test

# NOTE(review): PYTHONHASHSEED is consulted when the interpreter starts, so setting
# it here presumably only affects child processes that inherit the environment,
# not hashing in the current process -- confirm this is the intent.
os.environ['PYTHONHASHSEED'] = '0'
# Captured once at import time; used to name the images saved by
# plot_logger_error_line.
TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())


def get_data(trajectory_file):
    """Parse a demonstration file into flat state / action / episode-index lists.

    The file is read line by line. Only lines starting with ``state``,
    ``action`` or ``done`` are interpreted; every other line is ignored.
    The value of a line is the text between the first and second ``:``.

    Args:
        trajectory_file: Path of the text file to read.

    Returns:
        A tuple ``(expert_obs, expert_a, style_index, trajectory_index)``:
        the parsed states, the integer actions, the running episode number
        recorded for every ``done`` line, and the number of ``done`` lines
        that contain ``True``.

    Raises:
        OSError: If the file cannot be opened.
        ValueError, SyntaxError: If a state or action value cannot be parsed.
    """
    expert_obs, expert_a, style_index = [], [], []
    trajectory_index = 0
    # The context manager guarantees the handle is closed, even if a line
    # fails to parse (the handle used to be leaked).
    with open(trajectory_file, "r") as f:
        for line in f:
            if line.startswith("state"):
                expert_obs.append(ast.literal_eval(line.split(":")[1]))
            elif line.startswith("action"):
                expert_a.append(int(line.split(":")[1]))
            elif line.startswith("done"):
                # Record the episode this step belongs to *before* advancing.
                style_index.append(trajectory_index)
                if "True" in line:
                    trajectory_index += 1
    return expert_obs, expert_a, style_index, trajectory_index


def get_ep_data(trajectory_file):
    """Parse a demonstration file into one record per finished episode.

    Lines starting with ``state``, ``action`` or ``done`` are interpreted
    (value = text between the first and second ``:``); other lines are
    ignored. An episode is closed by a ``done`` line containing ``True``.
    Steps that follow the last such line are not returned.

    Args:
        trajectory_file: Path of the text file to read.

    Returns:
        A tuple ``(ep_data, trajectory_index)`` where ``ep_data`` is a list of
        dicts with keys ``"state"``, ``"action"`` and ``"ep_idx"`` (the
        episode number repeated once per ``done`` line of that episode) and
        ``trajectory_index`` is the number of finished episodes.

    Raises:
        OSError: If the file cannot be opened.
        ValueError, SyntaxError: If a state or action value cannot be parsed.
    """
    ep_data = []
    expert_obs, expert_a, style_index = [], [], []
    trajectory_index = 0
    # The context manager guarantees the handle is closed, even if a line
    # fails to parse (the handle used to be leaked).
    with open(trajectory_file, "r") as f:
        for line in f:
            if line.startswith("state"):
                expert_obs.append(ast.literal_eval(line.split(":")[1]))
            elif line.startswith("action"):
                expert_a.append(int(line.split(":")[1]))
            elif line.startswith("done"):
                style_index.append(trajectory_index)
                if "True" in line:
                    ep_data.append({"state": expert_obs, "action": expert_a, "ep_idx": style_index})
                    trajectory_index += 1
                    # Start fresh buffers so the stored episode is not mutated.
                    expert_obs, expert_a, style_index = [], [], []
    return ep_data, trajectory_index


def action_dis(state, action):
    """Dump state/action pairs to ``./result_str.csv`` as strings.

    Every state row is collapsed into a single ``_``-joined string and
    written next to the corresponding action.

    Args:
        state: Sequence of state vectors (converted with ``np.array``).
        action: Sequence of actions (converted with ``np.array``).
    """
    state_text = pd.DataFrame(np.array(state)).astype("str")
    action_text = pd.DataFrame(np.array(action)).astype("str")
    # Join the first column with all remaining columns, row by row.
    leading_column = state_text.iloc[:, 0]
    trailing_columns = state_text.iloc[:, 1:]
    joined_state = leading_column.str.cat(trailing_columns, sep="_")
    pd.concat([joined_state, action_text], axis=1).to_csv("./result_str.csv")


def _str2bool(value):
    """Convert a command-line token into a real boolean for argparse.

    ``type=bool`` is not usable with argparse because ``bool("False")`` is
    ``True``; this converter understands the usual spellings instead.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays falsy, as it was with ``type=bool``.
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def argparser():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` holding all options.

    Notes:
        * Boolean options accept ``true/false``, ``yes/no``, ``1/0`` (case
          insensitive); previously every non-empty value parsed as ``True``.
        * ``--try_num`` is parsed as ``int`` because it is used as a loop
          count, and ``--gamma`` is parsed as ``float``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--trajectory_file", default="../grid_word_demo.txt", type=str, help="Path for source data")
    parser.add_argument('--env', type=str, default='two_way_grid_world', help='env')
    parser.add_argument('--gamma', type=float, default=0.95)
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('--state_dim', type=int, default=4, help='state_dim')
    parser.add_argument('--batch_size', type=int, default=128, help='size of the mini batch in training and testing')
    parser.add_argument('--encode_dim', type=int, default=4, help='encode_dim')
    parser.add_argument('--action_dim', type=int, default=4, help='action_dim')
    parser.add_argument('--algo', type=str, default="stylebc", help='algorithm name')
    # NOTE(review): the default is the float 4e7 although the type is int; it is
    # kept unchanged because the value may end up in generated path names.
    parser.add_argument('--max_sample_num', type=int, default=4e7, help='max steps interaction with the environment')
    parser.add_argument('--save_step_count', type=int, default=1000, help='the interval to save the model according to the train step')
    parser.add_argument('--tb_step_count', type=int, default=10, help='the interval to write the tensorboard file')
    parser.add_argument('--train_log_interval', type=int, default=10, help='the interval to save the training info to the tensorboard file')
    parser.add_argument('--test_count', type=float, default=2, help='test episode for each saved model')
    parser.add_argument('--add_reward', type=float, default=0.5, help='weight for the posterior reward')
    parser.add_argument('--learning_rate', type=float, default=1e-4, help='learning rate in training')
    parser.add_argument('--try_num', type=int, default=2, help='try count when running multiple times')
    parser.add_argument('--random_reset', type=_str2bool, default=False, help='whether to random reset the env')
    parser.add_argument('--random_step', type=_str2bool, default=False, help='whether to act random steps in env')
    parser.add_argument('--save_gif', type=_str2bool, default=False, help='whether to save the test gif')
    parser.add_argument('--save_gif_interval', type=int, default=1, help='interval to save the gif')
    parser.add_argument('--test', type=_str2bool, default=True, help='load and test the model')
    parser.add_argument('--load_path', type=str, default=r"D:\project\Style_BC\Style_BC_multi_door\stylebc_start\style_25_encode_4\stylebc_lr_1.00e-04_nu_4_msn_4.00e+07_rr_False_test_False_bs_128_sn_25_seed_6632_2021-08-11_10-32-11/checkpoint",
                        help='the path of the model to be loaded')
    parser.add_argument('--load_step', type=str, default="283000_39923696", help='the step of the model to be loaded')
    parser.add_argument('--render', type=_str2bool, default=False, help='whether to render the game in the screen')
    parser.add_argument('--end', type=str, default="14_7", help='the string form of the env ends')
    parser.add_argument('--style_num', type=int, default=25, help='the number of the behavior styles')
    return parser.parse_args()


def get_batch_data(model, ep_data, style_num, batch_size=128, action_dim=None):
    """Sample whole episodes until at least ``batch_size`` steps are collected.

    Episodes are drawn with replacement via ``random.choice``; each drawn
    episode is labelled with a style index by ``get_ep_style_idx``.

    Args:
        model: Policy object forwarded to ``get_ep_style_idx``.
        ep_data: List of episode dicts as produced by ``get_ep_data``.
        style_num: Number of candidate styles to score per episode.
        batch_size: Minimum number of steps in the returned batch.
        action_dim: Width of the one-hot action encoding. When omitted, the
            module-level ``args.action_dim`` is used, which only exists when
            the file is run as a script.

    Returns:
        ``(observations, actions, onehot_actions, style_indices)`` as numpy
        arrays, concatenated over the sampled episodes.

    Raises:
        ValueError: If ``ep_data`` is empty or contains no actions at all;
            without this check the sampling loop could never finish.
    """
    if not any(len(episode["action"]) for episode in ep_data):
        raise ValueError("ep_data must contain at least one episode with actions.")
    if action_dim is None:
        # Backward-compatible fallback to the global created under __main__.
        action_dim = args.action_dim
    current_batch_size = 0
    ret_ep_data = []
    while current_batch_size < batch_size:
        selected_ep_data = random.choice(ep_data)
        ret_ep_data.append(selected_ep_data)
        current_batch_size += len(selected_ep_data["action"])
    expert_obs, expert_a, expert_ep_index = get_ep_style_idx(model, ret_ep_data, style_num)
    expert_observations = np.array(expert_obs)
    expert_actions = np.array(expert_a)
    # Row i of the identity matrix is the one-hot vector for action i.
    expert_onehot_actions = np.eye(action_dim)[expert_actions]
    expert_ep_index = np.array(expert_ep_index)
    return expert_observations, expert_actions, expert_onehot_actions, expert_ep_index


def get_ep_style_idx(model, ep_data, style_num):
    """Label every episode with the style whose predictions match it best.

    For each episode, every style index in ``range(style_num)`` is fed to
    ``model.get_action`` together with the episode's states, and the
    accuracy against the recorded actions is computed. The style with the
    highest accuracy wins; on ties the largest style index is kept.

    Args:
        model: Object exposing ``get_action(obs, styles, stochastic=False)``
            returning ``(predicted_action, action_prob)``.
        ep_data: Iterable of dicts with ``"state"`` and ``"action"`` entries.
        style_num: Number of styles to try.

    Returns:
        ``(ret_obs, ret_a, ret_style_idx)``: the states, actions and per-step
        winning style index, concatenated over all episodes.

    Raises:
        ValueError: If no style could be selected for an episode (for
            example when ``style_num`` is 0).
    """
    ret_obs, ret_a, ret_style_idx = [], [], []
    for data in ep_data:
        ret_obs.extend(data["state"])
        ret_a.extend(data["action"])
        # The arrays do not depend on the style, so build them once per
        # episode and not once per candidate style.
        obs = np.array(data["state"])
        actions = np.array(data["action"])
        highest_acc = 0
        highest_acc_style_idx = None
        for style_idx in range(style_num):
            styles = np.ones_like(actions) * style_idx
            predicted_action, _ = model.get_action(obs, styles, stochastic=False)
            acc = accuracy_score(actions, predicted_action)
            # ``>=`` keeps the last style among equally accurate ones.
            if acc >= highest_acc:
                highest_acc = acc
                highest_acc_style_idx = styles
        if highest_acc_style_idx is None:
            raise ValueError(f"Get no style idx for episode data {data}.")
        ret_style_idx.extend(list(highest_acc_style_idx))
    return ret_obs, ret_a, ret_style_idx


def main(args, **kwargs):
    """Train the style-conditioned BC policy, or evaluate a loaded one.

    Args:
        args: Parsed command-line namespace (see ``argparser``). Note that
            ``args.trajectory_file`` and ``args.end`` are overwritten here
            depending on ``args.random_reset``.
        **kwargs: Optional ``seed`` overriding ``args.seed``.

    Returns:
        The dictionary returned by ``create_base_path``.
    """
    seed = kwargs.get("seed", None)
    if seed is None:
        seed = args.seed
    set_random_seed(seed)
    path_params = {"seed": seed}
    # The demonstration file and output root are hard-wired per reset mode and
    # replace whatever was given on the command line.
    if args.random_reset:
        args.trajectory_file = "../test_random8.txt"
        args.end = "14_7"
        base_path = create_base_path("../stylebc_result/result", args, **path_params)
    else:
        args.trajectory_file = "../start_25.txt"
        args.end = "14_7"
        base_path = create_base_path("../final_result/style_25", args, **path_params)
    style_num = args.style_num
    batch_size = args.batch_size
    env, config = get_env_config(args)
    ep_data, episode_num = get_ep_data(args.trajectory_file)
    policy = BC('policy', args.state_dim, args.action_dim, args.encode_dim, style_num, learning_rate=args.learning_rate,
                save_path=base_path["save_path"], save_style_dir=base_path["save_style_dir"], tb_dir=base_path["tb_path"])
    if args.test:
        policy.load(args.load_path, args.load_step)
    writer = policy.log_writer
    logger.configure(dir=base_path["logger_path"])
    train_step = 0
    sample_cost = 0
    # NOTE(review): in test mode sample_cost is reset to 0 on every pass, so
    # this loop keeps evaluating until the process is interrupted -- confirm
    # that this is intended.
    while sample_cost < args.max_sample_num:
        if not args.test:
            expert_observations, expert_actions, expert_onehot_actions, expert_ep_index = get_batch_data(policy, ep_data, style_num, batch_size)
            sample_cost += expert_onehot_actions.shape[0]
            train_step, target_loss, style_step, style_loss, policy_loss = policy.train(expert_observations, expert_onehot_actions, expert_ep_index)
            train_action, train_action_prob = policy.get_action(expert_observations, expert_ep_index, stochastic=False)
            acc = accuracy_score(expert_actions, train_action)
            target_train_action, target_train_action_prob = policy.get_target_action(expert_observations, stochastic=False)
            target_acc = accuracy_score(expert_actions, target_train_action)
            # A batch on which the target head gets nothing right must not
            # abort training with a division by zero.
            acc_diff = (acc - target_acc) / target_acc if target_acc else float("nan")
            log_ep_data = {
                "loss/target_loss": target_loss,
                "loss/style_loss": style_loss,
                "loss/policy_loss": policy_loss,
                "train/acc": acc,
                "train/target_acc": target_acc,
                "train/acc_diff": acc_diff,
                "sample/sample_cost": sample_cost,
                "sample/train_step": train_step
            }
        else:
            train_step += args.save_step_count
            sample_cost = 0
            log_ep_data = {}
        if train_step % args.save_step_count == 0 or args.test:
            if not args.test:
                checkpoint_name = str(train_step) + "_" + str(sample_cost)
                policy.save(checkpoint_name)
                policy.save_styles(checkpoint_name)
                print('Model saved %s' % checkpoint_name)
            test_rewards = {}
            test_success_rate = {}
            # Add "target" to this list to evaluate the target head as well.
            for test_label in ["policy"]:
                test_log_data = env_test(policy, None, None, args, test_label=test_label, style_num=style_num, base_path=base_path, train_step=train_step)
                update_logger(test_log_data)
                update_tb(writer, sample_cost, test_log_data)
                test_rewards[test_label] = test_log_data[f"test_{test_label}/ep_reward"]
                test_success_rate[test_label] = test_log_data[f"test_{test_label}/success_rate"]
            # The policy/target difference only exists when both labels were
            # evaluated; indexing "target" unconditionally raised KeyError.
            if "policy" in test_rewards and "target" in test_rewards:
                test_diff_log = {"test_policy/reward_dff": test_rewards["policy"] - test_rewards["target"],
                                 "test_policy/success_rate_dff": test_success_rate["policy"] - test_success_rate["target"]}
                update_logger(test_diff_log)
                update_tb(writer, sample_cost, test_diff_log)
        if train_step % args.tb_step_count == 0:
            if not args.test:
                update_tb(writer, sample_cost, log_ep_data)
            writer.flush()
            update_logger(log_ep_data)
            logger.dump_tabular()
    writer.close()
    return base_path


def plot(base_path_list, args, plot="error_band", smooth_window_size=100, down_sampling_interval=10):
    """Plot training curves and/or learned styles for a set of runs.

    Args:
        base_path_list: A list of run directories, or a single directory
            whose entries are the run directories. An empty list falls back
            to a built-in list of earlier result directories.
        args: Parsed command-line namespace; forwarded to ``plot_style``.
        plot: ``"error_band"`` for accuracy/reward curves, ``"style"`` for the
            style bar charts, ``"both"`` for both. Any other value plots
            nothing.
        smooth_window_size: Accepted for compatibility; the curve
            configurations below use their own fixed values.
        down_sampling_interval: Accepted for compatibility; see above.
    """
    # only the supervised accuracy with the same seed
    first_file_list1 = ["../result/stylebc_learning_rate_0.0001_encode_dim_6_max_sample_num_10000000.0_seed_3566_2021-07-21_18-57-47/",
                        "../result/stylebc_learning_rate_0.0001_encode_dim_6_max_sample_num_10000000.0_seed_3566_2021-07-21_18-44-23/",
                        "../result/stylebc_learning_rate_0.0001_encode_dim_6_max_sample_num_10000000.0_seed_3566_2021-07-21_18-31-10/",
                        "../result/stylebc_learning_rate_0.0001_encode_dim_6_max_sample_num_10000000.0_seed_3566_2021-07-21_18-17-58/",
                        "../result/stylebc_learning_rate_0.0001_encode_dim_6_max_sample_num_10000000.0_seed_3566_2021-07-21_18-04-42/"]
    # only the supervised accuracy with five different seed
    first_file_list2 = ["../result/stylebc_learning_rate_0.0001_encode_dim_6_max_sample_num_10000000.0_seed_3411_2021-07-22_00-40-46/",
                        "../result/stylebc_learning_rate_0.0001_encode_dim_6_max_sample_num_10000000.0_seed_9305_2021-07-22_00-27-27/",
                        "../result/stylebc_learning_rate_0.0001_encode_dim_6_max_sample_num_10000000.0_seed_3189_2021-07-22_00-14-15/",
                        "../result/stylebc_learning_rate_0.0001_encode_dim_6_max_sample_num_10000000.0_seed_2824_2021-07-22_00-01-10/",
                        "../result/stylebc_learning_rate_0.0001_encode_dim_6_max_sample_num_10000000.0_seed_1717_2021-07-21_23-48-06/"]
    # supervised accuracy and the environment test with five different seed
    first_file_list3 = ["../result/stylebc_learning_rate_0.0001_encode_dim_6_max_sample_num_10000000.0_random_reset_False_seed_4822_2021-07-22_21-24-21/",
                        "../result/stylebc_learning_rate_0.0001_encode_dim_6_max_sample_num_10000000.0_random_reset_False_seed_930_2021-07-22_21-17-04/",
                        "../result/stylebc_learning_rate_0.0001_encode_dim_6_max_sample_num_10000000.0_random_reset_False_seed_6172_2021-07-22_21-10-07/",
                        "../result/stylebc_learning_rate_0.0001_encode_dim_6_max_sample_num_10000000.0_random_reset_False_seed_1196_2021-07-22_21-03-12/",
                        "../result/stylebc_learning_rate_0.0001_encode_dim_6_max_sample_num_10000000.0_random_reset_False_seed_9213_2021-07-22_20-56-18/"]
    if isinstance(base_path_list, list):
        # An empty list selects the built-in fallback; switch to
        # first_file_list1 / first_file_list2 here to plot the older runs.
        file_list = base_path_list if len(base_path_list) > 0 else first_file_list3
    else:
        file_list = [os.path.join(base_path_list, filename) + "/"
                     for filename in os.listdir(base_path_list)]
    if plot in ["error_band", "both"]:
        # plot acc
        plot_config = {
            "target_label": "acc",
            "source_label": ["train/acc", "train/target_acc"],
            "legends": ["policy", "target"],
            "smooth_window_size": 10,
            "down_sampling_interval": 10
        }
        plot_logger_error_line(file_list, **plot_config)
        # plot reward
        plot_config = {
            "target_label": "reward",
            "source_label": ["test_policy/ep_reward", "test_target/ep_reward"],
            "legends": ["policy", "target"],
            "smooth_window_size": 10,
            "down_sampling_interval": 1
        }
        plot_logger_error_line(file_list, **plot_config)
    # Independent ``if`` (was ``elif``): with "both" the first branch always
    # matched, so the style plot was never produced.
    if plot in ["style", "both"]:
        plot_style(file_list, args)


def plot_logger_error_line(file_list, smooth_window_size=100, down_sampling_interval=10, **kwargs):
    """Plot smoothed logger curves of several runs as seaborn line plots.

    For every run directory, ``logger/progress.csv`` is read, rows with
    missing values are dropped, the rows are down-sampled, and each source
    column is smoothed with a rolling mean. All runs are then drawn in one
    ``sns.lineplot`` (one hue per legend) and the figure is shown and saved
    to ``./plot/<target_label>_<TIMESTAMP>.png``.

    Args:
        file_list: Run directories (with or without a trailing separator).
        smooth_window_size: Window of the rolling mean.
        down_sampling_interval: Keep every n-th row; values <= 1 keep all.
        **kwargs: ``target_label`` (y-axis/column name), ``source_label``
            (CSV columns to plot) and ``legends`` (one name per column).

    Raises:
        ValueError: If ``file_list`` is empty, so there is nothing to plot.
    """
    target_label = kwargs.get("target_label", "acc")
    source_label = kwargs.get("source_label", ["test_policy/ep_reward", "test_target/ep_reward"])
    legends = kwargs.get("legends", ["policy", "target"])
    print(f"=================== {target_label} ===============================")
    print(file_list)
    frames = []
    for idx, file in enumerate(file_list):
        # os.path.join copes with a missing trailing separator.
        filename = os.path.join(file, "logger", "progress.csv")
        df = pd.read_csv(filename)
        df = df[["sample/train_step"] + source_label]
        df = df.dropna(axis=0, how="any")
        if down_sampling_interval > 1:
            # Rows 0, n, 2n, ... restricted to the complete groups of n rows.
            usable_rows = (len(df) // down_sampling_interval) * down_sampling_interval
            df = df.iloc[:usable_rows:down_sampling_interval]
        # Build a fresh long-format frame per curve in place of assigning into
        # slices of ``df``; the row order per run is curve 1, then curve 2, ...
        for label, legend in zip(source_label, legends):
            frames.append(pd.DataFrame({
                "train_step": df["sample/train_step"],
                target_label: df[label].rolling(smooth_window_size).mean(),
                "name": legend,
            }))
        print(f"Finish {idx}/{len(file_list)}")
    if not frames:
        raise ValueError("file_list is empty: no logger data to plot.")
    # One concatenation at the end avoids re-copying the growing frame.
    total_df = pd.concat(frames, axis=0)
    sns_plot = sns.lineplot(x="train_step", y=target_label,
                            hue="name",
                            data=total_df)
    sns_plot.set_title(target_label)
    plt.show()
    os.makedirs("./plot", exist_ok=True)
    sns_plot.figure.savefig(f"./plot/{target_label}_{TIMESTAMP}.png")


def plot_style(file_list, args):
    """Show bar charts of the saved style vectors for the first five styles.

    For every run directory the ``style_softmax_*`` file with the largest
    sample number (fourth ``_``-separated field of the file name) is loaded
    from its ``style`` sub-directory. Row ``i`` of each file is collected
    under style ``i``, and one bar chart comparing the runs is shown for
    each of the styles 0-4.

    Args:
        file_list: Run directories (with or without a trailing separator).
        args: Unused; kept for interface compatibility.

    Raises:
        FileNotFoundError: If a run has no ``style_softmax_*`` file.
    """
    style_dict = {}
    for file in file_list:
        dir_path = os.path.join(file, "style")
        # Start below zero so that a file with sample number 0 can be chosen.
        max_sample_number = -1
        target_file_name = None
        for file_name in os.listdir(dir_path):
            if not file_name.startswith("style_softmax_"):
                continue
            local_sample_num = int(file_name.split("_")[3].split(".")[0])
            if local_sample_num > max_sample_number:
                max_sample_number = local_sample_num
                target_file_name = os.path.join(dir_path, file_name)
        if target_file_name is None:
            # Fail with a clear message; loading an empty path would only
            # produce a confusing error.
            raise FileNotFoundError(f"No style_softmax_* file found in {dir_path}")
        local_style = np.loadtxt(target_file_name)
        for style_index in range(local_style.shape[0]):
            style_dict.setdefault(style_index, []).append(local_style[style_index])
    for k, v in style_dict.items():
        # Only the first five styles are drawn.
        if k >= 5:
            break
        local_df = pd.DataFrame(v)
        local_df.T.plot.bar()
        plt.show()


if __name__ == '__main__':
    # ``args`` has to stay a module-level name: get_batch_data reads
    # ``args.action_dim`` from the global scope.
    args = argparser()
    # plot([], args)
    # plot("../stylebc_result/result", args, "style")
    # range() needs an int, while --try_num may have been parsed as a float.
    try_num = int(args.try_num)
    base_path = []
    if try_num > 1:
        # Several runs, each with its own random seed.
        for _ in range(try_num):
            seed = random.randint(0, 10000)
            local_base_path = main(args, seed=seed)
            base_path.append(local_base_path["base_path"])
    else:
        # Single run using the seed from the command line.
        local_base_path = main(args)
        base_path.append(local_base_path["base_path"])
    if not args.test:
        plot(base_path, args, "error_band")
