from pathlib import Path

import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset


class TSPLIBDataset(Dataset):
    def __init__(self, name: str, cities: list[Tensor], values: Tensor):
        self.name = name
        self.cities = cities
        self.values = values

    def __len__(self) -> int:
        return len(self.cities)

    def __getitem__(self, i: int) -> tuple[Tensor, Tensor]:
        x = self.cities[i]
        v = self.values[i]

        # Rescale to the unit square.
        min_x = torch.min(x, dim=0, keepdim=True).values
        max_x = torch.max(x, dim=0, keepdim=True).values
        scale = torch.max(max_x - min_x)

        x = x - min_x  # New min is [0, 0].
        x = x / scale
        v = v / scale

        return x, v

    @classmethod
    def from_dir(
        cls, directory: Path, min_cities: int | None = None, max_cities: int | None = None
    ) -> "TSPLIBDataset":
        cities, values = [], []
        for filepath in directory.glob("*.npz"):
            data = np.load(filepath)

            c = data["coords"][0]
            v = data["tour_lens"][0]

            if (min_cities is not None and c.shape[0] < min_cities) or (
                max_cities is not None and c.shape[0] > max_cities
            ):
                continue

            cities.append(c)
            values.append(v)

        values = np.stack(values)

        cities = [torch.tensor(c, dtype=torch.float32) for c in cities]
        values = torch.tensor(values, dtype=torch.float32)

        match (min_cities, max_cities):
            case None, None:
                name = "TSPLIB"
            case int(), None:
                name = f"TSPLIB > {min_cities}"
            case None, int():
                name = f"TSPLIB < {max_cities}"
            case int(), int():
                name = f"TSPLIB {min_cities}~{max_cities}"

        return cls(name, cities, values)
