from collections import defaultdict
import numpy as np
import os
import json


def sliding_mean(data_array, window=5):
    """Smooth a sequence with a moving average clipped at both ends.

    The value at position ``i`` is the mean of the elements at indices
    ``i - window + 1`` through ``i + window`` (inclusive), restricted to
    indices that exist in ``data_array``. Note that the span reaches
    ``window - 1`` elements back but ``window`` elements forward.

    Args:
        data_array: Indexable sequence of numbers.
        window: Controls how far the averaging span extends (see above).

    Returns:
        A new list with one averaged value per input element.
    """
    total = len(data_array)
    smoothed = []
    for pos in range(total):
        # Clamp the span so it never runs off either end of the data.
        span = range(max(pos - window + 1, 0), min(pos + window + 1, total))
        acc = sum(data_array[k] for k in span)
        acc /= float(len(span))
        smoothed.append(acc)
    return smoothed


def get_file_path(data_path):
    """List the entries of ``data_path`` as joined paths.

    Every name returned by ``os.listdir(data_path)`` is joined onto
    ``data_path``, except an entry named exactly ``"_sources"``, which is
    left out. Order follows ``os.listdir``.
    """
    return [
        os.path.join(data_path, entry)
        for entry in os.listdir(data_path)
        if entry != "_sources"
    ]


def raw_data_loader(file_path_list, clip, metric):
    """Load ``info.json`` from each path and keep those that record ``metric``.

    Paths whose ``info.json`` is missing, unreadable as JSON, or does not
    contain ``metric`` as a top-level key are skipped (the first two with a
    printed notice).

    Args:
        file_path_list: Paths expected to be directories holding ``info.json``.
        clip: Not used for filtering here; only echoed in the error message.
        metric: Top-level key that a loaded JSON object must contain.

    Returns:
        List of the parsed JSON objects, one per accepted path.

    Raises:
        ValueError: If no path yielded usable data.
    """
    data = []
    for file_path in file_path_list:
        # Keep the try body limited to the I/O that can actually fail.
        try:
            with open(os.path.join(file_path, 'info.json')) as f:
                d = json.load(f)
        except FileNotFoundError:
            print(f'FileNotFoundError: {file_path}')
            continue
        except NotADirectoryError:
            # file_path is a plain file rather than a directory; skip it
            # like any other path without a readable info.json instead of
            # aborting the whole load.
            print(f'NotADirectoryError: {file_path}')
            continue
        except json.JSONDecodeError:
            print(f'Error, path: {file_path}')
            continue

        # A non-object JSON document cannot hold the metric as a key.
        if not isinstance(d, dict) or metric not in d:
            continue

        data.append(d)

    if not data:
        raise ValueError(
            f'data is empty, file_path_list: {file_path_list}, clip: {clip}, metric: {metric}')
    return data


def raw_data_processor(data_list, align=True, sliding_mean_config=None, metric='adv_eval_battle_won_mean'):
    """Extract one metric series per run, optionally aligned and smoothed.

    Args:
        data_list: Iterable of dicts, each holding a sequence under ``metric``.
        align: If true, truncate every series to the shortest kept length.
        sliding_mean_config: Dict with keys ``'sliding'`` (bool) and
            ``'window_size'`` (read only when ``'sliding'`` is truthy), or
            ``None`` to disable smoothing.
        metric: Key of the series to extract from each dict.

    Returns:
        List of series (one per kept run). Empty series are dropped.
    """
    min_points = 1  # series with fewer entries than this are discarded

    series_list = []
    for data in data_list:
        series = data[metric]
        if metric == 'test_return_mean':
            # Entries of this metric may be dicts carrying the number under
            # 'value'. Unwrap into a new list so the caller's data is not
            # modified.
            series = [d['value'] if isinstance(d, dict) else d
                      for d in series]

        if len(series) < min_points:
            continue
        series_list.append(series)

    if align:
        min_len = min((len(s) for s in series_list), default=0)
        processed_data = [s[:min_len] for s in series_list]
    else:
        processed_data = series_list

    # The default config is None, which means "no smoothing".
    if sliding_mean_config is not None and sliding_mean_config['sliding']:
        window = sliding_mean_config['window_size']
        processed_data = [sliding_mean(s, window=window)
                          for s in processed_data]

    return processed_data


def get_mean_std_for_each_algo(data_path,
                               metric,
                               clip=0,
                               sliding_mean_config={
                                   'sliding': True, 'window_size': 5},
                               return_processed_data=False,
                               use_median=False):
    """Aggregate one metric across all runs found under ``data_path``.

    Runs are loaded, truncated to a common length, optionally smoothed, and
    reduced along the run axis.

    Args:
        data_path: Directory whose entries are the individual runs.
        metric: Key of the series to aggregate. ``'test_battle_won_mean'``
            is multiplied by 100 (percent). For
            ``'adv_eval_battle_won_mean_T'`` only the per-step mean is
            returned (presumably the x-axis values -- confirm with callers).
        clip: If non-zero, the aggregated statistics are sliced to ``[:clip]``.
        sliding_mean_config: Smoothing settings passed to
            ``raw_data_processor``. The default dict is only read here,
            never modified.
        return_processed_data: If true, append the per-run series to the
            returned tuple.
        use_median: If true, report median and 75th/25th percentiles instead
            of mean and standard deviation.

    Returns:
        * ``metric == 'adv_eval_battle_won_mean_T'``: the mean array only.
        * ``use_median`` false: ``(mean, std)``.
        * ``use_median`` true: ``(median, q75, q25)``.
        With ``return_processed_data`` the processed series list is appended
        as a final tuple element.
    """
    file_path = get_file_path(data_path)

    raw_data = raw_data_loader(
        file_path_list=file_path, clip=clip, metric=metric)
    processed_data = raw_data_processor(
        data_list=raw_data, align=True, metric=metric,
        sliding_mean_config=sliding_mean_config)

    scale = 100 if metric == 'test_battle_won_mean' else 1  # percent %
    algo_data = np.array(processed_data) * scale

    if metric == 'adv_eval_battle_won_mean_T':
        return np.mean(algo_data, axis=0)

    if use_median:
        q75, q25 = np.percentile(algo_data, [75, 25], axis=0)
        stats = (np.median(algo_data, axis=0), q75, q25)
    else:
        stats = (np.mean(algo_data, axis=0), np.std(algo_data, axis=0))

    if clip != 0:
        stats = tuple(s[:clip] for s in stats)

    if return_processed_data:
        return (*stats, processed_data)
    # In median mode this is (median, q75, q25); there is no std to return.
    return stats
