import sys
# sys.path.append('/home/mila/x/xiyuan.zou/research/icl-reverse-cot')
# import hydra
# from omegaconf import DictConfig
# from utils import load_dataset, chunks, process_answer
import torch
import transformers
from transformers import AutoTokenizer, AutoModel, GPTJForCausalLM, AutoModelForCausalLM
# import hydra
# from omegaconf import DictConfig, OmegaConf
import argparse
import os
import ast
from scipy import stats

# Order in which the six condition lists appear, per model, in the results file:
# list k of each group of six belongs to _CONDITIONS[k].
_CONDITIONS = ("cont", "temp_stop", "temp", "cont_stop", "stop", "temp_cont")

# Condition pairs compared for every model, in print order. A separator line
# is printed before the fourth comparison (index 3).
_COMPARISONS = (
    ("temp", "cont"),
    ("temp", "stop"),
    ("cont", "stop"),
    ("temp_cont", "cont_stop"),
    ("temp_stop", "cont_stop"),
    ("temp_cont", "temp_stop"),
)

_SEPARATOR = "------------------------------------"


def _read_result_lists(input_file):
    """Return every list literal in *input_file*, skipping its first two lines.

    A line counts as a list when it starts with "[0."; it is parsed with
    ``ast.literal_eval`` (safe: literals only, no code execution).
    """
    with open(input_file, 'r') as reader:
        input_lines = reader.readlines()
    # i > 1: the first two lines are skipped even if they look like lists.
    return [
        ast.literal_eval(line)
        for i, line in enumerate(input_lines)
        if i > 1 and line.startswith("[0.")
    ]


def _half_p(a, b):
    """Return the two-sided paired t-test p-value of (a, b), divided by two.

    NOTE(review): halving the two-sided p-value only gives a valid one-sided
    p-value when the t statistic has the hypothesized sign; the sign is not
    checked here. Consider ``stats.ttest_rel(a, b, alternative=...)`` once the
    intended direction of each comparison is confirmed.
    """
    return float(stats.ttest_rel(a, b)[1]) / 2


def main(input_file, models=("3b", "7b", "13b", "30b"), alpha=0.05):
    """Run paired t-tests between conditions for each model and print them.

    The last ``6 * len(models)`` list lines of *input_file* are used. They are
    taken as consecutive groups of six, one group per model in ``models``
    order, with the lists in each group ordered as ``_CONDITIONS``. Presumably
    each list holds per-example scores — TODO confirm. Paired lists must have
    equal length (a ``stats.ttest_rel`` requirement).

    For each model this prints a verbose report (label, halved p-value,
    significance verdict against *alpha*), then the same six p-values alone.

    Args:
        input_file: path of the results file.
        models: labels of the model groups, in file order.
        alpha: threshold applied to the halved p-value.

    Raises:
        ValueError: if the file has fewer list lines than ``6 * len(models)``.
    """
    n_conditions = len(_CONDITIONS)
    needed = n_conditions * len(models)

    res = _read_result_lists(input_file)
    if len(res) < needed:
        raise ValueError(
            "expected at least {} result lists in {!r}, found {}".format(
                needed, input_file, len(res)))
    res = res[-needed:]

    # Strided slices: by_condition[name][i] is that condition's list for model i.
    by_condition = {
        name: res[k::n_conditions] for k, name in enumerate(_CONDITIONS)
    }

    for i, model in enumerate(models):
        print("This is the model ", model)

        # Verbose report; p-values are kept for the compact listing below
        # so each test is computed only once.
        pvals = []
        for idx, (left, right) in enumerate(_COMPARISONS):
            if idx == 3:
                print(_SEPARATOR)
            pval = _half_p(by_condition[left][i], by_condition[right][i])
            pvals.append(pval)
            print("{} <-> {} = ".format(left, right), pval)
            if pval <= alpha:
                print("\nTest result is significant with p-value: {}".format(pval))
            else:
                print("\nTest result is not significant with p-value: {}".format(pval))
            print()

        # Compact listing of the same six p-values, in the same order.
        for idx, pval in enumerate(pvals):
            if idx == 3:
                print(_SEPARATOR)
            print(pval)
        print()
        print()
        print()


if __name__ == '__main__':
    # Command-line entry point: parse the results-file path and hand it to main().
    parser = argparse.ArgumentParser(
        description='Run paired t-tests between condition result lists for each model.')
    # The previous help text ('num training examples to use') was a copy-paste
    # error; this argument is the path of the results file main() reads.
    parser.add_argument(
        '--input_file', dest='input_file', action='store', required=True, type=str,
        help='results file whose lines starting with "[0." are parsed as result lists')

    args = vars(parser.parse_args())
    main(**args)