from __future__ import annotations

import numpy as np

from numpy.typing import ArrayLike
from dataclasses import dataclass

def _get_split_indices(num_samples, sections):
    """Turn section sizes or fractions into boundary indices for ``np.split``.

    ``sections`` holds the size of each consecutive section, not boundary
    positions: the values are accumulated with a cumulative sum. When the
    accumulated total is at most one (allowing for floating-point round-off),
    the values are interpreted as fractions of ``num_samples`` and scaled to
    rounded integer positions; otherwise they are taken as absolute counts.

    Parameters
    ----------
    num_samples : int
        Total number of samples the sections must cover.
    sections : scalar or array-like
        Section sizes, either absolute counts summing to ``num_samples`` or
        fractions summing to one.

    Returns
    -------
    numpy.ndarray
        Cumulative boundaries with the final one (``num_samples``) dropped,
        i.e. one entry fewer than there are sections.

    Raises
    ------
    AssertionError
        If the sections do not add up to exactly ``num_samples``.
    """
    boundaries = np.cumsum(np.atleast_1d(sections))
    total = boundaries[-1]

    # Summing float fractions can overshoot 1.0 by a rounding step; the tight
    # tolerance keeps such input on the fraction path. Integer counts can never
    # fall inside this window, so their handling is unchanged.
    if total <= 1 or np.isclose(total, 1.0, rtol=1e-9, atol=0.0):
        boundaries = np.round(num_samples * boundaries).astype(int)

    # Raised explicitly (rather than via ``assert``) so the check still runs
    # under ``python -O``; the exception type is kept for existing callers.
    if boundaries[-1] != num_samples:
        raise AssertionError(
            f"sections add up to {boundaries[-1]}, expected {num_samples}"
        )

    return boundaries[:-1]


@dataclass(eq=False)
class Splitter:
    """Hold an ordering of sample indices and derive splits from it.

    The ordering is stored as a NumPy array. The methods slice, partition or
    rotate that array and return either raw index arrays or new ``Splitter``
    instances wrapping them, depending on ``return_splitter``.
    """

    _order: ArrayLike

    def __post_init__(self):
        # The field is annotated as array-like, but every method relies on
        # ndarray behaviour (``.size``, slicing), so coerce once up front.
        # An ndarray passed in is kept as the same object.
        self._order = np.asarray(self._order)

    def __eq__(self, other):
        """Compare two splitters by the content of their orderings."""
        # The dataclass-generated comparison would take the truth value of an
        # element-wise array comparison and raise ValueError, so it is
        # disabled (``eq=False``) and replaced by an explicit array check.
        if other.__class__ is not self.__class__:
            return NotImplemented
        return bool(np.array_equal(self._order, other._order))

    @property
    def num_samples(self):
        """Number of entries in the stored ordering."""
        return self._order.size

    @property
    def index(self):
        """The stored ordering itself (not a copy)."""
        return self._order

    def _fold_bins(self, num_folds, fold):
        """Validate ``fold`` and return the ``num_folds + 1`` fold boundaries."""
        # Negative folds must be rejected explicitly: Python's negative
        # indexing would otherwise silently pick the wrong boundaries.
        if not 0 <= fold < num_folds:
            raise AssertionError(
                f"fold must satisfy 0 <= fold < num_folds, "
                f"got fold={fold!r}, num_folds={num_folds!r}"
            )
        edges = np.linspace(0, self.num_samples, num_folds + 1)
        return np.round(edges).astype(int)

    def split(self, indices_or_fractions, return_splitter=False) -> tuple[np.ndarray, ...] | tuple[Splitter, ...]:
        """Partition the ordering into consecutive sections.

        Parameters
        ----------
        indices_or_fractions : scalar or array-like
            Size of each section, converted to boundary indices by
            ``_get_split_indices``.
        return_splitter : bool, default False
            When true, wrap every section in a ``Splitter``.

        Returns
        -------
        tuple
            One index array (or ``Splitter``) per section, in order.
        """
        boundaries = _get_split_indices(self.num_samples, indices_or_fractions)
        parts = np.split(self._order, boundaries)
        if not return_splitter:
            return tuple(parts)
        return tuple(Splitter(part) for part in parts)

    def cv(self, num_folds, fold, return_splitter=False):
        """Return the ``(train, test)`` pair for one cross-validation fold.

        The ordering is cut into ``num_folds`` contiguous chunks of nearly
        equal size. Fold 0 takes the last chunk as test set and higher folds
        move the test chunk towards the start. The training part is the
        samples after the test chunk followed by the samples before it.

        Parameters
        ----------
        num_folds : int
            Number of folds.
        fold : int
            Which fold to produce, ``0 <= fold < num_folds``.
        return_splitter : bool, default False
            When true, wrap both parts in ``Splitter`` instances.

        Raises
        ------
        AssertionError
            If ``fold`` is outside ``[0, num_folds)``.
        """
        # Reversed so that fold 0 addresses the chunk at the end.
        bins = self._fold_bins(num_folds, fold)[::-1]
        start, stop = bins[fold + 1], bins[fold]

        test = self._order[start:stop]
        train = np.concatenate([self._order[stop:], self._order[:start]])

        parts = train, test
        if not return_splitter:
            return parts
        return tuple(Splitter(part) for part in parts)

    def circshift(self, num_folds, fold, return_splitter=False):
        """Return the ordering circularly shifted by ``fold`` fold-widths.

        The shift amount is the ``fold``-th boundary of ``num_folds`` nearly
        equal chunks, applied with ``np.roll``; fold 0 leaves the ordering
        unchanged.

        Parameters
        ----------
        num_folds : int
            Number of folds.
        fold : int
            Which fold boundary to shift by, ``0 <= fold < num_folds``.
        return_splitter : bool, default False
            When true, wrap the shifted ordering in a ``Splitter``.

        Raises
        ------
        AssertionError
            If ``fold`` is outside ``[0, num_folds)``.
        """
        bins = self._fold_bins(num_folds, fold)
        shifted = np.roll(self._order, bins[fold])
        if not return_splitter:
            return shifted
        return Splitter(shifted)

    @classmethod
    def from_shuffle(cls, num_samples, seed=None):
        """Build a splitter over a random permutation of ``range(num_samples)``.

        ``seed`` is forwarded to ``np.random.default_rng``. The permutation is
        the argsort of uniform random draws.
        """
        prng = np.random.default_rng(seed)
        order = np.argsort(prng.random(num_samples))
        return cls(order)

    @classmethod
    def from_linear(cls, num_samples):
        """Build a splitter over the identity ordering ``0 .. num_samples - 1``."""
        order = np.arange(num_samples, dtype=np.int_)
        return cls(order)
    


