import itertools
import json
import random
from pathlib import Path
from typing import Callable, List, Iterable, Tuple, Any

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset as TorchDataset
from collections import Counter, defaultdict
import numpy as np
from itertools import combinations

# Directory containing this module; MNIST_datasets() passes it to torchvision
# as the root folder where the MNIST files are downloaded and read from.
_DATA_ROOT = Path(__file__).parent
import copy


class ShuffleIterator:
    """Iterator that cycles over ``iterable`` in a freshly shuffled order.

    Every time the current pass is used up, ``iterable`` is read again into a
    list, shuffled with the global ``random`` module, and iteration restarts.
    For a non-empty, re-iterable source (list, tuple, range, ...) the iterator
    therefore never ends.

    If reading ``iterable`` yields no items (an empty container, or a one-shot
    iterator that has already been consumed by an earlier pass),
    ``StopIteration`` is raised so that ``for`` loops and ``list()`` terminate
    cleanly instead of failing with ``IndexError``.
    """

    def __init__(self, iterable):
        self.iterable = iterable
        self.buffer = []  # items of the current pass, in shuffled order
        self.index = 0  # position of the next item to hand out from buffer

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= len(self.buffer):
            # Current pass exhausted (or first call): start a new shuffled pass.
            self.buffer = list(self.iterable)
            random.shuffle(self.buffer)
            self.index = 0
            if not self.buffer:
                # Nothing to yield; follow the iterator protocol.
                raise StopIteration

        item = self.buffer[self.index]
        self.index += 1
        return item


def MNIST_datasets():
    """Create the MNIST train and test datasets, downloading them if needed.

    Every image goes through ``ToTensor`` followed by
    ``Normalize((0.1307,), (0.3081,))``.  The files are stored under the
    directory given by ``_DATA_ROOT``.

    :return: Dict with the keys ``"train"`` and ``"test"``, each mapping to a
        ``torchvision.datasets.MNIST`` instance.
    """
    preprocessing = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
    )
    root_dir = str(_DATA_ROOT)

    return {
        split_name: torchvision.datasets.MNIST(
            root=root_dir,
            train=is_train,
            download=True,
            transform=preprocessing,
        )
        for split_name, is_train in (("train", True), ("test", False))
    }


def digits_to_number_binary(digits: Iterable[int]) -> int:
    """
    Returns the integer encoded by a sequence of bits, most significant first.

    An empty sequence yields 0.
    """
    value = 0
    for bit in digits:
        # Make room for the next bit, then put it in the lowest position.
        value = value * 2 + bit
    return value


def number_to_digits_binary(number: int, digit_size=None) -> List[int]:
    """
    Returns the binary digits of ``number``, most significant first.

    With ``digit_size`` given, exactly that many digits are produced: the
    result is left-padded with zeros, and any higher bits that do not fit are
    dropped.  Without it, the shortest representation is returned (``[0]``
    for zero).
    """
    if digit_size is None:
        if number == 0:
            return [0]
        bits = []
        while number > 0:
            number, bit = divmod(number, 2)
            bits.append(bit)
        bits.reverse()  # collected least significant first
        return bits

    bits = []
    for _ in range(digit_size):
        number, bit = divmod(number, 2)
        bits.append(bit)
    return list(reversed(bits))


def digits_to_number(digits, num_classes=2) -> int:
    """Return the integer whose base-``num_classes`` digits are ``digits``.

    The first digit is the most significant one; an empty sequence gives 0.
    """
    total = 0
    for digit in digits:
        total = total * num_classes + digit
    return total


def number_to_digits(number: int, digit_size=None, num_classes=2) -> List[int]:
    """Return the base-``num_classes`` digits of ``number``, most significant first.

    With ``digit_size`` given, exactly that many digits are returned
    (zero-padded on the left; higher digits that do not fit are dropped).
    Without it, the shortest representation is returned, ``[0]`` for zero.
    """
    if digit_size is None and number == 0:
        return [0]

    low_to_high = []
    if digit_size is None:
        # Peel off digits until nothing is left.
        while number > 0:
            low_to_high.append(number % num_classes)
            number //= num_classes
    else:
        # Fixed width: always take exactly digit_size digits.
        for _ in range(digit_size):
            low_to_high.append(number % num_classes)
            number //= num_classes

    # Digits were collected least significant first.
    return low_to_high[::-1]


# def number_to_digits(number: int, digit_size=None) -> List[int]:
#     if number == 0:
#         return [0]

#     digits = []
#     while number > 0:
#         digits.append(number % 10)
#         number //= 10

#     # The digits are appended in reverse order, so reverse the list before returning
#     return digits[::-1]


def chess_data(
    n: int,
    seed=None,
    train: bool = True,
    num_classes=2,
    sequence_num=30000,
):
    """Build a ``DigitsOperator`` dataset whose sequences contain ``n`` pieces.

    ``function_name`` is ``"addition"`` when ``n == 1`` and
    ``"multi_addition"`` otherwise; ``operator`` is ``sum`` and ``arity`` is 2.
    NOTE(review): these three values look inherited from an addition task;
    confirm whether anything still reads them.

    :param n: Number of pieces per sequence (passed as ``size``).
    :param seed: Forwarded to ``DigitsOperator``.
    :param train: Use the MNIST train split if true, else the test split.
    :param num_classes: Number of piece types / image classes.
    :param sequence_num: Number of sequences to generate.
    """
    function_name = "multi_addition" if n != 1 else "addition"
    options = dict(
        function_name=function_name,
        operator=sum,
        size=n,
        arity=2,
        seed=seed,
        train=train,
        num_classes=num_classes,
        sequence_num=sequence_num,
    )
    return DigitsOperator(**options)


class DigitsOperator(TorchDataset):
    def __getitem__(self, index: int) -> Tuple[list, list, int]:
        l = self.data[index]
        inputs = [x[0] for x in l]
        return inputs

    def get_attacked_positions(self, type, x, y):
        if type == 0:
            return self.get_pawn_attacked_positions(x, y)
        elif type == 1:
            return self.get_rook_attacked_positions(x, y)
        elif type == 2:
            return self.get_bishop_attacked_positions(x, y)
        elif type == 3:
            return self.get_knight_attacked_positions(x, y)
        elif type == 4:
            return self.get_king_attacked_positions(x, y)
        elif type == 5:
            return self.get_queen_attacked_positions(x, y)
        return []

    def get_queen_attacked_positions(self, x, y):
        # attacked_positions = set()

        # Queen attacks horizontally and vertically (rook-like)
        attacked_positions = self.get_rook_attacked_positions(
            x, y
        ) + self.get_bishop_attacked_positions(x, y)

        # Queen attacks diagonally (bishop-like)

        return attacked_positions

    def get_rook_attacked_positions(self, x, y):
        attacked_positions = []

        # Rook can attack along the row and column
        for i in range(8):
            if i != x and self.vis[i][y] is not True:
                attacked_positions.append((i, y))  # Same column
            if i != y and self.vis[x][i] is not True:
                attacked_positions.append((x, i))  # Same row

        return attacked_positions

    def get_bishop_attacked_positions(self, x, y):
        attacked_positions = []

        # Bishop attacks diagonally
        for i in range(1, 8):
            if 0 <= x + i < 8 and 0 <= y + i < 8 and self.vis[x + i][y + i] is not True:
                attacked_positions.append((x + i, y + i))
            if 0 <= x + i < 8 and 0 <= y - i < 8 and self.vis[x + i][y - i] is not True:
                attacked_positions.append((x + i, y - i))
            if 0 <= x - i < 8 and 0 <= y + i < 8 and self.vis[x - i][y + i] is not True:
                attacked_positions.append((x - i, y + i))
            if 0 <= x - i < 8 and 0 <= y - i < 8 and self.vis[x - i][y - i] is not True:
                attacked_positions.append((x - i, y - i))

        return attacked_positions

    def get_knight_attacked_positions(self, x, y):
        attacked_positions = []

        # Knight moves in "L" shape
        moves = [(2, 1), (2, -1), (-2, 1), (-2, -1), (1, 2), (1, -2), (-1, 2), (-1, -2)]

        for dx, dy in moves:
            if (
                0 <= x + dx < 8
                and 0 <= y + dy < 8
                and self.vis[x + dx][y + dy] is not True
            ):
                attacked_positions.append((x + dx, y + dy))

        return attacked_positions

    def get_king_attacked_positions(self, x, y):
        attacked_positions = []

        # King moves one square in any direction
        for dx in [-1, 0, 1]:
            for dy in [-1, 0, 1]:
                if dx == 0 and dy == 0:
                    continue
                if (
                    0 <= x + dx < 8
                    and 0 <= y + dy < 8
                    and self.vis[x + dx][y + dy] is not True
                ):
                    attacked_positions.append((x + dx, y + dy))

        return attacked_positions

    def get_pawn_attacked_positions(self, x, y):
        attacked_positions = []

        # Pawn attacks diagonally (assumes white pawn, so moves upwards)
        if 0 <= x - 1 < 8 and 0 <= y - 1 < 8 and self.vis[x - 1][y - 1] is not True:
            attacked_positions.append((x - 1, y - 1))  # Diagonal left
        if 0 <= x - 1 < 8 and 0 <= y + 1 < 8 and self.vis[x - 1][y + 1] is not True:
            attacked_positions.append((x - 1, y + 1))  # Diagonal right

        return attacked_positions

    def generate_sequences(self, n):
        types = []
        positions = []
        all_positions = []
        for x in range(8):
            for y in range(8):
                all_positions.append((x, y))
        for _ in range(n):
            while True:
                self.vis = [[False for i in range(8)] for j in range(8)]
                rest_pos = copy.deepcopy(all_positions)
                rest_type = list(range(self.num_classes))
                cur_type = []
                cur_pos = []
                pre_type = None
                find = True
                for i in range(self.size):
                    type = random.choice(rest_type)
                    if i == 0:
                        pos = random.choice(rest_pos)
                    else:
                        candidate_pos = self.get_attacked_positions(
                            pre_type, pre_pos[0], pre_pos[1]
                        )
                        if len(candidate_pos) == 0:
                            find = False
                            break
                        else:
                            pos = random.choice(candidate_pos)
                    self.vis[pos[0]][pos[1]] = True
                    cur_type.append(type)
                    cur_pos.append(pos)
                    pre_type = type
                    pre_pos = pos
                    rest_pos.remove(pos)
                    # rest_type.remove(type)

                if find is True:
                    types.append(cur_type)
                    positions.append(cur_pos)
                    break
        return types, positions

    def balance_indices(self):
        balance_size = sorted(Counter(self.labels).items())[0][1]
        labels_dist = defaultdict(int)
        sampler_iter = ShuffleIterator(list(range(len(self.dataset))))
        balanced_indices = []
        while len(balanced_indices) < balance_size * self.num_classes:
            sample = next(sampler_iter)
            sampled_class = self.labels[sample]
            if labels_dist[sampled_class] >= balance_size:
                continue
            balanced_indices.append(sample)
            labels_dist[sampled_class] += 1
        return balanced_indices

    def split_dataset_by_category(self, dataset):
        """
        Splits a dataset into groups based on categories.

        :param dataset: A list of tuples where each tuple contains (data, category).
        :return: A dictionary where keys are categories and values are lists of data points.
        """
        categorized_data = defaultdict(list)
        for data, category in dataset:
            categorized_data[category].append(data)
        return dict(categorized_data)

    def indices(self):
        return list(range(len(self.dataset)))

    def __init__(
        self,
        function_name: str,
        operator: Callable[[List[int]], int],
        size=1,
        arity=2,
        seed=None,
        train: bool = True,
        sequence_num: int = 30000,
        num_classes=2
        # shuffle_times=1
    ):
        """Generic dataset for operator(img, img) style datasets.

        :param dataset_name: Dataset to use (train, val, test)
        :param function_name: Name of Problog function to query.
        :param operator: Operator to generate correct examples
        :param size: Size of numbers (number of digits)
        :param arity: Number of arguments for the operator
        :param seed: Seed for RNG
        """
        super(DigitsOperator, self).__init__()
        assert size >= 1
        assert arity >= 1
        self.datasets = MNIST_datasets()
        self.dataset = self.datasets["train" if train else "test"]
        self.function_name = function_name
        self.operator = operator
        self.size = size
        self.arity = arity
        self.seed = seed
        self.num_classes = num_classes
        self.sequence_num = sequence_num
        self.all_list = []
        self.all_list, self.all_pos = self.generate_sequences(self.sequence_num)
        self.categorized_data = [
            self.split_dataset_by_category(self.dataset)[c] for c in range(num_classes)
        ]
        self.dataset = []
        self.labels = []
        for c in range(num_classes):
            self.dataset.extend(
                [
                    (self.categorized_data[c][i], c)
                    for i in range(len(self.categorized_data[c]))
                ]
            )
            self.labels.extend([c for i in range(len(self.categorized_data[c]))])
        self.data = []
        self.counts = np.zeros(num_classes)
        try:
            for sequence, pos in zip(self.all_list, self.all_pos):
                inputs = []
                for c, p in zip(sequence, pos):
                    self.counts[c] += 1
                    inputs.append((random.choice(self.categorized_data[c]), c, p))
                # inputs.append([(random.choice(self.categorized_data[c]),c) for c in digits_result])
                self.data.append(inputs)
        except StopIteration:
            pass
        self.prior = self.counts / np.sum(self.counts)
        # print('self.prior:',self.prior)

    def to_file_repr(self, i):
        """Old file represenation dump. Not a very clear format as multi-digit arguments are not separated"""
        return f"{tuple(itertools.chain(*self.data[i]))}"  # \t{self._get_label(i)}"

    def to_json(self):
        """
        Convert to JSON, for easy comparisons with other systems.

        Format is [EXAMPLE, ...]
        EXAMPLE :- [ARGS, expected_result]
        ARGS :- [MULTI_DIGIT_NUMBER, ...]
        MULTI_DIGIT_NUMBER :- [mnist_img_id, ...]
        """
        data = [(self.data[i], self._get_label(i)) for i in range(len(self))]
        return json.dumps(data)

    def _get_pos(self, i: int):
        l = self.data[i]
        pos = [x[2] for x in l]
        return pos

    def _get_symbol_label(self, i: int):
        l = self.data[i]
        ground_truth = [x[1] for x in l]
        return ground_truth

    def __len__(self):
        return len(self.data)


def get_chess(
    train=True,
    n=1,
    num_classes=2,
    sequence_num=30000,
    seed=None,
):
    """Build the chess sequence dataset and unpack it into plain lists.

    :return: ``((X, Z, P), prior)`` where, per sample, ``X`` holds the list of
        images, ``Z`` the list of piece-type labels and ``P`` the list of board
        positions; ``prior`` is the dataset's ``prior`` attribute.
    """
    dataset = chess_data(
        n,
        seed=seed,
        train=train,
        num_classes=num_classes,
        sequence_num=sequence_num,
    )
    sample_ids = range(len(dataset))
    X = [dataset[idx] for idx in sample_ids]
    Z = [dataset._get_symbol_label(idx) for idx in sample_ids]
    P = [dataset._get_pos(idx) for idx in sample_ids]
    return (X, Z, P), dataset.prior


# Script entry point: build the dataset with the default arguments of
# get_chess(). This goes through MNIST_datasets(), which requests the MNIST
# download (download=True).
if __name__ == "__main__":
    mnist_add = get_chess()
