import os
import json
import argparse
import pandas as pd
from collections import Counter


def select_triple(triple_file="evaluate/raw/triple.json"):
    """Filter the extracted triples and report which items pass.

    The JSON file is expected to hold a list of items, each with an ``idx``
    key and a ``triple`` key; ``triple`` is a list of dicts that carry at
    least ``Source`` and ``Target``.

    An item is rejected when:
      * any of its triples has ``Source == Target`` (self loop), or
      * it has fewer than 3 or more than 30 triples.

    The self-loop check runs first, so an item failing both checks is
    reported as a self loop.

    Args:
        triple_file: Path of the JSON file to read. Defaults to the
            location the pipeline has always used.

    Returns:
        A ``(valid_triple, invalid_reason)`` tuple. ``valid_triple`` is the
        list of accepted ``idx`` values in file order; ``invalid_reason`` is
        a list of single-entry dicts mapping a rejected ``idx`` to the
        reason string.
    """
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(triple_file, "r", encoding="utf-8") as f:
        triple = json.load(f)

    valid_triple = []
    invalid_reason = []
    for item in triple:
        edges = item['triple']
        if any(t['Source'] == t['Target'] for t in edges):
            invalid_reason.append({item['idx']: "Filtered: Self Loop"})
            continue
        if not 3 <= len(edges) <= 30:
            invalid_reason.append({item['idx']: "Filtered: Triple Number"})
            continue
        # NOTE: the two filters below are intentionally disabled and kept
        # only for reference.
        # if any("triple" in t['Source'].lower() or "triple" in t['Target'].lower() for t in item['triple']):
        #     invalid_reason.append({item['idx']: "Filtered: Include 'Triple'"})
        #     continue
        # if (max(Counter(t["Source"] for t in item['triple']).values()) > len(item['triple']) // 2 or
        #     max(Counter(t["Target"] for t in item['triple']).values()) > len(item['triple']) // 2):
        #     invalid_reason.append({item['idx']: "Filtered: Overuse Node"})
        #     continue
        valid_triple.append(item["idx"])

    return valid_triple, invalid_reason


def select_sol(args):
    """Return the sorted ids that pass all three comparison tables.

    Three CSV files ``evaluate/comparison/cmp_<name>.csv`` are read, where
    ``<name>`` comes from ``args.sol_name_empty``, ``args.sol_name_wiki`` and
    ``args.sol_name_triple``. Each file must have exactly five columns, which
    are relabelled ``idx, 0, 1, 2, 3``.

    An id is kept when:
      * in the "empty" table, column 0 or column 1 equals 0, and
      * in the "wiki" table, column 0 or column 1 equals 0, and
      * in the "triple" table, columns 0 through 3 all equal 1.

    Args:
        args: Object exposing the three ``sol_name_*`` attributes.

    Returns:
        Sorted list of the ``idx`` values satisfying all three conditions.
    """
    column_labels = ['idx', 0, 1, 2, 3]

    def load_comparison(sol_name):
        # The CSV's own header row is discarded in favour of fixed labels.
        table = pd.read_csv(f"evaluate/comparison/cmp_{sol_name}.csv")
        table.columns = list(column_labels)
        return table

    empty_table = load_comparison(args.sol_name_empty)
    wiki_table = load_comparison(args.sol_name_wiki)
    triple_table = load_comparison(args.sol_name_triple)

    # Row masks: "either of the first two columns is 0" for empty/wiki,
    # "all four columns are 1" for triple.
    empty_mask = (empty_table[0] == 0) | (empty_table[1] == 0)
    wiki_mask = (wiki_table[0] == 0) | (wiki_table[1] == 0)
    triple_mask = (triple_table[[0, 1, 2, 3]] == 1).all(axis=1)

    kept = set(empty_table['idx'][empty_mask].tolist())
    kept &= set(wiki_table['idx'][wiki_mask].tolist())
    kept &= set(triple_table['idx'][triple_mask].tolist())
    return sorted(kept)


def select(args):
    """Build the per-category test/train id splits and write them to disk.

    Steps:
      1. Intersect the ids accepted by ``select_triple()`` and
         ``select_sol(args)`` and sort them.
      2. Look up each id's category in ``intersect/intersect_idx.json``
         (keyed by the id as a string); "Astronomy", "History" and "Music"
         are folded into "Others".
      3. For every category, take the first ``test`` ids as the test split
         and the next ``train`` ids as the train split.
      4. Write ``select_stat.json`` (ids available per category),
         ``global_idx_test.json`` and ``global_idx_train.json`` into
         ``<cwd>/evaluate/selection``.

    Args:
        args: Parsed command-line namespace; forwarded to ``select_sol``.

    Raises:
        KeyError: If an id is missing from the category file, or its
            category is not one of the configured categories.
    """
    # The rejection reasons from select_triple() are not used here.
    valid_triple, _ = select_triple()
    valid_sol = select_sol(args)
    valid = sorted(set(valid_triple) & set(valid_sol))

    selection_dir = os.path.join(os.getcwd(), "evaluate/selection")
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(selection_dir, exist_ok=True)

    with open("intersect/intersect_idx.json", "r", encoding="utf-8") as f:
        intersect_idx_dict = json.load(f)

    # Split sizes per output category.
    new_categories = {
        "Bar Chart": {"test": 150, "train": 900},
        "Line Graph": {"test": 150, "train": 350},
        "Map": {"test": 150, "train": 2000},
        "Pie Chart": {"test": 150, "train": 0},
        "Biology": {"test": 150, "train": 900},
        "Chemistry": {"test": 150, "train": 1600},
        "Computer Science": {"test": 150, "train": 0},
        "Mathematics": {"test": 150, "train": 150},
        "Physics": {"test": 150, "train": 100},
        "Others": {"test": 150, "train": 0},
    }
    merged_into_others = ("Astronomy", "History", "Music")

    # Bucket the (already sorted) ids by category, preserving their order.
    global_idx = {key: [] for key in new_categories}
    for idx in valid:
        cate = intersect_idx_dict[str(idx)]
        if cate in merged_into_others:
            cate = "Others"
        global_idx[cate].append(idx)
    select_stat = {key: len(ids) for key, ids in global_idx.items()}

    # Slicing past the end of a bucket simply yields a shorter split.
    global_idx_test, global_idx_train = {}, {}
    for new_cate, num in new_categories.items():
        n_test = num['test']
        global_idx_test[new_cate] = global_idx[new_cate][:n_test]
        global_idx_train[new_cate] = global_idx[new_cate][n_test:n_test + num['train']]

    outputs = (
        ("select_stat.json", select_stat),
        ("global_idx_test.json", global_idx_test),
        ("global_idx_train.json", global_idx_train),
    )
    for filename, payload in outputs:
        with open(os.path.join(selection_dir, filename), "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=4)


if __name__ == '__main__':
    # Command-line entry point: collect the comparison-file name suffixes
    # and run the selection.
    cli = argparse.ArgumentParser(description='')
    string_options = (
        ('--sol_name_empty', 'empty_v2'),
        ('--sol_name_wiki', 'wiki_v2'),
        ('--sol_name_triple', 'triple_v2'),
    )
    for flag, fallback in string_options:
        cli.add_argument(flag, type=str, default=fallback)
    # NOTE(review): --round is parsed but not read by any code visible in
    # this file; confirm whether it is still needed.
    cli.add_argument('--round', type=int, default=2)

    select(cli.parse_args())

