from collections import Counter

import matplotlib.pyplot as plt
import numpy as np

# Seed NumPy's global RNG so the "random" point sets are identical on every run.
np.random.seed(seed=0)
import argparse
import os
from pickle import load

from scipy.spatial import KDTree
from tqdm import tqdm


def _dataset_label(filename):
    """Return a short display label for one entry of the dataset list.

    The literal "random" is returned unchanged. For a pickle path such as
    "map_data_osmnx/paris,france_9484.parquet_100node_128num/test_location.pickle"
    the parent directory name is cut at the first "_" and then at the first ",",
    giving "paris".
    """
    if filename == "random":
        return filename
    return filename.split("/")[-2].split("_")[0].split(",")[0]


def _load_point_sets(filename, num_try):
    """Return the sequence of point sets for one dataset.

    "random" yields ``num_try`` arrays of shape (100, 2) drawn uniformly from
    [0, 1) with NumPy's global RNG; any other value is treated as the path of a
    pickle holding a sequence of point sets.
    """
    if filename == "random":
        return [np.random.rand(100, 2) for _ in range(num_try)]
    # NOTE: pickle.load can execute arbitrary code -- only load trusted local files.
    with open(filename, "rb") as f:
        return load(f)


def _pairwise_distances(points):
    """Return the Euclidean distance of every unordered pair of points.

    Pairs come out in the same order as a ``for i: for j > i`` double loop
    (row-major upper triangle of the distance matrix). ``points`` is presumably
    an (n, d) array-like -- TODO confirm for the pickled map data.
    """
    pts = np.asarray(points, dtype=float)
    if len(pts) < 2:
        # Fewer than two points: no pairs (also avoids axis errors on empty input).
        return np.empty(0)
    rows, cols = np.triu_indices(len(pts), k=1)
    return np.linalg.norm(pts[rows] - pts[cols], axis=1)


def main(args):
    """Plot histograms of pairwise point distances for several point-set datasets.

    For each dataset (three city maps loaded from pickles plus uniformly random
    points), the distances between every pair of points are pooled over the
    first ``args.num_try`` point sets. Two PNGs are written to
    ``args.output_folder`` and their paths printed:

    * ``disthist_{num_try}_{num_bin}bin_total.png`` -- one raw-count histogram
      per dataset, side by side with a shared y axis.
    * ``disthist_{num_try}_{num_bin}bin_total_overwrite.png`` -- all datasets
      overlaid as density histograms on a single axis.

    Args:
        args: namespace providing ``num_try`` (int, point sets used per dataset),
            ``num_bin`` (int, number of histogram bins) and ``output_folder``
            (str, destination directory, created if missing).
    """
    num_try = args.num_try
    num_bin = args.num_bin
    filename_list = [
        "map_data_osmnx/munich,germany_14046.parquet_100node_128num/test_location.pickle",
        "map_data_osmnx/paris,france_9484.parquet_100node_128num/test_location.pickle",
        "map_data_osmnx/barcelona,spain_8914.parquet_100node_128num/test_location.pickle",
        "random",
    ]
    # One edge colour per dataset in the overlaid histogram (same order as filename_list).
    edge_color_list = ["blue", "orange", "green", "red"]
    labels = [_dataset_label(name) for name in filename_list]
    # The longest distance between two points of the unit square is sqrt(2) ~= 1.4142,
    # so an upper bound of 1.4 would silently drop the longest pairs from every
    # histogram. Assumes coordinates lie in [0, 1]^2, which holds for "random";
    # TODO confirm for the map data.
    hist_range = (0.0, float(np.sqrt(2)))

    # Pool pairwise distances over the first num_try point sets of each dataset.
    all_distances = []
    for filename in filename_list:
        point_sets = _load_point_sets(filename, num_try)[:num_try]
        parts = [_pairwise_distances(points) for points in tqdm(point_sets)]
        all_distances.append(np.concatenate(parts) if parts else np.empty(0))

    output_folder = args.output_folder
    os.makedirs(output_folder, exist_ok=True)

    # Figure 1: raw-count histogram per dataset, side by side with a shared y axis.
    # squeeze=False keeps `axes` 2-D even when only one dataset is configured.
    fig, axes = plt.subplots(1, len(labels), figsize=(20, 4), sharey=True, squeeze=False)
    for ax, label, distances in zip(axes[0], labels, all_distances):
        print(len(distances))
        ax.hist(distances, bins=num_bin, range=hist_range, alpha=0.6, density=False)
        ax.set_title(label)
    fig.tight_layout()
    filename = os.path.join(output_folder, f"disthist_{num_try}_{num_bin}bin_total.png")
    fig.savefig(filename)
    print(filename)
    plt.close(fig)

    # Figure 2: all datasets overlaid as density histograms on one axis. No axis
    # title: the legend identifies each dataset, whereas a per-dataset title
    # would only ever show the last one.
    fig, ax = plt.subplots(figsize=(10, 5))
    for label, distances, color in zip(labels, all_distances, edge_color_list):
        ax.hist(
            distances,
            bins=num_bin,
            range=hist_range,
            alpha=0.6,
            density=True,
            label=label,
            edgecolor=color,
            linewidth=1,
        )
    ax.legend()
    fig.tight_layout()
    filename = os.path.join(
        output_folder, f"disthist_{num_try}_{num_bin}bin_total_overwrite.png"
    )
    fig.savefig(filename)
    print(filename)
    # Return normally (no exit()) so main() can be called repeatedly, e.g. to
    # sweep several bin counts in one process.
    plt.close(fig)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Plot histograms of pairwise point distances for city-map and random point sets."
    )
    # --filepath and --random are accepted for command-line compatibility but
    # are not read by main().
    parser.add_argument(
        "--filepath", type=str, default=None, help="unused; kept for CLI compatibility"
    )
    parser.add_argument(
        "--num-try",
        type=int,
        default=100,
        help="number of point sets (trials) used per dataset",
    )
    parser.add_argument("--num-bin", type=int, default=10, help="number of histogram bins")
    parser.add_argument("--random", action="store_true", help="unused; kept for CLI compatibility")
    parser.add_argument(
        "--output-folder",
        type=str,
        default="output_density",
        help="directory where the histogram PNGs are written",
    )

    args = parser.parse_args()
    main(args)

    # Sweep over several bin counts (disabled):
    # for num_bin in [10, 20, 25, 50, 100]:
    #     args_dict = {
    #         "num_try": args.num_try,
    #         "num_bin": num_bin,
    #         "output_folder": args.output_folder,
    #     }
    #     tmp_args = argparse.Namespace(**args_dict)
    #     main(tmp_args)
