import os
import json
import torch
import numpy as np
import argparse
from tqdm import tqdm
from collections import Counter
from sentence_transformers import SentenceTransformer


def get_data(base_dir):
    """Load and concatenate every chunk file under ``<base_dir>/chunks``.

    Chunk files are expected to be named ``chunk_0.json``, ``chunk_1.json``,
    ... and are read in increasing index order.

    Args:
        base_dir: Directory that contains a ``chunks`` sub-directory.

    Returns:
        A single list built by extending with the parsed JSON content of
        each chunk file, in chunk order.

    Raises:
        FileNotFoundError: If the ``chunks`` directory is missing, or if the
            chunk indices are not contiguous starting at 0.
    """
    chunk_dir = os.path.join(base_dir, "chunks")
    prefix, suffix = "chunk_", ".json"

    # Count only real chunk files. Counting every directory entry would let
    # stray files (.DS_Store, editor backups, ...) inflate the total and make
    # the loop below try to open a chunk index that does not exist.
    num_chunks = sum(
        1
        for name in os.listdir(chunk_dir)
        if name.startswith(prefix)
        and name.endswith(suffix)
        and name[len(prefix):-len(suffix)].isdigit()
    )

    all_chunk_data = []
    for chunk_idx in range(num_chunks):
        file_name = os.path.join(chunk_dir, f"{prefix}{chunk_idx}{suffix}")
        with open(file_name, "r") as f:
            all_chunk_data.extend(json.load(f))
    return all_chunk_data


def get_statistics(data, idx_list, goal):
    """Count how often each ``goal`` tag value occurs among selected items.

    Args:
        data: Sequence of items, each indexable by ``'tag'``.
        idx_list: Indices into ``data`` selecting the items to count.
        goal: Key looked up inside each item's ``'tag'`` mapping
            (the caller in this file passes ``'Domain'`` or ``'Type'``).

    Returns:
        A dict mapping each tag value to its count, ordered from most to
        least common. Items without a usable tag are counted under the
        string ``"None"``.
    """
    # Treat any falsy tag (None, empty dict, ...) as missing, matching the
    # ``not item['tag']`` check used when similarities are computed; a plain
    # ``is not None`` test would raise KeyError on an empty tag dict.
    labels = [
        data[idx]['tag'][goal] if data[idx]['tag'] else "None"
        for idx in idx_list
    ]
    return dict(Counter(labels).most_common())


def use_sbert(args, data_1, data_2, output_dir, top_k=10000):
    """Select the items whose tags agree most between two tagging methods.

    For each aligned pair of items in ``data_1`` / ``data_2`` the cosine
    similarity of the sentence embeddings of their ``'Domain'`` tags and of
    their ``'Type'`` tags is computed. The two similarities are summed and
    the ``top_k`` highest-scoring positions are returned.

    Similarities are cached as JSON in ``<output_dir>/tag_sim`` under the
    name ``(<method_1>)(<method_2>).json``. A file with the two method names
    in either order is accepted as a cache hit.
    NOTE(review): the cache is reused whenever the file exists; nothing here
    checks that it still matches ``data_1`` / ``data_2``.

    Args:
        args: Object exposing ``method_1`` and ``method_2`` attributes, used
            only to build the cache file name.
        data_1: Items from the first method; each is indexable by ``'tag'``.
        data_2: Items from the second method, aligned by position with
            ``data_1`` (pairs are formed with ``zip``).
        output_dir: Directory under which the ``tag_sim`` cache folder lives.
        top_k: Maximum number of indices to return. Defaults to 10000.

    Returns:
        A list of at most ``top_k`` indices, sorted in ascending order.
    """

    def get_sim(model, goal):
        # Collect the tag strings for `goal`; pairs where either side has no
        # tag get a placeholder and are remembered so they can be zeroed.
        tag_1, tag_2, non_indices = [], [], []
        for idx, (item_1, item_2) in enumerate(zip(data_1, data_2)):
            if not item_1['tag'] or not item_2['tag']:
                non_indices.append(idx)
                tag_1.append("None")
                tag_2.append("None")
            else:
                tag_1.append(item_1['tag'][goal])
                tag_2.append(item_2['tag'][goal])

        # Row-wise cosine similarity: L2-normalise, then take the dot product.
        embeds_1 = model.encode(tag_1)
        embeds_2 = model.encode(tag_2)
        embeds_1 = embeds_1 / np.linalg.norm(embeds_1, axis=1, keepdims=True)
        embeds_2 = embeds_2 / np.linalg.norm(embeds_2, axis=1, keepdims=True)
        sim = np.sum(embeds_1 * embeds_2, axis=1)

        # Untagged pairs must not count as a perfect "None" == "None" match.
        sim[non_indices] = 0

        # Snap values at or above 0.999999 to exactly 1.0 (absorbs floating
        # point noise for identical tags); round everything else to 6 places.
        return [1.0 if x >= 0.999999 else round(x, 6) for x in sim.tolist()]

    sim_dir = os.path.join(output_dir, "tag_sim")
    sim_file_ord = os.path.join(sim_dir, f"({args.method_1})({args.method_2}).json")
    sim_file_rev = os.path.join(sim_dir, f"({args.method_2})({args.method_1}).json")

    # Prefer the cache file in the given order, then the reversed order.
    cached_file = next(
        (path for path in (sim_file_ord, sim_file_rev) if os.path.exists(path)),
        None,
    )

    if cached_file is None:
        if not os.path.exists(sim_dir):
            os.makedirs(sim_dir)
        # Load the model only on a cache miss; it is not needed (and is
        # expensive to construct) when cached similarities are available.
        model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
        domain_sim = get_sim(model, 'Domain')
        type_sim = get_sim(model, 'Type')
        sim_data = {'Domain': domain_sim, 'Type': type_sim}
        with open(sim_file_ord, "w") as f:
            json.dump(sim_data, f, indent=4)
    else:
        with open(cached_file, "r") as f:
            sim_data = json.load(f)
        domain_sim = sim_data['Domain']
        type_sim = sim_data['Type']

    # Rank by combined similarity, keep the best `top_k`, return them in
    # ascending index order.
    sum_sim = [d + t for d, t in zip(domain_sim, type_sim)]
    both_idx_list = sorted(np.argsort(sum_sim)[::-1][:top_k].tolist())
    return both_idx_list


def run(args):
    """Compare two tagging methods and write the agreement statistics.

    Loads the chunked outputs of ``args.method_1`` and ``args.method_2``
    from ``../../output/tagging`` (relative to the current working
    directory), selects the best-agreeing indices via ``use_sbert``, and
    writes tag distributions for that selection, taken from the first
    method's data, to ``tag_sbert_(<method_1>)(<method_2>).json`` inside the
    ``analysis`` output directory.

    Args:
        args: Object exposing ``method_1`` and ``method_2`` attributes.
    """
    cwd = os.getcwd()
    analysis_dir = os.path.join(cwd, "../../output/tagging/analysis")
    if not os.path.exists(analysis_dir):
        os.makedirs(analysis_dir)

    first = get_data(os.path.join(cwd, "../../output/tagging", args.method_1))
    second = get_data(os.path.join(cwd, "../../output/tagging", args.method_2))

    selected = use_sbert(args, first, second, analysis_dir)

    # Distributions are reported from the first method's tags only.
    report = {
        "Domain Distribution": get_statistics(first, selected, 'Domain'),
        "Type Distribution": get_statistics(first, selected, 'Type'),
        "Number of Diagrams": len(selected),
        # Stored as a string so the index list stays on one line in the JSON.
        "List of Diagrams": str(selected),
    }

    report_name = f"tag_sbert_({args.method_1})({args.method_2}).json"
    with open(os.path.join(analysis_dir, report_name), "w") as out:
        json.dump(report, out, indent=4)


if __name__ == '__main__':
    # Script entry point: read the two method names from the command line
    # and run the comparison.
    parser = argparse.ArgumentParser(description='')
    for option in ('--method_1', '--method_2'):
        parser.add_argument(option, type=str)
    args = parser.parse_args()
    run(args)

