from abc import abstractmethod, ABC
from typing import Tuple

from BACKEND import np, cp, to_gpu, to_cpu

from pathlib import Path

from datasets.normalise_data import zero_one
from datasets.slice_data import slice_data

# Directory that contains this module.
datasets = Path(__file__).resolve().parent

# Datasets that share the same settings, grouped so each value is written once.
_GROUP_168 = ("solar", "electricity")
_GROUP_12 = ("metr-la", "pems-bay")

# Per-dataset window length; DataLoader stores it as ``self.OL`` and passes it
# to ``slice_data`` (presumably the number of observed time steps per sample -
# confirm against slice_data).
OL = {
    **dict.fromkeys(_GROUP_168, 168),
    **dict.fromkeys(_GROUP_12, 12),
}

# Per-dataset split fractions. DataLoader uses index 0 for the train size and
# index 1 for the validation size; the test split is whatever remains.
SPLITS = {
    **dict.fromkeys(_GROUP_168, (0.6, 0.2, 0.2)),
    **dict.fromkeys(_GROUP_12, (0.7, 0.2, 0.1)),
}


class SubsetIterator:
    """Single-pass iterator over the batches of one split of a dataset.

    A split (``"train"``, ``"val"`` or ``"test"``) is a contiguous range of
    ``dataset.x_cpu`` / ``dataset.y_cpu`` located through ``dataset.train_size``,
    ``dataset.val_size`` and ``dataset.test_size``. Every slice is passed
    through ``transfer_func`` before it is returned. Up to ``load_n_batches``
    consecutive batches are transferred with one call (a "meta-batch") and then
    served as slices of the transferred array.

    Each step yields ``(x, y)`` when both are requested, otherwise only the
    requested array. The iterator cannot be restarted once exhausted.
    """

    def __init__(self, dataset, batch_size, target, x, y, neuron_idx, transfer_func=to_gpu, load_n_batches: int = 1):
        """Set up iteration over one split.

        Args:
            dataset: Object exposing ``x_cpu``, ``y_cpu``, ``train_size``,
                ``val_size``, ``test_size`` and ``get_mean(target)``.
            batch_size: Samples per batch, or ``-1`` to return the whole split
                as a single batch.
            target: ``"train"``, ``"val"`` or ``"test"``.
            x: Whether to yield the input array.
            y: Whether to yield the target array.
            neuron_idx: If not ``None``, index applied to the second axis of
                the ``y`` slice before it is transferred.
            transfer_func: Callable applied to every host slice before it is
                handed out.
            load_n_batches: Number of batches bundled into one transfer.

        Raises:
            ValueError: If ``target`` is unknown, neither ``x`` nor ``y`` is
                requested, ``batch_size`` is neither ``-1`` nor ``>= 1``, or
                ``load_n_batches`` is below 1.
        """
        if target == "train":
            self.offset = 0
            self.n_samples = dataset.train_size
        elif target == "val":
            self.offset = dataset.train_size
            self.n_samples = dataset.val_size
        elif target == "test":
            self.offset = dataset.train_size + dataset.val_size
            self.n_samples = dataset.test_size
        else:
            raise ValueError(f"Invalid target: {target}")

        # Fail early instead of hitting an IndexError on the first __next__.
        if not (x or y):
            raise ValueError("At least one of x and y must be requested")
        # 0 would divide by zero below; other negatives used to yield nothing silently.
        if batch_size != -1 and batch_size < 1:
            raise ValueError("batch_size must be -1 (whole split) or >= 1")

        # batch size and number of batches
        self.batch_size = self.n_samples if batch_size == -1 else batch_size
        # An empty split has no batches; this also avoids 0 / 0 when batch_size == -1.
        if self.n_samples <= 0:
            self.n_batches = 0
        else:
            self.n_batches = self._ceil_div(self.n_samples, self.batch_size)

        # what to return
        self.x = x
        self.y = y
        self.neuron_idx = neuron_idx

        # keep reference to the raw arrays
        self.x_cpu = dataset.x_cpu
        self.y_cpu = dataset.y_cpu

        # Transfer to cupy array or PyTorch Tensor
        self.transfer_func = transfer_func

        # how many batches to bundle into one transfer (meta-batch)
        if load_n_batches < 1:
            raise ValueError("load_n_batches must be >= 1")
        self.load_n_batches = int(load_n_batches)

        # iteration state
        self._k = 0  # number of batches already returned (global batch index)
        # _meta_loaded is None or a dict holding the transferred arrays under
        # 'x' / 'y' (only the requested ones) plus the bookkeeping entries
        # 'meta_start', 'meta_end', 'len' and 'n_inner_batches'.
        self._meta_loaded = None
        self._meta_inner_idx = 0  # index inside current meta (0 .. n_inner_batches-1)

        self.mean = dataset.get_mean(target)

    @staticmethod
    def _ceil_div(numerator, denominator):
        """Return ``ceil(numerator / denominator)`` as an int using host arithmetic only."""
        return int(-(-numerator // denominator))

    @staticmethod
    def _pack(parts):
        """Return the lone array as is, or several arrays as a tuple."""
        return tuple(parts) if len(parts) > 1 else parts[0]

    def _transfer_range(self, start, end):
        """Transfer the requested arrays for samples ``[start, end)``; keys are 'x' and/or 'y'."""
        start, end = int(start), int(end)
        loaded = {}
        if self.x:
            loaded['x'] = self.transfer_func(self.x_cpu[start:end])
        if self.y:
            y_cpu = self.y_cpu[start:end]
            if self.neuron_idx is not None:
                y_cpu = y_cpu[:, self.neuron_idx]
            loaded['y'] = self.transfer_func(y_cpu)
        return loaded

    def _load_meta_batch(self):
        """Transfer the meta-batch that begins at global batch index ``self._k``."""
        meta_batch_size_samples = int(self.batch_size * self.load_n_batches)
        split_end = self.offset + self.n_samples

        meta_start = int(self.offset + self._k * self.batch_size)
        # the last meta-batch may be cut short by the end of the split
        meta_end = int(min(split_end, meta_start + meta_batch_size_samples))
        meta_len = meta_end - meta_start

        meta = self._transfer_range(meta_start, meta_end)
        meta['meta_start'] = meta_start
        meta['meta_end'] = meta_end
        meta['len'] = meta_len
        meta['n_inner_batches'] = self._ceil_div(meta_len, self.batch_size)

        self._meta_loaded = meta
        self._meta_inner_idx = 0

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch, loading a new meta-batch when required.

        Raises:
            StopIteration: Once all ``n_batches`` batches have been returned.
        """
        if self._k >= self.n_batches:
            # cleanup any remaining references
            self._meta_loaded = None
            raise StopIteration

        # if batch_size == all samples, just transfer whole slice (no meta logic needed)
        if self.batch_size == self.n_samples:
            loaded = self._transfer_range(self.offset, self.offset + self.n_samples)
            self._k = self.n_batches
            return self._pack([loaded[key] for key in ('x', 'y') if key in loaded])

        # Ensure meta-batch is loaded for the current global batch index self._k
        if self._meta_loaded is None:
            self._load_meta_batch()
        meta = self._meta_loaded

        # Now serve the inner batch within the meta-batch
        start_in_meta = int(self._meta_inner_idx * self.batch_size)
        end_in_meta = int(min(start_in_meta + self.batch_size, meta['len']))
        # slice the already transferred arrays
        out = [meta[key][start_in_meta:end_in_meta] for key in ('x', 'y') if key in meta]

        # advance indices
        self._meta_inner_idx += 1
        self._k += 1

        # if we consumed the whole meta-batch, release it so next __next__ will load the next meta.
        if self._meta_inner_idx >= meta['n_inner_batches']:
            # drop our reference to the transferred arrays
            self._meta_loaded = None
            self._meta_inner_idx = 0

        return self._pack(out)


class IterableDataLoader(ABC):
    """Abstract base for loaders whose splits can be walked batch by batch.

    Iteration is delegated to ``SubsetIterator``, which reads ``x_cpu``,
    ``y_cpu``, ``train_size``, ``val_size``, ``test_size`` and ``get_mean``
    from the loader instance, so concrete subclasses have to provide them.
    """

    @abstractmethod
    def __init__(self):
        """Concrete loaders must implement their own constructor."""

    def iterate(
        self,
        batch_size=-1,
        target="test",
        x=True,
        y=True,
        neuron_idx=None,
        transfer_func=to_gpu,
        load_n_batches: int = 1,
    ):
        """Return a ``SubsetIterator`` over the ``target`` split of this loader.

        All arguments are handed to ``SubsetIterator`` unchanged; the default
        ``batch_size=-1`` requests the whole split as one batch.
        """
        transfer_options = {
            "transfer_func": transfer_func,
            "load_n_batches": load_n_batches,
        }
        return SubsetIterator(self, batch_size, target, x, y, neuron_idx, **transfer_options)

    def get_first_batch(
        self,
        batch_size,
        target,
        x=True,
        y=True,
        neuron_idx=None,
        transfer_func=to_gpu,
        load_n_batches: int = 1,
    ):
        """Return the first batch produced by ``iterate``, or ``None`` if there is none."""
        batches = self.iterate(
            batch_size,
            target=target,
            x=x,
            y=y,
            neuron_idx=neuron_idx,
            transfer_func=transfer_func,
            load_n_batches=load_n_batches,
        )
        return next(iter(batches), None)


class DataLoader(IterableDataLoader):
    """Loads a sliced dataset into host memory and partitions it into splits.

    The arrays returned by ``slice_data`` are kept in ``x_cpu`` / ``y_cpu``,
    optionally shuffled along the first axis, and divided into contiguous
    train / val / test ranges according to ``SPLITS[dataset]``.
    """

    def __init__(self, dataset: str, prediction_horizon, shuffle=True, seed=42, dtype=np.float32, force_recreate=False,
                 normaliser=zero_one, reduced=False):
        """Load ``dataset`` and compute the split sizes.

        Args:
            dataset: Key into ``OL`` and ``SPLITS``.
            prediction_horizon: Stored as ``self.H`` and passed to ``slice_data``.
            shuffle: If true, apply one shared random permutation to the first
                axis of ``x_cpu`` and ``y_cpu`` before splitting.
            seed: Seed of the permutation generator.
            dtype, force_recreate, normaliser, reduced: Forwarded to
                ``slice_data`` unchanged.

        Raises:
            KeyError: If ``dataset`` is not a key of ``OL`` / ``SPLITS``.
        """
        self.OL = OL[dataset]
        self.H = prediction_horizon
        self.x_cpu, self.y_cpu = slice_data(dataset, self.OL, prediction_horizon, overlap=self.OL - 1, force_recreate=force_recreate, dtype=dtype, normaliser=normaliser, reduced=reduced)

        if shuffle:
            rng = np.random.default_rng(seed)
            shuffle_idcs = rng.permutation(self.x_cpu.shape[0])

            # same permutation for both arrays keeps x/y pairs aligned
            self.x_cpu = self.x_cpu[shuffle_idcs]
            self.y_cpu = self.y_cpu[shuffle_idcs]

        n_tot = self.x_cpu.shape[0]

        self.split = SPLITS[dataset]
        # Split sizes are host-side scalars, so compute them with numpy rather
        # than routing plain Python floats through the cp array module.
        self.train_size = int(np.floor(self.split[0] * n_tot))
        self.val_size = int(np.ceil(self.split[1] * n_tot))
        # the test split takes whatever is left, so the three sizes sum to n_tot
        self.test_size = n_tot - (self.train_size + self.val_size)

    def _split_size(self, target):
        """Return the number of samples in ``target``; raise ValueError if it is unknown."""
        if target == "train":
            return self.train_size
        if target == "val":
            return self.val_size
        if target == "test":
            return self.test_size
        raise ValueError(f"Invalid target: {target}")

    def clear(self):
        """Drop the references to the host arrays held by this loader."""
        del self.x_cpu
        del self.y_cpu

    def get_ol_t(self):
        """Return the tuple ``(OL, H + OL)``."""
        return self.OL, self.H + self.OL

    def get_n_batches(self, batch_size, target="train"):
        """Return the split size divided by ``batch_size``.

        The result is the exact ratio and can be fractional when the split is
        not a multiple of ``batch_size`` (``SubsetIterator`` rounds up and
        serves a shorter last batch). ``batch_size=-1`` means "whole split in
        one batch", matching ``iterate``, and gives ``1.0`` for a non-empty
        split and ``0.0`` for an empty one.

        Raises:
            ValueError: If ``target`` is unknown or ``batch_size`` is neither
                ``-1`` nor ``>= 1``.
        """
        n_samples = self._split_size(target)
        if batch_size == -1:
            # mirror the sentinel handling of SubsetIterator; avoid 0 / 0 on an empty split
            return 1.0 if n_samples else 0.0
        if batch_size < 1:
            raise ValueError("batch_size must be -1 (whole split) or >= 1")
        return n_samples / batch_size

    def get_mean(self, target="train"):
        """Return the mean over all elements of ``x_cpu`` within the ``target`` split.

        Raises:
            ValueError: If ``target`` is unknown.
        """
        val_start = self.train_size
        test_start = self.train_size + self.val_size
        if target == "train":
            return self.x_cpu[:val_start].mean()
        elif target == "val":
            return self.x_cpu[val_start:test_start].mean()
        elif target == "test":
            return self.x_cpu[test_start:].mean()
        else:
            raise ValueError(f"Invalid target: {target}")

    def get_dim_t(self):
        """Return the shape of a single sample, i.e. ``x_cpu.shape[1:]``."""
        return self.x_cpu.shape[1:]




