"""Script to convert multiple training runs to single csv."""

import os
import argparse
import pandas as pd
import numpy as np
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
from scipy.stats import bootstrap
import logging
import glob

# Base directory two levels above the folder containing this script; the
# tensorboard input folder and the data output folder are resolved relative to it.
ROOT = f"{os.path.dirname(os.path.abspath(__file__))}/../../"
logger = logging.getLogger(__name__)
# "<action space>/<algorithm>" sub-paths searched for recursively below the input folder.
ALGORITHMS = [
    'Continuous/TD3',
    'Continuous/SAC',
    'Discrete/DQN',
    'Continuous/PPO',
    'Discrete/PPO',
]


def evaluate_data(array, window_size=1, confidence_level=0.95, n_resamples=10000):
    """Compute windowed summary statistics for every column of a 2-D array.

    Column ``i`` is summarised by pooling all non-NaN values of all rows inside a
    centred window of ``window_size`` columns (clipped at the array edges).

    Args:
        array: 2-D array indexed as ``array[row, column]``. In this script rows are
            runs/seeds and columns are logged steps; NaN marks missing values.
        window_size: Width of the centred window. Even values are rounded up to the
            next odd value; values below 1 are treated as 1.
        confidence_level: Confidence level of the bootstrap interval. The default of
            0.95 matches the ``bootstrap025``/``bootstrap975`` naming.
        n_resamples: Number of bootstrap resamples per column.

    Returns:
        mean: per-column mean of the pooled window values (NaN ignored).
        std: per-column standard deviation of the pooled window values (NaN ignored).
        bootstrap025: lower bound of the bootstrap confidence interval of the mean.
        bootstrap975: upper bound of the bootstrap confidence interval of the mean.

        When no interval can be computed, both bounds fall back to the window mean.
        This happens when the window holds fewer than two values, or when scipy
        returns NaN (e.g. for a constant window). The window mean is itself NaN when
        the window holds no values.
    """
    array = np.asarray(array, dtype=float)
    window_size = max(1, int(window_size))
    # Make sure the window size is odd so the window is centred on column i.
    if window_size % 2 == 0:
        window_size += 1
    half_window = (window_size - 1) // 2
    n_columns = array.shape[1]
    mean = np.zeros(n_columns)
    std = np.zeros(n_columns)
    bootstrap025 = np.zeros(n_columns)
    bootstrap975 = np.zeros(n_columns)
    for i in range(n_columns):
        lower_index = max(0, i - half_window)
        upper_index = min(n_columns, i + half_window + 1)
        window = array[:, lower_index:upper_index]
        mean[i] = np.nanmean(window)
        std[i] = np.nanstd(window)
        # Pool every non-NaN value of the window (all rows, all columns) into one
        # flat sample for the bootstrap.
        samples = window.ravel()
        samples = samples[~np.isnan(samples)]
        if samples.size < 2:
            # scipy.stats.bootstrap needs at least two observations.
            bootstrap025[i] = mean[i]
            bootstrap975[i] = mean[i]
            continue
        res = bootstrap(data=(samples,),
                        statistic=np.mean,
                        confidence_level=confidence_level,
                        axis=0,
                        n_resamples=n_resamples)
        # The fallback must be passed as `nan=`: the second positional argument of
        # np.nan_to_num is `copy`, not the replacement value.
        bootstrap025[i] = np.nan_to_num(res.confidence_interval.low, nan=mean[i])
        bootstrap975[i] = np.nan_to_num(res.confidence_interval.high, nan=mean[i])
    return mean, std, bootstrap025, bootstrap975


def create_df(step, mean_reward, std_reward, reward_bootstrap025, reward_bootstrap975, mean_safety, std_safety, safety_bootstrap025, safety_bootstrap975, is_baseline=False):
    """Assemble the per-step summary table as a DataFrame.

    Columns, in order: ``step``, ``mean_reward``, ``std_reward``,
    ``bootstrap025_reward``, ``bootstrap975_reward``, then the four safety columns
    ``mean_*``, ``std_*``, ``bootstrap025_*``, ``bootstrap975_*``.
    The safety suffix is ``safety_violation`` when ``is_baseline`` is true and
    ``safety_activity`` otherwise.

    Returns:
        pandas.DataFrame holding the given values under those column names.
    """
    if is_baseline:
        safety_keys = ('mean_safety_violation',
                       'std_safety_violation',
                       'bootstrap025_safety_violation',
                       'bootstrap975_safety_violation')
    else:
        safety_keys = ('mean_safety_activity',
                       'std_safety_activity',
                       'bootstrap025_safety_activity',
                       'bootstrap975_safety_activity')
    table = {
        'step': step,
        'mean_reward': mean_reward,
        'std_reward': std_reward,
        'bootstrap025_reward': reward_bootstrap025,
        'bootstrap975_reward': reward_bootstrap975,
    }
    # dict preserves insertion order, so the safety columns follow the reward ones.
    table.update(zip(safety_keys, (mean_safety, std_safety, safety_bootstrap025, safety_bootstrap975)))
    return pd.DataFrame(table)


def _list_run_dirs(path):
    """Return the sorted names of the sub-directories of ``path`` (one per seed/run)."""
    return sorted(name for name in os.listdir(path) if os.path.isdir(os.path.join(path, name)))


def _load_scalars(accumulator, tag):
    """Return the scalar events logged under ``tag`` as a DataFrame with one column per event field."""
    events = accumulator.Scalars(tag)
    return pd.DataFrame.from_records(events, columns=events[0]._fields)


def _reference_steps(input_path, reward_tag):
    """Return the reward-curve steps of the first run found, searching ``ALGORITHMS`` in order.

    These steps define the number of columns of every output table.

    Raises:
        FileNotFoundError: if no algorithm folder below ``input_path`` contains a run.
    """
    for alg in ALGORITHMS:
        for group in sorted(glob.glob(input_path + f'/**/{alg}', recursive=True)):
            run_dirs = _list_run_dirs(group)
            if run_dirs:
                accumulator = EventAccumulator(group + '/' + run_dirs[0]).Reload()
                return _load_scalars(accumulator, reward_tag)["step"]
    raise FileNotFoundError(f"No training runs found in {input_path}.")


def _load_group(group_path, reward_tag, safety_tag, n_steps):
    """Load the reward and safety curves of every run in ``group_path`` into (n_runs, n_steps) arrays.

    Missing entries stay NaN, and series longer than ``n_steps`` are truncated.
    Each safety value is stored in the column of the reward entry logged at the same step.
    """
    run_dirs = _list_run_dirs(group_path)
    reward_array = np.full((len(run_dirs), n_steps), np.nan)
    safety_array = np.full((len(run_dirs), n_steps), np.nan)
    for i, run_dir in enumerate(run_dirs):
        accumulator = EventAccumulator(group_path + '/' + run_dir).Reload()
        reward_df = _load_scalars(accumulator, reward_tag)
        safety_df = _load_scalars(accumulator, safety_tag)
        n_logged = len(safety_df)
        if n_logged < n_steps:
            logger.warning(f"Number of steps ({n_logged}) in {group_path}/{run_dir} is lower than the expected number of steps ({n_steps}).")
        elif n_logged > n_steps:
            logger.warning(f"Number of steps ({n_logged}) in {group_path}/{run_dir} exceeds the expected number of steps ({n_steps}).")
        # Truncate both series independently: either one may be longer than n_steps.
        reward = reward_df["value"].values[:n_steps]
        reward_steps = reward_df["step"].values[:n_steps]
        safety = safety_df["value"].values[:n_steps]
        safety_steps = safety_df["step"].values[:n_steps]
        # Align the safety values to the reward columns via their shared step numbers;
        # both index arrays are needed so values and positions stay matched.
        _, reward_idx, safety_idx = np.intersect1d(
            reward_steps, safety_steps, assume_unique=True, return_indices=True)
        reward_array[i, :len(reward)] = reward
        safety_array[i, reward_idx] = safety[safety_idx]
    return reward_array, safety_array


def _write_summary_csv(reward_array, safety_array, step, window_size, is_baseline, csv_path):
    """Summarise reward/safety arrays with ``evaluate_data`` and write the table to ``csv_path``."""
    mean_reward, std_reward, reward_bootstrap025, reward_bootstrap975 = evaluate_data(reward_array, window_size)
    mean_safety, std_safety, safety_bootstrap025, safety_bootstrap975 = evaluate_data(safety_array, window_size)
    output_df = create_df(
        step,
        mean_reward,
        std_reward,
        reward_bootstrap025,
        reward_bootstrap975,
        mean_safety,
        std_safety,
        safety_bootstrap025,
        safety_bootstrap975,
        is_baseline
    )
    os.makedirs(os.path.dirname(csv_path), exist_ok=True)
    output_df.to_csv(csv_path, index=False)


def average_all_seeds(base_folder="Default", window_size=1):
    """Convert multiple training runs to single csv files.

    The function reads every run (seed) found below ``<ROOT>/tensorboard/Train/<base_folder>``
    for each entry of ``ALGORITHMS``. It writes the following csv files under
    ``<ROOT>/data/Train/<base_folder>/``:

    * one ``<algorithm>.csv`` per algorithm folder;
    * ``all_algorithms.csv``, which pools all algorithms sharing the same parent folders;
    * ``both_ppo.csv``, which pools the continuous and discrete PPO runs.

    Columns are step, mean/std/bootstrap bounds of the reward, and the same four
    statistics for safety activity (or safety violation for baseline runs); see ``create_df``.

    Args:
        base_folder: Folder below ``<ROOT>/tensorboard/Train/`` holding the runs.
        window_size: Width of the centred smoothing window passed to ``evaluate_data``.

    Raises:
        FileNotFoundError: if no run folder is found at all.
    """
    tag_base = 'benchmark_train/'
    reward_tag = tag_base + 'avg_env_reward'
    input_path = ROOT + 'tensorboard/Train/' + base_folder
    output_path = ROOT + 'data/Train/' + base_folder + '/'

    # All tables share the step column of one reference run.
    first_step_array = _reference_steps(input_path, reward_tag)
    n_steps = first_step_array.shape[0]
    # alg_tuple -> list of (reward_array, safety_array), one entry per algorithm folder.
    all_runs = {}
    ppo_runs = {}
    for alg in ALGORITHMS:
        groups = glob.glob(input_path + f'/**/{alg}', recursive=True)
        if not groups:
            logger.warning(f"No runs found for {alg}.")
            continue
        for g in groups:
            # Path layout assumed: .../<A>/<B>/<action space>/<algorithm>
            parts = g.split('/')
            last_part_path = os.path.join(*parts[-4:-1])  # <A>/<B>/<action space>
            alg_tuple = os.path.join(*parts[-4:-2])  # <A>/<B>
            algorithm_str = parts[-1]  # <algorithm>
            is_baseline = 'Baseline' in g
            tag_sup = tag_base + ('is_safety_violation' if is_baseline else 'avg_safety_activity')

            reward_array, safety_array = _load_group(g, reward_tag, tag_sup, n_steps)
            if reward_array.shape[0] == 0:
                logger.warning(f"No run folders found in {g}.")
                continue
            all_runs.setdefault(alg_tuple, []).append((reward_array, safety_array))
            if 'PPO' in alg:
                ppo_runs.setdefault(alg_tuple, []).append((reward_array, safety_array))
            _write_summary_csv(reward_array, safety_array, first_step_array, window_size, is_baseline,
                               output_path + last_part_path + '/' + algorithm_str + '.csv')

    # Tables pooled over all algorithms, and over both PPO variants (continuous and discrete).
    for pooled_runs, file_name in ((all_runs, 'all_algorithms.csv'), (ppo_runs, 'both_ppo.csv')):
        for alg_tuple, arrays in pooled_runs.items():
            reward_arrays, safety_arrays = zip(*arrays)
            _write_summary_csv(np.concatenate(reward_arrays, axis=0),
                               np.concatenate(safety_arrays, axis=0),
                               first_step_array, window_size, 'Baseline' in alg_tuple,
                               output_path + alg_tuple + '/' + file_name)

def main():
    """Parse command line arguments and run ``average_all_seeds``."""
    parser = argparse.ArgumentParser(
        description='Calculate mean, standard deviation and bootstrap confidence intervals '
                    'of TensorBoard training logs over all seeds and write them to csv files.')
    parser.add_argument(
        'path', type=str,
        help='Folder, relative to <ROOT>/tensorboard/Train/, containing the TensorBoard event '
             'files of multiple seeds; results are written to <ROOT>/data/Train/<path>/.')
    parser.add_argument(
        '--window_size', type=int, required=False, default=1,
        help='Width (in logged steps) of the centred rolling window; '
             'even values are rounded up to the next odd value.')

    args = parser.parse_args()
    average_all_seeds(args.path, args.window_size)


if __name__ == '__main__':
    # Run the command line interface only when executed as a script, not on import.
    main()
