from config import get_config

from collections import defaultdict
import statistics
from math import sqrt

import os
import json

import numpy as np

from scipy import stats

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

# Usage: python user_study_plot.py --trajs_savepath ../overcooked_env/recordings

# Module-level accumulators filled by insert_traj(); all are indexed as
# [layout][algo] and hold one list entry per inserted trajectory.
# SCORES: number of deliveries (length of the serve-index list).
# NAMES: the id passed to insert_traj() for the same entry.
SCORES = defaultdict(lambda: defaultdict(lambda: []))
NAMES = defaultdict(lambda: defaultdict(lambda: []))

# Y-axis tick values used by plotScores() for each layout.
VALID_SCORES = {
    'cramped_room': [0, 1, 2, 3, 4],
    'coordination_ring': [0, 1, 2, 3]
}

# WEIGHTED_SCORES: index of the last delivery divided by the number of
# deliveries, i.e. the average number of timesteps per delivery.
WEIGHTED_SCORES = defaultdict(lambda: defaultdict(lambda: []))

# ALGO_LIST holds the keys used to index the accumulators (the algorithm
# directory names); ALGO_NAMES are the matching plot labels, in the same order.
ALGO_LIST = ['SP', 'ADAP', 'XP', 'MP']
ALGO_NAMES = ['SP', 'ADAP', 'CoMeDi$_0$', 'CoMeDi$_1$']

# Bar colors, one per entry of ALGO_NAMES.
ALGO_COLORS = [
    mcolors.CSS4_COLORS['lightgray'],
    mcolors.CSS4_COLORS['gray'],
    mcolors.CSS4_COLORS['khaki'],
    '#ff9b00'
]

# Ids recorded by insert_traj() for the ('cramped_room', 'MP') configuration.
PEOPLE_LIST = []

# Global matplotlib styling applied to every figure produced by this script.
matplotlib.rcParams['font.family'] = 'sans serif'
matplotlib.rcParams['text.color'] = 'black'

# Human-readable layout names used in the printed summaries.
LAYOUT_DICT = {
    "cramped_room": "Cramped Room",
    "coordination_ring": "Coordination Ring"
}

# When True, parse_config() inserts only the best trajectory found in each
# innermost directory; when False, parse_traj() inserts every parsed one.
USE_BEST = True

# NOTE(review): not referenced anywhere in the visible code; presumably the
# episode length of the recordings -- confirm before relying on it.
MAX_TIMESTEPS = 203


def get_stdev(values):
    """Return the standard error of the mean of ``values``.

    Despite the name, the result is not a plain standard deviation: it is
    the population standard deviation divided by the square root of the
    number of values. It is used as the error-bar size in the plots.

    Raises:
        statistics.StatisticsError: if ``values`` is empty.
    """
    spread = statistics.pstdev(values)
    sample_count = len(values)
    return spread / sqrt(sample_count)


def insert_traj(layout, algo, timespaces, id):
    """Record one trajectory in the module-level accumulators.

    Args:
        layout: layout key used to index the accumulators.
        algo: algorithm key used to index the accumulators.
        timespaces: indices of the timesteps at which a delivery happened.
        id: identifier stored alongside the scores.

    A trajectory without any delivery is ignored entirely.
    NOTE(review): this means zero-delivery runs never contribute a score of
    0 to the averages -- confirm that this is intended.
    """
    deliveries = len(timespaces)
    if deliveries == 0:
        return

    if (layout, algo) == ('cramped_room', 'MP'):
        PEOPLE_LIST.append(id)

    SCORES[layout][algo].append(deliveries)
    # Timestep of the last delivery spread over all deliveries: the average
    # number of timesteps needed per delivery.
    WEIGHTED_SCORES[layout][algo].append(timespaces[-1] / deliveries)
    NAMES[layout][algo].append(id)


def parse_traj(path, entrynum, layout, algo, id):
    """Load one recorded trajectory and return its delivery timesteps.

    Args:
        path: path of the JSON recording.
        entrynum: numeric prefix of the recording's file name; recordings
            numbered 2 or higher are ignored without being opened.
        layout, algo, id: forwarded to insert_traj() when USE_BEST is off.

    Returns:
        The indices, within the first entry of ``ep_rewards``, whose reward
        equals 20. An empty list is returned for ignored recordings, for a
        JSON ``null`` document and for an empty reward sequence.
    """
    if entrynum >= 2:
        return []

    with open(path) as handle:
        recording = json.load(handle)
    if recording is None:
        return []

    first_episode = recording['ep_rewards'][0]
    if len(first_episode) == 0:
        return []

    # A reward of exactly 20 marks a timestep at which a dish was served.
    serveindex = []
    for step, reward in enumerate(first_episode):
        if reward == 20:
            serveindex.append(step)

    # Without best-of selection every parsed trajectory is recorded directly.
    if not USE_BEST:
        insert_traj(layout, algo, serveindex, id)

    return serveindex


def get_max(timespaces, curbest):
    """Return the better of two delivery-timestep lists.

    More deliveries wins. With the same non-zero number of deliveries, the
    list whose last delivery happened earlier wins. In every other case
    (including exact ties) ``curbest`` is kept.
    """
    new_count, best_count = len(timespaces), len(curbest)
    if new_count > best_count:
        return timespaces

    same_nonzero_count = new_count == best_count and new_count > 0
    if same_nonzero_count and timespaces[-1] < curbest[-1]:
        return timespaces

    return curbest


def parse_config(path, layout, algo):
    """Parse every recording stored for one (layout, algorithm) pair.

    ``path`` is walked as ``<path>/<id>/<ip>/<number>.json``. For each
    innermost directory the best trajectory (see get_max()) is selected and,
    when USE_BEST is on, recorded through insert_traj().

    Directory listings are sorted: os.listdir() returns entries in arbitrary
    order, while the per-algorithm score lists built here are later compared
    element-wise by paired t-tests (stats.ttest_rel). Sorting makes every
    algorithm's list follow the same deterministic id order.
    NOTE(review): the lists only line up if every algorithm directory holds
    the same ids and none of them is skipped for having zero deliveries.

    Args:
        path: directory holding one sub-directory per participant id.
        layout: layout key forwarded to the accumulators.
        algo: algorithm key forwarded to the accumulators.

    Raises:
        ValueError: if a ``.json`` file name does not start with an integer.
    """
    for user_id in sorted(os.listdir(path)):
        iddir = os.path.join(path, user_id)

        # Debug listing of the entries found for this configuration.
        if layout == 'cramped_room' and algo == 'MP':
            print(user_id)

        if not os.path.isdir(iddir):
            continue

        for ip in sorted(os.listdir(iddir)):
            ipdir = os.path.join(iddir, ip)
            if not os.path.isdir(ipdir):
                continue

            besttimespaces = []
            for entry in sorted(os.listdir(ipdir)):
                splitentry = entry.split('.')
                if splitentry[-1] != 'json':
                    continue
                # The leading number of the file name is the entry number.
                timespaces = parse_traj(
                    os.path.join(ipdir, entry),
                    int(splitentry[0]),
                    layout,
                    algo,
                    user_id
                )
                besttimespaces = get_max(timespaces, besttimespaces)

            if USE_BEST:
                print(algo, besttimespaces, user_id)
                insert_traj(layout, algo, besttimespaces, user_id)


def parse_files():
    """Walk the recordings root and parse every layout/algorithm directory.

    The root is read from ``ARGS.trajs_savepath``; a single trailing "/" is
    removed and the cleaned value is written back to ``ARGS``. The tree is
    walked as ``<root>/<layout>/<algo>`` and each algorithm directory is
    handed to parse_config(). Entries that are not directories are skipped.
    """
    root = ARGS.trajs_savepath
    if root[-1] == "/":
        root = root[:-1]
        ARGS.trajs_savepath = root

    for layout in os.listdir(root):
        layoutdir = f"{root}/{layout}"
        if os.path.isdir(layoutdir):
            for algo in os.listdir(layoutdir):
                algodir = f"{layoutdir}/{algo}"
                if os.path.isdir(algodir):
                    parse_config(algodir, layout, algo)


def plotScores():
    """Print and plot the mean number of deliveries per algorithm.

    For every layout present in SCORES, the mean and the get_stdev() value
    of each algorithm in ALGO_LIST are printed, then drawn as a bar chart
    with error bars. The chart is saved to ``<layout>_scores.pdf`` and shown.

    Raises:
        ZeroDivisionError: if an algorithm has no recorded score.
    """
    for key in SCORES:
        print(f"Scores for {LAYOUT_DICT[key]}")
        per_algo = SCORES[key]
        algo_scores, algo_std = [], []
        for algo in ALGO_LIST:
            counts = per_algo[algo]
            algo_scores.append(sum(counts) / len(counts))
            algo_std.append(get_stdev(counts))
            print(f"{algo}: {algo_scores[-1]} ({algo_std[-1]})")

        plt.clf()
        plt.ylabel("Dishes Served", color="black", fontsize=25)
        plt.bar(ALGO_NAMES, algo_scores, yerr=algo_std, color=ALGO_COLORS)

        # Keep only the bottom/left spines and draw everything in black.
        axes = plt.gca()
        for side in ('right', 'top'):
            axes.spines[side].set_visible(False)
        for side in ('bottom', 'left'):
            axes.spines[side].set_color('black')
        for axis_name in ('x', 'y'):
            axes.tick_params(axis=axis_name, colors='black')

        ticks = VALID_SCORES[key]
        plt.yticks(ticks, ticks, fontsize=25)
        plt.xticks(fontsize=20)
        plt.tight_layout()

        plt.savefig(f"{key}_scores.pdf")
        plt.show()


def plotWeightedScores():
    """Print and plot the mean timesteps per delivery for each algorithm.

    For every layout present in SCORES, the mean and the get_stdev() value
    of the WEIGHTED_SCORES of each algorithm in ALGO_LIST are printed, then
    drawn as a horizontal bar chart with error bars (first algorithm on
    top). The chart is saved to ``<layout>_times.pdf`` and shown.

    Raises:
        ZeroDivisionError: if an algorithm has no recorded value.
    """
    for key in SCORES:
        print(f"Average Timesteps for {LAYOUT_DICT[key]}")
        per_algo = WEIGHTED_SCORES[key]
        algo_scores, algo_std = [], []
        for algo in ALGO_LIST:
            durations = per_algo[algo]
            algo_scores.append(sum(durations) / len(durations))
            algo_std.append(get_stdev(durations))
            print(f"{algo}: {algo_scores[-1]} ({algo_std[-1]})")

        plt.clf()
        plt.xlabel("Timesteps", color="black", fontsize=25)
        plt.barh(ALGO_NAMES, algo_scores, xerr=algo_std, color=ALGO_COLORS)

        axes = plt.gca()
        # barh draws the first entry at the bottom; flip it to the top.
        axes.invert_yaxis()

        # Keep only the bottom/left spines and draw everything in black.
        for side in ('right', 'top'):
            axes.spines[side].set_visible(False)
        for side in ('bottom', 'left'):
            axes.spines[side].set_color('black')
        for axis_name in ('x', 'y'):
            axes.tick_params(axis=axis_name, colors='black')

        plt.yticks(fontsize=25)
        plt.xticks(fontsize=25)
        plt.tight_layout()

        plt.savefig(f"{key}_times.pdf")
        plt.show()


def printTTestScores():
    """Print one-sided paired t-tests on the delivery counts.

    For every layout in SCORES the raw per-algorithm arrays are printed,
    followed by the tests "MP greater than each other algorithm" and
    "XP greater than each other algorithm".

    NOTE(review): stats.ttest_rel pairs the arrays element by element, so
    the per-algorithm lists must have equal length and the same ordering.
    """
    algo_order = ('MP', 'ADAP', 'XP', 'SP')
    comparisons = (
        ('MP', ('ADAP', 'XP', 'SP'), ''),
        ('XP', ('ADAP', 'MP', 'SP'), '\n'),
    )
    for layout in SCORES:
        print(layout)

        per_algo = SCORES[layout]
        samples = {name: np.array(per_algo[name]) for name in algo_order}
        for name in algo_order:
            print(samples[name])

        for lead, rivals, trailer in comparisons:
            for rival in rivals:
                outcome = stats.ttest_rel(
                    samples[lead], samples[rival], alternative='greater')
                print(f"{lead} > {rival}:", outcome)
            # Separator after each group: one blank line, then two.
            print(trailer)


def printTTestTimes():
    """Print one-sided paired t-tests on the timesteps per delivery.

    For every layout in SCORES the mean MP value is printed, followed by the
    tests "MP less than each other algorithm" and "XP less than each other
    algorithm", all computed on WEIGHTED_SCORES.

    NOTE(review): stats.ttest_rel pairs the arrays element by element, so
    the per-algorithm lists must have equal length and the same ordering.
    """
    algo_order = ('MP', 'ADAP', 'XP', 'SP')
    comparisons = (
        ('MP', ('ADAP', 'XP', 'SP'), ''),
        ('XP', ('ADAP', 'MP', 'SP'), '\n'),
    )
    for layout in SCORES:
        print(layout, " Times")

        per_algo = WEIGHTED_SCORES[layout]
        samples = {name: np.array(per_algo[name]) for name in algo_order}

        print(np.mean(samples['MP']))

        for lead, rivals, trailer in comparisons:
            for rival in rivals:
                outcome = stats.ttest_rel(
                    samples[lead], samples[rival], alternative='less')
                print(f"{lead} < {rival}:", outcome)
            # Separator after each group: one blank line, then two.
            print(trailer)


def main():
    """Run the whole analysis: parse, plot, then print the t-tests."""
    pipeline = (
        parse_files,          # fill the module-level accumulators
        plotScores,           # deliveries per algorithm
        plotWeightedScores,   # timesteps per delivery
        printTTestScores,
        printTTestTimes,
    )
    for step in pipeline:
        step()


if __name__ == '__main__':
    # get_config() comes from the project's config module and is expected to
    # return an argparse-style parser that this script extends.
    parser = get_config()
    # The path is mandatory: parse_files() indexes into it, so leaving it
    # out used to fail later with a TypeError on None instead of a usage
    # message.
    parser.add_argument('--trajs_savepath', type=str, required=True,
                        help="folder containing the recorded trajectories")
    # ARGS stays a module-level global because parse_files() reads it.
    ARGS = parser.parse_args()
    main()
