import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
import wandb
from transformers import GPT2LMHeadModel, GPT2Config
import math

######### TOKENIZER ####################



import re
import random
from torch.utils.data import Dataset
from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig

import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
import wandb
from transformers import GPT2LMHeadModel, GPT2Config

class Phoneme2TextModel(pl.LightningModule):
    """Lightning module that fine-tunes ``facebook/bart-base`` to turn phoneme
    sequences into text.

    The module owns both the BART model and its tokenizer, so batches are
    passed in as raw Python sequences (phoneme inputs, target texts) and are
    tokenized inside :meth:`forward` / :meth:`generate_text`.

    Args:
        learning_rate: Learning rate handed to AdamW.
        weight_decay: Weight decay handed to AdamW.
    """

    def __init__(self, learning_rate=5e-5, weight_decay=1e-5):
        super().__init__()
        self.model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
        self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

    def forward(self, phoneme_seqs, text_seqs):
        """Tokenize a batch and run a teacher-forced pass through BART.

        Args:
            phoneme_seqs: Batch of phoneme inputs accepted by the tokenizer.
            text_seqs: Batch of target texts, aligned with ``phoneme_seqs``.

        Returns:
            Tuple ``(loss, logits)`` taken from the model output.
        """
        inputs = self.tokenizer(phoneme_seqs, padding=True, return_tensors="pt")
        labels = self.tokenizer(text_seqs, padding=True, return_tensors="pt")

        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # padding=True fills shorter targets with pad tokens. Replace those
        # positions with -100 (the ignore index of the Hugging Face LM loss)
        # so that the loss is computed on real target tokens only instead of
        # also rewarding the model for predicting padding.
        label_ids = labels["input_ids"].to(self.device)
        label_mask = labels["attention_mask"].to(self.device)
        label_ids = label_ids.masked_fill(label_mask == 0, -100)

        outputs = self.model(**inputs, labels=label_ids)

        return outputs.loss, outputs.logits

    def training_step(self, batch, batch_idx):
        """
        Training loop: Forward pass + Compute loss + Log progress
        """
        phonemes, texts = batch
        batch_size = len(texts)  # Number of samples in the batch

        loss, _ = self.forward(phonemes, texts)

        # Log loss to progress bar
        self.log("train_loss", loss, prog_bar=True, logger=True, batch_size=batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Validation loop: Compute loss + Log WandB examples
        """
        phonemes, texts = batch
        loss, _ = self.forward(phonemes, texts)

        batch_size = len(texts)  # Number of samples in the batch

        # Log loss to progress bar
        self.log("val_loss", loss, prog_bar=True, logger=True, batch_size=batch_size)

        # Log qualitative examples once per validation run (first batch only)
        if batch_idx == 0:
            self.log_validation_examples(phonemes, texts)

        return loss

    def log_validation_examples(self, phonemes, texts, ):
        """
        Logs phoneme input, ground truth text, and reconstructed text in WandB.

        At most five examples from the batch are decoded and logged. Nothing
        is logged when no WandB run is active.
        """
        # wandb.log needs an active run; skip quietly (and avoid the cost of
        # beam search) when WandB has not been initialised in this process.
        if wandb.run is None:
            return

        num_examples = min(5, len(texts))
        if num_examples == 0:
            return

        # Only decode the rows that end up in the table; running beam search
        # over the whole batch would be wasted work.
        sample_phonemes = list(phonemes[:num_examples])
        generated_texts = self.generate_text(sample_phonemes)

        # Rows in the order expected by the table columns below
        examples = [
            [phoneme_str, truth, generated]
            for phoneme_str, truth, generated in zip(sample_phonemes, texts, generated_texts)
        ]

        # Log examples in WandB
        wandb.log({"Validation Examples": wandb.Table(columns=["Phonemes", "Ground Truth", "Reconstructed"], data=examples)})

    def configure_optimizers(self):
        """
        Define optimizer (AdamW over all module parameters)
        """
        return optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)

    def generate_text(self, phoneme_seqs, num_beams=5, max_length=40, **kwargs):
        """Decode text for a batch of phoneme inputs.

        Args:
            phoneme_seqs: Batch of phoneme inputs accepted by the tokenizer.
            num_beams: Beam width forwarded to ``model.generate``.
            max_length: Maximum generated length forwarded to ``model.generate``.
            **kwargs: Extra keyword arguments forwarded to ``model.generate``.

        Returns:
            List of decoded strings with special tokens stripped.
        """
        inputs = self.tokenizer(phoneme_seqs, padding=True, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        outputs = self.model.generate(**inputs, num_beams=num_beams, max_length=max_length, **kwargs)
        decoded_sentences = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return decoded_sentences
    

class NoisyBookCorpusPhonemeDataset(Dataset):
    """Dataset yielding ``(phonemes, cleaned_text)`` pairs, with optional
    random corruption of the phoneme sequence."""

    # Patterns are compiled once instead of on every __getitem__ call.
    _PUNCTUATION_RE = re.compile(r'[^\w\s]')
    _WHITESPACE_RE = re.compile(r'\s+')
    _STRESS_DIGIT_RE = re.compile(r'[0-9]')
    _PHONEME_RE = re.compile(r'^[A-Z]+$')

    def __init__(
        self,
        dataset,
        g2p,
        addInterWordSymbol=True,
        apply_augmentation=False,
        corruption_prob=0.5,
        error_rate=0.2,
        valid_phonemes=None,
        collapse_to_string=True,
    ):
        """
        Args:
            dataset (list of str): List of text sentences.
            g2p (callable): A g2p function/class instance (e.g. g2p_en.G2p())
                            that converts a string into a list of phonemes.
            addInterWordSymbol (bool): Whether to keep a ' ' marker token at word
                                       boundaries and append one at the end of the sequence.
            apply_augmentation (bool): If True, randomly corrupt some phoneme sequences.
            corruption_prob (float): Probability with which a sample's phoneme sequence is corrupted.
            error_rate (float): For each token, probability of applying a corruption operation.
            valid_phonemes (list or None): List of valid phonemes to use for substitution/insertion.
                                           If None, the valid phonemes will be taken from each sequence.
            collapse_to_string (bool): If True, the phoneme tokens are joined with single
                                       spaces into one string; otherwise a list is returned.
        """
        self.dataset = dataset
        self.g2p = g2p
        self.addInterWordSymbol = addInterWordSymbol
        self.apply_augmentation = apply_augmentation
        self.corruption_prob = corruption_prob
        self.error_rate = error_rate
        self.valid_phonemes = valid_phonemes
        self.collapse_to_string = collapse_to_string

    def corrupt_phoneme_sequence(self, seq, error_rate):
        """
        Corrupt a phoneme sequence by randomly applying one of:
          - substitution: Replace the token with another random phoneme.
          - deletion: Remove the token.
          - insertion: Insert a random phoneme before the current token.
        The decision to corrupt each token is governed by error_rate.

        When the pool of candidate phonemes is empty, substitution and
        insertion leave the token unchanged instead of failing.

        Returns:
            A new list; ``seq`` itself is not modified.
        """
        # Determine the set of phonemes to use for augmentation.
        # If valid_phonemes is not provided, use the unique tokens from the current sequence.
        if self.valid_phonemes is None:
            phoneme_set = list(set(seq))
        else:
            phoneme_set = self.valid_phonemes

        corrupted_seq = []
        for token in seq:
            if random.random() >= error_rate:
                corrupted_seq.append(token)
                continue

            op = random.choice(["substitute", "delete", "insert"])
            if op == "delete":
                # Delete this token (skip appending)
                continue

            if op == "substitute":
                # Substitute with a random phoneme (excluding the current token);
                # keep the token when there is nothing else to choose from.
                alternatives = [p for p in phoneme_set if p != token]
                corrupted_seq.append(random.choice(alternatives) if alternatives else token)
            else:
                # Insert a random phoneme before the current token, then append the token.
                # random.choice raises IndexError on an empty pool (e.g. valid_phonemes=[]),
                # so the insertion is skipped in that case.
                if phoneme_set:
                    corrupted_seq.append(random.choice(phoneme_set))
                corrupted_seq.append(token)
        return corrupted_seq

    def __len__(self):
        """Number of sentences in the underlying dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """
        Converts text to phonemes and returns (phonemes, text).
        The text is lower-cased and cleaned by removing punctuation and quotes,
        stripping leading/trailing spaces, and collapsing multiple spaces into one.
        Optionally, the phoneme sequence is corrupted according to the augmentation parameters.

        ``phonemes`` is a single space-joined string when ``collapse_to_string``
        is True, otherwise a list of tokens.
        """
        text = self.dataset[idx]

        # --- Clean the text ---
        text = text.strip().lower()                    # Trim and lower-case
        text = self._PUNCTUATION_RE.sub('', text)      # Remove punctuation and quotes
        text = self._WHITESPACE_RE.sub(' ', text)      # Collapse whitespace runs into one space
        text = text.strip()                            # Removing punctuation can expose new edge spaces

        # --- Convert text to phonemes using g2p ---
        phonemes = []
        for token in self.g2p(text):
            if token == ' ':
                # Word boundary emitted by g2p: kept as a marker only when requested.
                if self.addInterWordSymbol:
                    phonemes.append(' ')
                continue
            # Remove stress markers (digits) from the phoneme.
            p = self._STRESS_DIGIT_RE.sub('', token)
            if self._PHONEME_RE.match(p):  # Only keep tokens made of uppercase letters
                phonemes.append(p)

        if self.addInterWordSymbol:
            # Close the sequence with a trailing boundary marker.
            phonemes.append(' ')

        # --- Optionally apply noise augmentation ---
        if self.apply_augmentation and random.random() < self.corruption_prob:
            phonemes = self.corrupt_phoneme_sequence(phonemes, self.error_rate)

        if self.collapse_to_string:
            # Collapse the token list into one space-separated string
            phonemes = ' '.join(phonemes)

        return phonemes, text

# 🔹 Custom collate function to handle variable-length phoneme sequences
def collate_fn(batch):
    """
    Groups a batch of (phonemes, text) samples into two parallel lists.

    Phoneme sequences given as lists (or tuples) of tokens are right-padded
    with " " tokens to the longest sequence in the batch. Phoneme sequences
    given as strings (the dataset default, collapse_to_string=True) are
    returned unchanged, because a list cannot be concatenated onto a str.

    Args:
        batch: Iterable of (phonemes, text) pairs.

    Returns:
        Tuple (phoneme_seqs, text_seqs) of two lists of equal length.
        An empty batch yields ([], []).
    """
    if not batch:
        return [], []

    phoneme_seqs, text_seqs = zip(*batch)  # Unpack batch elements

    # String inputs cannot be padded token-wise; hand them over as-is.
    if any(isinstance(seq, str) for seq in phoneme_seqs):
        return list(phoneme_seqs), list(text_seqs)

    # Find the maximum length of phoneme sequences in the batch.
    max_phoneme_length = max(len(seq) for seq in phoneme_seqs)

    # Pad all sequences with " " tokens to reach the maximum length.
    # list(seq) also accepts tuples and avoids aliasing the input lists.
    padded_phonemes = [
        list(seq) + [" "] * (max_phoneme_length - len(seq)) for seq in phoneme_seqs
    ]

    return padded_phonemes, list(text_seqs)






