print('Started')
import os
import json
from run_leo import LEO_METHODS, IMPRESSION_FNS, improve
from scipy import stats
import numpy as np
import itertools

def get_cum_stats(report_array):
    """Collapse a list of per-query report rows into one 'Overall' summary row.

    Every element of ``report_array`` is read for the keys ``init_mean``,
    ``final_mean``, ``init_std``, ``final_std``, ``init_scores`` and
    ``final_scores``.

    The per-query means are compared with a paired t-test
    (``scipy.stats.ttest_rel``) twice: once over all rows (``p_value``) and
    once over only the rows where neither mean is exactly zero
    (``nonzero_p_value``).

    Returns:
        A dict whose ``query``, ``dset``, ``category`` and ``tags`` entries
        are the string ``'Overall'``; callers may overwrite those labels.
    """
    mean_before = np.array([row['init_mean'] for row in report_array])
    mean_after = np.array([row['final_mean'] for row in report_array])
    std_before = np.array([row['init_std'] for row in report_array])
    std_after = np.array([row['final_std'] for row in report_array])
    raw_before = np.array([row['init_scores'] for row in report_array])
    raw_after = np.array([row['final_scores'] for row in report_array])

    p_value = stats.ttest_rel(mean_before, mean_after).pvalue

    # Keep only the rows where both the initial and the final mean are non-zero.
    both_nonzero = (mean_after != 0) & (mean_before != 0)
    nonzero_p_value = stats.ttest_rel(
        mean_before[both_nonzero], mean_after[both_nonzero]
    ).pvalue

    # Insertion order is the order the keys appear in the dumped JSON.
    summary = dict.fromkeys(('query', 'dset', 'category', 'tags'), 'Overall')
    summary['increment'] = mean_after.mean() - mean_before.mean()
    summary['init_mean'] = mean_before.mean()
    summary['init_std'] = std_before.mean()
    # Pooled statistics over the raw scores: mean over every score, and the
    # spread of the axis-0 (across-row) averages.
    summary['final_final_mean'] = raw_after.mean()
    summary['final_final_std'] = raw_after.mean(axis=0).std()
    summary['init_init_mean'] = raw_before.mean()
    summary['init_init_std'] = raw_before.mean(axis=0).std()
    summary['final_mean'] = mean_after.mean()
    summary['final_std'] = std_after.mean()
    summary['significant_change'] = str(p_value < 0.05)
    summary['init_scores'] = mean_before.tolist()
    summary['final_scores'] = mean_after.tolist()
    summary['nonzero_p_value'] = nonzero_p_value
    summary['p_value'] = p_value
    return summary


# Initial value of the index that appears in the report directory name
# (`idx-{IDX}`); the per-row loops further down reassign it.
IDX = -1

# Key of the impression function used for this run (looked up in IMPRESSION_FNS).
# Other keys that have been selected or toggled on this line:
#   'line_para_score', 'simple_wordpos', 'simple_word', 'simple_pos',
#   'subjpos_detailed', 'diversity_detailed', 'uniqueness_detailed',
#   'follow_detailed', 'influence_detailed', 'relevance_detailed',
#   'subjcount_detailed'
IMPRESSION_METHOD = 'subjective_score'

import sys

# First CLI argument: path of the JSON Lines dataset. The same string is also
# used as a component of the report directory path.
DSET = sys.argv[1] if len(sys.argv) >= 2 else 'perplexity_28'

# Optional row window [a:b]; only honoured when exactly two extra arguments
# follow the dataset path.
if len(sys.argv) == 4:
    a = int(sys.argv[2])
    b = int(sys.argv[3])
MODEL = 'gpt-3.5-turbo-16k'

# IDX still holds its initial value at this point.
os.makedirs(f'reports_bench/{DSET}/{MODEL}/{IMPRESSION_METHOD}/idx-{IDX}', exist_ok=True)

# One JSON object per line. The context manager closes the handle (the
# previous bare open() never did), and blank lines are skipped instead of
# making json.loads raise.
with open(DSET) as dset_file:
    dt = [json.loads(line) for line in dset_file if line.strip()]

if len(sys.argv) == 4:
    dt = dt[a:b]

# Accumulators for the values returned by improve(), one entry per dataset row.
all_init_scores, all_final_scores = [], []

# Base pattern of 80 per-row indices (values 0-4).
indices = [2, 4, 3, 2, 2, 0, 0, 2, 2, 3, 3, 3, 2, 4, 3, 4, 0, 0, 4, 1, 3, 4, 0, 2, 4, 2, 0, 4, 1, 4, 2, 3, 1, 1, 1, 4, 2, 1, 1, 1, 1, 3, 1, 4, 3, 2, 4, 3, 1, 0, 3, 2, 3, 3, 2, 4, 1, 4, 3, 3, 1, 0, 3, 0, 0, 0, 0, 0, 1, 0, 0, 2, 2, 3, 4, 2, 0, 1, 1, 4]

# The pattern is repeated five times, giving 400 entries (`indices[:80]` is the
# whole base list).
# Earlier variant: segments of length 80, 80, 50, 48 and 52 taken from the
# start of the base list.
indices = indices * 5

# Earlier ad-hoc subsets, applied to `dt` and `indices` together:
#   first 50 rows; all rows except 165..170; only rows 167..170.
# Another variant truncated dt to 120 rows and used index 0 for every row.



    
# Fail fast: every row needs its own entry in `indices`. Without this check the
# loop below would raise IndexError only after the first len(indices) rows had
# already been processed.
if len(dt) > len(indices):
    raise IndexError(
        f'indices has {len(indices)} entries but the dataset has {len(dt)} rows'
    )

# Loop-invariant arguments, looked up once. The cache flag is on only when the
# STATIC_CACHE environment variable is exactly the string 'True'.
use_static_cache = os.environ.get('STATIC_CACHE', None) == 'True'
impression_fn = IMPRESSION_FNS[IMPRESSION_METHOD]

# Score every query; improve() returns an (initial, final) pair that is kept
# for the per-method reports.
for dt_idx, k in enumerate(dt):
    IDX = indices[dt_idx]
    init_scores, final_scores = improve(
        k['query'],
        idx=IDX,
        impression_fn=impression_fn,
        returnFullData=True,
        static_cache=use_static_cache,
    )
    all_init_scores.append(init_scores)
    all_final_scores.append(final_scores)

# All scores are collected; build and write one JSON report per LEO method.
leo_methods = list(LEO_METHODS.keys())

# The grouping keys depend only on the dataset, so gather them once instead of
# once per method.
all_dsets = {row['dset'] for row in dt}
all_cats = {row['category'] for row in dt}
all_tags = set(itertools.chain.from_iterable(row['tags'] for row in dt))

for method_idx, meth in enumerate(leo_methods):
    report_array = []
    for dt_idx, row in enumerate(dt):
        IDX = indices[dt_idx]
        # Column IDX of this row's score matrices: baseline vs. this method.
        init_scores = all_init_scores[dt_idx][:, IDX]
        relevant_scores = all_final_scores[dt_idx][method_idx][:, IDX]
        p_value = stats.ttest_ind(init_scores, relevant_scores).pvalue

        # Positions where either side scored exactly zero are dropped from both
        # samples before the "nonzero" test. A set keeps each membership check
        # O(1) instead of scanning a list.
        zero_positions = set(np.where(np.array(relevant_scores) == 0)[0].tolist())
        zero_positions.update(np.where(np.array(init_scores) == 0)[0].tolist())
        nonzero_final = [
            score for pos, score in enumerate(relevant_scores)
            if pos not in zero_positions
        ]
        nonzero_init = [
            score for pos, score in enumerate(init_scores)
            if pos not in zero_positions
        ]

        # Above are one-dimensional arrays of scores.
        report_array.append(
            {
                'query': row['query'],
                'dset': row['dset'],
                'category': row['category'],
                'tags': row['tags'],
                'increment': relevant_scores.mean() - init_scores.mean(),
                'init_mean': init_scores.mean(),
                'init_std': init_scores.std(),
                'final_mean': relevant_scores.mean(),
                'final_std': relevant_scores.std(),
                'significant_change': str(p_value < 0.05),
                'init_scores': init_scores.tolist(),
                'final_scores': relevant_scores.tolist(),
                'nonzero_p_value': stats.ttest_ind(nonzero_init, nonzero_final).pvalue,
                'p_value': p_value,
            }
        )

    # Aggregate rows are appended after the per-query rows; they are computed
    # from a snapshot so they never include each other.
    per_query_rows = list(report_array)

    report_array.append(get_cum_stats(per_query_rows))
    for dset in all_dsets:
        dset_stats = get_cum_stats([r for r in per_query_rows if r['dset'] == dset])
        dset_stats['dset'] = dset
        report_array.append(dset_stats)

    for cat in all_cats:
        # Exact match, like the per-dset grouping. The previous
        # `cat in r['category']` is a substring test when category is a
        # string, so it also pooled rows of any category whose name merely
        # contains `cat`.
        cat_stats = get_cum_stats([r for r in per_query_rows if r['category'] == cat])
        cat_stats['category'] = cat
        report_array.append(cat_stats)

    print(all_tags)
    for tag in all_tags:
        tag_stats = get_cum_stats([r for r in per_query_rows if tag in r['tags']])
        tag_stats['tags'] = [tag]
        report_array.append(tag_stats)

    # NOTE(review): IDX here is whatever the last row of the loop above
    # assigned (its initial value if dt is empty), so the directory name
    # follows the last row's index rather than a fixed setting. Confirm this
    # is intended.
    out_dir = f'reports_bench/{DSET}/{MODEL}/{IMPRESSION_METHOD}/idx-{IDX}'
    os.makedirs(out_dir, exist_ok=True)
    with open(f'{out_dir}/{meth}.json', 'w') as f:
        json.dump(report_array, f, indent=2)

