# Copyright (C) Authors of submission, all rights reserved


from dataclasses import dataclass, field
from typing import ClassVar, List, Type
import numpy as np


@dataclass
class PatchDropoutConfig:
    """Hyper-parameters for patch-dropout augmentation.

    Raises:
        ValueError: if ``patch_size`` is not positive, or ``max_drop_prob``
            lies outside ``[0, 1]`` (including NaN).
    """

    # Length of one patch, in positions along the sample's first axis.
    patch_size: int
    # Upper bound of the per-call drop probability, which ``CPM`` draws
    # uniformly from [0, max_drop_prob).
    max_drop_prob: float = 0.25
    # Upper bound on how many adjacent patches are dropped/kept as one unit.
    # As ``CPM`` currently samples it (exclusive upper bound), the drawn run
    # length stays strictly below this value whenever it exceeds 1.
    max_consecutive: int = 5

    def __post_init__(self):
        # Fail fast at construction: a non-positive patch size leads to a
        # division by zero or a negative mask size at call time, and a
        # probability outside [0, 1] is an invalid Bernoulli parameter.
        if self.patch_size <= 0:
            raise ValueError(f"patch_size must be positive, got {self.patch_size}")
        if not 0.0 <= self.max_drop_prob <= 1.0:
            raise ValueError(f"max_drop_prob must be in [0, 1], got {self.max_drop_prob}")

@dataclass
class CPM:
    """Patch-dropout augmentation: randomly replace runs of patches with NaN.

    On each call, a drop probability ``drop_p`` is drawn uniformly from
    ``[0, config.max_drop_prob)``. A run length ``consecutive`` (in patches) is
    also drawn. The context of ``past_end_idx`` positions is cut into
    ``past_end_idx // (patch_size * consecutive)`` runs. Each run is dropped
    independently with probability ``drop_p``.
    """

    config_class: ClassVar[Type] = PatchDropoutConfig
    config: PatchDropoutConfig

    def set_seed(self, seed: int):
        """Reseed the generator used for all subsequent random draws."""
        self._rng = np.random.default_rng(seed)

    @property
    def rng(self):
        """Return this instance's ``numpy.random.Generator``.

        If ``set_seed`` was never called, an unseeded generator is created on
        first access and reused afterwards.
        """
        # Build the fallback lazily: passing it as a ``getattr`` default would
        # construct (and usually discard) a fresh generator on every access.
        rng = getattr(self, "_rng", None)
        if rng is None:
            rng = np.random.default_rng()
            self._rng = rng
        return rng

    def return_with_opt_info(self, sample, info, return_info):
        """Return ``(sample, info)`` if ``return_info`` is truthy, else ``sample``."""
        if return_info:
            return sample, info
        else:
            return sample

    def __call__(self, sample: np.ndarray, past_end_idx, return_info=False, **kwargs) -> np.ndarray:
        """Return a copy of ``sample`` with randomly chosen patches set to NaN.

        Args:
            sample: Array-like; masking is applied along its first axis.
            past_end_idx: Number of context positions. It determines how many
                whole patches are eligible for dropout.
            return_info: If truthy, return ``(sample, {})`` instead of ``sample``.
            **kwargs: Accepted and ignored (presumably so this matches a shared
                transform call signature — confirm against callers).

        Returns:
            The masked copy, or ``(masked_copy, {})`` when ``return_info`` is
            truthy. Integer and bool inputs are promoted to float64 so that NaN
            can be stored. When ``past_end_idx`` covers less than one full patch
            (or is negative), ``sample`` itself is returned unchanged and uncopied.
        """
        n_patches = past_end_idx // self.config.patch_size
        # Fewer than one full patch of context: nothing to drop. ``<= 0`` also
        # covers negative indices, which would otherwise yield a negative size.
        if n_patches <= 0:
            return self.return_with_opt_info(sample, {}, return_info)

        dropped = self._sample_drop_mask(n_patches, past_end_idx)
        # NOTE(review): the mask is aligned to the END of ``sample``, not to
        # ``past_end_idx``. If ``sample`` extends beyond the context, this masks
        # its tail rather than the positions just before ``past_end_idx`` —
        # confirm this is intended.
        dropped = self._right_align_mask(dropped, len(sample))

        out = np.array(sample)  # always a copy; the caller's array is untouched
        if out.dtype.kind in "biu":
            # bool / signed / unsigned integer arrays cannot hold NaN.
            out = out.astype(np.float64)
        out[dropped] = np.nan
        return self.return_with_opt_info(out, {}, return_info)

    def _sample_drop_mask(self, n_patches, past_end_idx):
        """Draw a per-position boolean drop mask made of whole runs of patches."""
        patch_size = self.config.patch_size
        drop_p = self.rng.uniform(low=0.0, high=self.config.max_drop_prob)
        max_consecutive = min(self.config.max_consecutive, n_patches)
        if max_consecutive > 1:
            # NOTE(review): ``high`` is exclusive, so ``consecutive`` never
            # reaches ``max_consecutive`` (e.g. 5 -> 1..4, 2 -> always 1). Kept
            # as-is to preserve seeded outputs; confirm whether an inclusive
            # bound (``endpoint=True``) was intended.
            consecutive = self.rng.integers(low=1, high=max_consecutive)
        else:
            consecutive = 1

        run_len = patch_size * consecutive
        n_runs = past_end_idx // run_len
        runs = self.rng.choice([True, False], size=n_runs, p=[drop_p, 1 - drop_p])
        # Expand each run decision to every position it covers.
        return np.repeat(runs, run_len)

    @staticmethod
    def _right_align_mask(mask, length):
        """Fit ``mask`` to ``length`` by padding False at the front or trimming the front."""
        diff = length - len(mask)
        if diff > 0:
            return np.concatenate((np.zeros(diff, dtype=bool), mask))
        if diff < 0:
            return mask[-diff:]
        return mask
