import logging
import os
from typing import Dict, List, Tuple

import awkward  # type: ignore
import numpy as np  # type: ignore
import pandas as pd  # type: ignore
import torch
from torch.utils.data import Dataset

__all__ = [
    "Jets",
    "AtlasJets",
]


# Shorthand for the tensor type used in the datasets' ``__getitem__`` return annotations.
T = torch.Tensor


def pad_array(a, maxlen: int, value: float = 0., dtype: str = 'float32') -> np.ndarray:
    """Pad or truncate a ragged collection of 1-D rows into a dense 2-D array.

    Args:
        a: sized sequence of 1-D array-likes (e.g. the rows of a jagged array or plain lists);
            rows may have different lengths.
        maxlen: number of columns in the result; longer rows are truncated to their first
            ``maxlen`` entries, shorter rows are padded.
        value: fill value for padded positions.
        dtype: dtype of the returned array (each row is converted to it as well).

    Returns:
        Array of shape ``(len(a), maxlen)`` and dtype ``dtype``.
    """
    # Allocate directly in the target dtype instead of building a float64 ``ones * value`` and casting.
    out = np.full((len(a), maxlen), value, dtype=dtype)
    for idx, row in enumerate(a):
        if not len(row):
            continue  # nothing to copy; the whole row stays at ``value``
        # np.asarray also accepts plain Python lists, unlike calling ``.astype`` on the row.
        trunc = np.asarray(row[:maxlen], dtype=dtype)
        out[idx, :len(trunc)] = trunc
    return out


class Jets(Dataset):
    """
    Particle-jet dataset read eagerly from a converted ``.awkd`` file.

    NOTE:
     - this dataset was used in this paper: https://arxiv.org/pdf/1902.08570.pdf (Jet Tagging via Particle Clouds)
     - the above paper references another paper which uses DeepSets for the same dataset: https://arxiv.org/pdf/1810.05165.pdf (Energy Flow Networks: DeepSets For Particle Jets)
     - the dataset preprocessing script can be found here: https://github.com/hqucms/ParticleNet/tree/master/tf-keras

    Every column listed in ``feature_dict`` is padded/truncated to ``pad_len`` entries per jet, and the
    columns of each group are stacked on the last axis. ``self.x`` concatenates the "points" (2 columns)
    and "features" (4 columns) groups, giving a tensor of shape ``(num_jets, pad_len, 6)``. The "mask" group
    is loaded into ``self._values`` but is not part of ``self.x``. ``self.y`` is the argmax of the stored
    ``label`` array over its last axis (presumably a one-hot encoding -- verify against the data).

    Args:
        root: directory that contains ``particle-jet/converted/``.
        split: one of ``"train"``, ``"val"`` or ``"test"``.
        pad_len: number of particles each jet is padded or truncated to.

    Raises:
        KeyError: if ``split`` is not one of the known splits.
        ValueError: if the per-jet element counts differ between the loaded columns.
    """
    def __init__(self, root: str, split: str = "train", pad_len: int = 100):
        super().__init__()
        datasets = {
            "train": "particle-jet/converted/train_file_0.awkd",
            "val": "particle-jet/converted/val_file_0.awkd",
            "test": "particle-jet/converted/test_file_0.awkd"
        }

        self.filepath = os.path.join(root, datasets[split])

        # Column groups to load; insertion order fixes both the stacking order and which column
        # provides the reference counts in _load().
        self.feature_dict: Dict[str, List[str]] = {
            'points': ['part_etarel', 'part_phirel'],
            'features': ['part_pt_log', 'part_e_log', 'part_etarel', 'part_phirel'],
            'mask': ['part_pt_log'],
        }
        self.label = "label"
        self.pad_len = pad_len
        self.stack_axis = -1
        self._values: Dict[str, np.ndarray] = {}
        self._label: np.ndarray = None  # type: ignore[assignment]  # set by _load()
        self.name = "Jets"
        self._load()

        self.x = torch.cat((torch.from_numpy(self._values["points"]), torch.from_numpy(self._values["features"])), dim=-1)
        self.y = torch.argmax(torch.from_numpy(self._label), dim=-1, keepdim=True)

    def _load(self) -> None:
        """Read the label array and all feature columns of ``self.filepath`` into memory.

        Fills ``self._label`` with the raw ``label`` array and, for each key of ``self.feature_dict``,
        ``self._values[key]`` with that group's padded columns stacked on ``self.stack_axis``.

        Raises:
            ValueError: if a column's per-jet element counts differ from those of the first column read.
        """
        logging.info('Start loading file %s', self.filepath)
        counts = None

        with awkward.load(self.filepath) as a:
            self._label = a[self.label]
            for key, cols in self.feature_dict.items():
                if not isinstance(cols, (list, tuple)):
                    cols = [cols]  # type: ignore

                arrs = []
                for col in cols:
                    column = a[col]  # look the column up once and reuse it below
                    if counts is None:
                        counts = column.counts
                    elif not np.array_equal(counts, column.counts):
                        # an explicit raise (unlike assert) survives ``python -O``
                        raise ValueError(
                            f"column {col!r} in {self.filepath} has per-jet counts that differ from the other columns"
                        )
                    arrs.append(pad_array(column, self.pad_len))
                self._values[key] = np.stack(arrs, axis=self.stack_axis)
        logging.info('Finished loading file %s', self.filepath)

    def __len__(self) -> int:
        """Return the number of jets (length of the label array)."""
        return len(self._label)

    def __getitem__(self, i: int) -> Tuple[T, T]:
        """Return ``(x, y)`` for jet ``i``."""
        return self.x[i], self.y[i]


col_names = ["hwid", "idx", "x", "y", "z", "r", "eta", "phi", "raw", "pid", "n", "truth_eta", "truth_phi", "truth_pt", "trk_good", "trk_barcode", "trk_pt"]

col_dtype = {
    "hwid": np.int64, "idx": np.int32, "x": np.float32, "y": np.float32, "z": np.float32,
    "r": np.float32, "eta": np.float32, "phi": np.float32, "raw": np.float32, "pid": np.int32,
    "n": np.int32, "truth_eta": np.float32, "truth_phi": np.float32, "truth_pt": np.float32, "trk_good": np.float32,
    "trk_barcode": np.float32, "trk_pt": np.float32,
}


class AtlasJets(Dataset):
    """
    Hit-level dataset read lazily, one tab-separated file per sample, with an in-memory cache.

    NOTE:
     - this dataset comes from CERN where there is an accompanying paper comparing some models on it.
     - https://opendata.cern.ch/record/15009
     - paper: http://cds.cern.ch/record/2753414/files/ATL-PHYS-PUB-2021-002.pdf

    Item ``i`` is built from ``<root>/zej-particle-jets/data/<files[i]>``:
     - ``x``: the columns ``x`` through ``raw`` (7 columns), standardized with ``mean.npy`` / ``std.npy``, as float32.
     - ``y``: labels derived from ``pid`` (-99 -> 1, -11 and 11 -> 2, other values unchanged), as int64.
     - Rows whose label falls outside [0, 3] are dropped.
     - The result is padded (zero features, label 3) or truncated to exactly ``MAX_HITS`` rows.

    Args:
        root: directory that contains ``zej-particle-jets``.
        split: ``"train"`` or ``"test"``.
        pad_len: not used by this class; the padded length is ``MAX_HITS``.

    Raises:
        KeyError: if ``split`` is not ``"train"`` or ``"test"``.
    """
    # Number of rows every item is padded or truncated to.
    MAX_HITS = 10000

    def __init__(self, root: str, split: str = "train", pad_len: int = 100):
        super().__init__()
        self.path = os.path.join(root, "zej-particle-jets")
        split_name = {"train": "train.txt", "test": "test.txt"}[split]
        # `with` closes the handle. rstrip only removes the line ending, so a final line without a
        # trailing newline keeps its last character. Blank lines are skipped.
        with open(os.path.join(self.path, "splits", split_name), encoding="utf-8") as fh:
            self.files = [name for name in (line.rstrip("\r\n") for line in fh) if name]

        # NOTE(review): despite the .npy extension these are read with np.loadtxt, i.e. as text
        # files -- confirm the on-disk format. The arrays are float64 (np.loadtxt default).
        self.means = torch.from_numpy(np.loadtxt(os.path.join(self.path, "mean.npy")))
        self.std = torch.from_numpy(np.loadtxt(os.path.join(self.path, "std.npy")))
        self.name = "AtlasJets"

        # build a cache for the files which will be filled as the data are called individually
        self.data = {k: (None, None) for k in self.files}

    def __len__(self) -> int:
        """Return the number of files in the split."""
        return len(self.files)

    def __getitem__(self, i: int) -> Tuple[T, T]:
        """Return ``(x, y)`` for file ``i``, loading and caching it on first access."""
        file = self.files[i]
        if self.data[file][0] is None:
            # cache miss: load the file now and remember the result
            self.data[file] = self._load_file(file)
        return self.data[file]  # type: ignore  # filled in the block above

    def _load_file(self, file: str) -> Tuple[T, T]:
        """Read one data file, standardize features, remap labels, and pad/truncate to ``MAX_HITS`` rows."""
        data = pd.read_csv(os.path.join(self.path, "data", file), header=None, names=col_names, dtype=col_dtype, sep="\t")

        y = torch.from_numpy(data.loc[:, "pid"].to_numpy(np.float32))
        # label-based slicing is inclusive: columns x, y, z, r, eta, phi, raw
        x = torch.from_numpy(data.loc[:, "x":"raw"].to_numpy(np.float32))  # type: ignore  # slice labels can be strings in pandas
        # means/std are float64, so the arithmetic promotes to float64; cast back once so the
        # padding below concatenates tensors of the same dtype.
        x = ((x - self.means) / self.std).float()

        # the weights for the classes should be [30, 63, 7, 0]
        # jets are already labeled as zeros so we don't have to do anything there
        # class #4 (label 3) is a dead class because those are the padding pixels. It makes it easier to just
        # give them a weight of 0 instead of trying to get rid of them during training.
        # NOTE(review): original pid values 1, 2 or 3 (if any occur) would be kept and merge with these classes.
        y[y == -99] = 1  # backgrounds are labeled as -99, change them to 1
        y[y == -11] = 2  # positrons are labeled as -11, change them to 2
        y[y == 11] = 2   # electrons are labeled as 11, change them to 2 (same as positrons)

        keep_mask = torch.logical_and(y <= 3, y >= 0)
        x, y = x[keep_mask], y[keep_mask]

        pad_length = self.MAX_HITS - x.size(0)
        if pad_length > 0:
            x = torch.cat((x, x.new_zeros((pad_length, x.size(1)))), dim=0)
            y = torch.cat((y, y.new_full((pad_length,), 3)))
        else:
            x, y = x[:self.MAX_HITS], y[:self.MAX_HITS]
        return x, y.long()
