import numpy as np
import itertools
import os
import pickle
import postprocess.util

# Identifiers of the baseline algorithms. The loaders below keep only those
# that also appear in meta_data['algs'], preserving this order.
baselines = ['ts', 'ucb', 'greedy', 'egreedy']
# Identifiers of the LLM-based algorithms, filtered against meta_data['algs']
# in the same way.
llms = ['gpt35', 'gpt4', 'llama13b']

class DataLoader(object):
    """Load pickled per-replicate results for baseline and LLM runs.

    Results are read from directories under ``base_dir`` whose names encode
    K, T and delta, and are exposed through helper methods that turn the
    loaded runs into numpy arrays.
    """

    # Root directory (relative path) under which result folders are looked up.
    base_dir = "../results/"
    # Prefix of the directory name that holds the baseline results.
    baseline_pref = 'baselines'

    # Identity mapping from algorithm name to a sub-name.
    # NOTE(review): not referenced by any method visible in this file --
    # confirm whether external code still relies on it.
    f_sub_names = {
        'ts': 'ts',
        'ucb': 'ucb',
        'greedy': 'greedy',
        'egreedy': 'egreedy',
        'gpt35': 'gpt35',
        'gpt4': 'gpt4',
        'llama13b': 'llama13b'
    }

    def all_configs():
        configs = []
        for item in itertools.product(['neu', 'sug'], ['raw', 'sum'], ['not', 'cot', 'cotn'], ['0', '1', 'dist']):
            config = {'suggestive': item[0],
                      'summarized': item[1],
                      'cot': item[2]}
            if item[3] == 'dist':
                config['dist'] = 'dist'
                config['temp'] = 0
            else:
                config['dist'] = 'uni'
                config['temp'] = item[3]
            configs.append(config)
        return (configs)

    def __init__(self, meta_data):
        """Store the experiment settings, then load and summarize all results.

        Args:
            meta_data: mapping with keys 'T', 'K', 'delta', 'reps', 'algs',
                'llm_pref', 'etas', 'n0s', 'epss' and, optionally, 'configs'.
                When 'configs' is absent, the full grid from
                ``DataLoader.all_configs()`` is used.
        """
        self.T, self.K = meta_data['T'], meta_data['K']
        # delta is kept as a one-decimal string because it is embedded in paths.
        self.delta = format(meta_data['delta'], '0.1f')
        self.reps = meta_data['reps']
        self.algs = meta_data['algs']
        self.llm_pref = meta_data['llm_pref']
        self.etas = meta_data['etas']  ## UCB hyperparameter
        self.n0s = meta_data['n0s']    ## Greedy hyperparameter
        self.epss = meta_data['epss']
        self.configs = meta_data['configs'] if 'configs' in meta_data.keys() else DataLoader.all_configs()

        # Keep only the requested algorithms, in the module-level canonical order.
        self.baselines = [name for name in baselines if name in self.algs]
        self.llms = [name for name in llms if name in self.algs]

        self.load_results()
        self.summarize()


    def get_baseline_configs(self,alg):
        """Expand a baseline algorithm name into its hyperparameter variants.

        Returns a dict mapping each variant name to itself:
        'ts' -> 'ts'; 'greedy' -> 'greedy_{n0}' for each n0 in ``self.n0s``;
        'ucb' -> 'ucb_{eta:0.1f}' for each eta in ``self.etas``;
        'egreedy' -> 'egreedy_{eps:0.2f}' for each eps in ``self.epss``.
        Any other name gives an empty dict.
        """
        if alg == 'ts':
            names = ['ts']
        elif alg == 'greedy':
            names = [f"greedy_{n0}" for n0 in self.n0s]
        elif alg == 'ucb':
            names = [f"ucb_{eta:0.1f}" for eta in self.etas]
        elif alg == 'egreedy':
            names = [f"egreedy_{eps:0.2f}" for eps in self.epss]
        else:
            names = []
        return {name: name for name in names}
                
    def load_results(self):
        """Read every available result pickle into memory.

        Sets three attributes:
          * ``alg_names``   -- result keys in load order (baselines, then LLMs).
          * ``all_results`` -- key -> list with one entry per replicate file found.
          * ``failures``    -- LLM key -> {'reverted': [...], 'failed': [...]}.

        Replicate files that do not exist are skipped silently, so a list may
        contain fewer than ``self.reps`` entries.

        NOTE: pickle can execute arbitrary code while loading; only point
        ``base_dir`` at result files you trust.
        """
        all_results = {}
        failures = {}
        alg_names = []

        ## Load baseline results
        # The directory depends only on the problem settings, so build it once.
        baseline_dir = DataLoader.base_dir+f"{DataLoader.baseline_pref}_K={self.K}_T={self.T}_delta={self.delta}/"
        for alg in self.baselines:
            for (k,v) in self.get_baseline_configs(alg).items():
                alg_names.append(k)
                all_results[k] = []
                for rep in range(self.reps):
                    fname = baseline_dir + f"{v}_{rep}.pkl"
                    if os.path.isfile(fname):
                        # Context manager closes the handle even if unpickling fails.
                        with open(fname,'rb') as f:
                            all_results[k].append(pickle.load(f))

        ## Load LLM results
        for alg in self.llms:
            for config in self.configs:
                short = postprocess.util.get_config_shorthand(config)
                name = f"{alg}_{short}"
                alg_names.append(name)
                all_results[name] = []
                failures[name] = {'reverted': [], 'failed': []}
                dir_name = DataLoader.base_dir+f"{self.llm_pref}_{config['suggestive']}_{config['summarized']}_{config['dist']}_{config['cot']}_K={self.K}_T={self.T}_delta={self.delta}/"
                fpref = dir_name + f"{alg}_t={config['temp']}_"
                for rep in range(self.reps):
                    fname = fpref + f"{rep}.pkl"
                    if not os.path.isfile(fname):
                        continue
                    with open(fname,'rb') as f:
                        data = pickle.load(f)
                    # A 3-element payload is unpacked as (results, reverted, failed);
                    # anything else is stored as the results themselves.
                    # NOTE(review): a bare result list of length 3 would be
                    # misread here -- confirm T is never 3 for such files.
                    if len(data) == 3:
                        all_results[name].append(data[0])
                        failures[name]['reverted'].append(data[1])
                        failures[name]['failed'].append(data[2])
                    else:
                        all_results[name].append(data)

        self.alg_names = alg_names
        self.all_results = all_results
        self.failures = failures

    def summarize(self):
        """Print one line per loaded algorithm with its run counts.

        Every line reports how many replicates were loaded. For keys present in
        ``self.failures`` it also reports how many 'reverted' entries are
        positive and how many 'failed' entries are exactly ``True``.
        """
        print(f"Delta: {self.delta}")
        for alg in self.alg_names:
            parts = [f"{alg}: Loaded {len(self.all_results[alg])}"]
            if alg in self.failures:
                record = self.failures[alg]
                num_reverted = sum(1 for x in record['reverted'] if x > 0)
                # Identity test on purpose: only the bool singleton True is counted.
                num_failed = sum(1 for x in record['failed'] if x is True)
                parts.append(f", Reverted {num_reverted}, Failed {num_failed}")
            parts.append(".")
            print("".join(parts), flush=True)
                
                
    def get_reward_matrix(self,alg):
        """Stack per-step rewards of all complete runs of ``alg``.

        Only runs with exactly ``self.T`` steps are used; the reward is the
        second entry of each step. Returns an array with one row per kept run,
        or None when no run qualifies.
        """
        rows = [[step[1] for step in run]
                for run in self.all_results[alg]
                if len(run) == self.T]
        if not rows:
            return None
        return np.vstack(rows)

    def get_pseudoreward_matrix(self,alg):
        """Stack per-step pseudo-rewards of all complete runs of ``alg``.

        A step whose first entry (the chosen action) equals 0 contributes
        ``0.5 + delta/2``; every other step contributes ``0.5 - delta/2``.
        Only runs with exactly ``self.T`` steps are used. Returns an array
        with one row per kept run, or None when no run qualifies.
        """
        half_gap = float(self.delta)/2
        value_if_zero = 0.5+half_gap
        value_otherwise = 0.5-half_gap
        rows = []
        for run in self.all_results[alg]:
            if len(run) != self.T:
                continue
            rows.append([value_if_zero if step[0] == 0 else value_otherwise for step in run])
        return np.vstack(rows) if rows else None
        
    def get_opt_action_matrix(self,alg):
        """Stack 0/1 indicators of 'action 0 was chosen' for complete runs.

        Only runs with exactly ``self.T`` steps are used. Entry (r, t) is 1
        when the first entry of step t in run r equals 0, else 0. Returns None
        when no run qualifies.
        """
        indicator_rows = []
        for run in self.all_results[alg]:
            if len(run) != self.T:
                continue
            indicator_rows.append([int(step[0] == 0) for step in run])
        if indicator_rows:
            return np.vstack(indicator_rows)
        return None

    def get_action_tensor(self,alg):
        """Return one-hot action indicators as a (runs, T, K) float tensor.

        Runs with fewer than ``self.T`` steps are dropped; for the remaining
        runs only the first ``self.T`` steps are encoded. Entry (r, t, a) is 1
        when run r chose action a at step t, else 0.

        When no run qualifies an empty tensor of shape (0, T, K) is returned
        (indexing with an empty, float-typed index array used to raise
        IndexError in that case).
        """
        kept_runs = [run for run in self.all_results[alg] if len(run) >= self.T]
        tens = np.zeros((len(kept_runs), self.T, self.K))
        for i, run in enumerate(kept_runs):
            for t in range(self.T):
                tens[i,t,run[t][0]] = 1
        return tens

    def get_action_freqs(self,alg):
        """Count how often each action was chosen in every complete run.

        Only runs with exactly ``self.T`` steps are used. Returns a float
        array with one row per kept run and ``self.K`` columns, or None when
        no run qualifies.
        """
        per_run_counts = []
        for run in self.all_results[alg]:
            if len(run) != self.T:
                continue
            counts = np.zeros(self.K)
            for step in run:
                # step[0] is the index of the chosen action.
                counts[step[0]] += 1
            per_run_counts.append(counts)
        return np.vstack(per_run_counts) if per_run_counts else None

    def get_suffix_counts(self,alg):
        """Flag the steps from which action 0 is never chosen again.

        Entry (r, t) of the returned boolean array is True when run r contains
        no choice of action 0 at step t or at any later step. Returns None
        when ``get_opt_action_matrix`` has no complete run to report.
        """
        opt = self.get_opt_action_matrix(alg)
        if opt is None:
            return None
        # Reverse, accumulate, reverse back: number of action-0 picks in [t, T).
        remaining = opt[:, ::-1].cumsum(axis=1)[:, ::-1]
        return remaining == 0

    def get_mean_std(self,alg,pseudo=False):
        """Mean cumulative reward per step with a +/- 2 standard-error band.

        Args:
            alg: key into ``self.all_results``.
            pseudo: use ``get_pseudoreward_matrix`` instead of
                ``get_reward_matrix`` when True.

        Returns:
            ``(means, lower, upper)`` where ``lower``/``upper`` are
            ``means -/+ 2*std/sqrt(N)`` over the N complete runs, or None when
            there is no complete run.
        """
        fetch = self.get_pseudoreward_matrix if pseudo else self.get_reward_matrix
        rewards = fetch(alg)
        if rewards is None:
            return None
        num_runs = rewards.shape[0]
        cum_rewards = np.cumsum(rewards,axis=1)
        means = np.mean(cum_rewards,axis=0)
        half_width = 2*np.std(cum_rewards,axis=0)/np.sqrt(num_runs)
        return (means, means-half_width, means+half_width)

    def get_median_quantile(self,alg,q,pseudo=False):
        """Median cumulative reward per step with a quantile band.

        Args:
            alg: key into ``self.all_results``.
            q: lower quantile level; the upper level is ``1-q``.
            pseudo: use ``get_pseudoreward_matrix`` instead of
                ``get_reward_matrix`` when True.

        Returns:
            ``(medians, q-quantile, (1-q)-quantile)`` taken across complete
            runs, or None when there is no complete run.
        """
        fetch = self.get_pseudoreward_matrix if pseudo else self.get_reward_matrix
        rewards = fetch(alg)
        if rewards is None:
            return None
        cum_rewards = np.cumsum(rewards,axis=1)
        band = [np.quantile(cum_rewards,level,axis=0) for level in (q, 1-q)]
        return (np.median(cum_rewards,axis=0), band[0], band[1])

    def get_scatter_data(self,x_fn,y_fn):
        """Apply two statistics to every loaded replicate of every algorithm.

        Args:
            x_fn: callable taking one replicate, giving its x value.
            y_fn: callable taking one replicate, giving its y value.

        Returns:
            dict mapping each key of ``self.all_results`` to a pair
            ``(xpts, ypts)`` of lists with one value per replicate.
        """
        scatter = {}
        for name, replicates in self.all_results.items():
            # Evaluate x then y per replicate, in replicate order.
            pairs = [(x_fn(rep), y_fn(rep)) for rep in replicates]
            scatter[name] = ([p[0] for p in pairs], [p[1] for p in pairs])
        return scatter
        
class PuzzlesDataLoader(object):
    """Load pickled 'puzzles' results indexed by (t, delta, history).

    Each algorithm's results are stored in a dict keyed by the string
    ``'{t}_{delta:0.1f}_{history}'``.
    """
    # Root directory (relative path) under which result folders are looked up.
    base_dir = "../results/"
    # Prefix of the directory names that hold the baseline results.
    baseline_pref = "puzzles_baselines"


    def __init__(self,meta_data):
        """Store the experiment settings and load all puzzle results.

        Args:
            meta_data: mapping with keys 'T', 'K', 'deltas', 'ts', 'reps',
                'algs', 'llm_pref', 'etas', 'n0s', 'epss' and 'histories'.
                LLM configurations always come from
                ``DataLoader.all_configs()``.
        """
        self.T, self.K = meta_data['T'], meta_data['K']
        self.deltas, self.ts = meta_data['deltas'], meta_data['ts']
        self.reps = meta_data['reps']
        self.algs = meta_data['algs']
        self.llm_pref = meta_data['llm_pref']
        # Hyperparameter grids used to expand the baseline names.
        self.etas = meta_data['etas']
        self.n0s = meta_data['n0s']
        self.epss = meta_data['epss']
        self.configs = DataLoader.all_configs()
        self.histories = meta_data['histories']

        # Keep only the requested algorithms, in the module-level canonical order.
        self.baselines = [name for name in baselines if name in self.algs]
        self.baseline_configs = self.get_baseline_configs()
        self.llms = [name for name in llms if name in self.algs]

        self.load_results()

    def get_baseline_configs(self):
        """List the baseline variant names for all selected baselines.

        'ts' gives 'ts'; 'ucb' gives 'ucb_{eta:0.1f}' per eta; 'greedy' gives
        'greedy_{n0}' per n0; 'egreedy' gives 'egreedy_{eps:0.1f}' per eps.
        Names in ``self.baselines`` outside these four are ignored.

        NOTE(review): egreedy uses one decimal here, whereas
        ``DataLoader.get_baseline_configs`` uses two -- confirm this matches
        the file names on disk.
        """
        # Lazily evaluated so a hyperparameter list is only read when needed.
        expanders = {
            'ts': lambda: ['ts'],
            'ucb': lambda: [f'ucb_{eta:0.1f}' for eta in self.etas],
            'greedy': lambda: [f'greedy_{n0}' for n0 in self.n0s],
            'egreedy': lambda: [f'egreedy_{eps:0.1f}' for eps in self.epss],
        }
        names = []
        for b_alg in self.baselines:
            expand = expanders.get(b_alg)
            if expand is not None:
                names.extend(expand())
        return names

    def load_results(self):
        """Read every available puzzle result pickle into ``self.all_results``.

        ``self.all_results`` maps each algorithm key (a baseline variant name
        or ``'{alg}_{shorthand}'`` for LLMs) to a dict keyed by
        ``'{t}_{delta:0.1f}_{history}'``. Settings whose file does not exist
        are simply absent from that inner dict.

        NOTE: pickle can execute arbitrary code while loading; only point
        ``base_dir`` at result files you trust.
        """
        all_results = {}
        # The (t, delta, history) grid is the same for every algorithm.
        settings = list(itertools.product(self.ts, self.deltas, self.histories))

        for alg in self.baseline_configs:
            all_results[alg] = {}
            for (t,delta,history) in settings:
                fname = PuzzlesDataLoader.base_dir+f"{PuzzlesDataLoader.baseline_pref}_K={self.K}_T={self.T}_t={t}_delta={delta:0.1f}_{history}/{alg}.pkl"
                if os.path.isfile(fname):
                    # Context manager closes the handle even if unpickling fails.
                    with open(fname, 'rb') as f:
                        all_results[alg][f'{t}_{delta:0.1f}_{history}'] = pickle.load(f)

        for (alg,config) in itertools.product(self.llms,self.configs):
            short = postprocess.util.get_config_shorthand(config)
            name = f"{alg}_{short}"
            all_results[name] = {}
            for (t,delta,history) in settings:
                fname = PuzzlesDataLoader.base_dir+f"puzzles_{self.llm_pref}_{config['suggestive']}_{config['summarized']}_{config['dist']}_{config['cot']}_K={self.K}_T={self.T}_t={t}_delta={delta:0.1f}_{history}/{alg}_t={config['temp']}.pkl"
                if os.path.isfile(fname):
                    with open(fname, 'rb') as f:
                        all_results[name][f'{t}_{delta:0.1f}_{history}'] = pickle.load(f)

        self.all_results = all_results


    def get_count_ranking(self,alg,history):
        """Average standardized rank of the chosen entry, per (t, delta) cell.

        For every sample ``tup`` stored under ``'{t}_{delta:0.1f}_{history}'``
        the rank of ``tup[1][tup[2]]`` within ``tup[1]`` is obtained from
        ``postprocess.util.rank`` and averaged over the cell's samples.

        Returns:
            array of shape (len(self.ts), len(self.deltas)). Cells with no
            samples are NaN.

        Raises:
            KeyError: if a (t, delta, history) cell was never loaded.
        """
        mat = np.zeros((len(self.ts), len(self.deltas)))
        for i, t in enumerate(self.ts):
            for j, delta in enumerate(self.deltas):
                samples = self.all_results[alg][f'{t}_{delta:0.1f}_{history}']
                ranks = [postprocess.util.rank(tup[1],tup[1][tup[2]]) for tup in samples]
                # Builtin float: the np.float alias no longer exists in current
                # NumPy. An empty cell has no mean, so report NaN instead of
                # dividing by zero.
                mat[i,j] = float(sum(ranks))/len(ranks) if ranks else np.nan

        ## mat has entries between 1 and self.K.
        ## We standardize to get entries between 0 and 1.
        return (mat-1)/(self.K-1)
        

    def get_conditional_maxcount_alignment(self,alg,history):
        """Fraction of max-count choices among samples that miss the arg-max of x[0].

        For each sample ``x`` in a (t, delta) cell, ``x[2]`` is the choice.
        Samples whose choice is an arg-max of ``x[0]`` are excluded; among the
        rest, the result is the share whose choice is an arg-max of ``x[1]``.
        (Presumably x[0] holds empirical rewards and x[1] play counts --
        confirm against the code that writes these files.)

        Returns:
            array of shape (len(self.ts), len(self.deltas)); a cell is NaN
            when every sample was excluded (or there were none).
        """
        mat = np.zeros((len(self.ts), len(self.deltas)))
        for i, t in enumerate(self.ts):
            for j, delta in enumerate(self.deltas):
                samples = self.all_results[alg][f'{t}_{delta:0.1f}_{history}']
                hits = 0.
                denom = 0.
                for x in samples:
                    top_count = np.flatnonzero(x[1] == x[1].max())
                    top_value = np.flatnonzero(x[0] == x[0].max())
                    picked = x[2]
                    if picked in top_value:
                        continue
                    if picked in top_count:
                        hits += 1
                    denom += 1
                mat[i,j] = np.nan if denom == 0 else hits/denom
        return mat

    def get_absolute_maxcount_alignment(self,alg,history):
        """Share of all samples choosing a max-count entry that is not an arg-max of x[0].

        For each sample ``x`` in a (t, delta) cell, ``x[2]`` is the choice. A
        sample counts as a hit when its choice is an arg-max of ``x[1]`` but
        not an arg-max of ``x[0]``. The hit count is divided by the total
        number of samples in the cell.
        (Presumably x[0] holds empirical rewards and x[1] play counts --
        confirm against the code that writes these files.)

        Returns:
            array of shape (len(self.ts), len(self.deltas)). Cells with no
            samples are NaN, as in ``get_conditional_maxcount_alignment``,
            rather than raising ZeroDivisionError.

        Raises:
            KeyError: if a (t, delta, history) cell was never loaded.
        """
        mat = np.zeros((len(self.ts), len(self.deltas)))
        for i, t in enumerate(self.ts):
            for j, delta in enumerate(self.deltas):
                # Look the cell up once instead of once per derived list.
                samples = self.all_results[alg][f'{t}_{delta:0.1f}_{history}']
                hits = 0.
                total = 0
                for x in samples:
                    max_count = np.flatnonzero(x[1] == x[1].max())
                    emp_best = np.flatnonzero(x[0] == x[0].max())
                    choice = x[2]
                    total += 1
                    if choice not in emp_best and choice in max_count:
                        hits += 1
                mat[i,j] = hits/total if total else np.nan
        return mat

    def get_history_by_algs_mat(self,t,delta,fn):
        """Evaluate ``fn`` on every (algorithm, history) cell for fixed t and delta.

        Algorithms with no loaded results at all are skipped. Each remaining
        algorithm contributes one row, with one column per entry of
        ``self.histories``; rows follow the order of ``self.all_results``.

        Args:
            t, delta: select the cell key ``'{t}_{delta:0.1f}_{history}'``.
            fn: callable applied to the data stored in each cell.
        """
        rows = [
            np.array([fn(per_setting[f'{t}_{delta:0.1f}_{history}']) for history in self.histories])
            for per_setting in self.all_results.values()
            if len(per_setting) != 0
        ]
        return np.vstack(rows)
    
    def get_ts_by_deltas_mat(self, alg, history, fn):
        """Evaluate ``fn`` on every (t, delta) cell of one algorithm and history.

        Returns a float array of shape (len(self.ts), len(self.deltas)) whose
        entry (i, j) is ``fn`` applied to the data stored under
        ``'{ts[i]}_{deltas[j]:0.1f}_{history}'``.
        """
        per_setting = self.all_results[alg]
        grid = np.zeros((len(self.ts), len(self.deltas)))
        for row, t in enumerate(self.ts):
            for col, delta in enumerate(self.deltas):
                grid[row,col] = fn(per_setting[f'{t}_{delta:0.1f}_{history}'])
        return grid
                
