import numpy as np

import pickle, os, plotting
import itertools

def build_table(meta_data):
    """Assemble the summary table for the baseline and LLM bandit agents.

    For each ``delta`` in ``meta_data['deltas']`` the table receives one row per
    baseline algorithm ('ts', 'ucb', 'g1'), followed by, for every combination
    of the prompt options (suggestive, summarized, cot, temperature/dist), one
    row per LLM algorithm ('gpt35', 'gpt4', 'llama13b') that is listed in
    ``meta_data['algs']`` and has at least one run for that delta.

    NOTE: ``meta_data`` is mutated in place: the keys 'suggestive',
    'summarized', 'cot', 'temp' and 'dist' are overwritten for every prompt
    configuration and keep the values of the last configuration on return.

    Args:
        meta_data: experiment configuration dict.

    Returns:
        ``(output, row_names, col_names)`` where ``output`` is a list of
        15-tuples ordered like ``col_names`` and ``row_names[i]`` is the
        display name (from ``plotting.get_name``) of ``output[i]``.
    """
    all_results = plotting.get_baseline_results(meta_data)

    mr_steps = meta_data['mr_time_steps']
    col_names = []
    for idx in (0, 1):
        step = mr_steps[idx]
        col_names.extend([f"MedRew@{step}", f"MuRew@{step}",
                          f"MedP@{step}", f"MuP@{step}"])
    col_names.extend([
        f"OptCt@{meta_data['fail_thresholds'][0]}",
        # NOTE(review): "OptCc" may be a typo for "OptCt"; kept as-is because it
        # is an emitted column label.
        f"OptCc@{meta_data['fail_thresholds'][1]}",
        f"SufFail@{meta_data['sf_time_steps'][0]}",
        f"SufFail@{meta_data['sf_time_steps'][1]}",
        "Revs",
        "Fails",
        "Runs",
    ])

    def _row(results):
        # Flatten get_and_print_results' 9-tuple into the column order of col_names.
        mrs, murs, mps, mups, opt_counts, sfs, reversions, failures, runs = results
        return (mrs[0], murs[0], mps[0], mups[0],
                mrs[1], murs[1], mps[1], mups[1],
                opt_counts[0], opt_counts[1], sfs[0], sfs[1],
                reversions, failures, runs)

    output = []
    row_names = []
    for delta in meta_data['deltas']:
        print(f"delta = {delta:0.1f}", flush=True)
        delta_key = f"{delta:0.1f}"

        # Baselines carry no LLM failure bookkeeping, hence fail_mat=None.
        for alg in ('ts', 'ucb', 'g1'):
            results = get_and_print_results(alg, delta, all_results[alg][delta_key], None, meta_data)
            output.append(_row(results))
            row_names.append(plotting.get_name(alg, meta_data))

        prompt_options = itertools.product(['neu', 'sug'], ['raw', 'sum'],
                                           ['cot', 'not'], [0, 1, 'dist'])
        for suggestive, summarized, cot, temperature in prompt_options:
            meta_data['suggestive'] = suggestive
            meta_data['summarized'] = summarized
            meta_data['cot'] = cot
            # The 'dist' option means temp=0 with dist='dist'; numeric options use dist='uni'.
            dist_mode = temperature == 'dist'
            meta_data['temp'] = 0 if dist_mode else temperature
            meta_data['dist'] = 'dist' if dist_mode else 'uni'
            fail_mat = plotting.load_llm_results(all_results, meta_data, debug=False)
            for alg in ('gpt35', 'gpt4', 'llama13b'):
                if alg not in meta_data['algs']:
                    continue
                runs_data = all_results[alg][delta_key]
                if not len(runs_data) > 0:
                    continue
                results = get_and_print_results(alg, delta, runs_data, fail_mat, meta_data)
                print([len(x) for x in runs_data])
                output.append(_row(results))
                row_names.append(plotting.get_name(alg, meta_data))
    return (output, row_names, col_names)

def get_and_print_results(alg, delta, data, fail_mat, meta_data):
    """Compute, print and return every summary metric for one algorithm/delta.

    Args:
        alg: algorithm key; used to index ``fail_mat`` and for the display name.
        delta: gap parameter; formatted with one decimal to index ``fail_mat[alg]``.
        data: list of runs for this algorithm and delta.
        fail_mat: failure bookkeeping indexed as ``fail_mat[alg][delta_str]``,
            or ``None`` (then reversions and failures are both 0).
        meta_data: experiment configuration dict.

    Returns:
        ``(mrs, murs, mps, mups, opt_counts, sfs, reversions, failures, runs)``.
        When every run failed (``failures == 1.0``) the first six entries are
        all ``(None, None)``. When the reward matrix does not span
        ``meta_data['T']`` steps, the four reward statistics are tuples holding
        one ``None`` per entry of ``meta_data['mr_time_steps']``.
    """
    def _fmt(value):
        # Reward statistics may be None placeholders; print them instead of crashing.
        return "n/a" if value is None else f"{value:0.2f}"

    opt_counts = get_opt_counts(data, meta_data)
    if fail_mat is None:
        (reversions, failures) = (0, 0)
    else:
        (reversions, failures) = get_failures(data, fail_mat[alg][f"{delta:0.1f}"], meta_data)
    runs = get_runs(data, meta_data)
    if failures == 1.0:
        # Every run failed: leave all per-step/threshold metrics empty.
        return ((None, None), (None, None), (None, None), (None, None), (None, None), (None, None), reversions, failures, runs)
    (mrs, mps) = get_median_rews(data, delta, meta_data)
    (murs, mups) = get_mean_rews(data, delta, meta_data)
    # get_median_rews / get_mean_rews return (None, None) when the reward matrix
    # does not have meta_data['T'] columns. Substitute one None per time step so
    # the prints below and callers that index [0]/[1] do not raise TypeError.
    placeholder = tuple(None for _ in meta_data['mr_time_steps'])
    mrs, mps, murs, mups = (placeholder if stat is None else stat
                            for stat in (mrs, mps, murs, mups))
    sfs = get_suffix_failures(data, meta_data)
    print(f"{plotting.get_name(alg,meta_data)}: OC@{meta_data['fail_thresholds'][0]}: {opt_counts[0]:0.2f}, OC@{meta_data['fail_thresholds'][1]}: {opt_counts[1]:0.2f}, reverted_rounds_fraction: {reversions:0.2f}, fail_fraction: {failures:0.2f}, total_runs {runs}", flush=True)
    print(f"MR@{meta_data['mr_time_steps'][0]}: {_fmt(mrs[0])}, MR@{meta_data['mr_time_steps'][1]}: {_fmt(mrs[1])}", flush=True)
    print(f"SF@{meta_data['sf_time_steps'][0]}: {sfs[0]:0.2f}, SF@{meta_data['sf_time_steps'][1]}: {sfs[1]:0.2f}", flush=True)
    return (mrs, murs, mps, mups, opt_counts, sfs, reversions, failures, runs)
    
def get_opt_counts(data, meta_data):
    """Return, per threshold, the fraction of runs that pulled arm 0 rarely.

    Each run is tallied into a length-``meta_data['K']`` vector of arm-pull
    counts, where ``step[0]`` is the arm index of each step. For every
    threshold ``p`` in ``meta_data['fail_thresholds']``, the result is the
    share of runs whose arm-0 count is one of ``0, 1, ..., p-1``.

    Args:
        data: list of runs; each run is a sequence of steps indexable at ``[0]``.
        meta_data: configuration dict providing ``'fail_thresholds'`` and ``'K'``.

    Returns:
        list of floats, one per threshold.
    """
    thresholds = meta_data['fail_thresholds']
    per_run = []
    for run in data:
        pulls = np.zeros(meta_data['K'])
        for step in run:
            pulls[step[0]] += 1
        per_run.append(pulls)
    # Column 0 of the per-run pull counts (arm 0).
    arm0_pulls = np.vstack(per_run)[:, 0]
    return [
        sum(np.count_nonzero(arm0_pulls == k) for k in range(p)) / len(data)
        for p in thresholds
    ]

def get_failures(data, fail_mat, meta_data):
    """Return ``(reversions, failures)`` for one algorithm/delta setting.

    Args:
        data: list of runs; ``len(data)`` is the denominator of ``failures``.
        fail_mat: dict with a ``'failed'`` sequence of flags and a
            ``'reverted'`` sequence of counts, or ``None``.
        meta_data: configuration dict; ``meta_data['T']`` divides each
            ``'reverted'`` count.

    Returns:
        reversions: mean of ``count / T`` over ``fail_mat['reverted']``.
        failures: number of True flags in ``fail_mat['failed']`` divided by
            ``len(data)``.
        Both are 0 when ``fail_mat`` is None.
    """
    failures = 0
    reversions = 0
    if fail_mat is not None:
        # ``x is True`` misses numpy.bool_ flags (np.True_ is not True) and would
        # silently undercount; accept genuine Python or NumPy booleans that are True.
        n_failed = sum(1 for x in fail_mat['failed']
                       if isinstance(x, (bool, np.bool_)) and x)
        failures = n_failed / len(data)
        reversions = np.mean([float(x) / meta_data['T'] for x in fail_mat['reverted']])
    return (reversions, failures)

def get_runs(data, meta_data):
    """Return the number of runs in ``data`` (``meta_data`` is not used)."""
    num_runs = len(data)
    return num_runs

def get_median_rews(data, delta, meta_data):
    """Median running-average reward at selected steps, rescaled to [0, 1].

    The rescaling maps ``0.5 - delta/2`` to 0 and ``0.5 + delta/2`` to 1.
    Two statistics are computed per step ``t`` in ``meta_data['mr_time_steps']``:
    one from ``plotting.get_reward_matrix(data)`` and one from
    ``plotting.get_opt_action_matrix(data)`` rescaled as ``M*delta + 0.5 - delta/2``.
    In both cases: cumulative sum along axis 1, median over axis 0, then divide
    by the step count ``1..T``.

    Returns:
        ``(reward_scores, opt_scores)``, each a list with one entry per time
        step, or ``(None, None)`` when the reward matrix does not have exactly
        ``meta_data['T']`` columns.
    """
    steps = meta_data['mr_time_steps']
    reward_mat = plotting.get_reward_matrix(data)
    if reward_mat.shape[1] != meta_data['T']:
        return (None, None)
    lo = 0.5 - delta/2
    hi = 0.5 + delta/2

    def _normalized_median_average(matrix):
        # Median over runs of the cumulative sum, divided by elapsed steps.
        running = np.median(np.cumsum(matrix, axis=1), axis=0) / np.arange(1, meta_data['T'] + 1)
        return [(running[t - 1] - lo) / (hi - lo) for t in steps]

    reward_scores = _normalized_median_average(reward_mat)
    opt_mat = plotting.get_opt_action_matrix(data) * delta + lo
    opt_scores = _normalized_median_average(opt_mat)
    return (reward_scores, opt_scores)

def get_mean_rews(data, delta, meta_data):
    """Mean running-average reward at selected steps, rescaled to [0, 1].

    Same as the median variant but averaging over runs with ``np.mean``.
    The rescaling maps ``0.5 - delta/2`` to 0 and ``0.5 + delta/2`` to 1.
    One statistic comes from ``plotting.get_reward_matrix(data)``, the other
    from ``plotting.get_opt_action_matrix(data)`` rescaled as
    ``M*delta + 0.5 - delta/2``.

    Returns:
        ``(reward_scores, opt_scores)``, each a list with one entry per step
        in ``meta_data['mr_time_steps']``, or ``(None, None)`` when the reward
        matrix does not have exactly ``meta_data['T']`` columns.
    """
    steps = meta_data['mr_time_steps']
    reward_mat = plotting.get_reward_matrix(data)
    if reward_mat.shape[1] != meta_data['T']:
        return (None, None)
    lo = 0.5 - delta/2
    hi = 0.5 + delta/2

    def _normalized_mean_average(matrix):
        # Mean over runs of the cumulative sum, divided by elapsed steps.
        cum_avg = np.mean(np.cumsum(matrix, axis=1), axis=0)
        running = cum_avg / np.arange(1, meta_data['T'] + 1)
        return [(running[t - 1] - lo) / (hi - lo) for t in steps]

    reward_scores = _normalized_mean_average(reward_mat)
    opt_scores = _normalized_mean_average(plotting.get_opt_action_matrix(data) * delta + lo)
    return (reward_scores, opt_scores)

def get_suffix_failures(data, meta_data):
    """Fraction of runs whose optimal-action entries are all zero from step t on.

    The function computes a reversed cumulative sum along axis 1 of
    ``plotting.get_opt_action_matrix(data)``. For each ``t`` in
    ``meta_data['sf_time_steps']`` it returns the fraction of rows (axis 0)
    whose reversed sum starting at column ``t-1`` equals zero.
    NOTE(review): assumes that matrix is (runs, T) with 0/1 entries marking
    optimal-arm pulls; confirm in plotting.

    Returns:
        list with one fraction per entry of ``meta_data['sf_time_steps']``.
    """
    steps = meta_data['sf_time_steps']
    opt_acts = plotting.get_opt_action_matrix(data)
    # Entry [r, j] = sum of opt_acts[r, j:], i.e. what remains from column j onward.
    remaining = np.fliplr(np.cumsum(np.fliplr(opt_acts), axis=1))
    frac_zero_suffix = np.mean(remaining == 0, axis=0)
    return [frac_zero_suffix[t - 1] for t in steps]
