from ffcv.writer import DatasetWriter
from ffcv.fields import NDArrayField, FloatField, TorchTensorField, BytesField

from typing import Union, List, Tuple, Dict, Any


from pathlib import Path

from shrp.git_re_basin.git_re_basin import PermutationSpec

from shrp.datasets.dataset_tokens_trojai import DatasetTokens
from shrp.datasets.augmentations import (
    WindowCutter,
    PermutationAugmentation,
    CheckpointAugmentationPipeline,
)

from shrp.datasets.dataset_auxiliaries import (
    tokenize_checkpoint,
)

import logging

import json
import torch
import numpy as np


def prepare_ffcv_dataset(
    dataset_target_path: Union[str, Path],
    zoo_path: Union[list, str, Path],
    permutation_spec: PermutationSpec,
    map_to_canonical: bool = True,
    standardize: bool = True,
    ds_split: list = [0.7, 0.15, 0.15],
    splits: list = ["train", "val", "test"],
    max_samples: int = 1000,
    weight_threshold: int = 15,
    property_keys: dict = {
        "result_keys": [
            "test_acc",
            "training_iteration",
            "ggap",
        ],
        "config_keys": [],
    },
    filter_fn: Any = None,
    num_threads: int = 12,
    shuffle_path: bool = True,
    windowsize: int = 160,
    supersample: Union[str, int] = "auto",
    precision: int = 16,
    ignore_bn: bool = True,
    tokensize: int = 576,
    permutations_per_sample_train: int = 0,
    permutations_per_sample_test: int = 0,
    page_size: int = 4 * 1 << 21,
    drop_pt_dataset: bool = False,
    reference_model_path: Union[Path, str] = None,
):
    """
    Prepares one ffcv dataset per requested split from the token dataset.

    Every entry of ``splits`` is handed to ``preprocess_single_split`` together
    with the shared configuration below; only the split name and the number of
    permutations per sample differ between the calls.

    Args:
        dataset_target_path: Path to the target dataset directory.
        zoo_path: Path (or list of paths) to the model zoo.
        permutation_spec: PermutationSpec to use.
        map_to_canonical: Whether to map models to canonical form using git-rebasin.
        standardize: Whether to standardize the weights (per layer).
        ds_split: Dataset split fractions, in "train" "val" "test" order.
        splits: Names of the splits to process, one ffcv dataset is written per entry.
        max_samples: Maximum number of samples, split by model path to prevent leakage, distributed over splits.
        weight_threshold: Weight threshold in 1-norm.
        property_keys: Property keys (load properties).
        filter_fn: Function used to filter out models.
        num_threads: Number of threads.
        shuffle_path: Whether to shuffle the path.
        windowsize: Window size forwarded to the per-split preprocessing.
        supersample: Supersample factor, or "auto".
        precision: Precision setting forwarded to the per-split preprocessing.
        ignore_bn: Whether to ignore batchnorm parameters when loading.
        tokensize: Dimension of tokens. Set to 0 to discover size.
        permutations_per_sample_train: Number of permutations per sample used for the "train" split.
        permutations_per_sample_test: Number of permutations per sample used for every split other than "train".
        page_size: ffcv page size, forwarded to the per-split preprocessing.
        drop_pt_dataset: Whether to additionally write the torch dataset as a .pt file.
        reference_model_path: Path to the reference model to use for model alignment.
    Returns:
        None
    """
    # everything that is identical for all splits is collected once
    shared_kwargs = dict(
        dataset_target_path=dataset_target_path,
        zoo_path=zoo_path,
        permutation_spec=permutation_spec,
        map_to_canonical=map_to_canonical,
        standardize=standardize,
        ds_split=ds_split,
        max_samples=max_samples,
        weight_threshold=weight_threshold,
        property_keys=property_keys,
        filter_fn=filter_fn,
        num_threads=num_threads,
        shuffle_path=shuffle_path,
        windowsize=windowsize,
        supersample=supersample,
        precision=precision,
        ignore_bn=ignore_bn,
        tokensize=tokensize,
        page_size=page_size,
        drop_pt_dataset=drop_pt_dataset,
        reference_model_path=reference_model_path,
    )

    for split_key in splits:
        # only the "train" split uses the train permutation count;
        # all other splits fall back to the test permutation count
        if split_key == "train":
            n_permutations = permutations_per_sample_train
        else:
            n_permutations = permutations_per_sample_test

        preprocess_single_split(
            split=split_key,
            permutations_per_sample=n_permutations,
            **shared_kwargs,
        )


def preprocess_single_split(
    dataset_target_path: Union[str, Path],
    zoo_path: Union[list, str, Path],
    permutation_spec: PermutationSpec,
    map_to_canonical: bool = True,
    permutations_per_sample: int = 0,
    standardize: bool = True,
    ds_split: list = [0.7, 0.15, 0.15],
    max_samples: int = 1000,
    weight_threshold=15,
    property_keys: dict = {
        "result_keys": [
            "test_acc",
            "training_iteration",
            "ggap",
        ],
        "config_keys": [],
    },
    filter_fn: Any = None,
    num_threads: int = 12,
    shuffle_path: bool = True,
    windowsize: int = 160,
    supersample: Union[str, int] = "auto",
    precision: str = "16",
    split: str = "train",
    ignore_bn: bool = True,
    tokensize: int = 576,
    page_size: int = 4 * 1 << 21,
    drop_pt_dataset: bool = False,
    reference_model_path: Union[Path, str] = None,
):
    """
    Loads a single split of the token dataset and writes it to ffcv.

    Files written into ``dataset_target_path`` (created if missing):
        - ``dataset_beton.{split}``: the ffcv dataset with fields "w", "m", "p", "props".
        - ``dataset_info_{split}.json``: the configuration and inferred values of this run.
        - ``dataset_{split}.pt``: the pickled torch dataset, only if ``drop_pt_dataset`` is set.
    After writing, the ffcv file is read back with an ffcv ``Loader`` for a few
    batches and the shapes/dtypes are printed as a smoke test.

    Args:
        dataset_target_path: Path to the target dataset directory.
        zoo_path: Path (or list of paths) to the zoo.
        permutation_spec: PermutationSpec to use. If a callable is passed, it is called without arguments to obtain the spec.
        map_to_canonical: Whether to map models to canonical form using git-rebasin.
        permutations_per_sample: Number of permutations per sample (each sample is a stack of several permuted versions). If 0, only the window cutter is applied.
        standardize: Whether to standardize the weights (per layer).
        ds_split: Dataset split, in "train" "val" "test".
        max_samples: Maximum number of samples, split by model path to prevent leakage, distributed over splits.
        weight_threshold: Weight threshold in 1-norm.
        property_keys: Property keys (load properties).
        filter_fn: Function to filter out models with; passed to the dataset as ``filter_function``.
        num_threads: Number of threads used by the token dataset.
        shuffle_path: Whether the token dataset shuffles the path.
        windowsize: Window size of the augmentations; also used to infer ``supersample``.
        supersample: Supersample factor. "auto" sets it to the token sequence length of the first checkpoint divided (floor) by ``windowsize``.
        precision: Precision setting passed to the token dataset.
        split: Split to use.
        ignore_bn: Whether to ignore batchnorm parameters when loading.
        tokensize: Dimension of tokens. Set to 0 to discover size.
        page_size: ffcv page size used by the dataset writer (default evaluates to 8 MiB).
        drop_pt_dataset: Whether to additionally write the torch dataset as a .pt file.
        reference_model_path: Path to the reference model, passed to the token dataset.

    Returns:
        None
    """
    # make sure the target directory exists
    target_dir = Path(dataset_target_path)
    target_dir.mkdir(parents=True, exist_ok=True)

    # a factory may be passed instead of an instantiated permutation spec
    if callable(permutation_spec):
        permutation_spec = permutation_spec()

    if isinstance(zoo_path, list):
        root = [Path(pdx).absolute() for pdx in zoo_path]
    else:
        root = Path(zoo_path).absolute()

    logging.info("Load token dataset")
    dataset = DatasetTokens(
        root=root,
        mode="tokenize",
        permutation_spec=permutation_spec,
        map_to_canonical=map_to_canonical,
        standardize=standardize,
        train_val_test=split,  # determines which dataset split to use
        ds_split=ds_split,
        max_samples=max_samples,
        weight_threshold=weight_threshold,
        precision=precision,
        filter_function=filter_fn,  # gets sample path as argument and returns True if model needs to be filtered out
        property_keys=property_keys,
        # forward the caller's settings (previously hard-coded to 12 / True,
        # which silently ignored the arguments recorded in the info json)
        num_threads=num_threads,
        shuffle_path=shuffle_path,
        verbosity=3,
        getitem="tokens+props",
        ignore_bn=ignore_bn,
        tokensize=tokensize,
        reference_model_path=reference_model_path,
    )

    # optionally drop the torch dataset as .pt file; best effort, failures are only logged
    if drop_pt_dataset:
        pt_path = target_dir.joinpath(f"dataset_{split}.pt")
        try:
            torch.save(dataset, pt_path)
        except Exception as e:
            logging.error(f"could not save dataset.pt at {pt_path}")
            logging.error(e)

    # set the transforms that are applied to every sample before writing
    logging.info("set augmentations before  ffcv dataset")
    if permutations_per_sample > 0:
        logging.info("augmentations: prepare permutations")
        dataset.transforms = CheckpointAugmentationPipeline(
            perm_spec=permutation_spec,
            tokensize=dataset.tokensize,
            permutation_number=permutations_per_sample,
            windowsize=windowsize,
            ignore_bn=ignore_bn,
        )
    else:
        logging.info("augmentations: prepare windowcutter")
        dataset.transforms = WindowCutter(windowsize=windowsize)

    # tokenize the first checkpoint without windowing to get the full
    # sequence length and the positions (the mask is not needed here)
    ddx, _, pos = tokenize_checkpoint(
        dataset.data[0], tokensize, return_mask=True, ignore_bn=ignore_bn
    )

    # set supersample
    if supersample == "auto":
        # infer number of iterations over each sample as len of token sequence divided by windowsize
        supersample = ddx.shape[-2] // windowsize
    logging.info(f"set to supersample: {supersample}")

    # set supersample in the dataset
    dataset.supersample = supersample

    logging.info(f"dataset len: {len(dataset)}")

    # maximum over dim 0 of the positions of the first checkpoint
    max_positions = pos.max(dim=0).values.tolist()

    # cast properties to numbers
    logging.info("cast downstream tasks properties to numbers")
    dataset, label_dict = cast_properties_to_numbers(dataset)

    # get one (transformed) sample and infer the field dimensions from it
    logging.info("get sample and infer dimensions")
    ddx, mask, pos, props = dataset.__getitem__(0)

    logging.info(f"ddx.shape: {ddx.shape} - dtype: {ddx.dtype}")
    logging.info(f"mask.shape: {mask.shape} - dtype: {mask.dtype}")
    logging.info(f"pos.shape: {pos.shape} - dtype: {pos.dtype}")
    logging.info(f"props.shape: {props.shape} - dtype: {props.dtype}")

    # configure writer
    logging.info("configure ffcv writer")
    beton_path = target_dir.joinpath(f"dataset_beton.{split}")
    writer = DatasetWriter(
        beton_path,
        {
            "w": TorchTensorField(shape=ddx.shape, dtype=ddx.dtype),
            "m": TorchTensorField(shape=mask.shape, dtype=mask.dtype),
            "p": TorchTensorField(shape=pos.shape, dtype=pos.dtype),
            "props": TorchTensorField(shape=props.shape, dtype=props.dtype),
        },
        page_size=page_size,
        # NOTE(review): writer workers are fixed at 16 and independent of num_threads
        num_workers=16,
    )
    # write dataset
    logging.info("write ffcv dataset to disk")
    writer.from_indexed_dataset(dataset)

    # drop info
    logging.info("collect info and write to disk")
    info = {
        "zoo_path": str(zoo_path),
        "num_samples": dataset._len,
        "supersample": supersample,
        "properties": list(dataset.properties.keys()),
        "map_to_canonical": map_to_canonical,
        "permutations_per_sample": permutations_per_sample,
        "standardize": standardize,
        "ds_split": ds_split,
        "max_samples": max_samples,
        "weight_threshold": weight_threshold,
        "property_keys": property_keys,
        "num_threads": num_threads,
        "shuffle_path": shuffle_path,
        "windowsize": windowsize,
        "split": split,
        "max_positions": max_positions,
        "properties_labels": label_dict,
    }
    # add info json to the same path; the context manager closes the file handle
    json_path = target_dir.joinpath(f"dataset_info_{split}.json")
    with json_path.open("w") as info_file:
        json.dump(info, info_file)

    # smoke test: read the freshly written dataset back with an ffcv loader
    logging.info("test dataset with dataloader")
    from ffcv.loader import Loader, OrderOption

    batch_size = 6
    num_workers = 4
    ordering = OrderOption.QUASI_RANDOM
    loader = Loader(
        beton_path,
        batch_size=batch_size,
        num_workers=num_workers,
        order=ordering,
        drop_last=True,
        os_cache=False,
    )
    for idx, (ddx, mask, pos, props) in enumerate(loader):
        print(f"ddx.shape: {ddx.shape} - dtype: {ddx.dtype}")
        print(f"mask.shape: {mask.shape} - dtype: {mask.dtype}")
        print(f"pos.shape: {pos.shape} - dtype: {pos.dtype}")
        print(f"props.shape: {props.shape} - dtype: {props.dtype}")
        # inspecting the first 11 batches is enough
        if idx == 10:
            break


def cast_properties_to_numbers(dataset):
    """
    Replaces non-numeric properties of the dataset with integer class labels.

    The type of a property is decided by looking at its first entry only.
    Properties whose first entry is an int, float, np.ndarray, or a
    torch.Tensor of dtype torch.float / torch.int are left untouched, as are
    empty properties. Every other property is replaced in place by a
    torch.long tensor holding, for each entry, the index of its value in the
    sorted unique values of that property.

    Args:
        dataset: object with a ``properties`` mapping from property name to a
            sequence of values; modified in place.

    Returns:
        Tuple ``(dataset, label_dict)`` where ``label_dict`` maps each cast
        property name to the list of its unique values (as native python
        objects, so the dict can be serialized to json); index ``i`` in that
        list is the value encoded as label ``i``.
    """
    label_dict = {}
    # snapshot the keys, since values are replaced while looping
    for key in list(dataset.properties.keys()):
        values = dataset.properties[key]
        # nothing to inspect or cast in an empty property
        if len(values) == 0:
            continue
        first = values[0]
        # keep numbers
        if isinstance(first, (int, float, np.ndarray)):
            continue
        # NOTE(review): only float32 / int32 tensors are kept as numbers;
        # tensors of any other dtype are cast to class labels below
        if isinstance(first, torch.Tensor) and first.dtype in (
            torch.float,
            torch.int,
        ):
            continue
        # unique classes and, for every entry, the index of its class
        classes, inverse = np.unique(values, return_inverse=True)
        # replace labels in properties
        dataset.properties[key] = torch.tensor(
            inverse.reshape(-1), dtype=torch.long
        )
        # save label dict with native python values
        label_dict[key] = classes.tolist()
    return dataset, label_dict
