import sys
from pathlib import Path
from typing import Any, List, Optional, Tuple

import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from zendo.game import evaluate, parser
from zendo.utils import get_env, load_envs


class FixedRuleOnlineDataset(Dataset):
    """Dataset that labels structures on the fly against one fixed rule.

    The rule text is parsed once at construction time; each item access
    evaluates the resulting AST on the requested structure to get its label.

    Args:
        rule: Rule source text, parsed with ``parser.parse``.
        structures: Structures to serve, one per dataset item.
        structure_encoder: Object whose ``transform`` method maps the list of
            elements of a structure to integer ids (presumably a fitted
            ``LabelEncoder`` -- confirm against callers).
    """

    def __init__(self, rule: str, structures: List[str], structure_encoder: Any):
        self.rule = rule
        # Parse once so that __getitem__ only pays for the evaluation.
        self.ast = parser.parse(rule)
        self.structures = structures

        self.structure_encoder = structure_encoder

    def __len__(self) -> int:
        """Return the number of structures in the dataset."""
        return len(self.structures)

    def __getitem__(self, item: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(encoded_structure, label)`` for the structure at ``item``.

        The label is the result of evaluating the fixed rule on the structure,
        converted to ``int``. Both returned tensors have dtype ``torch.long``.
        """
        structure = self.structures[item]
        label = int(evaluate(self.ast, structure))
        # The encoder works element-wise, so split the structure into a list.
        encoded = self.structure_encoder.transform(list(structure))
        return (
            torch.as_tensor(encoded, dtype=torch.long),
            # A Python int already maps to int64; the dtype is spelled out so
            # both outputs visibly share the same type.
            torch.as_tensor(label, dtype=torch.long),
        )


class FixedRuleStaticDataset(Dataset):
    """Dataset serving structures together with pre-computed labels.

    Args:
        structures: Structures to serve, one per dataset item.
        labels: Labels aligned with ``structures``; each one is converted
            with ``int()`` when its item is accessed.
        rule: Rule the labels were computed for, or ``None``. It is only
            stored on the instance and never evaluated here.
        structure_encoder: Object whose ``transform`` method maps the list of
            elements of a structure to integer ids (presumably a fitted
            ``LabelEncoder`` -- confirm against callers).

    Raises:
        ValueError: If ``structures`` and ``labels`` differ in length.
    """

    def __init__(
        self,
        structures: List[str],
        labels: List[str],
        rule: Optional[str],
        structure_encoder: Any,
    ):
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if len(structures) != len(labels):
            raise ValueError(
                "structures and labels must have the same length, "
                f"got {len(structures)} and {len(labels)}"
            )
        self.rule = rule
        self.labels = labels
        self.structures = structures

        self.structure_encoder = structure_encoder

    def __len__(self) -> int:
        """Return the number of (structure, label) pairs."""
        return len(self.structures)

    def __getitem__(self, item: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(encoded_structure, label)`` for the pair at ``item``.

        Both returned tensors have dtype ``torch.long``.
        """
        structure, label = self.structures[item], self.labels[item]
        # The encoder works element-wise, so split the structure into a list.
        encoded = self.structure_encoder.transform(list(structure))
        return (
            torch.as_tensor(encoded, dtype=torch.long),
            # A Python int already maps to int64; the dtype is spelled out so
            # both outputs visibly share the same type.
            torch.as_tensor(int(label), dtype=torch.long),
        )


def get_train_test_structures(
    dataset_name: str, test_size: float = 0.2, random_state: int = 42
) -> Tuple[List[str], List[str]]:
    """Load a dataset's structures and split them into train and test sets.

    Args:
        dataset_name: Key passed to ``get_env``; the returned value is used
            as the dataset directory, which must contain ``structures.txt``
            with one structure per line.
        test_size: Forwarded to ``train_test_split``.
        random_state: Seed forwarded to ``train_test_split``.

    Returns:
        ``(train_structures, test_structures)``.
    """
    root = Path(get_env(dataset_name))
    # splitlines() -- unlike split("\n") -- does not produce a spurious empty
    # structure when the file ends with a newline.
    structures = (root / "structures.txt").read_text().splitlines()
    x_train, x_test = train_test_split(
        structures, test_size=test_size, random_state=random_state
    )
    return x_train, x_test


def get_train_test_labeled_structures(
    dataset_name: str,
    rule_folder_name: str,
    test_size: float = 0.2,
    random_state: int = 42,
) -> Tuple[str, List[str], List[str], List[str], List[str]]:
    """Load a rule with its labelled structures and split them train/test.

    Args:
        dataset_name: Key passed to ``get_env``; the returned value is used
            as the dataset directory, which must contain ``structures.txt``.
        rule_folder_name: Sub-directory of the dataset directory holding
            ``rule.txt`` and ``labels.txt`` (one label per line, aligned with
            ``structures.txt``).
        test_size: Forwarded to ``train_test_split``.
        random_state: Seed forwarded to ``train_test_split``.

    Returns:
        ``(rule, X_train, X_test, y_train, y_test)`` where ``rule`` is the
        raw content of ``rule.txt`` and labels are returned as strings.

    Raises:
        ValueError: If the numbers of structures and labels differ.
    """
    root = Path(get_env(dataset_name))
    rule_dir = root / rule_folder_name
    rule = (rule_dir / "rule.txt").read_text()
    # splitlines() -- unlike split("\n") -- does not produce a spurious empty
    # trailing entry when a file ends with a newline (an empty label would
    # later fail the ``int(label)`` conversion).
    structures = (root / "structures.txt").read_text().splitlines()
    labels = (rule_dir / "labels.txt").read_text().splitlines()
    if len(structures) != len(labels):
        raise ValueError(
            f"Found {len(structures)} structures but {len(labels)} labels "
            f"for rule folder {rule_folder_name!r}"
        )
    X_train, X_test, y_train, y_test = train_test_split(
        structures, labels, test_size=test_size, random_state=random_state
    )
    return rule, X_train, X_test, y_train, y_test


if __name__ == "__main__":
    # Ad-hoc benchmark: iterate once over a DataLoader and let tqdm report
    # the throughput.
    load_envs()
    dataset_name = "DATASET_S6_startPROP"
    batch_size = 16
    test_size = 0.001
    print("Data loading benchmarking!\n", file=sys.stderr)

    # # ---
    #
    # Disabled static benchmark, kept for reference. A fitted encoder must be
    # available as ``structure_encoder`` before re-enabling it.
    #
    # print("Static data loading speed (loading pre-computed labels)", file=sys.stderr)
    # rule, X_train, X_test, y_train, y_test = get_train_test_labeled_structures(
    #     dataset_name, f"{0:010d}", test_size=test_size
    # )
    #
    # train_dataset = FixedRuleStaticDataset(
    #     structures=X_train,
    #     labels=y_train,
    #     rule=rule,
    #     structure_encoder=structure_encoder,
    # )
    # train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # for x, y in tqdm(train_dl):
    #     pass
    #
    # # ---
    #
    # print(file=sys.stderr)

    # ---

    print(
        "Online data loading speed (online labels computation with fixed ast)",
        file=sys.stderr,
    )
    x_train, x_test = get_train_test_structures(
        dataset_name=dataset_name, test_size=test_size
    )

    # The dataset requires an encoder. Fit it on the symbols of both splits so
    # that transform() never meets an unseen element.
    alphabet = sorted(set("".join(x_train)) | set("".join(x_test)))
    structure_encoder = LabelEncoder().fit(alphabet)

    train_dataset = FixedRuleOnlineDataset(
        rule="at_least 1 blue",
        structures=x_train,
        structure_encoder=structure_encoder,
    )
    train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    for x, y in tqdm(train_dl):
        pass
