import argparse
import os
from pickle import dump

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm


def _min_max_normalize(points: np.ndarray) -> np.ndarray:
    """Scale each column of ``points`` independently into [0, 1].

    A column whose values are all equal (zero range) maps to 0 rather than
    producing NaN from a 0/0 division.
    """
    low = points.min(axis=0)
    span = points.max(axis=0) - low
    span[span == 0] = 1
    return (points - low) / span


def main(args: argparse.Namespace) -> None:
    """Sample node-location instances and pickle them to disk.

    Two sources are supported:

    * ``args.random`` false: read the parquet file at ``args.filepath``, take
      its ``x``/``y`` columns, min-max normalize each column to [0, 1], then
      build ``args.n_data`` instances of ``args.n_nodes`` distinct rows each.
    * ``args.random`` true: draw ``args.n_data`` instances of ``args.n_nodes``
      points with ``torch.rand`` (uniform on [0, 1)).

    The stacked result is pickled to ``<output_folder>/<mode>_location.pickle``,
    where ``output_folder`` is built from ``args.output_folder`` and the
    sampling settings. The RNG seed is fixed per ``args.mode`` (train=0,
    validation=1, anything else=2), so each split is reproducible and the
    splits differ from one another.

    Args:
        args: parsed namespace providing ``mode``, ``random``, ``filepath``,
            ``output_folder``, ``n_nodes``, ``n_data`` and ``disable_tqdm``.

    Raises:
        ValueError: from ``np.random.choice`` when ``args.n_nodes`` exceeds the
            number of rows in the parquet file.
    """
    seed = {"train": 0, "validation": 1}.get(args.mode, 2)
    np.random.seed(seed)
    # torch keeps its own RNG state; seeding only numpy left the --random
    # branch non-reproducible and ignored the per-mode seed.
    torch.manual_seed(seed)

    if not args.random:
        output_folder = f"{args.output_folder}/{os.path.basename(args.filepath)}_{args.n_nodes}node_{args.n_data}num"
        frame = pd.read_parquet(args.filepath)
        locations = frame[["x", "y"]].to_numpy()
        # Integer coordinate columns would otherwise truncate the normalized
        # values to 0/1 on assignment; existing float dtypes are kept as-is.
        if not np.issubdtype(locations.dtype, np.floating):
            locations = locations.astype(np.float64)
        locations = _min_max_normalize(locations)
        generator = tqdm(range(args.n_data)) if not args.disable_tqdm else range(args.n_data)
        # Passing the row count draws the same permutation as passing an
        # explicit index list, so sampled instances are unchanged.
        data = [
            locations[np.random.choice(len(locations), args.n_nodes, replace=False)]
            for _ in generator
        ]
    else:
        output_folder = f"{args.output_folder}/random_{args.n_nodes}node_{args.n_data}num"
        data = torch.rand(size=[args.n_data, args.n_nodes, 2])

    os.makedirs(output_folder, exist_ok=True)
    data = np.array(data)

    with open(f"{output_folder}/{args.mode}_location.pickle", "wb") as f:
        dump(data, f)


if __name__ == "__main__":
    # Command-line entry point: parse options and delegate to main().
    parser = argparse.ArgumentParser(description="Sample node-location instances and pickle them.")
    parser.add_argument(
        "--filepath",
        type=str,
        default=None,
        help="parquet file with 'x' and 'y' columns to sample from (required unless --random)",
    )
    parser.add_argument(
        "--output_folder",
        "--output-folder",
        type=str,
        default="map_data_osmnx",
        help="root directory under which the dataset folder is created",
    )
    parser.add_argument("--n-nodes", type=int, default=100, help="number of nodes per instance")
    parser.add_argument("--n-data", type=int, default=30, help="number of instances to generate")
    parser.add_argument(
        "--mode",
        choices=["train", "validation", "test"],
        required=True,
        help="dataset split; also selects the RNG seed",
    )
    parser.add_argument(
        "--random",
        action="store_true",
        help="sample points uniformly at random instead of reading --filepath",
    )
    parser.add_argument("--disable-tqdm", action="store_true", help="hide the progress bar")

    args = parser.parse_args()
    # --filepath is only consumed by the file-based branch of main().
    if not args.random and args.filepath is None:
        parser.error("--filepath is required unless --random is given")
    main(args)
