import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
import utils

import os
######################################################################

class Lower_TSPDataset(Dataset):
    """TSP instances that are either sampled randomly or loaded from a file.

    Without ``data_path`` the dataset holds ``num_samples`` random instances
    of shape ``(size, node_dim)``. With ``data_path`` the random sampling is
    skipped entirely and the instances come from the file:

    * a path whose extension is ``tsp`` is read with ``readTSPLib`` and the
      raw (un-normalised) coordinates are stored; if a sibling
      ``<name>.opt.tour`` file exists, ``self.opt_route`` holds the tour
      returned by ``readTSPLibOpt`` and ``self.opt`` its Euclidean length,
      otherwise ``self.opt`` is ``-1``;
    * any other path is read with ``readDataFile``.

    Args:
        size: Number of nodes per random instance.
        node_dim: Coordinate dimension of a random instance.
        num_samples: Number of random instances to generate.
        data_distribution: ``"uniform"`` (``torch.rand``) or ``"normal"``
            (``torch.randn``); only consulted when ``data_path`` is None.
        data_path: Optional file to load instead of sampling.

    Raises:
        ValueError: If ``data_path`` is None and ``data_distribution`` is
            not one of the supported names.
    """

    def __init__(
        self,
        size=50,
        node_dim=2,
        num_samples=100000,
        data_distribution="uniform",
        data_path=None,
    ):
        super(Lower_TSPDataset, self).__init__()

        if data_path is None:
            if data_distribution == "uniform":
                self.data = torch.rand(num_samples, size, node_dim).cpu()
            elif data_distribution == "normal":
                self.data = torch.randn(num_samples, size, node_dim).cpu()
            else:
                # Fail at construction time instead of leaving ``self.data``
                # undefined and crashing later inside ``__getitem__``.
                raise ValueError(
                    "Unsupported data_distribution: {!r} "
                    "(expected 'uniform' or 'normal')".format(data_distribution)
                )
            self.size = num_samples
            return

        # A file was given: do not allocate random tensors that would be
        # thrown away immediately.
        if data_path.split(".")[-1] == "tsp":
            _, data_raw = readTSPLib(data_path)
            # Only swap the extension; ``str.replace`` would also rewrite a
            # ".tsp" that appears earlier in the path.
            opt_path = os.path.splitext(data_path)[0] + ".opt.tour"
            print(opt_path)
            if os.path.exists(opt_path):
                self.opt_route = readTSPLibOpt(opt_path)
                # Index with an int64 tensor so the lookup is a real gather
                # whatever integer dtype the reader returned (torch treats
                # uint8 index arrays as masks, not as positions).
                route = torch.from_numpy(
                    np.asarray(self.opt_route).astype(np.int64)
                )
                successor = torch.roll(route, -1)
                hops = data_raw[0, route] - data_raw[0, successor]
                self.opt = np.linalg.norm(hops.numpy(), axis=-1).sum()
            else:
                self.opt = -1
            self.data = data_raw
        else:
            self.data = readDataFile(data_path)
        self.size = self.data.shape[0]

    def __len__(self):
        """Return the number of instances held by the dataset."""
        return self.size

    def __getitem__(self, idx):
        """Return the instance stored at position ``idx``."""
        return self.data[idx]


class Lower_TSPTunnel(Dataset):
    """Dataset of randomly generated tunnel instances.

    Each sample is built from ``utils.generate_random_order`` called with
    ``(num_samples, size, 2 * tunnels)`` and reshaped to
    ``(num_samples, tunnels, node_dim)``.
    NOTE(review): presumably each row pairs the two endpoint indices of a
    tunnel, which would require ``node_dim == 2`` -- confirm against
    ``utils.generate_random_order``.

    Args:
        size: Number of nodes in the underlying instance.
        tunnels: Number of tunnels; ``2 * tunnels`` must not exceed ``size``.
        node_dim: Last dimension of each stored sample.
        num_samples: Number of samples to generate.
        data_path: Accepted for signature parity; not used here.
    """

    def __init__(
        self,
        size=50,
        tunnels=15,
        node_dim=2,
        num_samples=100000,
        data_path=None,
    ):
        super().__init__()
        endpoint_count = 2 * tunnels
        assert endpoint_count <= size
        order = utils.generate_random_order(num_samples, size, endpoint_count)
        shaped = order.reshape(num_samples, tunnels, node_dim)
        self.data = torch.tensor(shaped)
        self.size = num_samples

    def __len__(self):
        """Return the number of generated samples."""
        return self.size

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        return self.data[idx]


def readDataFile(filePath):
    """
        read validation dataset from "https://github.com/Spider-scnu/TSP"

        Every non-empty line holds one instance: whitespace-separated
        coordinates ``x1 y1 x2 y2 ...``; everything from the first ``o``
        character onwards is ignored. Lines that contain no coordinates
        (e.g. blank lines) are skipped.

        Args:
            filePath: Path of the text file to read.

        Returns:
            A float32 tensor of shape ``(num_instances, num_nodes, 2)``.

        Raises:
            RuntimeError: Raised by ``torch.stack`` when the file holds no
                instance or the instances have different node counts.
    """
    instances = []
    with open(filePath, "r") as fp:
        # Stream the file instead of materialising every line up front.
        for line in fp:
            coords = [float(tok) for tok in line.split("o")[0].split()]
            if not coords:
                # A blank line would yield a (0, 2) tensor and break the
                # final stack, so it is skipped.
                continue
            loc_x = torch.tensor(coords[::2], dtype=torch.float32)
            loc_y = torch.tensor(coords[1::2], dtype=torch.float32)
            instances.append(torch.stack([loc_x, loc_y], dim=1))
    return torch.stack(instances, dim=0)


def readTSPLib(filePath):
    """
        read TSPLib

        Header lines (anything containing ``:``) and ``EOF`` lines are
        ignored. Node lines of the form ``<id> <x> <y>`` are collected; a
        non-numeric line other than ``NODE_COORD_SECTION`` is treated as the
        start of another section whose rows are not read as coordinates.

        Args:
            filePath: Path of the ``.tsp`` file.

        Returns:
            ``(data_trans, data_raw)``, both float32 tensors of shape
            ``(1, num_nodes, 2)``. ``data_raw`` holds the coordinates as
            written in the file; ``data_trans`` is shifted so the minimum of
            each axis is 0 and divided by the larger of the two axis extents.

        Raises:
            ValueError: If the file contains no node coordinates.
    """
    xs, ys = [], []
    # Files without an explicit section header are still accepted, so
    # coordinate collection starts enabled.
    in_coord_section = True
    with open(filePath, "r") as fp:
        for line in fp:
            if ":" in line or "EOF" in line:
                continue
            if "NODE_COORD_SECTION" in line:
                in_coord_section = True
                continue
            try:
                fields = [float(tok) for tok in line.split()]
            except ValueError:
                # Some other section keyword (e.g. DISPLAY_DATA_SECTION):
                # its numeric rows must not be mistaken for coordinates.
                in_coord_section = False
                continue
            if in_coord_section and len(fields) == 3:
                xs.append(fields[1])
                ys.append(fields[2])

    if not xs:
        raise ValueError(
            "No node coordinates found in TSPLIB file: {}".format(filePath)
        )

    loc_x = torch.tensor(xs, dtype=torch.float32)
    loc_y = torch.tensor(ys, dtype=torch.float32)
    raw = torch.stack([loc_x, loc_y], dim=1)

    x_min, y_min = loc_x.min(), loc_y.min()
    extent = torch.max(loc_x.max() - x_min, loc_y.max() - y_min)
    shifted = torch.stack([loc_x - x_min, loc_y - y_min], dim=1)
    # When every node coincides the extent is zero; the shifted data is
    # already all zeros, and dividing would turn it into NaN.
    normalised = shifted / extent if extent > 0 else shifted

    data_trans = torch.stack([normalised], dim=0)
    data_raw = torch.stack([raw], dim=0)
    return data_trans, data_raw

def readTSPLibOpt(opt_path):
    """Read a TSPLIB ``.opt.tour`` file and return the tour as 0-based indices.

    Header lines (anything containing ``:``), ``TOUR_SECTION`` and ``EOF``
    lines are ignored. Node ids are 1-based in the file and are shifted to
    0-based; reading stops at the ``-1`` terminator token.

    Args:
        opt_path: Path of the ``.opt.tour`` file.

    Returns:
        A 1-D ``numpy`` array of dtype ``int64`` with the visiting order.
        (``uint8`` cannot represent node indices of 256 or more.)

    Raises:
        ValueError: If a tour line contains a non-integer token.
    """
    tour = []
    finished = False
    with open(opt_path, "r") as fp:
        for line in fp:
            if ":" in line or "TOUR_SECTION" in line or "EOF" in line:
                continue
            for token in line.split():
                node = int(token)
                if node == -1:
                    # Terminator token: the tour is complete.
                    finished = True
                    break
                tour.append(node - 1)
            if finished:
                break
    return np.array(tour, dtype=np.int64)

if __name__ == "__main__":
    # Manual smoke test: load one TSPLIB instance, then print the dataset
    # object, its length and its first entry.
    demo = Lower_TSPDataset(data_path="./data/TSPLIB-master/res/att48.tsp")
    print(demo)
    print(len(demo))
    print(demo[0])