import tensorflow as tf
import numpy as np

class BatchCyclingSchedule:
    """Serve fixed-size batches of ``real_data``, cycling as the step advances.

    Steps below ``init_steps`` always receive the first batch. From
    ``init_steps`` onward each batch is held for ``steps_per_batch``
    consecutive steps before the next one is served, and the sequence wraps
    around after the last full batch. Trailing rows that do not fill a whole
    batch are never served.

    Args:
        real_data: Array indexed along its first axis. It must expose
            ``shape``, support slicing, and its slices must provide
            ``astype`` (e.g. a NumPy array).
        batch_size: Number of rows per batch; must be positive.
        steps_per_batch: Number of consecutive steps each batch is served
            for; must be positive.
        init_steps: Steps strictly below this value get the first batch.

    Raises:
        ValueError: If ``batch_size`` or ``steps_per_batch`` is not positive,
            or if ``real_data`` holds fewer than ``batch_size`` rows.
    """

    def __init__(self, real_data, batch_size=16, steps_per_batch=500, init_steps=1):
        # Validate up front: a zero value would otherwise surface as a bare
        # ZeroDivisionError (for steps_per_batch, only later inside the
        # py_function callback), and a negative batch_size would slip past
        # the num_batches check and produce a negative output shape.
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}.")
        if steps_per_batch <= 0:
            raise ValueError(
                f"steps_per_batch must be positive, got {steps_per_batch}."
            )

        self.real_data = real_data
        self.batch_size = batch_size
        self.steps_per_batch = steps_per_batch
        self.init_steps = init_steps
        # Only complete batches are counted; a partial tail is ignored.
        self.num_batches = real_data.shape[0] // batch_size
        if self.num_batches == 0:
            raise ValueError("Not enough data for one batch with the given batch_size.")
        self.output_shape = (batch_size,) + real_data.shape[1:]

    def _batch_for_step(self, step_value):
        """Return the float32 batch assigned to ``step_value``."""
        step_value = int(step_value)
        if step_value < self.init_steps:
            # Before the schedule starts, always serve the first batch.
            start_idx = 0
        else:
            relative_step = step_value - self.init_steps
            batch_index = (relative_step // self.steps_per_batch) % self.num_batches
            start_idx = batch_index * self.batch_size
        batch = self.real_data[start_idx:start_idx + self.batch_size]
        return batch.astype(np.float32)

    def __call__(self, step):
        """Return the batch for ``step`` as a float32 tensor.

        Args:
            step: Scalar step value (anything ``int()`` accepts once it is
                handed to the Python callback).

        Returns:
            A ``tf.float32`` tensor with static shape ``self.output_shape``.
        """
        # The lookup runs as Python/NumPy code wrapped in tf.py_function.
        batch = tf.py_function(self._batch_for_step, [step], tf.float32)
        # Attach the statically known shape to the py_function output.
        batch.set_shape(self.output_shape)
        return batch

class ConstantSchedule:
    """Schedule that yields the same value at every step.

    Args:
        value: The value returned by every call. Defaults to ``1.0``.
    """

    def __init__(self, value=1.0):
        self.value = value

    def __call__(self, step):
        """Return the stored value; ``step`` is accepted but ignored."""
        del step  # The schedule does not depend on the step.
        constant = self.value
        return constant

class ZeroOneSchedule:
    """Step function of the training step: 0.0 first, 1.0 past a threshold.

    Args:
        threshold_step: The schedule yields 1.0 only for steps strictly
            greater than this value.
        init_step: Stored on the instance; ``__call__`` does not use it.
    """

    def __init__(self, threshold_step, init_step=1):
        self.threshold_step = threshold_step
        self.init_step = init_step

    def __call__(self, step):
        """Return 1.0 if ``step > threshold_step`` else 0.0, as float32.

        Args:
            step: Current step (Python number or tensor).

        Returns:
            A scalar ``tf.float32`` tensor holding 0.0 or 1.0.
        """
        # Cast the comparison instead of branching with a Python ``if``:
        # truth-testing a symbolic tensor fails in graph mode unless
        # AutoGraph rewrites the branch. Returning a float32 tensor also
        # matches the tensor-valued schedules in this module.
        past_threshold = tf.greater(step, self.threshold_step)
        return tf.cast(past_threshold, tf.float32)

class LinearSchedule:
    """Schedule that grows linearly with the step: ``step / max_steps``.

    The result is not clipped, so it exceeds 1.0 once ``step`` passes
    ``max_steps``.

    Args:
        max_steps: Divisor applied to the step; stored as float32.
        init_step: Stored as float32 on the instance; ``__call__`` does not
            use it.
    """

    def __init__(self, max_steps=32 * 100.0, init_step=1):
        self.max_steps = tf.cast(max_steps, tf.float32)
        self.init_step = tf.cast(init_step, tf.float32)

    def __call__(self, step):
        """Return ``step`` divided by ``max_steps`` as a float32 tensor."""
        step_f = tf.cast(step, tf.float32)
        return step_f / self.max_steps
    

class RampSchedule:
    """Linear ramp from 0 at ``init_step`` towards ``max_val`` at ``max_steps``.

    The raw ramp value is clipped with ``tf.clip_by_value`` using 0.0 as the
    lower bound and ``max_val`` as the upper bound.

    Args:
        init_step: Step at which the ramp starts (raw value 0).
        max_steps: Step at which the raw ramp reaches ``max_val``.
        max_val: Upper clip bound and ramp target. Defaults to ``1.0``.
    """

    def __init__(self, init_step: float, max_steps: float, max_val: float = 1.0):
        # All parameters are kept as float32 tensors for the arithmetic below.
        self.max_val = tf.cast(max_val, tf.float32)
        self.max_steps = tf.cast(max_steps, tf.float32)
        self.init_step = tf.cast(init_step, tf.float32)

    def __call__(self, step):
        """Return the clipped ramp value for ``step`` as a float32 tensor."""
        # The small constant keeps the denominator non-zero when
        # max_steps equals init_step.
        ramp_length = self.max_steps - self.init_step + 1e-8
        slope = self.max_val / ramp_length
        elapsed = tf.cast(step, tf.float32) - self.init_step
        unclipped = elapsed * slope
        return tf.clip_by_value(unclipped, 0.0, self.max_val)
