import os
import json
import matplotlib.pyplot as plt
from typing import Tuple, List

# Line colors for the per-loss-type curves drawn by tendency_c; a loss type at
# index i in the sorted loss-type list uses COLORS[i % len(COLORS)].
COLORS = ['red', 'green', 'yellow', 'blue', 'purple', 'black']


def internal_plot_scales(xs, list_ys) -> Tuple[List[float], List[float]]:
    """Spread each group of y values evenly over the interval ending at its x.

    ``list_ys[k]`` holds the values belonging to ``xs[k]``. Its ``n`` values are
    placed at the midpoints of ``n`` equal sub-intervals of ``[xs[k] - w, xs[k]]``.
    ``w`` is the distance from ``xs[k]`` to the next x. The last x reuses the
    distance to its predecessor, and a lone x uses ``w = xs[0]``. Empty groups
    contribute no points.

    :param xs: x positions, one per group.
    :param list_ys: one sequence of y values per x position.
    :return: flattened ``(x_positions, y_values)`` lists of equal length.
    Asserts that ``xs`` and ``list_ys`` have the same length.
    """
    assert len(xs) == len(list_ys)
    out_x: List[float] = []
    out_y: List[float] = []
    last = len(xs) - 1
    for idx, (x, group) in enumerate(zip(xs, list_ys)):
        # Right end of the "reference gap" used as the spreading width.
        if idx < last:
            right = xs[idx + 1]
        elif idx == 0:
            right = x + x
        else:
            right = x + x - xs[idx - 1]
        count = len(group)
        width = right - x
        # An empty group simply skips this loop, so count is never 0 here.
        for pos, y in enumerate(group):
            out_x.append(x + width * ((pos + 0.5) / count - 1))
            out_y.append(y)
    return out_x, out_y


def is_lower_better(loss_type: str) -> bool:
    """Return whether a smaller value of ``loss_type`` indicates a better model.

    Lower is better for names starting with ``'adj'`` and for ``'d'``,
    ``'phi'``, ``'psi'``, ``'rmsd'`` and ``'mat'``. Higher is better for
    ``'lddt-score'`` and names starting with ``'cov'``.

    :param loss_type: loss name, as extracted from ``'test_<type>_loss'`` keys.
    :return: ``True`` if lower is better, ``False`` if higher is better.
    :raises ValueError: for any loss type not listed above.
    """
    if loss_type.startswith('adj') or loss_type in ('d', 'phi', 'psi', 'rmsd', 'mat'):
        return True
    if loss_type == 'lddt-score' or loss_type.startswith('cov'):
        return False
    # Raise explicitly instead of `assert False`. Asserts are stripped under
    # `python -O`, which made this function silently return None (falsy,
    # i.e. "higher is better") for unknown loss types.
    raise ValueError(f'Undefined loss_type {loss_type}, please modify this function.')


def _plot_training_losses(ax, log: list, epoch_ids: list, path: str) -> None:
    """Draw every entry's ``'on_training_losses'`` in grey on a twin y-axis of ``ax``.

    Prints a note and draws nothing if any entry lacks ``'on_training_losses'``.
    The entry at epoch id 0 is dropped when present.
    """
    try:
        training_losses = [dic['on_training_losses'] for dic in log]
    except KeyError:
        print(f'\t# No training loss in {path}')
        return
    if epoch_ids[0] == 0:
        training_losses = training_losses[1:]
        epoch_ids = epoch_ids[1:]
    xs, ys = internal_plot_scales(epoch_ids, training_losses)
    # Create the twin axis only once there is something to draw on it.
    # Otherwise a run without training losses would get an empty second y-axis.
    twin = ax.twinx()
    twin.plot(xs, ys, color='grey', label='train loss')
    twin.legend(loc='upper right')


def tendency_c(log: list, path: str, higher_is_better=False):
    """Report the test losses of one training run and plot its loss curves.

    ``log`` is a list of per-epoch dicts. Keys read: ``'epoch'``,
    ``'test_<type>_loss'``, ``'validate_<type>_loss'``, and optionally
    ``'train_<type>_loss'`` and ``'on_training_losses'`` (a list of values
    spread over the x-axis by ``internal_plot_scales``).

    If ``path`` ends with ``'rdkit.png'`` or contains ``'ff'``, only the last
    test value of each loss type is printed and no image is written.
    Otherwise, for each loss type (``'adj3'`` is skipped):

    * the test value at the epoch with the best validation value is printed;
    * the train curve is drawn dashed (the validation curve if no train
      values are logged);
    * the test curve is drawn solid.

    ``on_training_losses`` is drawn in grey on a secondary y-axis when every
    entry has it. The figure is saved to ``path``.

    :param log: per-epoch metric dicts, in epoch order.
    :param path: output image path; also selects the print-only mode above.
    :param higher_is_better: unused; kept for call compatibility. The
        direction is chosen per loss type by ``is_lower_better``.
    """
    seq_dict = {k: [dic[k] for dic in log] for k in log[0]}

    try:
        # Scale epochs by the length of the last epoch's 'on_training_losses'.
        # This puts the epoch ticks on the same x-axis as those per-step
        # values (presumably one value per batch — confirm).
        epoch_ids = [eid * len(log[-1]['on_training_losses']) for eid in seq_dict['epoch']]
    except KeyError:
        epoch_ids = seq_dict['epoch']

    # Keys of the form 'test_<type>_loss' -> '<type>'.
    loss_types = sorted(k[5:-5] for k in seq_dict if k.startswith('test') and k.endswith('loss'))

    print(f'{path} @ epoch {epoch_ids[-1]}')
    if path.endswith('rdkit.png') or 'ff' in path:
        # Print-only mode. No figure is created here, so none can be leaked.
        for loss_type in loss_types:
            if loss_type == 'adj3':
                continue
            test_c = [dic[f'test_{loss_type}_loss'] for dic in log]
            print(f'\t{loss_type}: {test_c[-1]:.3f}')
        return

    fig, ax = plt.subplots()
    try:
        for i, loss_type in enumerate(loss_types):
            if loss_type == 'adj3':
                continue
            color = COLORS[i % len(COLORS)]
            valid_c = [dic[f'validate_{loss_type}_loss'] for dic in log]
            try:
                ref_c = [dic[f'train_{loss_type}_loss'] for dic in log]
            except KeyError:
                ref_c = valid_c
            test_c = [dic[f'test_{loss_type}_loss'] for dic in log]

            # Report the test value at the best validation value. The epoch-0
            # entry is ignored when other epochs exist.
            ps = list(zip(valid_c, test_c))
            if epoch_ids[0] == 0 and len(ps) > 1:
                ps = ps[1:]
            # Sort so the best validation value ends up last. The sort is
            # stable, so among ties the latest epoch wins.
            ps.sort(key=lambda p: p[0], reverse=is_lower_better(loss_type))
            print(f'\t{loss_type}: {ps[-1][1]:.3f}')

            ax.plot(epoch_ids, ref_c, color=color, linestyle='--')
            ax.plot(epoch_ids, test_c, color=color, label=loss_type)
        if ax.lines:
            ax.legend(loc='upper left')

        _plot_training_losses(ax, log, epoch_ids, path)
        fig.savefig(path)
    finally:
        # Always release the figure, even if a key is missing mid-plot.
        # pyplot keeps every open figure alive otherwise.
        plt.close(fig)


# Runs to report on. Each entry is (directory, run name, flag):
#   - the log is read from '<directory>/<run name>.json' and the figure is
#     written to '<directory>/<run name>.png' (see the __main__ block below);
#   - the flag is passed to tendency_c as `higher_is_better`, which that
#     function does not read.
# Entries whose JSON log does not exist are skipped. Uncomment entries to
# include more datasets/runs.
tuples = [
    ('qm9-conf', 'rdkit', False),
    ('qm9-conf', 'rdkit-newton-equiv-trunc', False),
    ('qm9-conf', 'rdkit-newton-equiv-trunc_nof', False),
    ('qm9-conf', 'rdkit-newton-equiv-trunc_nos', False),
    ('qm9-conf', 'rdkit-newton-adj3', False),
    ('qm9-conf', 'rdkit-newton-kabsch', False),
    ('qm9-conf', 'rdkit-newton-naive', False),
    ('qm9-conf', 'rdkit-newton-lddt5', False),
    ('qm9-conf', 'rdkit-newton-adj-1', False),

    # ('geom_qm9-conf', 'rdkit', False),
    # ('geom_qm9-conf', 'rdkit-newton-equiv-trunc', False),
    # ('geom_qm9-conf', 'rdkit-newton-adj3', False),
    # ('geom_qm9-conf', 'rdkit-newton-kabsch', False),
    # ('geom_qm9-conf', 'rdkit-newton-naive', False),
    # ('geom_qm9-conf', 'cvgae--equiv-trunc', False),
    # ('geom_qm9-conf', 'cvgae--adj3', False),
    # ('geom_qm9-conf', 'cvgae--kabsch', False),
    # ('geom_qm9-conf', 'cvgae--naive', False),
    # ('geom_qm9-ff', 'rdkit-newton-equiv-trunc', False),
    # ('geom_qm9-ff', 'rdkit-newton-adj3', False),
    # ('geom_qm9-ff', 'rdkit-newton-kabsch', False),
    # ('geom_qm9-ff', 'rdkit-newton-naive', False),
    # ('geom_qm9-ff', 'cvgae--equiv-trunc', False),
    # ('geom_qm9-ff', 'cvgae--adj3', False),
    # ('geom_qm9-ff', 'cvgae--kabsch', False),
    # ('geom_qm9-ff', 'cvgae--naive', False),
    #
    # ('geom_drugs-conf', 'rdkit', False),
    # ('geom_drugs-conf', 'rdkit-newton-equiv-trunc', False),
    # ('geom_drugs-conf', 'rdkit-newton-adj3', False),
    # ('geom_drugs-conf', 'rdkit-newton-kabsch', False),
    # ('geom_drugs-conf', 'rdkit-newton-naive', False),
    # ('geom_drugs-conf', 'cvgae--equiv-trunc', False),
    # ('geom_drugs-conf', 'cvgae--adj3', False),
    # ('geom_drugs-conf', 'cvgae--kabsch', False),
    # ('geom_drugs-conf', 'cvgae--naive', False),
    # ('geom_drugs-ff', 'rdkit-newton-equiv-trunc', False),
    # ('geom_drugs-ff', 'rdkit-newton-adj3', False),
    # ('geom_drugs-ff', 'rdkit-newton-kabsch', False),
    # ('geom_drugs-ff', 'rdkit-newton-naive', False),
    # ('geom_drugs-ff', 'cvgae--equiv-trunc', False),
    # ('geom_drugs-ff', 'cvgae--adj3', False),
    # ('geom_drugs-ff', 'cvgae--kabsch', False),
    # ('geom_drugs-ff', 'cvgae--naive', False),

    # ('qm7-conf', 'rdkit-newton-equiv-trunc', False),
    # ('qm7-conf', 'rdkit-newton-adj3', False),
    # ('geom_qm9-small-conf', 'rdkit', False),
    # ('geom_qm9-small-conf', 'rdkit-newton-equiv-trunc', False),
    # ('geom_qm9-small-conf', 'rdkit-newton-adj3', False),
    # ('geom_qm9-small-conf', 'cvgae-newton-equiv-trunc', False),
    # ('geom_qm9-small-conf', 'cvgae--equiv-trunc', False),
    # ('geom_qm9-small-conf', 'cvgae--adj3', False),
    # ('geom_qm9-small-conf', 'cvgae--kabsch', False),
    # ('geom_qm9-small-ff', 'rdkit-newton-equiv-trunc', False),
    # ('geom_qm9-small-ff', 'rdkit-newton-adj3', False),
    # ('geom_qm9-small-ff', 'cvgae--equiv-trunc', False),
    # ('geom_qm9-small-ff', 'cvgae--adj3', False),
    # ('geom_qm9-small-ff', 'cvgae--kabsch', False),
    # ('geom_drugs-small-conf', 'rdkit', False),
    # ('geom_drugs-small-conf', 'rdkit-newton-equiv-trunc', False),
    # ('geom_drugs-small-conf', 'rdkit-newton-adj3', False),
    # ('geom_drugs-small-ff', 'rdkit-newton-equiv-trunc', False),
    # ('geom_drugs-small-ff', 'rdkit-newton-adj3', False),
]

if __name__ == '__main__':
    # For every configured run: ensure its directory exists, then load
    # '<dir>/<run>.json' and report/plot it to '<dir>/<run>.png'.
    # Runs without a JSON log are skipped silently.
    for dataset_dir, run_name, higher_better in tuples:
        if not os.path.exists(dataset_dir):
            os.mkdir(dataset_dir)
        log_file = f'{dataset_dir}/{run_name}.json'
        figure_file = f'{dataset_dir}/{run_name}.png'
        try:
            handle = open(log_file)
        except FileNotFoundError:
            continue
        with handle:
            run_log = json.load(handle)
        tendency_c(run_log, figure_file, higher_better)
