import numpy as np
import itertools
import os
import pickle
import postprocess.util
import postprocess.dataloader

# Known algorithm families. CotDataLoader keeps only the entries that also
# appear in its own `algs` list.
# Classical bandit baselines:
baselines = [
    'ts',
    'ucb',
    'greedy',
    'egreedy',
]
# LLM agents:
llms = [
    'gpt4',
]

class CotDataLoader(postprocess.dataloader.DataLoader):
    """Load pickled experiment results for baseline and LLM algorithms.

    Constructing an instance sets the experiment parameters, reads every
    matching result pickle found under ``base_dirs``, and prints a one-line
    summary per algorithm configuration.

    Attributes populated by ``load_results``:
        alg_names: ordered result keys. Baseline sub-configurations come
            first, then one ``f"{llm}_{config shorthand}"`` entry per LLM
            and config.
        all_results: maps each key in ``alg_names`` to the list of loaded
            per-replicate results.
        failures: maps each LLM key to ``{'reverted': [...], 'failed': [...]}``.
            These lists are filled only from pickles holding a 3-element
            payload.
    """

    # Directories searched for result pickles. Baselines are read from the
    # first entry only; LLM runs are read from every entry.
    base_dirs = ['../stacked_results/', '../results/']
    # Directory-name prefix for baseline result folders.
    baseline_pref = 'baselines'

    def __init__(self):
        """Set the experiment parameters, load all results, and print a summary."""
        # T, K and delta are used here only to build result directory names
        # (presumably horizon / number of arms / gap -- TODO confirm).
        self.T = 200
        self.K = 5
        self.delta = '0.2'
        # Replicate indices 0 .. reps-1 are probed; missing files are skipped.
        self.reps = 1000
        self.algs = ['ts', 'ucb', 'greedy', 'egreedy', 'gpt4']
        # Directory-name prefix for LLM result folders.
        self.llm_pref = 'buttons'
        # Hyper-parameter grids expanded by get_baseline_configs.
        self.etas = [1.0]
        self.n0s = [1]
        self.epss = [0.1]
        # Prompt configurations; their fields select the LLM result directory.
        self.configs = [
            {'suggestive': 'sug',
             'summarized': 'raw',
             'cot': 'cotn',
             'dist': 'uni',
             'temp': '0'},
            {'suggestive': 'sug',
             'summarized': 'sum',
             'cot': 'cotn',
             'dist': 'uni',
             'temp': '0'}
        ]

        self.baselines = [x for x in baselines if x in self.algs]
        self.llms = [x for x in llms if x in self.algs]

        self.load_results()
        self.summarize()

    def get_baseline_configs(self, alg):
        """Return ``{result key: file-name prefix}`` for baseline ``alg``.

        One entry is produced per hyper-parameter value in the matching grid
        (``n0s`` for greedy, ``etas`` for ucb, ``epss`` for egreedy). Key and
        prefix are identical. An unknown ``alg`` yields an empty dict.
        """
        out = {}
        if alg == 'ts':
            out['ts'] = 'ts'
        elif alg == 'greedy':
            for n0 in self.n0s:
                out[f"greedy_{n0}"] = f"greedy_{n0}"
        elif alg == 'ucb':
            for eta in self.etas:
                out[f"ucb_{eta:0.1f}"] = f"ucb_{eta:0.1f}"
        elif alg == 'egreedy':
            for eps in self.epss:
                out[f"egreedy_{eps:0.2f}"] = f"egreedy_{eps:0.2f}"
        return out

    @staticmethod
    def _load_pickle(fname):
        """Unpickle and return the object stored in ``fname``, closing the file."""
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # this loader at trusted result directories.
        with open(fname, 'rb') as fh:
            return pickle.load(fh)

    def load_results(self):
        """Read all available result pickles into ``all_results``/``failures``.

        Sets ``self.alg_names``, ``self.all_results`` and ``self.failures``.
        Replicate files that do not exist are silently skipped.
        """
        all_results = {}
        failures = {}
        alg_names = []

        # Baseline results live in a single folder under base_dirs[0].
        baseline_dir = (
            CotDataLoader.base_dirs[0]
            + f"{CotDataLoader.baseline_pref}_K={self.K}_T={self.T}_delta={self.delta}/"
        )
        for alg in self.baselines:
            for (k, v) in self.get_baseline_configs(alg).items():
                alg_names.append(k)
                all_results[k] = []
                for rep in range(self.reps):
                    fname = baseline_dir + f"{v}_{rep}.pkl"
                    if os.path.isfile(fname):
                        all_results[k].append(self._load_pickle(fname))

        for alg in self.llms:
            for config in self.configs:
                short = postprocess.util.get_config_shorthand(config)
                name = f"{alg}_{short}"
                alg_names.append(name)
                results = all_results[name] = []
                fails = failures[name] = {'reverted': [], 'failed': []}
                # NOTE(review): the same rep index may exist under several
                # base_dirs; every file found is appended, so such reps are
                # counted more than once -- confirm this is intended.
                for base_dir in CotDataLoader.base_dirs:
                    dir_name = base_dir + f"{self.llm_pref}_{config['suggestive']}_{config['summarized']}_{config['dist']}_{config['cot']}_K={self.K}_T={self.T}_delta={self.delta}/"
                    fpref = dir_name + f"{alg}_t={config['temp']}_"
                    for rep in range(self.reps):
                        fname = fpref + f"{rep}.pkl"
                        if not os.path.isfile(fname):
                            continue
                        data = self._load_pickle(fname)
                        # A 3-element payload is unpacked as
                        # (result, reverted, failed); anything else is stored
                        # whole as the result. NOTE(review): a bare result
                        # that happens to have length 3 would be mis-split.
                        if len(data) == 3:
                            results.append(data[0])
                            fails['reverted'].append(data[1])
                            fails['failed'].append(data[2])
                        else:
                            results.append(data)
        self.alg_names = alg_names
        self.all_results = all_results
        self.failures = failures

    def summarize(self):
        """Print the number of loaded runs per algorithm key.

        For LLM keys, the line also reports how many runs had a positive
        ``reverted`` value and how many had a true ``failed`` flag.
        """
        print(f"Delta: {self.delta}")
        for alg in self.alg_names:
            num_runs = len(self.all_results[alg])
            string = f"{alg}: Loaded {num_runs}"
            if alg in self.failures:
                num_reverted = sum(1 for x in self.failures[alg]['reverted'] if x > 0)
                # Accept NumPy booleans as well: `np.True_ is True` is False,
                # so an identity check would silently count zero failures.
                # Non-bool truthy values (e.g. ints) are still not counted.
                num_failed = sum(
                    1 for x in self.failures[alg]['failed']
                    if isinstance(x, (bool, np.bool_)) and x
                )
                string += f", Reverted {num_reverted}, Failed {num_failed}"
            string += "."
            print(string, flush=True)

