import argparse
import pickle
import numpy as np
from theorem_expansion import *
from collections import Counter
import pandas as pd
import json
import torch
import matplotlib.pyplot as plt
import copy


def get_proof_level_acc(node_correctness, batch_batch):
    """Return a per-proof 0/1 tensor: 1 where every node of that proof is correct.

    Args:
        node_correctness: per-node correctness flags (e.g. the boolean result
            of ``y == y_hat.round()``); anything supporting ``.long()``.
        batch_batch: tensor mapping each node to its proof index. Its last
            entry is taken as the largest proof index, so indices are assumed
            to be sorted — TODO confirm for all callers.

    Returns:
        Float tensor of length ``batch_batch[-1] + 1`` on ``batch_batch.device``.
    """
    num_proofs = batch_batch[-1].item() + 1
    acc = torch.zeros((num_proofs,)).to(batch_batch.device)
    for proof_idx in range(num_proofs):
        nodes_ok = node_correctness[batch_batch == proof_idx]
        all_ok = nodes_ok.long().sum().item() == nodes_ok.shape[0]
        acc[proof_idx] = 1 if all_ok else 0
    return acc


def analyze_node_level_accuracy(y_hat, y, batch_batch):
    """Print, for each proof in a batch, node accuracy and predicted/true red-node ratio.

    Args:
        y_hat: per-node soft predictions; rounded to obtain hard 0/1 colors.
        y: per-node ground-truth labels.
        batch_batch: tensor mapping each node to its proof index; its last
            entry is taken as the largest proof index.

    The ratio is ``sum(rounded y_hat) / sum(y)`` per proof; when a proof has no
    ground-truth red node this is a tensor division by zero (inf/nan), not an
    exception.
    """
    hard_pred = y_hat.round()
    correct_mask = y == hard_pred
    num_proofs = batch_batch[-1].item() + 1
    for proof_idx in range(num_proofs):
        in_proof = batch_batch == proof_idx
        print('node accuracy: {0}'.format(correct_mask[in_proof].float().mean()))
        ratio = hard_pred[in_proof].sum() / y[in_proof].sum()
        print('percentage of predicted red nodes over ground truth: {0}'.format(ratio))


def evaluate_loader(loader, model, device=None):
    """Run ``model`` over every batch of ``loader`` and collect flat predictions/labels.

    Args:
        loader: iterable of batches; each batch must support ``.to(device)``
            and expose its targets as ``batch.y``.
        model: callable mapping a batch to a prediction tensor.
        device: device every batch is moved to. Defaults to CUDA, which was
            previously hard-coded; pass e.g. ``torch.device('cpu')`` to run
            without a GPU.

    Returns:
        ``(predictions, labels)`` as Python lists, concatenated over batches.
    """
    if device is None:
        device = torch.device('cuda')
    predictions = []
    labels = []
    # Outputs are converted to plain lists right away, so building an autograd
    # graph would only waste memory.
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            y_hat = model(batch)
            predictions.extend(y_hat.tolist())
            labels.extend(batch.y.tolist())
    return predictions, labels


def get_correct_total_stat(correct_proof_names, dict_by_expanding_theorem):
    """Count correct and total proofs per expanding theorem.

    Each entry of ``correct_proof_names`` is expected to look like
    ``...expand_<theorem>_in_..._variant_<k>``: everything from ``_variant`` on
    is stripped, and ``<theorem>`` is the text between ``expand_`` and ``_in_``.

    Args:
        correct_proof_names: names of proofs that were predicted correctly.
        dict_by_expanding_theorem: mapping theorem -> collection of its proofs;
            ``len`` of each value is added to that theorem's total.

    Returns:
        Dict mapping theorem -> ``[num_correct, num_total]``.
    """
    stats = {}
    for correct_proof_name in correct_proof_names:
        name = correct_proof_name[:correct_proof_name.find('variant') - 1]
        expanding_theorem = name[name.find('expand_') + len('expand_'):name.find('_in_')]
        stats.setdefault(expanding_theorem, [0, 0])[0] += 1
    for theorem, proofs in dict_by_expanding_theorem.items():
        stats.setdefault(theorem, [0, 0])[1] += len(proofs)
    return stats


def plot_histogram(data, bins=(0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)):
    """Show a histogram of ``data`` and print the fraction of values in [0.1, 0.9].

    Args:
        data: sequence of numeric values (e.g. soft predictions).
        bins: forwarded unchanged to ``plt.hist``. The default bin edges are an
            immutable tuple so the default cannot be mutated and shared across
            calls.

    The printed fraction counts values ``v`` with ``0.1 <= v <= 0.9``; for empty
    ``data`` numpy yields ``nan`` (with a RuntimeWarning).
    """
    plt.hist(data, bins)
    plt.show()
    values = np.array(data)
    print(((values >= 0.1) & (values <= 0.9)).sum() / len(values))


def change_proof_name(proofs, suffix):
    """Append ``suffix`` to the ``name`` of every proof in ``proofs``, in place.

    Keys containing ``'expand'`` map to a list of variant proofs; the i-th
    variant additionally gets ``'_variant_<i>'`` before the suffix. Every
    other key maps to a single proof object.
    """
    for key, value in proofs.items():
        if 'expand' not in key:
            value.name += suffix
            continue
        for idx, variant_proof in enumerate(value):
            variant_proof.name += '_variant_{0}'.format(idx) + suffix


def check_proof_correct(predictions, labels):
    """Return 1 if every rounded prediction equals its label, otherwise 0.

    Predictions are rounded with ``np.round`` (half-to-even). An empty proof
    counts as correct.

    Raises:
        AssertionError: if ``predictions`` and ``labels`` differ in length.
    """
    assert len(predictions) == len(labels)
    rounded = np.round(np.array(predictions))
    matches = np.sum(rounded == labels)
    return 1 if matches == len(rounded) else 0


def check_proof_is_tree(proof_raw, predictions):
    """Return 1 if the nodes predicted as colored form a single tree, else 0.

    A node is colored when ``round(prediction) == 1``. Edge ``i`` is read as
    going from ``proof_raw[2][i]`` to ``proof_raw[1][i]``; ``proof_raw[3]`` is
    the node list and must be as long as ``predictions``.

    The colored set counts as a tree when it has at least two nodes and exactly
    one of them is not the target of an edge from another colored node.

    Raises:
        AssertionError: on a node-count mismatch, or if a colored node is the
            target of more than one colored edge.
    """
    edge_targets = proof_raw[1]
    edge_sources = proof_raw[2]
    node_list = proof_raw[3]
    assert len(node_list) == len(predictions)

    colored = {idx: [] for idx, pred in enumerate(predictions) if round(pred) == 1}
    # A single colored node (or none) is never accepted as a tree.
    if len(colored) <= 1:
        return 0

    for edge_idx, src in enumerate(edge_sources):
        dst = edge_targets[edge_idx]
        if src in colored and dst in colored:
            colored[src].append(dst)

    children = [child for kids in colored.values() for child in kids]
    assert len(set(children)) == len(children)
    # Exactly one colored node without a colored parent -> the unique root.
    return 1 if len(colored) - len(children) == 1 else 0


def find_root_node_proof_tree(proof):
    """Return the first colored node of ``proof`` in pre-order, or None.

    A node counts as colored when ``round(node.subst) == 1``. The node itself
    is checked before its children, which are visited ``mand_vars`` first and
    then ``hps``, each in list order.

    An explicit stack replaces recursion so deep proof trees do not hit
    Python's recursion limit.
    """
    stack = [proof]
    while stack:
        node = stack.pop()
        if round(node.subst) == 1:
            return node
        # Push in reverse so the leftmost child is popped first (pre-order).
        stack.extend(reversed(list(node.mand_vars) + list(node.hps)))
    return None


def check_proof_meaningful(mm, proof, extracted_proof_name):
    """Extract and standardize the colored sub-proof if its coloring is consistent.

    Intended to be called only after the colored (``round(subst) == 1``)
    nodes of ``proof`` are known to form a tree (see ``check_proof_is_tree``).

    Starting at the first colored node returned by ``find_root_node_proof_tree``,
    every node in its subtree (colored or not) must have children that are
    either all colored or all uncolored. If so, the colored part is extracted
    with ``extract_potential_meaningful_proof``, renamed to
    ``extracted_proof_name`` and passed to ``standardize``.

    Args:
        mm: object forwarded to ``standardize``.
        proof: proof tree whose nodes expose ``subst``, ``mand_vars`` and ``hps``.
        extracted_proof_name: name assigned to the extracted proof.

    Returns:
        The result of ``standardize`` (which may itself be None). Returns None
        when no node is colored or some node has mixed-color children.
    """
    root_node = find_root_node_proof_tree(proof)
    if root_node is None:
        # Nothing is colored, so there is nothing to extract.
        return None

    # Visit every node under the root. Visit order is irrelevant because any
    # violation yields None. A plain stack makes this O(n); the previous
    # pop(0) + list concatenation was O(n^2).
    stack = [root_node]
    while stack:
        node = stack.pop()
        children = list(node.mand_vars) + list(node.hps)
        if len({round(child.subst) for child in children}) > 1:
            return None
        stack.extend(children)

    # Coloring is consistent, so the extraction is well defined.
    extracted_proof = extract_potential_meaningful_proof(proof)
    extracted_proof.name = extracted_proof_name
    return standardize(mm, extracted_proof)


def extract_potential_meaningful_proof(proof):
    """Return a pruned deep copy of the colored sub-proof of ``proof``.

    Precondition: the colored nodes (``round(subst) == 1``) form a tree and,
    for every node, the children are either all colored or all uncolored
    (as checked by ``check_proof_meaningful``).

    The subtree rooted at the first colored node (``find_root_node_proof_tree``)
    is deep-copied. In the copy, scanning ``mand_vars`` stops at the first
    uncolored child and the whole list is cleared; ``hps`` is handled the same
    way. ``proof`` itself is not modified.

    Returns:
        Root of the pruned copy, or None if ``proof`` has no colored node.
    """
    root_node = find_root_node_proof_tree(proof)
    if root_node is None:
        return None
    new_root_node = copy.deepcopy(root_node)

    # Each node is pruned independently of the others, so a LIFO stack (O(n))
    # gives the same result as the previous pop(0)-based walk (O(n^2)).
    stack = [new_root_node]
    while stack:
        node = stack.pop()
        kept = []
        for child in node.mand_vars:
            if round(child.subst) == 1:
                kept.append(child)
            else:
                # Uncolored child: this branch ends here in the extracted proof.
                node.mand_vars = []
                break
        for child in node.hps:
            if round(child.subst) == 1:
                kept.append(child)
            else:
                node.hps = []
                break
        stack.extend(kept)

    return new_root_node


def standardize(mm, extracted_proof):
    """Rename the ``$f`` leaves of ``extracted_proof`` to unused ``$f`` labels of ``mm``.

    Leaves come from ``extracted_proof.get_leaves(change_type=True)``. Each
    distinct ``$f`` leaf expression is mapped to the first label in
    ``mm.labels`` whose value is a ``('$f', expr)`` pair with the same first
    symbol and that has not been assigned yet. Every leaf with that
    expression gets the same label, and its ``expr``/``data`` become a deep copy
    of the label's expression. ``$e`` leaves are left untouched. The proof is
    modified in place.

    Args:
        mm: object whose ``labels`` dict maps label -> ``(kind, expr)``.
        extracted_proof: proof object providing ``get_leaves`` and
            ``summarize_proof``.

    Returns:
        ``extracted_proof`` on success, or None if no unused ``$f`` label is
        left for some leaf ("used up our alphabet").

    Raises:
        NotImplementedError: for a leaf whose type is neither ``$e`` nor ``$f``.
    """
    leaves = extracted_proof.get_leaves(change_type=True)
    labels = mm.labels
    replace_dict = {}
    used_mand_vars = set()  # set for O(1) membership tests
    for leaf in leaves:
        if leaf.type == '$e':
            # Hypothesis leaves are kept as-is (renaming them is disabled).
            continue
        if leaf.type != '$f':
            raise NotImplementedError
        key = tuple(leaf.expr)
        if key not in replace_dict:
            for k, v in labels.items():
                if v[0] == '$f' and v[1][0] == leaf.expr[0] and k not in used_mand_vars:
                    assert len(v[1]) == 2
                    replace_dict[key] = k
                    used_mand_vars.add(k)
                    break
            else:
                # No unused $f label left for this symbol: used up our alphabet.
                print('used up our alphabet')
                return None
        leaf.label = replace_dict[key]
        # Deep copy so later edits to the leaf cannot alter mm.labels.
        leaf.expr = copy.deepcopy(labels[leaf.label][1])
        leaf.data = leaf.expr
    # Result is unused; the call is kept in case summarize_proof has side
    # effects on the proof — TODO confirm and drop if it is pure.
    extracted_proof.summarize_proof()
    # NOTE: re-verification of the standardized proof
    # (mm.propagate_and_substitute_leaf_hps + mm.verify_custom) used to follow
    # here and is currently disabled.
    return extracted_proof


def color_proof_tree(proof, predictions):
    """Write ``predictions`` onto the nodes of ``proof`` in pre-order, in place.

    Nodes are enumerated depth-first, each node before its children, with
    ``mand_vars`` visited before ``hps``. ``predictions[i]`` is stored in the
    ``subst`` attribute of the i-th enumerated node.

    Raises:
        AssertionError: if the number of nodes differs from ``len(predictions)``.
    """
    visited = []
    stack = [proof]
    while stack:
        node = stack.pop()
        visited.append(node)
        # Reverse push keeps leftmost-first pre-order while popping from the
        # end in O(1). The old pop(0) + concatenation was O(n^2).
        stack.extend(reversed(list(node.mand_vars) + list(node.hps)))

    assert len(visited) == len(predictions)
    for node, prediction in zip(visited, predictions):
        node.subst = prediction


def _lookup_proof(proofs, name, is_expanded):
    """Return the proof object for dataset entry ``name`` from ``proofs``.

    For expanded datasets ``name`` ends in ``_variant_<k>``. The text before
    ``_variant`` is the key, and ``k`` indexes that key's list of variants.
    """
    if is_expanded:
        variant = int(name.split('_')[-1])
        return proofs[name[:name.find('variant') - 1]][variant]
    return proofs[name]


def _extracted_proof_name(name, is_expanded, correct, num_meaningful):
    """Return the name given to the sub-proof extracted from dataset entry ``name``."""
    if is_expanded:
        extracted_name = name.replace('expand_', 'extracted_')
        if correct:
            return extracted_name
        return 'new_theorem_{0}_from_'.format(num_meaningful) + extracted_name
    return 'new_theorem_{0}_from_'.format(num_meaningful) + name


def analyze_predictions(predictions, labels, mm, raw_dataset, is_expanded):
    """Score per-node predictions proof by proof and print summary counts.

    Args:
        predictions: flat per-node soft predictions for all proofs in
            ``raw_dataset``, concatenated in dataset order.
        labels: flat per-node ground-truth labels in the same layout.
        mm: object whose ``proofs`` dict holds the proof trees. It is not
            modified; a deep copy is colored and renamed instead.
        raw_dataset: sequence of entries indexed as ``[0]`` name, ``[1]``/``[2]``
            edge endpoints and ``[3]`` node list.
        is_expanded: whether names carry a ``_variant_<k>`` suffix (see
            ``_lookup_proof``).

    Raises:
        NotImplementedError: if a fully correct prediction is not a tree, or is
            not meaningful (kept as the original sanity-check signal).
    """
    # Color and rename a private copy. Renaming mm.proofs in place would leak
    # into the caller's object and corrupt names on repeated calls.
    proofs_for_predictions = copy.deepcopy(mm.proofs)
    change_proof_name(proofs_for_predictions, '_prediction')

    num_is_tree = 0
    num_correct = 0
    num_color_one_or_less = 0
    num_color_all = 0
    num_meaningful = 0  # meaningful extractions that are NOT fully correct
    counter = 0  # offset of the current proof's nodes in predictions/labels
    for proof_raw in raw_dataset:
        name = proof_raw[0]
        proof_length = len(proof_raw[3])
        proof_prediction = _lookup_proof(proofs_for_predictions, name, is_expanded)
        current_predictions = predictions[counter:counter + proof_length]
        current_labels = labels[counter:counter + proof_length]

        num_colored_nodes = np.sum(np.round(current_predictions))
        color_one_or_less = int(num_colored_nodes <= 1)
        num_color_one_or_less += color_one_or_less
        num_color_all += int(num_colored_nodes == proof_length)

        color_proof_tree(proof_prediction, current_predictions)
        correct = check_proof_correct(current_predictions, current_labels)
        is_tree = check_proof_is_tree(proof_raw, current_predictions)

        if correct == 1 and is_tree != 1 and not color_one_or_less:
            raise NotImplementedError('if correct, should definitely be a tree')
        # Try to extract a sub-proof only from a proper multi-node tree.
        if is_tree == 1 and not color_one_or_less:
            extracted_proof_name = _extracted_proof_name(name, is_expanded, correct, num_meaningful)
            meaningful_proof = check_proof_meaningful(mm, proof_prediction, extracted_proof_name)
            if correct == 1 and meaningful_proof is None:
                raise NotImplementedError('correct proof should definitely be meaningful')
            if meaningful_proof is not None and correct == 0:
                num_meaningful += 1
        num_correct += correct
        num_is_tree += is_tree
        counter += proof_length
    print('num correct: {0}'.format(num_correct))
    print('num color one or less: {0}'.format(num_color_one_or_less))
    print('num color all: {0}'.format(num_color_all))
    print('num meaningful but not correct: {0}'.format(num_meaningful))
    print('num meaningful: {0}'.format(num_meaningful + num_correct))
    print('num is_tree: {0}'.format(num_is_tree))
    print('num total: {0}'.format(len(raw_dataset)))


def main(args):
    """Load validation predictions/labels and pickled data, then run ``analyze_predictions``.

    Args:
        args: namespace with ``path`` (directory of ``valid_predictions.npy`` /
            ``valid_labels.npy``), ``data_path`` (directory of the pickles) and
            ``main_file`` (prefix of ``<main_file>_verified_expanded.pkl``).

    Directories are joined with ``os.path.join``. Plain string concatenation
    silently produced a wrong path when the directory lacked a trailing
    separator.
    """
    import os

    valid_predictions = np.load(os.path.join(args.path, 'valid_predictions.npy'))
    valid_labels = np.load(os.path.join(args.path, 'valid_labels.npy'))
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(os.path.join(args.data_path, 'valid_dataset.pkl'), 'rb') as f:
        valid_dataset_raw = pickle.load(f)
    with open(os.path.join(args.data_path, '{0}_verified_expanded.pkl'.format(args.main_file)), 'rb') as f:
        mm = pickle.load(f)
    analyze_predictions(valid_predictions, valid_labels, mm, valid_dataset_raw, True)


if __name__ == '__main__':
    # Command-line entry point: every option is a string with a default.
    arg_parser = argparse.ArgumentParser(description="theorem verification")
    cli_options = (
        ('-path', 'path', 'experiment/'),
        ('-data_path', 'data_path', 'dataset/propositional_mm_split_by_theorem/'),
        ('-main_file', 'main_file', 'propositional.mm'),
    )
    for flag, destination, fallback in cli_options:
        arg_parser.add_argument(flag, dest=destination, type=str, default=fallback)
    main(arg_parser.parse_args())
