import json
import os
import os.path as osp
import zipfile
from typing import Literal

import numpy as np
import pandas as pd
import requests
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from torch.utils.data import Dataset
from tqdm import tqdm

from build_arc_dataset import DataProcessConfig as ArcConfig
from build_arc_dataset import main as preprocess_arc_data
from build_maze_dataset import DataProcessConfig as MazeConfig
from build_maze_dataset import preprocess_data as preprocess_maze_data
from build_sudoku_dataset import DataProcessConfig as SudokuConfig
from build_sudoku_dataset import preprocess_data as preprocess_sudoku_data


class SudokuDataset(Dataset):
    """Sudoku puzzles read from one-hot ``features.pt`` / ``labels.pt`` tensors.

    The tensors live in ``<dataset_dir>/sudoku``. When either file is missing
    the archive is downloaded and extracted into ``dataset_dir`` first.

    Each item is the tuple ``(features, labels, is_input, group_id)``:
      * ``features``: digit index 0-8 for given cells, 9 for blank cells
      * ``labels``: digit index 0-8 (argmax of the one-hot label)
      * ``is_input``: boolean mask, True where the cell was given
      * ``group_id``: dummy group id (all different)
    """

    def __init__(self, dataset_dir: str, split: str):
        """Load the tensors (downloading them if needed) and slice one split.

        Args:
            dataset_dir: Directory that holds, or will receive, the ``sudoku``
                folder containing ``features.pt`` and ``labels.pt``.
            split: ``"train"`` selects rows 0-8999; any other value selects
                rows 9000-9999.

        Raises:
            requests.HTTPError: If a download is required and the server
                responds with an error status.
        """
        os.makedirs(dataset_dir, exist_ok=True)
        filename_features = os.path.join(dataset_dir, "sudoku", "features.pt")
        filename_labels = os.path.join(dataset_dir, "sudoku", "labels.pt")

        cached = os.path.exists(filename_features) and os.path.exists(
            filename_labels
        )
        if not cached:
            print("Download sudoku dataset...")
            res = requests.get("https://powei.tw/sudoku.zip")
            # Fail here rather than writing an error page to disk and
            # tripping over it later as a corrupt zip file.
            res.raise_for_status()
            archive_path = os.path.join(dataset_dir, f"{split}.zip")
            with open(archive_path, "wb") as f:
                f.write(res.content)
            # The archive is expected to unpack into ``<dataset_dir>/sudoku``.
            with zipfile.ZipFile(archive_path, "r") as zip_ref:
                zip_ref.extractall(dataset_dir)
            os.remove(archive_path)

        # Read each tensor exactly once, whether it was cached or just fetched.
        # NOTE(review): torch.load may unpickle arbitrary objects depending on
        # the torch version / ``weights_only`` default; only load trusted files.
        with open(filename_features, "rb") as f:
            features = torch.load(f)
        with open(filename_labels, "rb") as f:
            labels = torch.load(f)

        # A cell is a given clue when its one-hot vector (dim 3) is non-zero.
        is_input = features.sum(dim=3).bool()
        features = features.argmax(dim=-1)  # digit index 0..8
        labels = labels.argmax(dim=-1)  # digit index 0..8
        features[~is_input] = 9  # blank cells get the extra class 9
        indices = (
            torch.arange(0, 9000) if split == "train" else torch.arange(9000, 10000)
        )

        self.features = features[indices]
        self.labels = labels[indices]
        self.is_input = is_input[indices]
        self.group_id = torch.arange(
            len(self.features)
        )  # dummy group id (all different)

    def __len__(self):
        """Return the number of puzzles in the selected split."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return ``(features, labels, is_input, group_id)`` for ``idx``."""
        return (
            self.features[idx],
            self.labels[idx],
            self.is_input[idx],
            self.group_id[idx],  # dummy group id (all different)
        )


class HardSudokuDataset(Dataset):
    """Sudoku puzzles read from ``<dataset_dir>/sudoku-hard/<split>.csv``.

    The CSV has no header; column 0 is the puzzle and column 1 the solution,
    each a string of digits where ``0`` marks a blank cell. The archive is
    downloaded and extracted into ``dataset_dir`` when the CSV is missing.

    Each item is the tuple ``(features, labels, is_input, group_id)``:
      * ``features``: digit index 0-8 for given cells, 9 for blank cells
      * ``labels``: digit index 0-8
      * ``is_input``: boolean mask, True where the cell was given
      * ``group_id``: dummy group id (all different)
    """

    def __init__(self, dataset_dir: str, split: Literal["train", "valid", "test"]):
        """Load one split, downloading the archive first if necessary.

        Args:
            dataset_dir: Directory that holds, or will receive, ``sudoku-hard``.
            split: Name of the CSV (without extension) to read.

        Raises:
            requests.HTTPError: If a download is required and the server
                responds with an error status.
        """
        os.makedirs(dataset_dir, exist_ok=True)
        filename = os.path.join(dataset_dir, "sudoku-hard", f"{split}.csv")
        if not os.path.exists(filename):
            print("Downloading sudoku-hard dataset...")
            res = requests.get(
                "https://www.dropbox.com/s/rp3hbjs91xiqdgc/sudoku-hard.zip?dl=1"
            )
            # Fail here rather than writing an error page to disk and
            # tripping over it later as a corrupt zip file.
            res.raise_for_status()
            archive_path = os.path.join(dataset_dir, f"{split}.zip")
            with open(archive_path, "wb") as f:
                f.write(res.content)
            # The archive is expected to unpack into ``<dataset_dir>/sudoku-hard``.
            with zipfile.ZipFile(archive_path, "r") as zip_ref:
                zip_ref.extractall(dataset_dir)
            os.remove(archive_path)
        df = pd.read_csv(filename, header=None)

        features = []
        labels = []
        # Iterate the two columns directly; building a row Series per sample
        # with ``df.iloc[i]`` is far slower and yields the same values.
        for inp, out in tqdm(zip(df[0], df[1]), total=len(df)):
            features.append(self.str2onehot(inp))
            labels.append(self.str2onehot(out))

        self.features = torch.tensor(np.array(features))
        self.labels = torch.tensor(np.array(labels))

        # A cell is a given clue when its one-hot vector (dim 3) is non-zero.
        self.is_input = self.features.sum(dim=3).bool()
        self.features = self.features.argmax(dim=-1)  # 0-8
        self.labels = self.labels.argmax(dim=-1)  # 0-8
        self.features[~self.is_input] = 9  # blank cells get the extra class 9
        self.group_id = torch.arange(
            len(self.features)
        )  # dummy group id (all different)

    @staticmethod
    def str2onehot(x):
        """Convert an 81-character digit string into a ``(9, 9, 9)`` one-hot array.

        Digit ``d`` in 1-9 sets channel ``d - 1``; ``0`` leaves the cell all-zero.
        """
        x = np.array(list(map(int, x)), dtype="int64")
        y = np.zeros((len(x), 9), dtype="float32")
        idx = np.where(x > 0)[0]
        y[idx, x[idx] - 1] = 1
        return y.reshape((9, 9, 9))

    def __len__(self):
        """Return the number of puzzles in the split."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return ``(features, labels, is_input, group_id)`` for ``idx``."""
        return (
            self.features[idx],
            self.labels[idx],
            self.is_input[idx],
            self.group_id[idx],  # dummy group id (all different)
        )


def convert_onehot_to_int(X):
    """Collapse a one-hot board into integer digits.

    [B, H, W, 9] -> [B, H, W]: a cell whose one-hot vector is all zero maps to
    0, otherwise to ``argmax + 1``. The result is cast to ``torch.int32``.
    """
    filled = X.sum(dim=-1)
    digit = X.argmax(-1) + 1
    # Multiplying by the per-cell sum zeroes out cells without any hot channel.
    return (filled * digit).to(torch.int32)


# copied from https://github.com/yilundu/ired_code_release/blob/3d74b85fab7fcf5e28aaf15e9ed3bf51c1a1d545/sat_dataset.py#L17
def load_rrn_dataset(data_dir, split):
    """Read one split of the RRN sudoku CSVs as one-hot tensors.

    Args:
        data_dir: Directory containing ``train.csv``, ``valid.csv`` and
            ``test.csv`` (header-less; column 0 puzzle, column 1 solution).
        split: One of ``"train"``, ``"val"`` or ``"test"``.

    Returns:
        ``(features, labels)`` tensors, one ``(9, 9, 9)`` one-hot board per row.

    Raises:
        ValueError: If ``data_dir`` does not exist.
        KeyError: If ``split`` is not one of the known names.
    """
    if not osp.exists(data_dir):
        raise ValueError(
            f"Data directory {data_dir} does not exist. Run data/download-rrn.sh to download the dataset."
        )

    csv_name = {"train": "train.csv", "val": "valid.csv", "test": "test.csv"}[split]
    table = pd.read_csv(osp.join(data_dir, csv_name), header=None)

    def encode(board):
        # Digit d in 1..9 lights channel d - 1; a 0 leaves the cell all-zero.
        digits = np.array([int(ch) for ch in board], dtype="int64")
        onehot = np.zeros((len(digits), 9), dtype="float32")
        filled = np.where(digits > 0)[0]
        onehot[filled, digits[filled] - 1] = 1
        return onehot.reshape((9, 9, 9))

    puzzles = []
    solutions = []
    for puzzle, solution in zip(table[0], table[1]):
        puzzles.append(encode(puzzle))
        solutions.append(encode(solution))

    return torch.tensor(np.array(puzzles)), torch.tensor(np.array(solutions))


def load_sat_dataset(path):
    """Load ``features.pt``, ``labels.pt`` and ``perm.pt`` from ``path``.

    Returns:
        The tuple ``(X, Y, perm)`` in that order, each as returned by
        ``torch.load``.
    """
    loaded = []
    for name in ("features.pt", "labels.pt", "perm.pt"):
        with open(os.path.join(path, name), "rb") as handle:
            loaded.append(torch.load(handle))
    X, Y, perm = loaded
    return X, Y, perm


class ExtremeSudokuDataset(Dataset):
    """Sudoku puzzles from the ``sapientinc/sudoku-extreme`` Hugging Face dataset.

    The CSV is cached at ``<dataset_dir>/sudoku-extreme/{train,test}.csv``.
    ``"train"`` and ``"valid"`` are the first 90% and last 10% of ``train.csv``;
    ``"test"`` is all of ``test.csv``.

    Each item is the tuple ``(features, labels, is_input, group_id)``:
      * ``features``: digit index 0-8 for given cells, 9 for blank cells
      * ``labels``: digit index 0-8
      * ``is_input``: boolean mask, True where the cell was given
      * ``group_id``: dummy group id (all different)
    """

    def __init__(self, dataset_dir: str, split: Literal["train", "valid", "test"]):
        """Load one split, fetching the CSV from the hub when it is not cached.

        Args:
            dataset_dir: Directory that holds, or will receive, ``sudoku-extreme``.
            split: ``"train"``, ``"valid"`` or ``"test"``.
        """
        os.makedirs(dataset_dir, exist_ok=True)
        if split in ["train", "valid"]:
            filename = "train"
        else:
            filename = "test"
        filepath = os.path.join(dataset_dir, "sudoku-extreme", f"{filename}.csv")
        if not os.path.exists(filepath):
            res = hf_hub_download(
                "sapientinc/sudoku-extreme", f"{filename}.csv", repo_type="dataset"
            )
            # ``to_csv`` does not create parent folders, so make the cache
            # directory first; otherwise the very first run fails here.
            os.makedirs(os.path.dirname(filepath), exist_ok=True)
            pd.read_csv(res).to_csv(filepath)
        df = pd.read_csv(filepath)
        # Blanks are written as "." in the source data. ``regex=False`` keeps
        # this a literal replacement regardless of the pandas default.
        df["question"] = df["question"].str.replace(".", "0", regex=False)
        df = df[["question", "answer"]]
        df.columns = [0, 1]

        # "train" / "valid" share train.csv: first 90% vs. remaining 10%.
        cut = int(len(df) * 0.9)
        if split == "train":
            df = df.iloc[:cut]
        elif split == "valid":
            df = df.iloc[cut:]

        features = []
        labels = []
        # Iterate the two columns directly; building a row Series per sample
        # with ``df.iloc[i]`` is far slower and yields the same values.
        for inp, out in tqdm(zip(df[0], df[1]), total=len(df)):
            features.append(self.str2onehot(inp))
            labels.append(self.str2onehot(out))

        self.features = torch.tensor(np.array(features))
        self.labels = torch.tensor(np.array(labels))

        # A cell is a given clue when its one-hot vector (dim 3) is non-zero.
        self.is_input = self.features.sum(dim=3).bool()
        self.features = self.features.argmax(dim=-1)  # 0-8
        self.labels = self.labels.argmax(dim=-1)  # 0-8
        self.features[~self.is_input] = 9  # 0-9
        self.group_id = torch.arange(len(self.features))

    @staticmethod
    def str2onehot(x):
        """Convert an 81-character digit string into a ``(9, 9, 9)`` one-hot array.

        Digit ``d`` in 1-9 sets channel ``d - 1``; ``0`` leaves the cell all-zero.
        """
        x = np.array(list(map(int, x)), dtype="int64")
        y = np.zeros((len(x), 9), dtype="float32")
        idx = np.where(x > 0)[0]
        y[idx, x[idx] - 1] = 1
        return y.reshape((9, 9, 9))

    def __len__(self):
        """Return the number of puzzles in the split."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return ``(features, labels, is_input, group_id)`` for ``idx``."""
        return (
            self.features[idx],
            self.labels[idx],
            self.is_input[idx],
            self.group_id[idx],  # dummy group id (all different)
        )


class ExtremeSudokuAugDataset(Dataset):
    """Augmented Sudoku-extreme subset stored as ``.npy`` arrays.

    Data is read from ``<dataset_dir>/sudoku-extreme-1k-aug-1000/<split>``;
    the folder is generated with ``preprocess_sudoku_data`` when missing.

    Each item is the tuple ``(features, labels, is_input, group_id)``:
      * ``features``: 0-8 for given cells, 9 for blank cells
      * ``labels``: stored value minus 2 (0-8)
      * ``is_input``: boolean mask, True where the cell was given
      * ``group_id``: index of the group the sample belongs to
    """

    def __init__(self, dataset_dir: str, split: Literal["train", "test"]):
        """Load one split, building the dataset first if it does not exist.

        Args:
            dataset_dir: Parent directory of ``sudoku-extreme-1k-aug-1000``.
            split: Sub-folder to read (``"train"`` or ``"test"``).
        """
        root = os.path.join(dataset_dir, "sudoku-extreme-1k-aug-1000")
        if not os.path.exists(root):
            config = SudokuConfig(
                output_dir=root,
                subsample_size=1000,
                num_aug=1000,
            )
            preprocess_sudoku_data(config)

        split_dir = os.path.join(root, split)
        features = np.load(os.path.join(split_dir, "all__inputs.npy"))
        labels = np.load(os.path.join(split_dir, "all__labels.npy"))
        group_indices = np.load(os.path.join(split_dir, "all__group_indices.npy"))

        # ``group_indices`` is treated as sorted group boundaries: sample j
        # belongs to group k when group_indices[k] <= j < group_indices[k + 1],
        # and anything at or past the last boundary gets the next id.
        # Vectorized so it also works with fewer than two boundaries.
        group_id = (
            np.searchsorted(group_indices, np.arange(features.shape[0]), side="right")
            - 1
        )
        # Samples before the first boundary stay in group 0.
        group_id = np.maximum(group_id, 0).astype(np.int64)

        # Stored value 1 is a blank cell; values above it are digits.
        self.features = torch.tensor(features).reshape(-1, 9, 9) - 1
        self.is_input = self.features != 0
        self.features = self.features - 1
        self.features[~self.is_input] = 9  # 0-9
        self.labels = torch.tensor(labels).reshape(-1, 9, 9) - 2  # 0-8
        self.group_id = torch.from_numpy(group_id).to(torch.int64)  # (N,)

    def __len__(self):
        """Return the number of samples in the split."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return ``(features, labels, is_input, group_id)`` for ``idx``."""
        return (
            self.features[idx],
            self.labels[idx],
            self.is_input[idx],
            self.group_id[idx],
        )


class MazeDataset(Dataset):
    """30x30 maze samples stored as ``.npy`` arrays.

    Data is read from ``<dataset_dir>/maze-30x30-hard-1k/<split>``; when the
    folder is missing ``preprocess_maze_data`` is run first.

    Each item is the tuple ``(features, labels, is_input, group_id)``:
      * ``features``: stored input value minus 1, shaped ``(30, 30)``
      * ``labels``: stored label value minus 1, shaped ``(30, 30)``
      * ``is_input``: boolean mask, True where ``features != 1``
      * ``group_id``: index of the group the sample belongs to
    """

    def __init__(self, dataset_dir: str, split: Literal["train", "test"]):
        """Load one split, building the dataset first if it does not exist.

        Args:
            dataset_dir: Parent directory of ``maze-30x30-hard-1k``.
            split: Sub-folder to read (``"train"`` or ``"test"``).
        """
        root = os.path.join(dataset_dir, "maze-30x30-hard-1k")
        if not os.path.exists(root):
            # NOTE(review): the default config is used, so the output location
            # does not follow ``dataset_dir``. Confirm that MazeConfig's default
            # output directory matches ``root``; otherwise the loads below fail.
            config = MazeConfig()
            preprocess_maze_data(config)

        split_dir = os.path.join(root, split)
        features = np.load(os.path.join(split_dir, "all__inputs.npy"))
        labels = np.load(os.path.join(split_dir, "all__labels.npy"))
        group_indices = np.load(os.path.join(split_dir, "all__group_indices.npy"))

        # ``group_indices`` is treated as sorted group boundaries: sample j
        # belongs to group k when group_indices[k] <= j < group_indices[k + 1],
        # and anything at or past the last boundary gets the next id.
        # Vectorized so it also works with fewer than two boundaries.
        group_id = (
            np.searchsorted(group_indices, np.arange(features.shape[0]), side="right")
            - 1
        )
        # Samples before the first boundary stay in group 0.
        group_id = np.maximum(group_id, 0).astype(np.int64)

        self.features = torch.tensor(features).reshape(-1, 30, 30) - 1
        # Cells whose shifted value is 1 are the ones left to be predicted.
        self.is_input = self.features != 1
        self.labels = torch.tensor(labels).reshape(-1, 30, 30) - 1
        self.group_id = torch.from_numpy(group_id).to(torch.int64)  # (N,)

    def __len__(self):
        """Return the number of samples in the split."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return ``(features, labels, is_input, group_id)`` for ``idx``."""
        return (
            self.features[idx],
            self.labels[idx],
            self.is_input[idx],
            self.group_id[idx],
        )


class ARCDataset(Dataset):
    """Augmented ARC samples stored as ``.npy`` arrays of 30x30 grids.

    Data is read from ``<dataset_dir>/arc-aug-1000/<split>``; the folder is
    generated with ``preprocess_arc_data`` from the ARC-AGI and ConceptARC raw
    data when missing.

    Each item is the tuple ``(features, labels, is_input, group_id)``:
      * ``features`` / ``labels``: stored values reshaped to ``(30, 30)``
      * ``is_input``: all-False boolean mask with the shape of ``features``
      * ``group_id``: index of the group the sample belongs to
    """

    def __init__(self, dataset_dir: str, split: Literal["train", "test"]):
        """Load one split, building the dataset first if it does not exist.

        Args:
            dataset_dir: Parent directory of ``arc-aug-1000``.
            split: Sub-folder to read (``"train"`` or ``"test"``).
        """
        root = os.path.join(dataset_dir, "arc-aug-1000")
        if not os.path.exists(root):
            print("Creating arc-aug-1000 dataset...")
            # Raw-data paths are relative to the current working directory.
            config = ArcConfig(
                output_dir=root,
                dataset_dirs=[
                    "src/dataset/raw-data/ARC-AGI/data",
                    "src/dataset/raw-data/ConceptARC/corpus",
                ],
            )
            preprocess_arc_data(config)

        split_dir = os.path.join(root, split)
        features = np.load(os.path.join(split_dir, "all__inputs.npy"))
        labels = np.load(os.path.join(split_dir, "all__labels.npy"))
        group_indices = np.load(os.path.join(split_dir, "all__group_indices.npy"))

        # ``group_indices`` is treated as sorted group boundaries: sample j
        # belongs to group k when group_indices[k] <= j < group_indices[k + 1],
        # and anything at or past the last boundary gets the next id.
        # Vectorized so it also works with fewer than two boundaries.
        group_id = (
            np.searchsorted(group_indices, np.arange(features.shape[0]), side="right")
            - 1
        )
        # Samples before the first boundary stay in group 0.
        group_id = np.maximum(group_id, 0).astype(np.int64)

        self.features = torch.tensor(features).reshape(-1, 30, 30)
        # No cell is treated as a fixed input for ARC.
        self.is_input = torch.zeros_like(self.features).bool()
        self.labels = torch.tensor(labels).reshape(-1, 30, 30)
        self.group_id = torch.from_numpy(group_id).to(torch.int64)  # (N,)

    def __len__(self):
        """Return the number of samples in the split."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return ``(features, labels, is_input, group_id)`` for ``idx``."""
        return (
            self.features[idx],
            self.labels[idx],
            self.is_input[idx],
            self.group_id[idx],
        )


class ARC2Dataset(Dataset):
    """Augmented ARC-AGI-2 samples stored as ``.npy`` arrays of 30x30 grids.

    Data is read from ``<dataset_dir>/arc-2-aug-1000/<split>``; the folder is
    generated with ``preprocess_arc_data`` from the ARC-AGI-2 raw data when
    missing.

    Each item is the tuple ``(features, labels, is_input, group_id)``:
      * ``features`` / ``labels``: stored values reshaped to ``(30, 30)``
      * ``is_input``: all-False boolean mask with the shape of ``features``
      * ``group_id``: index of the group the sample belongs to
    """

    def __init__(self, dataset_dir: str, split: Literal["train", "test"]):
        """Load one split, building the dataset first if it does not exist.

        Args:
            dataset_dir: Parent directory of ``arc-2-aug-1000``.
            split: Sub-folder to read (``"train"`` or ``"test"``).
        """
        root = os.path.join(dataset_dir, "arc-2-aug-1000")
        if not os.path.exists(root):
            print("Creating arc-2-aug-1000 dataset...")
            # Raw-data path is relative to the current working directory.
            config = ArcConfig(
                output_dir=root,
                dataset_dirs=[
                    "src/dataset/raw-data/ARC-AGI-2/data",
                ],
            )
            preprocess_arc_data(config)

        split_dir = os.path.join(root, split)
        features = np.load(os.path.join(split_dir, "all__inputs.npy"))
        labels = np.load(os.path.join(split_dir, "all__labels.npy"))
        group_indices = np.load(os.path.join(split_dir, "all__group_indices.npy"))

        # ``group_indices`` is treated as sorted group boundaries: sample j
        # belongs to group k when group_indices[k] <= j < group_indices[k + 1],
        # and anything at or past the last boundary gets the next id.
        # Vectorized so it also works with fewer than two boundaries.
        group_id = (
            np.searchsorted(group_indices, np.arange(features.shape[0]), side="right")
            - 1
        )
        # Samples before the first boundary stay in group 0.
        group_id = np.maximum(group_id, 0).astype(np.int64)

        self.features = torch.tensor(features).reshape(-1, 30, 30)
        # No cell is treated as a fixed input for ARC.
        self.is_input = torch.zeros_like(self.features).bool()
        self.labels = torch.tensor(labels).reshape(-1, 30, 30)
        self.group_id = torch.from_numpy(group_id).to(torch.int64)  # (N,)

    def __len__(self):
        """Return the number of samples in the split."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return ``(features, labels, is_input, group_id)`` for ``idx``."""
        return (
            self.features[idx],
            self.labels[idx],
            self.is_input[idx],
            self.group_id[idx],
        )


class ARCOrigDataset(Dataset):
    """Un-augmented ARC-AGI tasks, each grid zero-padded to 30x30.

    Samples are cached in ``<dataset_dir>/arc-orig/<split>/dataset.pt``. When
    the cache file is missing it is built from the JSON tasks under
    ``src/dataset/raw-data/ARC-AGI/data/{training,evaluation}`` (relative to
    the current working directory): every ``train`` pair plus the first
    ``test`` pair of each task becomes one sample.

    Each item is the tuple ``(features, labels, is_input, group_id)``:
      * ``features`` / ``labels``: padded input / output grid, ``(30, 30)``
      * ``is_input``: all-False boolean mask with the shape of ``features``
      * ``group_id``: the sample's own index (all different)
    """

    def __init__(self, dataset_dir: str, split: Literal["train", "test"]):
        """Load the cached samples, building the cache first when needed.

        Args:
            dataset_dir: Parent directory of ``arc-orig``.
            split: ``"train"`` reads the ``training`` tasks; anything else
                reads the ``evaluation`` tasks.
        """
        split_dir = f"{dataset_dir}/arc-orig/{split}"
        cache_file = f"{split_dir}/dataset.pt"
        # Check for the file rather than the folder so that an interrupted
        # build (empty folder) is repaired instead of failing on load.
        if os.path.exists(cache_file):
            with open(cache_file, "rb") as cache:
                samples = torch.load(cache)
        else:
            raw_data_dir = os.path.join(
                "src/dataset/raw-data/ARC-AGI/data",
                "training" if split == "train" else "evaluation",
            )
            # Sorted so the sample order does not depend on the file system.
            task_files = sorted(
                name for name in os.listdir(raw_data_dir) if name.endswith(".json")
            )
            samples = []
            for task_file in tqdm(task_files):
                with open(os.path.join(raw_data_dir, task_file), "r") as json_file:
                    json_data = json.load(json_file)
                # NOTE(review): str hashes are salted per process, so this id
                # is not reproducible; it is only stored in the cache and is
                # not read back by this class.
                group_id = hash(task_file.split(".")[0])
                # All demonstration pairs, then only the first test pair.
                pairs = list(json_data["train"]) + [json_data["test"][0]]
                for pair in pairs:
                    samples.append(
                        [
                            self._pad_grid(pair["input"]),
                            self._pad_grid(pair["output"]),
                            group_id,
                        ]
                    )
            os.makedirs(split_dir, exist_ok=True)
            with open(cache_file, "wb") as cache:
                torch.save(samples, cache)

        inputs = [sample[0] for sample in samples]
        outputs = [sample[1] for sample in samples]
        self.features = torch.stack(inputs).reshape(-1, 30, 30)
        # No cell is treated as a fixed input for ARC.
        self.is_input = torch.zeros_like(self.features).bool()
        self.labels = torch.stack(outputs).reshape(-1, 30, 30)
        self.group_id = torch.arange(len(self.features))

    @staticmethod
    def _pad_grid(grid):
        """Return ``grid`` as a 30x30 tensor, zero-padded on the right and bottom."""
        tensor = torch.tensor(grid)
        h, w = tensor.shape
        return F.pad(tensor, (0, 30 - w, 0, 30 - h))

    def __len__(self):
        """Return the number of samples in the split."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return ``(features, labels, is_input, group_id)`` for ``idx``."""
        return (
            self.features[idx],
            self.labels[idx],
            self.is_input[idx],
            self.group_id[idx],
        )


class ARC2OrigDataset(Dataset):
    """Un-augmented ARC-AGI-2 tasks, each grid zero-padded to 30x30.

    Samples are cached in ``<dataset_dir>/arc2-orig/<split>/dataset.pt``. When
    the cache file is missing it is built from the JSON tasks under
    ``src/dataset/raw-data/ARC-AGI-2/data/{training,evaluation}`` (relative to
    the current working directory): every ``train`` pair plus the first
    ``test`` pair of each task becomes one sample.

    Each item is the tuple ``(features, labels, is_input, group_id)``:
      * ``features`` / ``labels``: padded input / output grid, ``(30, 30)``
      * ``is_input``: all-False boolean mask with the shape of ``features``
      * ``group_id``: the sample's own index (all different)
    """

    def __init__(self, dataset_dir: str, split: Literal["train", "test"]):
        """Load the cached samples, building the cache first when needed.

        Args:
            dataset_dir: Parent directory of ``arc2-orig``.
            split: ``"train"`` reads the ``training`` tasks; anything else
                reads the ``evaluation`` tasks.
        """
        split_dir = f"{dataset_dir}/arc2-orig/{split}"
        cache_file = f"{split_dir}/dataset.pt"
        # Check for the file rather than the folder so that an interrupted
        # build (empty folder) is repaired instead of failing on load.
        if os.path.exists(cache_file):
            with open(cache_file, "rb") as cache:
                samples = torch.load(cache)
        else:
            raw_data_dir = os.path.join(
                "src/dataset/raw-data/ARC-AGI-2/data",
                "training" if split == "train" else "evaluation",
            )
            # Sorted so the sample order does not depend on the file system.
            task_files = sorted(
                name for name in os.listdir(raw_data_dir) if name.endswith(".json")
            )
            samples = []
            for task_file in tqdm(task_files):
                with open(os.path.join(raw_data_dir, task_file), "r") as json_file:
                    json_data = json.load(json_file)
                # NOTE(review): str hashes are salted per process, so this id
                # is not reproducible; it is only stored in the cache and is
                # not read back by this class.
                group_id = hash(task_file.split(".")[0])
                # All demonstration pairs, then only the first test pair.
                pairs = list(json_data["train"]) + [json_data["test"][0]]
                for pair in pairs:
                    samples.append(
                        [
                            self._pad_grid(pair["input"]),
                            self._pad_grid(pair["output"]),
                            group_id,
                        ]
                    )
            os.makedirs(split_dir, exist_ok=True)
            with open(cache_file, "wb") as cache:
                torch.save(samples, cache)

        inputs = [sample[0] for sample in samples]
        outputs = [sample[1] for sample in samples]
        self.features = torch.stack(inputs).reshape(-1, 30, 30)
        # No cell is treated as a fixed input for ARC.
        self.is_input = torch.zeros_like(self.features).bool()
        self.labels = torch.stack(outputs).reshape(-1, 30, 30)
        self.group_id = torch.arange(len(self.features))

    @staticmethod
    def _pad_grid(grid):
        """Return ``grid`` as a 30x30 tensor, zero-padded on the right and bottom."""
        tensor = torch.tensor(grid)
        h, w = tensor.shape
        return F.pad(tensor, (0, 30 - w, 0, 30 - h))

    def __len__(self):
        """Return the number of samples in the split."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return ``(features, labels, is_input, group_id)`` for ``idx``."""
        return (
            self.features[idx],
            self.labels[idx],
            self.is_input[idx],
            self.group_id[idx],
        )
