from datasets import load_dataset, Audio, DatasetDict
import torch
import audiomentations as am
import torchaudio
import numpy as np
import random
from transformers import (
    WhisperForConditionalGeneration,
    WhisperProcessor,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    WhisperFeatureExtractor,
    WhisperTokenizer
)
import matplotlib.pyplot as plt
import pandas as pd
import evaluate
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import torch.nn.utils.prune as prune
from collections import defaultdict
import torch.nn as nn



# Initialize a DatasetDict to hold train/test splits
common_voice = DatasetDict()

# Load Hindi ("hi") data from Common Voice 11.0.
# Training uses the combined train + validation splits (no subsetting is applied).
common_voice["train"] = load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="train+validation", trust_remote_code=True)
common_voice["test"] = load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="test", trust_remote_code=True)

# Drop metadata columns that preprocessing does not use.
common_voice = common_voice.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes"])

# Resample audio to 16kHz (Whisper's expected input)
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))

# Define waveform augmentation pipeline (each transform has its own probability p)
augmentation_pipeline = am.Compose([
    am.AddGaussianNoise(p=0.3, min_amplitude=0.001, max_amplitude=0.015),
    am.PitchShift(p=0.4, min_semitones=-1, max_semitones=1),
    am.TimeStretch(p=0.3, min_rate=0.9, max_rate=1.1),
])

# Load Whisper feature extractor and tokenizer (tokenizer configured for Hindi transcription)
from transformers import WhisperFeatureExtractor, WhisperTokenizer

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="Hindi", task="transcribe")

# Load processor (handles feature extraction + tokenization), also configured for Hindi
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="Hindi", task="transcribe")

# Load pretrained small Whisper model (adjust size as needed)
from transformers import WhisperForConditionalGeneration

model_path = "openai/whisper-small"
model = WhisperForConditionalGeneration.from_pretrained(model_path)
# NOTE: the processor is deliberately NOT re-loaded here. A bare
# WhisperProcessor.from_pretrained("openai/whisper-small") without
# language/task would replace the Hindi/transcribe processor created above.
print(model)

# Generation settings: Hindi transcription.
model.generation_config.language = "hindi"
model.generation_config.task = "transcribe"
model.config.apply_spec_augment = True

# Clear legacy forced decoder ids; language/task are set explicitly above.
model.generation_config.forced_decoder_ids = None
# SpecAugment masking parameters (read from model.config).
model.config.mask_time_prob = 0.05  # WhisperConfig default: 0.05
model.config.mask_feature_prob = 0.03  # WhisperConfig default: 0.0
model.config.mask_feature_length = 10  # WhisperConfig default: 10

def prepare_dataset(batch):
    """Turn one Common Voice row into Whisper training inputs.

    With probability 0.5 the waveform is passed through
    ``augmentation_pipeline`` before Whisper input features are computed.

    Args:
        batch: a single dataset row with an ``"audio"`` dict (keys ``"array"``
            and ``"sampling_rate"``) and a ``"sentence"`` transcript.

    Returns:
        The same row with ``"input_features"`` (features of the possibly
        augmented audio) and ``"labels"`` (token ids of the transcript) added.

    NOTE(review): this runs inside ``Dataset.map``, so the augmentation is
    applied once when the dataset is built, not freshly on every epoch.
    """
    audio = batch["audio"]
    sampling_rate = audio["sampling_rate"]
    # audiomentations documents float32 sample arrays as its input; decoded
    # audio may arrive as float64. The feature extractor also works in float32.
    samples = np.asarray(audio["array"], dtype=np.float32)

    # Apply augmentations with 50% probability
    if random.random() < 0.5:
        samples = augmentation_pipeline(samples=samples, sample_rate=sampling_rate)

    # Compute model input features and tokenize the reference transcript.
    batch["input_features"] = feature_extractor(
        samples,
        sampling_rate=sampling_rate,
    ).input_features[0]
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch

def _prepare_eval_example(batch):
    """Featurize one evaluation row like ``prepare_dataset`` does, but never augment.

    Args:
        batch: a dataset row with ``"audio"`` (keys ``"array"`` and
            ``"sampling_rate"``) and ``"sentence"``.

    Returns:
        The row with ``"input_features"`` and ``"labels"`` added.
    """
    audio = batch["audio"]
    batch["input_features"] = feature_extractor(
        audio["array"],
        sampling_rate=audio["sampling_rate"],
    ).input_features[0]
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch


# Apply preprocessing per split. Only the training split goes through
# ``prepare_dataset`` (which randomly augments audio); the test split is
# featurized without augmentation so evaluation WER is measured on the
# original recordings and is reproducible between runs.
# NOTE: ``map`` materializes its output once, so the training augmentation is
# fixed at preprocessing time rather than re-sampled each epoch.
common_voice["train"] = common_voice["train"].map(
    prepare_dataset,
    remove_columns=common_voice["train"].column_names,  # Drop original columns
    num_proc=2,
)
common_voice["test"] = common_voice["test"].map(
    _prepare_eval_example,
    remove_columns=common_voice["test"].column_names,  # Drop original columns
    num_proc=2,
)

import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate Whisper examples into a padded training batch.

    Audio features and label ids have different lengths and padding rules, so
    they are padded separately with the processor's feature extractor and
    tokenizer respectively.
    """

    # Object exposing ``feature_extractor.pad`` and ``tokenizer.pad``
    # (a WhisperProcessor in this script).
    processor: Any
    # Token id that is removed from the front of the labels when every row
    # starts with it.
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad ``features`` and return the batch with a ``"labels"`` tensor added.

        Label positions introduced by padding are set to -100 so the loss
        ignores them. If every label row begins with
        ``decoder_start_token_id``, that first column is dropped.
        """
        feature_extractor = self.processor.feature_extractor
        tokenizer = self.processor.tokenizer

        audio_inputs = []
        label_inputs = []
        for example in features:
            audio_inputs.append({"input_features": example["input_features"]})
            label_inputs.append({"input_ids": example["labels"]})

        batch = feature_extractor.pad(audio_inputs, return_tensors="pt")
        padded_labels = tokenizer.pad(label_inputs, return_tensors="pt")

        # Anything the tokenizer marked as padding becomes -100 (ignored by the loss).
        is_padding = padded_labels.attention_mask.ne(1)
        labels = padded_labels["input_ids"].masked_fill(is_padding, -100)

        # A start token shared by every row is stripped here; presumably the
        # model prepends it again when building decoder inputs (not visible here).
        starts_with_start_token = labels[:, 0] == self.decoder_start_token_id
        if starts_with_start_token.all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

# Batch collator: pads audio features and label ids per batch, and strips a
# leading decoder-start token (id taken from the model config) from labels.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor, model.config.decoder_start_token_id)

import re
import unicodedata

import re
import unicodedata



import evaluate
# Word-error-rate metric from the `evaluate` library, consumed by compute_metrics.
metric = evaluate.load(path="wer")

# === WER evaluation with Hindi-safe text normalization ===
# ``tokenizer.normalize`` is Whisper's *English* normalizer: it deletes or
# splits on combining marks (e.g. Devanagari vowel signs and virama) and
# rewrites English number/spelling forms, which corrupts Hindi words and
# distorts WER. Whisper's multilingual ``tokenizer.basic_normalize`` also
# replaces combining marks with spaces, so a script-preserving rule is used.
def _normalize_transcript(text):
    """Return ``text`` normalized for WER scoring without breaking Indic words.

    Applies NFC composition and lowercasing, turns Unicode punctuation (P*)
    and symbols (S*) -- including the danda "।" -- into spaces, and collapses
    runs of whitespace. Combining marks (M*) are kept so vowel signs stay
    attached to their consonants.
    """
    text = unicodedata.normalize("NFC", text).lower()
    text = "".join(" " if unicodedata.category(ch)[0] in "PS" else ch for ch in text)
    return re.sub(r"\s+", " ", text).strip()


def compute_metrics(pred):
    """Compute word error rate (in percent) for generated transcripts.

    Args:
        pred: evaluation output with ``predictions`` (generated token ids) and
            ``label_ids`` (reference token ids, -100 at ignored positions).

    Returns:
        ``{"wer": float}``: 100 x WER over the examples whose normalized
        reference is non-empty, or NaN if no such example exists.
    """
    pad_id = tokenizer.pad_token_id
    predictions = np.asarray(pred.predictions)
    references = np.asarray(pred.label_ids)
    # -100 marks ignored/padded positions (always in labels; possibly in
    # predictions if generations are padded across batches) and is not a
    # decodable id. np.where returns new arrays, so the caller's buffers are
    # not mutated.
    pred_ids = np.where(predictions == -100, pad_id, predictions)
    label_ids = np.where(references == -100, pad_id, references)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    norm_pred = [_normalize_transcript(text) for text in pred_str]
    norm_ref = [_normalize_transcript(text) for text in label_str]

    # WER is undefined for an empty reference (the metric backend may reject
    # it), so such examples are skipped.
    kept = [(hyp, ref) for hyp, ref in zip(norm_pred, norm_ref) if ref]
    if not kept:
        return {"wer": float("nan")}
    kept_pred = [hyp for hyp, _ in kept]
    kept_ref = [ref for _, ref in kept]

    wer = 100 * metric.compute(predictions=kept_pred, references=kept_ref)
    return {"wer": wer}

# Training arguments.
# NOTE: CustomTrainer (defined below) installs UniformNoisyLR, which overwrites
# the optimizer LR with a random value on every step, so learning_rate and the
# warmup setting here do not shape the LR actually used during training.
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-small-vanilla-plrs-run3",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    learning_rate=1e-5,
    # 1000 warmup steps (= 20% of max_steps). Only one of warmup_steps /
    # warmup_ratio should be given: warmup_steps overrides warmup_ratio, so
    # also passing warmup_ratio=0.2 had no effect.
    warmup_steps=1000,
    max_steps=5000,
    gradient_checkpointing=True,
    fp16=True,
    eval_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,  # must be a multiple of eval_steps for load_best_model_at_end
    eval_steps=1000,
    logging_steps=50,
    report_to="none",
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
    weight_decay=0.01,
    max_grad_norm=1.0,  # Gradient clipping
    optim="adamw_torch",  # Explicit optimizer
    group_by_length=False,  # Disable length grouping to avoid input_ids error
    # The mapped datasets only contain input_features and labels; keep them as-is.
    remove_unused_columns=False,
    # Parallel batch loading (augmentation already happened during dataset.map).
    dataloader_num_workers=4,
)

# Raise dropout for stronger regularization.
# In the transformers Whisper implementation, layers copy these rates from the
# config into plain float attributes when they are constructed, so changing
# ``model.config`` after ``from_pretrained`` would only change the saved
# config, not training. ``_apply_dropout_rates`` pushes the new values into
# the already-built modules.
def _apply_dropout_rates(root, dropout, activation_dropout, attention_dropout):
    """Overwrite float dropout attributes on every submodule of ``root``.

    Modules whose class name contains ``"Attention"`` get ``attention_dropout``
    in their ``dropout`` attribute; every other module gets ``dropout`` and
    ``activation_dropout`` in attributes of those names. Only attributes that
    already exist and hold a float are modified.
    """
    for module in root.modules():
        is_attention = "Attention" in type(module).__name__
        if isinstance(getattr(module, "dropout", None), float):
            module.dropout = attention_dropout if is_attention else dropout
        if not is_attention and isinstance(getattr(module, "activation_dropout", None), float):
            module.activation_dropout = activation_dropout


model.config.dropout = 0.1  # Increased from default 0.0
model.config.activation_dropout = 0.1  # Increased from default 0.0
model.config.attention_dropout = 0.1  # Increased from default 0.0
_apply_dropout_rates(
    model,
    dropout=model.config.dropout,
    activation_dropout=model.config.activation_dropout,
    attention_dropout=model.config.attention_dropout,
)


# ======================================================================================
# ======================================================================================
# =================  CUSTOM NOISY LEARNING-RATE SCHEDULE & TRAINER  ====================
# ======================================================================================
# ======================================================================================

import torch.nn.functional as F



import numpy as np
import warnings
from torch.optim.lr_scheduler import _LRScheduler

class UniformNoisyLR(_LRScheduler):
    """Learning-rate scheduler that draws a fresh random LR on every step.

    Each ``step()`` sets every parameter group's LR to an independent sample
    from the uniform distribution on ``[min_lr, max_lr]``, drawn from numpy's
    global RNG. The optimizer's configured ``lr`` (``base_lrs``) is only used
    to determine how many groups there are.

    Args:
        optimizer: wrapped optimizer.
        max_lr: upper bound of the sampled LR.
        min_lr: lower bound of the sampled LR.
        offset: shift subtracted from both bounds before sampling and added
            back afterwards; it cancels out, so the range is unchanged.
        verbose: forwarded to the torch base class only when truthy. Recent
            PyTorch releases deprecate this argument and warn whenever it is
            passed, even as ``False``.
        last_epoch: index of the last step when resuming (-1 starts fresh).

    Raises:
        ValueError: if ``min_lr`` is greater than ``max_lr``.
    """

    def __init__(self, optimizer, max_lr, min_lr, offset, verbose=False, last_epoch=-1):
        if min_lr > max_lr:
            raise ValueError(f"min_lr ({min_lr}) must not exceed max_lr ({max_lr})")
        self.min_lr = min_lr
        self.offset = offset
        self.max_lr = max_lr
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, "
                "please use `get_last_lr()`.",
                UserWarning,
            )
        # One independent draw per parameter group.
        return [
            np.random.uniform(self.min_lr - self.offset, self.max_lr - self.offset) + self.offset
            for _ in self.base_lrs
        ]


class CustomTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer that trains with ``UniformNoisyLR`` instead of the default schedule.

    The noisy-LR bounds are class attributes so a subclass or instance can
    override them without re-implementing ``create_scheduler``.
    """

    # Arguments passed to UniformNoisyLR. They replace args.learning_rate /
    # warmup for the LR actually used during training.
    noisy_max_lr = 1e-5
    noisy_min_lr = 1e-6
    noisy_lr_offset = 1e-5

    def create_scheduler(self, num_training_steps: int, optimizer=None):
        """Create the noisy LR scheduler if none exists yet and return it.

        Like the base ``Trainer.create_scheduler``, an existing
        ``self.lr_scheduler`` (e.g. one supplied via ``optimizers=``) is kept
        rather than overwritten. ``num_training_steps`` is unused because the
        schedule has no fixed horizon.
        """
        if self.lr_scheduler is None:
            optimizer = self.optimizer if optimizer is None else optimizer
            self.lr_scheduler = UniformNoisyLR(
                optimizer,
                max_lr=self.noisy_max_lr,
                min_lr=self.noisy_min_lr,
                offset=self.noisy_lr_offset,
            )
            # The base Trainer sets this flag for schedulers it creates so they
            # can be rebuilt on a later train() call.
            self._created_lr_scheduler = True
        return self.lr_scheduler
# Assemble the trainer (custom noisy-LR schedule) and run fine-tuning.
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=common_voice["train"],
    eval_dataset=common_voice["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

# ``exit()`` is a helper the ``site`` module installs for the interactive
# shell and may be unavailable (e.g. under ``python -S``); ``sys.exit()``
# raises the same SystemExit from a script.
import sys

sys.exit()

