import torch
import torchaudio
import numpy as np
import random
from pydub import AudioSegment
from io import BytesIO

def _to_numpy(waveform):
    """Return ``waveform`` as a NumPy array.

    Tensors are converted and have their singleton dimensions squeezed out
    (so a ``(1, T)`` tensor becomes a ``(T,)`` array). Any non-tensor input
    is returned as-is, without squeezing.
    """
    if isinstance(waveform, torch.Tensor):
        # detach()/cpu() keep the conversion from raising on tensors that
        # track gradients or live on an accelerator; for plain CPU tensors
        # both calls return the same storage, so nothing changes.
        return waveform.detach().cpu().numpy().squeeze()
    return waveform

def _to_tensor(array):
    """Turn a NumPy array into a float32 tensor with a new leading axis.

    The array is first copied to float32, so the returned tensor never
    shares memory with ``array``. A ``(T,)`` input yields a ``(1, T)`` tensor.
    """
    as_float32 = array.astype(np.float32)
    tensor = torch.from_numpy(as_float32)
    return tensor.unsqueeze(0)

# ==== Benign Operations ====
class AudioProcessor:
    """Static waveform transformations, grouped as benign or malicious edits.

    The methods index channel 0 and call pydub with ``channels=1``, so they
    appear to expect mono tensors laid out as ``(1, T)`` (confirm with
    callers). Random choices are drawn from Python's ``random`` module, so
    seeding it makes the malicious edits reproducible. No method modifies its
    input tensor in place.
    """

    @staticmethod
    def benign_compression(waveform, sample_rate, bitrate="128k"):
        """Round-trip the waveform through AAC in an MP4 container.

        The samples are quantized to 16-bit PCM, encoded with pydub at
        ``bitrate``, decoded again and rescaled by 1/32768.

        Args:
            waveform: tensor or NumPy array of samples (scaled by 32767 and
                clipped to the int16 range before encoding).
            sample_rate: frame rate handed to pydub.
            bitrate: encoder bitrate string passed to pydub's ``export``.

        Returns:
            A float32 tensor with a leading singleton axis. Its length is
            whatever the decoder produces; it is not aligned to the input.
        """
        waveform = _to_numpy(waveform)
        buffer = BytesIO()

        waveform_int16 = (waveform * 32767).clip(-32768, 32767).astype(np.int16)
        AudioSegment(
            waveform_int16.tobytes(),
            frame_rate=sample_rate,
            sample_width=2,
            channels=1
        ).export(buffer, format="mp4", codec="aac", bitrate=bitrate)
        buffer.seek(0)
        decoded = np.array(AudioSegment.from_file(buffer, format="mp4").get_array_of_samples()) / 32768.0
        return _to_tensor(decoded)

    @staticmethod
    def benign_resample(waveform, orig_sr, target_sr=8000):
        """Resample to ``target_sr`` and back to ``orig_sr``.

        The result is trimmed or zero-padded on the right so that its last
        axis has exactly as many samples as ``waveform``.

        Args:
            waveform: 2-D tensor; its size along dim 1 sets the output length.
            orig_sr: sample rate of ``waveform``.
            target_sr: intermediate sample rate.

        Returns:
            The resampled tensor, length-aligned to the input.
        """
        target_len = waveform.size(1)
        down = torchaudio.transforms.Resample(orig_sr, target_sr)(waveform)
        up = torchaudio.transforms.Resample(target_sr, orig_sr)(down)

        # The two resampling steps may not return exactly the input length.
        if up.size(1) > target_len:
            up = up[:, :target_len]
        elif up.size(1) < target_len:
            up = torch.nn.functional.pad(up, (0, target_len - up.size(1)))
        return up

    @staticmethod
    def benign_reencode(waveform, sample_rate=None):
        """Encode to 16-bit WAV and decode again, three times in a row.

        Args:
            waveform: tensor or NumPy array of samples.
            sample_rate: frame rate written into the WAV container. When
                ``None`` it falls back to 16000.

        Returns:
            A float32 tensor with a leading singleton axis.
        """
        # pydub needs sample width, frame rate and channel count together when
        # given raw PCM bytes, so passing the ``None`` default straight through
        # could never work. The rate only labels the container; it does not
        # alter the sample values.
        frame_rate = 16000 if sample_rate is None else sample_rate

        waveform = _to_numpy(waveform)
        for _ in range(3):
            buffer = BytesIO()
            waveform_int16 = (waveform * 32767).clip(-32768, 32767).astype(np.int16)
            AudioSegment(
                waveform_int16.tobytes(),
                frame_rate=frame_rate,
                sample_width=2,
                channels=1
            ).export(buffer, format="wav")
            buffer.seek(0)
            waveform = np.array(AudioSegment.from_file(buffer, format="wav").get_array_of_samples()) / 32768.0
        return _to_tensor(waveform)

    @staticmethod
    def benign_noise_suppression(waveform, sr, energy_threshold=0.01, frame_size=400, hop_size=160):
        """Zero every frame of channel 0 whose RMS is below ``energy_threshold``.

        Args:
            waveform: tensor indexed as ``waveform[0]`` for the samples.
            sr: unused; kept for a uniform call signature.
            energy_threshold: RMS level below which a frame is zeroed.
            frame_size: frame length in samples.
            hop_size: step between frame starts in samples.

        Returns:
            A modified copy of ``waveform``. Inputs shorter than
            ``frame_size`` come back unchanged. Only index 0 along the first
            axis is processed.
        """
        out = waveform.clone()
        channel = out[0]  # a view: writes below land in ``out``
        total = channel.size(0)

        # NOTE: frames overlap and are zeroed as the loop advances, so a
        # frame's RMS is measured on samples that earlier iterations may
        # already have silenced.
        for start in range(0, total - frame_size + 1, hop_size):
            stop = start + frame_size
            rms = torch.sqrt((channel[start:stop] ** 2).mean())
            if rms < energy_threshold:
                channel[start:stop] = 0.0

        return out

    # ==== Malicious Operations ====
    @staticmethod
    def malicious_delete(waveform, sample_rate=16000, ratio=0.1):
        """Cut one contiguous span out of the waveform.

        The span covers ``ratio`` of the samples and is placed at random so
        that it lies between 10% and 90% of the signal. If it does not fit
        there, it is centred instead.

        Args:
            waveform: tensor whose last axis is time.
            sample_rate: unused; kept for a uniform call signature.
            ratio: fraction of samples to remove. Values outside [0, 1] are
                clamped so the output length is always ``T - deleted``.

        Returns:
            A new tensor with the span removed.
        """
        T = waveform.size(-1)
        # Clamp so out-of-range ratios cannot produce negative slice bounds.
        del_len = min(max(int(T * ratio), 0), T)

        center_start = int(T * 0.1)
        center_end = int(T * 0.9 - del_len)

        if center_end <= center_start:
            start = (T - del_len) // 2
        else:
            start = random.randint(center_start, center_end)

        return torch.cat(
            [waveform[..., :start], waveform[..., start + del_len:]],
            dim=-1,
        )

    @staticmethod
    def malicious_silence(waveform, sample_rate=16000, ratio=0.1, frame_size=400, hop_size=160, energy_threshold=0.01):
        """Mute a span that begins at a randomly chosen voiced frame.

        Frames of channel 0 whose RMS exceeds ``energy_threshold`` count as
        voiced. One is picked at random and ``ratio`` of the total length is
        zeroed from its start, truncated at the end of the signal.

        Args:
            waveform: tensor indexed as ``waveform[0]`` for the samples.
            sample_rate: unused; kept for a uniform call signature.
            ratio: fraction of the signal length to mute.
            frame_size: analysis frame length in samples.
            hop_size: step between frame starts in samples.
            energy_threshold: RMS level above which a frame is voiced.

        Returns:
            A modified copy of ``waveform``. It is returned unchanged when the
            signal is shorter than one frame or no frame is voiced.
        """
        x = waveform.clone()
        x_flat = x[0]  # [T]
        T = x_flat.size(0)

        if T < frame_size:
            # unfold() raises when the signal is shorter than one frame.
            return x

        frames = x_flat.unfold(0, frame_size, hop_size)  # [num_frames, frame_size]
        energy = torch.sqrt((frames ** 2).mean(dim=1))   # [num_frames]
        voiced_frames = (energy > energy_threshold).nonzero(as_tuple=True)[0]

        if voiced_frames.numel() == 0:
            return x  # no speech detected, return unchanged

        pick = random.randint(0, voiced_frames.numel() - 1)
        start = int(voiced_frames[pick]) * hop_size  # plain int, not a 0-d tensor
        mute_len = max(int(T * ratio), 0)  # a negative ratio must not wrap the slice
        end = min(start + mute_len, T)

        x[..., start:end] = 0
        return x

    @staticmethod
    def malicious_reorder(waveform, sample_rate=None, num_segments=None):
        """Split the waveform at random cut points and shuffle the pieces.

        Args:
            waveform: tensor whose last axis is time.
            sample_rate: unused; kept for a uniform call signature.
            num_segments: number of pieces; chosen from 4, 6 or 8 when
                ``None``. It is capped at the number of samples, since a
                signal of ``T`` samples has only ``T - 1`` interior cuts.

        Returns:
            A new tensor holding the same samples in shuffled segment order.
            The shuffle may leave the order unchanged.
        """
        if num_segments is None:
            num_segments = random.choice([4, 6, 8])

        T = waveform.shape[-1]
        # random.sample raises if asked for more cut points than exist.
        num_cuts = min(num_segments - 1, max(T - 1, 0))
        cut_points = sorted(random.sample(range(1, T), num_cuts))

        boundaries = [0] + cut_points + [T]
        spans = list(zip(boundaries[:-1], boundaries[1:]))
        random.shuffle(spans)
        return torch.cat([waveform[..., lo:hi] for lo, hi in spans], dim=-1)

    @staticmethod
    def malicious_splice(waveform, sample_rate=None, spliced_waveform=None):
        """Insert ``spliced_waveform`` at a random position in ``waveform``.

        Args:
            waveform: tensor whose last axis is time.
            sample_rate: unused; kept for a uniform call signature.
            spliced_waveform: tensor to insert. It is required despite the
                ``None`` default.

        Returns:
            A new tensor whose length is the sum of both inputs.

        Raises:
            TypeError: if ``spliced_waveform`` is not given.
        """
        if spliced_waveform is None:
            # torch.cat raised TypeError here too, but with an opaque message.
            raise TypeError("malicious_splice() requires 'spliced_waveform'")

        start = random.randint(0, waveform.size(-1))
        return torch.cat(
            [waveform[..., :start], spliced_waveform, waveform[..., start:]],
            dim=-1,
        )

    @staticmethod
    def malicious_substitute(waveform, sample_rate=None, replace_waveform=None, frame_size=400, hop_size=160, energy_threshold=0.01):
        """Overwrite a span of the waveform with ``replace_waveform``.

        The span starts at a randomly chosen voiced frame (RMS above
        ``energy_threshold``) that leaves room for the whole replacement. If
        no voiced frame leaves enough room, the replacement is centred.

        Args:
            waveform: tensor indexed as ``waveform[0]`` for the samples.
            sample_rate: unused; kept for a uniform call signature.
            replace_waveform: tensor to write in. It is required despite the
                ``None`` default. A replacement longer than the waveform is
                trimmed to the waveform's length.
            frame_size: analysis frame length in samples.
            hop_size: step between frame starts in samples.
            energy_threshold: RMS level above which a frame is voiced.

        Returns:
            A tensor with the same length as ``waveform``. An unchanged copy
            is returned when the signal is shorter than one frame or no frame
            is voiced.
        """
        x = waveform.clone()
        x_flat = x[0]
        T = x_flat.size(0)
        sub_len = replace_waveform.size(-1)

        if T < frame_size:
            # unfold() raises when the signal is shorter than one frame.
            return x

        if sub_len > T:
            # Without trimming, the centred fallback start goes negative and
            # the output ends up longer than the input.
            replace_waveform = replace_waveform[..., :T]
            sub_len = T

        # Energy-based VAD
        frames = x_flat.unfold(0, frame_size, hop_size)
        energy = torch.sqrt((frames ** 2).mean(dim=1))
        voiced_frames = (energy > energy_threshold).nonzero(as_tuple=True)[0]

        if voiced_frames.numel() == 0:
            return x

        valid_starts = voiced_frames * hop_size
        valid_starts = valid_starts[valid_starts <= T - sub_len]
        if valid_starts.numel() == 0:
            start = (T - sub_len) // 2
        else:
            start = valid_starts[random.randint(0, valid_starts.numel() - 1)].item()

        return torch.cat([x[..., :start], replace_waveform, x[..., start + sub_len:]], dim=-1)

    @staticmethod
    def malicious_voice_conversion(waveform, sample_rate=16000, pitch_shift=None):
        """Pitch-shift the waveform with SoX effects.

        Args:
            waveform: tensor passed to ``apply_effects_tensor``.
            sample_rate: sample rate of ``waveform``; also the rate requested
                from the trailing ``rate`` effect.
            pitch_shift: shift amount; multiplied by 100 before being passed
                to the SoX ``pitch`` effect (presumably semitones converted to
                cents; confirm against the SoX manual). Chosen at random from
                ±2, ±3 and ±4 when ``None``.

        Returns:
            The converted tensor, trimmed or zero-padded on the right to the
            input length.
        """
        if pitch_shift is None:
            pitch_shift = random.choice([-4, -3, -2, 2, 3, 4])

        effects = [
            ["pitch", f"{pitch_shift * 100}"],
            ["rate",  f"{sample_rate}"],
        ]
        converted, _ = torchaudio.sox_effects.apply_effects_tensor(
            waveform, sample_rate, effects
        )

        target_len = waveform.size(-1)
        if converted.size(-1) > target_len:
            converted = converted[..., :target_len]
        elif converted.size(-1) < target_len:
            converted = torch.nn.functional.pad(converted, (0, target_len - converted.size(-1)))

        return converted