import glob
import os
from collections import defaultdict
import pickle
import numpy as np
import ipdb
import re


def load_results(d, test_set):
    """Unpickle and return the saved results for a single test set.

    Reads ``<d>/saved_results_test<test_set>.pkl``.

    NOTE: this uses ``pickle``; only load result directories you trust.
    """
    results_path = os.path.join(d, f'saved_results_test{test_set}.pkl')
    with open(results_path, 'rb') as fh:
        return pickle.load(fh)

def load_all_results(d, test_sets=(0, 1, 2, 3)):
    """Load the pickled results of several test sets from run directory ``d``.

    Args:
        d: Run directory containing ``saved_results_test<i>.pkl`` files.
        test_sets: Test-set indices to load, in order. Defaults to
            ``(0, 1, 2, 3)``, the indices that used to be hard-coded here.
            A tuple is used as the default so it cannot be mutated between calls.

    Returns:
        A list holding one loaded result per entry of ``test_sets``.
    """
    return [load_results(d, test_set) for test_set in test_sets]
        
def load_args(d):
    """Return the ``'args'`` entry stored in ``<d>/metadata.pkl``.

    NOTE: this uses ``pickle``; only load run directories you trust.
    """
    metadata_path = os.path.join(d, 'metadata.pkl')
    with open(metadata_path, 'rb') as fh:
        metadata = pickle.load(fh)
    return metadata['args']


def aggregate_list_of_results_reward(lst, reward_key='max_reward'):
    """Aggregate reward entries across runs into ``[mean, std]`` pairs.

    Only keys of ``lst[0]`` whose name contains ``'rewards'`` are aggregated.
    For each such key, the code takes the *last* element of ``run[key]`` from
    every run and collects its ``reward_key`` field (for example
    ``'final_reward'``, ``'max_reward'`` or ``'avg_reward'``).

    Args:
        lst: Per-run result dicts. The keys are taken from the first run, so
            the list must be non-empty.
        reward_key: Field to read from the last reward record of each run.

    Returns:
        Dict mapping each reward key to ``[mean, std]``, computed with
        ``np.mean`` / ``np.std`` over axis 0 (the run axis).
    """
    aggregated = dict()
    for name in list(lst[0].keys()):
        if 'rewards' in name:
            samples = np.array([run[name][-1][reward_key] for run in lst])
            aggregated[name] = [np.mean(samples, 0), np.std(samples, 0)]
    return aggregated

def aggregate_list_of_results(lst):
    """Aggregate the non-reward entries across runs into ``[mean, std]`` pairs.

    Every key of ``lst[0]`` whose name does not contain ``'rewards'`` is
    gathered from all runs. The values are stacked with ``np.array`` and
    reduced with ``np.mean`` / ``np.std`` over axis 0 (the run axis).

    Args:
        lst: Per-run result dicts. The keys are taken from the first run, so
            the list must be non-empty.

    Returns:
        Dict mapping each non-reward key to ``[mean, std]``.
    """
    aggregated = dict()
    for name in list(lst[0].keys()):
        if 'rewards' not in name:
            stacked = np.array([run[name] for run in lst])
            aggregated[name] = [np.mean(stacked, 0), np.std(stacked, 0)]
    return aggregated

#dir_i = 'meta-fun-maze-k-40-grid_size-6-encoder-none-use-img-0-num-hyper-layers-1-supp-sz-0-dataset-sz-1000-k-40-num-hypernet-layers-1-seed-3'


def create_aggregated_results(dirs, test_set, filter_hypernets=True, aggregate_rewards=False):
    """Load per-seed runs from ``dirs`` and aggregate them per model config.

    Runs are grouped when their directory names differ only in the
    ``-seed...`` suffix. With ``filter_hypernets``, names that differ only in
    ``-num-hyper-layers-<n>`` are grouped too. Each group is reduced to
    mean/std with ``aggregate_list_of_results``, or with
    ``aggregate_list_of_results_reward`` when ``aggregate_rewards`` is true.

    A directory whose args or results cannot be loaded is reported (printed)
    and skipped. A group whose aggregation fails is reported and left out of
    the returned results. Before, it was stored with an undefined value or
    with the stale value of a previous group.

    Args:
        dirs: Iterable of run directory paths.
        test_set: Test-set index passed to ``load_results``.
        filter_hypernets: Strip ``-num-hyper-layers-<n>`` from group names.
        aggregate_rewards: Aggregate the reward entries instead of the others.

    Returns:
        ``(aggregated_results, all_args)``. Both are dicts keyed by group name.
        ``all_args`` holds the args of the last successfully loaded run of
        each group.
    """
    total_results = defaultdict(list)
    all_args = dict()
    for d in dirs:
        try:
            args = load_args(d)
        except Exception as e:
            print("Could not load args", d)
            print("Exception", e)
            continue
        # normpath drops a trailing separator, so 'runs/foo/' gives 'foo' rather than ''.
        model_name = os.path.basename(os.path.normpath(d))
        model_name_unseeded = model_name.split('-seed')[0]
        if filter_hypernets:
            model_name_unseeded = re.sub(r"-num-hyper-layers-[0-9]+", "", model_name_unseeded)

        try:
            results = load_results(d, test_set)
        except Exception as e:
            print("Could not load results", d)
            print("Exception", e)
            continue
        total_results[model_name_unseeded].append(results)
        all_args[model_name_unseeded] = args

    aggregated_results = dict()
    for k, list_of_results in total_results.items():
        try:
            if aggregate_rewards:
                aggr = aggregate_list_of_results_reward(list_of_results)
            else:
                aggr = aggregate_list_of_results(list_of_results)
        except Exception as e:
            print("Exception", e)
            print("error in", k)
            # Skip this group: `aggr` is unbound or belongs to another group here.
            continue
        aggregated_results[k] = aggr

    return aggregated_results, all_args

def map_model_name(model, set_encoder):
    """Return a display name: ``set_encoder`` for ``'hyper'`` models, else ``model``."""
    return set_encoder if model == 'hyper' else model

def create_list_of_dicts(all_args, aggregated_results, test_set, key_to_plot='mu_loss'):
    """Flatten aggregated results into one row dict per model group.

    For each group ``k`` in ``aggregated_results``, the ``[mean, std]`` pair
    stored under ``key_to_plot`` is read. If ``key_to_plot`` contains
    ``'reward'``, the pair is used as-is. Otherwise the last entry of each is
    used. The row receives ``final_loss_<test_set>``, ``final_std_<test_set>``,
    ``test_set``, ``new_model_name`` and ``key``.

    NOTE: each row is ``vars(all_args[k])``, i.e. the args object's own
    ``__dict__``. This function therefore mutates the objects in
    ``all_args`` in place, and the returned dicts alias them.

    Returns:
        List of dicts, one per key of ``aggregated_results``.
    """
    rows = []
    for group, results in aggregated_results.items():
        row = vars(all_args[group])
        mean, std = results[key_to_plot]
        if 'reward' not in key_to_plot:
            # Non-reward metrics are indexable; report their last entry
            # (presumably the final step — confirm against the producer).
            mean, std = mean[-1], std[-1]
        row[f'final_loss_{test_set}'] = mean
        row[f'final_std_{test_set}'] = std
        row['test_set'] = test_set
        row['new_model_name'] = map_model_name(row['model'], row['set_encoder'])
        row['key'] = group
        rows.append(row)
    return rows