import argparse
import os
from pickle import dump

import numpy as np
import torch
from tqdm import tqdm


def read_file(filename):
    """Load node coordinates from a TSPLIB-style file and normalize them.

    Lines before ``NODE_COORD_SECTION`` are ignored. Every following
    non-blank line is parsed as ``<index> <x> <y>`` (the index is discarded)
    until ``EOF``, the next ``*_SECTION`` header, or the end of the file.
    Each coordinate axis is then min-max scaled to ``[0, 1]``. Finally the
    two columns are swapped, so column 0 holds the scaled y values and
    column 1 the scaled x values.

    Args:
        filename: path to the input file.

    Returns:
        ``np.ndarray`` of shape ``(n, 2)`` with float values in ``[0, 1]``.
        An axis on which all points share the same value is mapped to 0
        (instead of NaN from a zero division).

    Raises:
        ValueError: if no coordinates follow ``NODE_COORD_SECTION``.
    """
    coords = []
    in_section = False
    with open(filename, "r") as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Blank lines (e.g. a trailing newline after EOF) carry no data.
                continue
            if not in_section:
                if fields[0] == "NODE_COORD_SECTION":
                    in_section = True
                continue
            if fields[0] == "EOF" or fields[0].endswith("_SECTION"):
                # End of the coordinate section; anything after is not a coordinate.
                break
            coords.append((float(fields[1]), float(fields[2])))

    if not coords:
        raise ValueError(f"no coordinates found after NODE_COORD_SECTION in {filename!r}")

    box = np.array(coords, dtype=float)
    mins = box.min(axis=0)
    span = box.max(axis=0) - mins
    # Avoid 0/0 -> NaN when every point shares a coordinate on one axis.
    span[span == 0] = 1.0
    box = (box - mins) / span

    # Swap columns: (x, y) -> (y, x). Fancy indexing returns a new array.
    return box[:, [1, 0]]


def main(args: argparse.Namespace) -> None:
    """Generate a dataset of node-location instances and pickle it.

    Each instance is an ``(args.n_nodes, 2)`` array of coordinates.

    * With ``args.random``, coordinates come from ``torch.rand`` (uniform on
      ``[0, 1)``, float32).
    * Otherwise, ``args.n_nodes`` distinct rows are sampled without
      replacement from the normalized locations returned by
      ``read_file(args.filepath)``.

    The instances are collected into one numpy array and pickled to
    ``<output folder>/<args.mode>_location.pickle``. The output folder name is
    printed.

    The RNG seed is chosen from ``args.mode`` (train=0, validation=1,
    anything else=2). It is used for both the numpy and the torch sampling
    path, so every split is reproducible and the splits differ from each
    other.

    Args:
        args: parsed CLI namespace with ``mode``, ``random``, ``filepath``,
            ``output_folder``, ``n_nodes`` and ``n_data`` attributes.

    Raises:
        ValueError: if the file provides fewer than ``args.n_nodes`` locations.
    """
    seed = {"train": 0, "validation": 1}.get(args.mode, 2)
    np.random.seed(seed)

    if not args.random:
        output_folder = f"{args.output_folder}/{os.path.basename(args.filepath)}_{args.n_nodes}node_{args.n_data}num"
        locations = read_file(args.filepath)
        n_locations = len(locations)
        if args.n_nodes > n_locations:
            raise ValueError(
                f"--n-nodes={args.n_nodes} exceeds the {n_locations} locations in {args.filepath}"
            )
        data = []
        for _ in tqdm(range(args.n_data)):
            # Distinct node indices per instance (sampling without replacement).
            sampled_idxs = np.random.choice(n_locations, args.n_nodes, replace=False)
            data.append(locations[sampled_idxs])
    else:
        output_folder = f"{args.output_folder}/random_{args.n_nodes}node_{args.n_data}num"
        # torch has its own RNG, which np.random.seed does not touch. Without a
        # seeded generator the random splits would not be reproducible.
        generator = torch.Generator().manual_seed(seed)
        data = torch.rand(size=[args.n_data, args.n_nodes, 2], generator=generator)
    os.makedirs(output_folder, exist_ok=True)
    data = np.array(data)
    print(output_folder)

    with open(f"{output_folder}/{args.mode}_location.pickle", "wb") as f:
        dump(data, f)


def parse_arguments(argv=None) -> argparse.Namespace:
    """Parse the command-line options for the dataset generator.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed namespace consumed by ``main``.

    Raises:
        SystemExit: on invalid arguments, including a missing ``--filepath``
            when ``--random`` is not given.
    """
    parser = argparse.ArgumentParser(
        description="Sample node-location instances and pickle them for a dataset split."
    )
    parser.add_argument(
        "--filepath", type=str, default=None,
        help="TSPLIB file with a NODE_COORD_SECTION to sample nodes from (required unless --random)",
    )
    parser.add_argument(
        "--output_folder", "--output-folder", dest="output_folder", type=str, default="map_data",
        help="root folder in which the dataset subfolder is created",
    )
    parser.add_argument("--n-nodes", type=int, default=100, help="number of nodes per instance")
    parser.add_argument("--n-data", type=int, default=30, help="number of instances to generate")
    parser.add_argument(
        "--mode", choices=["train", "validation", "test"], required=True,
        help="dataset split; also selects the random seed",
    )
    parser.add_argument(
        "--random", action="store_true",
        help="draw coordinates uniformly from [0, 1) instead of sampling from --filepath",
    )

    args = parser.parse_args(argv)
    # --filepath is only needed when sampling from a file.
    if not args.random and args.filepath is None:
        parser.error("--filepath is required unless --random is given")
    return args


if __name__ == "__main__":
    main(parse_arguments())
