import csv
import os
import urllib.request
import zipfile
from copy import copy

import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.dataset import Dataset

import dgl


def _basic_sudoku_graph():
    grids = [
        [0, 1, 2, 9, 10, 11, 18, 19, 20],
        [3, 4, 5, 12, 13, 14, 21, 22, 23],
        [6, 7, 8, 15, 16, 17, 24, 25, 26],
        [27, 28, 29, 36, 37, 38, 45, 46, 47],
        [30, 31, 32, 39, 40, 41, 48, 49, 50],
        [33, 34, 35, 42, 43, 44, 51, 52, 53],
        [54, 55, 56, 63, 64, 65, 72, 73, 74],
        [57, 58, 59, 66, 67, 68, 75, 76, 77],
        [60, 61, 62, 69, 70, 71, 78, 79, 80],
    ]
    edges = set()
    for i in range(81):
        row, col = i // 9, i % 9
        # same row and col
        row_src = row * 9
        col_src = col
        for _ in range(9):
            edges.add((row_src, i))
            edges.add((col_src, i))
            row_src += 1
            col_src += 9
        # same grid
        grid_row, grid_col = row // 3, col // 3
        for n in grids[grid_row * 3 + grid_col]:
            if n != i:
                edges.add((n, i))
    edges = list(edges)
    g = dgl.graph(edges)
    return g


class ListDataset(Dataset):
    def __init__(self, *lists_of_data):
        assert all(len(lists_of_data[0]) == len(d) for d in lists_of_data)
        self.lists_of_data = lists_of_data

    def __getitem__(self, index):
        return tuple(d[index] for d in self.lists_of_data)

    def __len__(self):
        return len(self.lists_of_data[0])


def _get_sudoku_dataset(segment="train"):
    assert segment in ["train", "valid", "test"]
    url = "https://data.dgl.ai/dataset/sudoku-hard.zip"
    zip_fname = "/tmp/sudoku-hard.zip"
    dest_dir = "/tmp/sudoku-hard/"

    if not os.path.exists(dest_dir):
        print("Downloading data...")

        urllib.request.urlretrieve(url, zip_fname)
        with zipfile.ZipFile(zip_fname) as f:
            f.extractall("/tmp/")

    def read_csv(fname):
        print("Reading %s..." % fname)
        with open(dest_dir + fname) as f:
            reader = csv.reader(f, delimiter=",")
            return [(q, a) for q, a in reader]

    data = read_csv(segment + ".csv")

    def encode(samples):
        def parse(x):
            return list(map(int, list(x)))

        encoded = [(parse(q), parse(a)) for q, a in samples]
        return encoded

    data = encode(data)
    print(f"Number of puzzles in {segment} set : {len(data)}")

    return data


def sudoku_dataloader(batch_size, segment="train"):
    """
    Get a DataLoader instance for dataset of sudoku. Every iteration of the dataloader returns
    a DGLGraph instance, the ndata of the graph contains:
    'q': question, e.g. the sudoku puzzle to be solved, the position is to be filled with number from 1-9
         if the value in the position is 0
    'a': answer, the ground truth of the sudoku puzzle
    'row': row index for each position in the grid
    'col': column index for each position in the grid
    :param batch_size: Batch size for the dataloader
    :param segment: The segment of the datasets, must in ['train', 'valid', 'test']
    :return: A pytorch DataLoader instance
    """
    data = _get_sudoku_dataset(segment)
    q, a = zip(*data)

    dataset = ListDataset(q, a)
    if segment == "train":
        data_sampler = RandomSampler(dataset)
    else:
        data_sampler = SequentialSampler(dataset)

    basic_graph = _basic_sudoku_graph()
    sudoku_indices = np.arange(0, 81)
    rows = sudoku_indices // 9
    cols = sudoku_indices % 9

    def collate_fn(batch):
        graph_list = []
        for q, a in batch:
            q = torch.tensor(q, dtype=torch.long)
            a = torch.tensor(a, dtype=torch.long)
            graph = copy(basic_graph)
            graph.ndata["q"] = q  # q means question
            graph.ndata["a"] = a  # a means answer
            graph.ndata["row"] = torch.tensor(rows, dtype=torch.long)
            graph.ndata["col"] = torch.tensor(cols, dtype=torch.long)
            graph_list.append(graph)
        batch_graph = dgl.batch(graph_list)
        return batch_graph

    dataloader = DataLoader(
        dataset, batch_size, sampler=data_sampler, collate_fn=collate_fn
    )
    return dataloader
