import os
import pickle
import argparse
import numpy as np
import scipy.stats as stats


def mean_confidence_interval(data, confidence=0.95):
    """Return the sample mean of *data* and the half-width of its t-based CI.

    The interval is ``mean +/- half_width``, where ``half_width`` is the
    standard error of the mean (``scipy.stats.sem``) scaled by the two-sided
    Student-t quantile for ``len(data) - 1`` degrees of freedom.

    Args:
        data: Sequence of numeric scores (callers in this script pass flat
            lists of per-run metric values).
        confidence: Two-sided confidence level, e.g. 0.95 for a 95% interval.

    Returns:
        Tuple ``(mean, half_width)``. With fewer than two samples the
        half-width is not defined (scipy yields NaN).
    """
    samples = 1.0 * np.array(data)  # force a floating-point array
    degrees_of_freedom = len(samples) - 1
    upper_quantile = (1 + confidence) / 2.
    t_critical = stats.t.ppf(upper_quantile, degrees_of_freedom)
    centre = np.mean(samples)
    half_width = stats.sem(samples) * t_critical
    return centre, half_width

# Command-line interface. ``args`` is parsed at import time and handed to
# ``main`` by the ``__main__`` guard at the bottom of the file.
parser = argparse.ArgumentParser(
    description="Aggregate per-run evaluation_results pickles into "
                "mean +/- confidence-interval summaries."
)
parser.add_argument('-sd', '--source_dir', type=str, default="saved_models_best",
                    help="Root directory containing the saved evaluation results.")
parser.add_argument('-ut', '--unsuitable_type', type=str, default="all",
                    help="Results-path component (per-concept mode only).")
# Only these two datasets have a results layout in main(); anything else
# would otherwise fail later with an unbound results directory.
parser.add_argument('-d', '--dataset', type=str, default="humanact12",
                    choices=("humanact12", "uestc"),
                    help="Dataset whose results are summarised.")
# Values arrive as strings; main() maps HumanAct12 indices via str(concept).
parser.add_argument('--concept_list', nargs="+",
                    help="Concepts to summarise: HumanAct12 action indices 0-11, "
                         "or UESTC concept directory names. "
                         "Required unless --eval_full is given.")
parser.add_argument('-mn', '--model_name', type=str, default="action2motion",
                    help="Model name (results-path component).")
parser.add_argument('--experiment_name', type=str, default='test',
                    help="Experiment name (results-path component, "
                         "per-concept mode only).")
parser.add_argument('--num_samples', type=int, default=20,
                    help="Number of evaluation runs to aggregate.")
parser.add_argument('--split', type=str, default='test',
                    help="Data split (results-path component, "
                         "UESTC per-concept mode only).")
parser.add_argument('--eval_full', action='store_true',
                    help="Summarise full-model results instead of "
                         "per-concept results.")

args = parser.parse_args()
# Per-concept mode iterates over --concept_list; fail with a usage message
# instead of a TypeError on None.
if not args.eval_full and not args.concept_list:
    parser.error("--concept_list is required unless --eval_full is set")

# HumanAct12 action index -> action name. Keys are strings because concepts
# from the command line arrive as strings (lookups use ``str(concept)``).
# The name is used both as a results-directory component and in the
# printed concept header.
humanact12_match = {
    str(action_index): action_name
    for action_index, action_name in enumerate((
        "warm_up",
        "walk",
        "run",
        "jump",
        "drink",
        "lift_dumbbell",
        "sit",
        "eat",
        "turn_steer_wheel",
        "phone",
        "boxing",
        "throw",
    ))
}

# Printed between a metric's mean and its CI half-width, ready to paste into a
# LaTeX table. Raw string so the backslashes need no escaping (a plain "\pm"
# is an invalid escape sequence).
_PM = r"\textsuperscript{$\pm$}"

# (printed label, key in the evaluation_results pickle) for per-concept mode,
# in print order.
_CONCEPT_METRICS = (
    ("GT Accuracy", "gt_accuracy"),
    ("GT Diversity", "gt_diversity"),
    ("GT Multimodality", "gt_multimodality"),
    ("Gen Accuracy", "generated_accuracy"),
    ("Gen Diversity", "generated_diversity"),
    ("Gen Multimodality", "generated_multimodality"),
    ("Refined Anchor Accuracy", "refined_anchor_accuracy"),
    ("Refined Anchor Diversity", "refined_anchor_diversity"),
    ("Refined Anchor Multimodality", "refined_anchor_multimodality"),
    ("Refined Accuracy", "refined_accuracy"),
    ("Refined Diversity", "refined_diversity"),
    ("Refined Multimodality", "refined_multimodality"),
    ("Gen FID", "fid_generated"),
    ("Refined Anchor FID", "fid_refined_anchor"),
    ("Refined FID", "fid_refined"),
)

# Full-evaluation mode reads and reports the same metrics minus the
# refined-anchor ones.
_FULL_METRICS = tuple(
    (label, key) for label, key in _CONCEPT_METRICS if "anchor" not in key
)


def _resolve_concept(args, concept):
    """Return ``(display_name, results_dir)`` for one --concept_list entry."""
    if args.dataset == "humanact12":
        name = humanact12_match[str(concept)]
        results_dir = os.path.join(
            args.source_dir, args.dataset, name, args.experiment_name,
            args.dataset, args.unsuitable_type, args.model_name,
        )
    elif args.dataset == "uestc":
        name = concept
        results_dir = os.path.join(
            args.source_dir, args.dataset, concept, args.experiment_name,
            args.dataset, args.unsuitable_type, args.model_name, args.split,
        )
    else:
        raise ValueError(f"Unsupported dataset: {args.dataset!r}")
    return name, results_dir


def _load_metrics(paths, metrics):
    """Load every pickle in *paths* and gather the values of *metrics*.

    Returns a dict mapping each pickle key in *metrics* (in the same order)
    to the list of its values, one entry per path, in path order. Missing
    files or keys propagate as FileNotFoundError / KeyError.
    """
    collected = {key: [] for _, key in metrics}
    for path in paths:
        # NOTE: pickle.load can execute arbitrary code; only point this
        # script at result files produced by your own evaluation runs.
        with open(path, 'rb') as f:
            results = pickle.load(f)
        for key, values in collected.items():
            values.append(results[key])
    return collected


def _print_metrics(collected, metrics):
    """Print ``mean <_PM> half_width`` (95% t-interval) for each metric."""
    for label, key in metrics:
        mean, half_width = mean_confidence_interval(collected[key])
        print(f"==> {label}:", mean, _PM, half_width)


def main(args):
    """Summarise saved evaluation results as mean and 95% CI half-width.

    Per-concept mode (default) reads
    ``<results_dir>/<i>/evaluation_results.pkl`` for ``i`` in
    ``range(args.num_samples)`` for every concept in ``args.concept_list``.
    Full mode (``args.eval_full``) reads
    ``<source_dir>/<dataset>/<model_name>/evaluation_results_<i>.pkl``.
    Each metric is printed as ``mean \\textsuperscript{$\\pm$} half_width``,
    where the half-width is that of a 95% Student-t confidence interval
    (not a standard deviation).

    Args:
        args: Parsed command-line namespace (see the module's argparse setup).

    Raises:
        ValueError: ``concept_list`` is missing in per-concept mode, or the
            dataset is unsupported.
        NotImplementedError: ``eval_full`` is requested for a dataset other
            than humanact12 (no results layout is defined for it).
    """
    if not args.eval_full:
        if not args.concept_list:
            raise ValueError("--concept_list is required unless --eval_full is set")
        for concept in args.concept_list:
            concept_name, results_dir = _resolve_concept(args, concept)
            paths = [
                os.path.join(results_dir, str(i), "evaluation_results.pkl")
                for i in range(args.num_samples)
            ]
            collected = _load_metrics(paths, _CONCEPT_METRICS)
            print(f"==> Concept: {concept_name}")
            _print_metrics(collected, _CONCEPT_METRICS)
    else:
        if args.dataset != "humanact12":
            raise NotImplementedError(
                f"--eval_full is only implemented for humanact12, got {args.dataset!r}"
            )
        results_dir = os.path.join(args.source_dir, args.dataset, args.model_name)
        paths = [
            os.path.join(results_dir, f"evaluation_results_{i}.pkl")
            for i in range(args.num_samples)
        ]
        collected = _load_metrics(paths, _FULL_METRICS)
        print(f"==> Dataset: {args.dataset}, Model: {args.model_name}")
        _print_metrics(collected, _FULL_METRICS)


if __name__ == "__main__":
    # Script entry point: run the aggregation with the command-line arguments
    # parsed at module level.
    main(args=args)

