import torch, math

class DynThresh:
    """Classifier-free guidance with dynamic thresholding ("mimic scale").

    The guided prediction is computed at ``cfg_scale``. Its variability
    (standard deviation or absolute deviation) is then rescaled to match the
    prediction that ``mimic_scale`` would have produced. Both scales can
    follow a per-step schedule chosen from ``Modes``.
    """

    # Schedule names understood by ``interpret_scale``; any other string is
    # treated like "Constant" (no branch matches, so the scale is unchanged).
    Modes = [
        "Constant",
        "Linear Down",
        "Cosine Down",
        "Half Cosine Down",
        "Linear Up",
        "Cosine Up",
        "Half Cosine Up",
        "Power Up",
        "Power Down",
        "Linear Repeating",
        "Cosine Repeating",
        "Sawtooth",
    ]
    # Anchor for the rescaling: each channel's mean ("MEAN") or zero ("ZERO").
    Startpoints = ["MEAN", "ZERO"]
    # Variability measure: absolute deviation ("AD") or standard deviation ("STD").
    Variabilities = ["AD", "STD"]

    def __init__(
        self,
        mimic_scale,
        threshold_percentile,
        mimic_mode,
        mimic_scale_min,
        cfg_mode,
        cfg_scale_min,
        sched_val,
        experiment_mode,
        max_steps,
        separate_feature_channels,
        scaling_startpoint,
        variability_measure,
        interpolate_phi,
    ):
        """Store the thresholding configuration.

        Args:
            mimic_scale: Guidance scale whose variability the output mimics.
            threshold_percentile: Quantile ``q`` (passed to ``torch.quantile``)
                of the absolute centered CFG values, used in "AD" mode.
            mimic_mode: Schedule name (see ``Modes``) applied to ``mimic_scale``.
            mimic_scale_min: Floor that the mimic schedule scales toward.
            cfg_mode: Schedule name (see ``Modes``) applied to the CFG scale.
            cfg_scale_min: Floor that the CFG schedule scales toward.
            sched_val: Exponent for the "Power" schedules; repetition count
                for the "Repeating"/"Sawtooth" schedules.
            experiment_mode: Stored only; not read anywhere in this class.
            max_steps: Total step count; ``step`` is normalized as
                ``step / (max_steps - 1)``.
            separate_feature_channels: If true, statistics are computed per
                (batch, channel); otherwise over the whole tensor.
            scaling_startpoint: One of ``Startpoints``.
            variability_measure: One of ``Variabilities``.
            interpolate_phi: Blend weight between the thresholded result and
                the plain CFG result; 1.0 means the thresholded result only.
        """
        self.mimic_scale = mimic_scale
        self.threshold_percentile = threshold_percentile
        self.mimic_mode = mimic_mode
        self.cfg_mode = cfg_mode
        self.max_steps = max_steps
        self.cfg_scale_min = cfg_scale_min
        self.mimic_scale_min = mimic_scale_min
        self.experiment_mode = experiment_mode
        self.sched_val = sched_val
        self.sep_feat_channels = separate_feature_channels
        self.scaling_startpoint = scaling_startpoint
        self.variability_measure = variability_measure
        self.interpolate_phi = interpolate_phi

    def interpret_scale(self, scale, mode, min, step):
        """Return ``scale`` adjusted by the schedule ``mode`` at ``step``.

        The schedule scales only the part above ``min``. The result is
        ``min + (scale - min) * factor``, where ``factor`` depends on ``mode``
        and on the progress ``frac = step / (max_steps - 1)``.

        Args:
            scale: Base scale (float, or a 0-d tensor; it is never modified).
            mode: One of ``Modes``. Unknown names behave like "Constant".
            min: Floor that the schedule scales toward. (The name shadows the
                builtin but is kept for callers that pass it by keyword.)
            step: Current sampling step index.

        Returns:
            The scheduled scale.
        """
        # Use a fresh value instead of ``scale -= min`` so that a tensor
        # argument (e.g. ``self.mimic_scale``) is not mutated in place.
        span = scale - min
        last_step = self.max_steps - 1
        # With a single step there is no progress to measure; treat it as the
        # start of the schedule instead of dividing by zero.
        frac = step / last_step if last_step > 0 else 0.0
        if mode == "Constant":
            pass
        elif mode == "Linear Down":
            span *= 1.0 - frac
        elif mode == "Half Cosine Down":
            span *= math.cos(frac)
        elif mode == "Cosine Down":
            span *= math.cos(frac * 1.5707)  # 1.5707 ~= pi / 2
        elif mode == "Linear Up":
            span *= frac
        elif mode == "Half Cosine Up":
            span *= 1.0 - math.cos(frac)
        elif mode == "Cosine Up":
            span *= 1.0 - math.cos(frac * 1.5707)
        elif mode == "Power Up":
            span *= math.pow(frac, self.sched_val)
        elif mode == "Power Down":
            span *= 1.0 - math.pow(frac, self.sched_val)
        elif mode == "Linear Repeating":
            # Triangle wave: 1 -> 0 -> 1 over each repetition.
            portion = (frac * self.sched_val) % 1.0
            span *= (0.5 - portion) * 2 if portion < 0.5 else (portion - 0.5) * 2
        elif mode == "Cosine Repeating":
            span *= math.cos(frac * 6.28318 * self.sched_val) * 0.5 + 0.5
        elif mode == "Sawtooth":
            span *= (frac * self.sched_val) % 1.0
        return span + min

    def _scale_refs(self, mim_centered, cfg_centered):
        """Return ``(mimic_ref, cfg_ref)`` variability references.

        Both inputs are mean-centered tensors flattened to ``(bs, c, n)``.
        With ``sep_feat_channels`` the references are reduced along dim 2
        with the dim kept, so they broadcast per (batch, channel). Otherwise
        each reference is a single scalar over the whole tensor.
        """
        if self.sep_feat_channels:
            if self.variability_measure == "STD":
                return (
                    mim_centered.std(dim=2, keepdim=True),
                    cfg_centered.std(dim=2, keepdim=True),
                )
            # 'AD': mimic uses the largest |deviation|, CFG uses the
            # threshold percentile of |deviation|.
            return (
                mim_centered.abs().max(dim=2, keepdim=True).values,
                torch.quantile(
                    cfg_centered.abs(), self.threshold_percentile, dim=2, keepdim=True
                ),
            )
        if self.variability_measure == "STD":
            return mim_centered.std(), cfg_centered.std()
        # 'AD'
        return (
            mim_centered.abs().max(),
            torch.quantile(cfg_centered.abs(), self.threshold_percentile),
        )

    def dynthresh(self, cond, uncond, cfg_scale, weights, step):
        """Apply CFG with dynamic thresholding and return the guided result.

        Args:
            cond: Conditional predictions. ``cond.shape[0]`` must be a multiple
                of ``uncond.shape[0]``; the conds belonging to each batch item
                are summed.
            uncond: Unconditional predictions. Assumed to be laid out as
                ``(bs, c, *rest)``, e.g. ``(bs, c, h, w)`` or
                ``(bs, c, f, h, w)`` (TODO confirm against callers).
            cfg_scale: Base CFG scale before the ``cfg_mode`` schedule.
            weights: Optional multiplier for each ``cond - uncond`` difference.
                It must broadcast against ``(bs, conds_per_batch, c, *rest)``.
            step: Current sampling step index, used by the schedules.

        Returns:
            The thresholded prediction, shaped like the CFG target
            ``uncond + relative * cfg_scale``.
        """
        mimic_scale = self.interpret_scale(
            self.mimic_scale, self.mimic_mode, self.mimic_scale_min, step
        )
        cfg_scale = self.interpret_scale(
            cfg_scale, self.cfg_mode, self.cfg_scale_min, step
        )

        conds_per_batch, remainder = divmod(cond.shape[0], uncond.shape[0])
        assert (
            remainder == 0
        ), "Expected # of conds per batch to be constant across batches"
        cond_stacked = cond.reshape((-1, conds_per_batch) + uncond.shape[1:])

        # Standard CFG difference, optionally weighted, then summed over the
        # conds of each batch item.
        diff = cond_stacked - uncond.unsqueeze(1)
        if weights is not None:
            diff = diff * weights
        relative = diff.sum(1)

        # Plain CFG results at the mimic scale and at the real CFG scale.
        mim_target = uncond + relative * mimic_scale
        cfg_target = uncond + relative * cfg_scale

        # Merge every dim after the channel dim, so statistics cover all
        # positions of each (batch, channel). For example (bs, c, h, w)
        # becomes (bs, c, h*w) and (bs, c, f, h, w) becomes (bs, c, f*h*w).
        mim_flattened = mim_target.flatten(2)
        cfg_flattened = cfg_target.flatten(2)
        mim_means = mim_flattened.mean(dim=2, keepdim=True)
        cfg_means = cfg_flattened.mean(dim=2, keepdim=True)
        mim_centered = mim_flattened - mim_means
        cfg_centered = cfg_flattened - cfg_means

        mim_scaleref, cfg_scaleref = self._scale_refs(mim_centered, cfg_centered)

        if self.scaling_startpoint == "ZERO":
            # Rescale the raw values about zero.
            result = cfg_flattened * (mim_scaleref / cfg_scaleref)
        else:  # 'MEAN'
            if self.variability_measure == "STD":
                cfg_renormalized = (cfg_centered / cfg_scaleref) * mim_scaleref
            else:  # 'AD'
                # Clamp to the larger reference, normalize by it, then grow to
                # the mimic reference.
                max_scaleref = torch.maximum(mim_scaleref, cfg_scaleref)
                cfg_clamped = cfg_centered.clamp(-max_scaleref, max_scaleref)
                cfg_renormalized = (cfg_clamped / max_scaleref) * mim_scaleref
            # Add the per-channel means back to return to the original range.
            result = cfg_renormalized + cfg_means

        actual_res = result.unflatten(2, mim_target.shape[2:])

        if self.interpolate_phi != 1.0:
            actual_res = actual_res * self.interpolate_phi + cfg_target * (
                1.0 - self.interpolate_phi
            )

        # NOTE: ``experiment_mode`` is stored by ``__init__`` but is not
        # applied anywhere in this implementation.

        return actual_res
