import os, sys
import glob
import argparse
import pickle
import random

import numpy as np
import pandas as pd
import torch

def set_seed_everything(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also exports ``PYTHONHASHSEED`` into ``os.environ``.

    Args:
        seed: Integer seed handed unchanged to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Each seeding call is independent of the others, so apply them in turn.
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def concat_and_out(
    concat_max_level,
    in_dir, out_dir, out_filename):
    """Merge pickled result tables found under ``in_dir`` and write CSV reports.

    Every ``*.pickle`` file below ``in_dir`` (searched recursively) is
    unpickled and the loaded objects are concatenated row-wise with
    ``pd.concat``. The combined table is grouped by index levels
    ``0..concat_max_level`` (inclusive) and, for each group, the median, the
    0.25 and 0.75 quantiles and the non-null count are computed.

    Two files are written into ``out_dir`` (created if missing):

    * ``summary_<out_filename>.csv`` -- the per-group statistics, one column
      per ``<original column>_<statistic>``, columns sorted by name.
    * ``all_<out_filename>.csv`` -- the concatenated, un-aggregated rows.

    Args:
        concat_max_level: Highest index level (0-based) used as a group key.
        in_dir: Directory searched recursively for ``*.pickle`` files.
        out_dir: Directory that receives the two CSV files.
        out_filename: Name stem inserted into both output file names.

    Raises:
        ValueError: If no ``*.pickle`` file is found under ``in_dir``.
    """
    os.makedirs(out_dir, exist_ok=True)

    pickle_pattern = os.path.join(in_dir, '**', '*.pickle')
    # Sorted so the row order of the "all_" CSV does not depend on the
    # arbitrary order in which the filesystem lists the files.
    pickle_paths = sorted(glob.glob(pickle_pattern, recursive=True))
    if not pickle_paths:
        raise ValueError(
            f'No pickle files found matching {pickle_pattern!r}')

    loaded_results = []
    for pickle_path in pickle_paths:
        # NOTE(security): unpickling can execute arbitrary code; only point
        # in_dir at result files produced by trusted runs.
        with open(pickle_path, 'rb') as f:
            loaded_results.append(pickle.load(f))

    # Assumes every pickle holds a pandas DataFrame with string column
    # names -- TODO confirm against the code that writes the pickles.
    res_all_df = pd.concat(loaded_results, axis=0)

    group_levels = list(range(concat_max_level + 1))

    def summarize(suffix, aggregate):
        # Group on the index levels (no ``axis`` keyword: it is deprecated in
        # pandas and row-wise grouping is the default), aggregate, then tag
        # every column with the statistic's name.
        grouped = res_all_df.groupby(level=group_levels)
        return aggregate(grouped).rename(
            columns=lambda col: '_'.join([col, suffix]))

    def quantile_suffix(q):
        # 0.25 -> '025-quantile', 0.75 -> '075-quantile'
        return f'{q*100:03.0f}-quantile'

    summary_parts = [
        summarize('median', lambda g: g.median()),
        summarize(quantile_suffix(0.25), lambda g: g.quantile(0.25)),
        summarize(quantile_suffix(0.75), lambda g: g.quantile(0.75)),
        summarize('count', lambda g: g.count()),
    ]

    res_summary_df = pd.concat(summary_parts, axis=1)
    # Sorting by column name places all statistics of one source column
    # next to each other.
    res_summary_df.sort_index(axis=1, inplace=True)

    res_summary_df.to_csv(
        os.path.join(out_dir, 'summary_' + out_filename + '.csv'))
    res_all_df.to_csv(
        os.path.join(out_dir, 'all_' + out_filename + '.csv'))
# %%

def concat_and_out_all(concat_max_level, in_dir, out_dir, out_filename):
    """Merge pickled result tables under ``in_dir`` and write full CSV reports.

    Every ``*.pickle`` file below ``in_dir`` (searched recursively) is
    unpickled and the loaded objects are concatenated row-wise with
    ``pd.concat``. The combined table is grouped by index levels
    ``0..concat_max_level`` (inclusive) and, for each group, the median, the
    0.25 and 0.75 quantiles, the mean, the standard deviation and the
    non-null count are computed.

    Two files are written into ``out_dir`` (created if missing):

    * ``summary_<out_filename>.csv`` -- the per-group statistics, one column
      per ``<original column>_<statistic>``, columns sorted by name.
    * ``all_<out_filename>.csv`` -- the concatenated, un-aggregated rows.

    Args:
        concat_max_level: Highest index level (0-based) used as a group key.
        in_dir: Directory searched recursively for ``*.pickle`` files.
        out_dir: Directory that receives the two CSV files.
        out_filename: Name stem inserted into both output file names.

    Raises:
        ValueError: If no ``*.pickle`` file is found under ``in_dir``.
    """
    os.makedirs(out_dir, exist_ok=True)

    pickle_pattern = os.path.join(in_dir, '**', '*.pickle')
    # Sorted so the row order of the "all_" CSV does not depend on the
    # arbitrary order in which the filesystem lists the files.
    pickle_paths = sorted(glob.glob(pickle_pattern, recursive=True))
    if not pickle_paths:
        raise ValueError(
            f'No pickle files found matching {pickle_pattern!r}')

    loaded_results = []
    for pickle_path in pickle_paths:
        # NOTE(security): unpickling can execute arbitrary code; only point
        # in_dir at result files produced by trusted runs.
        with open(pickle_path, 'rb') as f:
            loaded_results.append(pickle.load(f))

    # Assumes every pickle holds a pandas DataFrame with string column
    # names -- TODO confirm against the code that writes the pickles.
    res_all_df = pd.concat(loaded_results, axis=0)

    group_levels = list(range(concat_max_level + 1))

    def summarize(suffix, aggregate):
        # Group on the index levels (no ``axis`` keyword: it is deprecated in
        # pandas and row-wise grouping is the default), aggregate, then tag
        # every column with the statistic's name.
        grouped = res_all_df.groupby(level=group_levels)
        return aggregate(grouped).rename(
            columns=lambda col: '_'.join([col, suffix]))

    def quantile_suffix(q):
        # 0.25 -> '025-quantile', 0.75 -> '075-quantile'
        return f'{q*100:03.0f}-quantile'

    summary_parts = [
        summarize('median', lambda g: g.median()),
        summarize(quantile_suffix(0.25), lambda g: g.quantile(0.25)),
        summarize(quantile_suffix(0.75), lambda g: g.quantile(0.75)),
        summarize('mean', lambda g: g.mean()),
        summarize('std', lambda g: g.std()),
        summarize('count', lambda g: g.count()),
    ]

    res_summary_df = pd.concat(summary_parts, axis=1)
    # Sorting by column name places all statistics of one source column
    # next to each other.
    res_summary_df.sort_index(axis=1, inplace=True)

    res_summary_df.to_csv(
        os.path.join(out_dir, 'summary_' + out_filename + '.csv'))
    res_all_df.to_csv(
        os.path.join(out_dir, 'all_' + out_filename + '.csv'))


