from datasets import load_dataset
import re 
from g2p_en import G2p
import numpy as np
from transformers import GPT2Tokenizer
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
import wandb
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import re
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2LMHeadModel, GPT2Config
import torch.nn as nn
import pytorch_lightning as pl
from transformers import GPT2LMHeadModel
import wandb





class PhonemeTokenizer:
    """Map phoneme strings to integer token IDs and back.

    IDs 0-3 are reserved for the special tokens ``[PAD]``, ``[UNK]``,
    ``[BOS]`` and ``[EOS]``; the supplied phonemes receive contiguous IDs
    starting at 4, so ``len(tokenizer)`` is always ``max_id + 1``.
    """

    def __init__(self, phonemes):
        """
        Args:
            phonemes: Iterable of phoneme strings to build the vocabulary.
                Duplicates, and strings equal to a special token, are ignored
                so that IDs stay contiguous and special tokens keep their IDs.
        """
        # Define special tokens
        self.pad_token = "[PAD]"
        self.unk_token = "[UNK]"
        self.bos_token = "[BOS]"
        self.eos_token = "[EOS]"

        # Initialize vocabulary with special tokens
        self.vocab = {
            self.pad_token: 0,
            self.unk_token: 1,
            self.bos_token: 2,
            self.eos_token: 3,
        }

        self.pad_token_id = self.vocab[self.pad_token]
        self.unk_token_id = self.vocab[self.unk_token]
        self.bos_token_id = self.vocab[self.bos_token]
        self.eos_token_id = self.vocab[self.eos_token]

        # Add phonemes after the special tokens. len(self.vocab) is always the
        # next free ID, so skipping already-known strings keeps IDs gap-free.
        for p in phonemes:
            if p not in self.vocab:
                self.vocab[p] = len(self.vocab)

        # Reverse mapping: ID -> phoneme
        self.id2phoneme = {v: k for k, v in self.vocab.items()}

    def __len__(self):
        return len(self.vocab)

    def encode(self, phoneme_seq, max_length=None, add_special_tokens=True):
        """
        Convert a phoneme sequence into a list of token IDs.

        Args:
            phoneme_seq: Sequence (list/tuple) of phoneme strings, or a single
                whitespace-separated string such as ``"B D AH"``.
            max_length: Optional length to truncate/right-pad to.
            add_special_tokens: If True, wraps the sequence in BOS ... EOS.

        Returns:
            List of token IDs; unknown phonemes map to the [UNK] ID.
        """
        if isinstance(phoneme_seq, str):
            tokens = phoneme_seq.split()
        else:
            tokens = list(phoneme_seq)
        if add_special_tokens:
            tokens = [self.bos_token] + tokens + [self.eos_token]

        encoded = [self.token_to_id(p) for p in tokens]

        if max_length is not None:
            # Truncation happens before padding, so EOS may be cut off.
            encoded = encoded[:max_length]
            encoded += [self.pad_token_id] * (max_length - len(encoded))

        return encoded

    def batch_encode(self, phoneme_seqs, max_length=None, add_special_tokens=True):
        """
        Convert a batch of phoneme sequences into a right-padded ID tensor.

        Args:
            phoneme_seqs: Iterable of phoneme sequences (see ``encode``).
            max_length: Optional fixed length for all sequences; defaults to
                the longest encoded sequence in the batch.
            add_special_tokens: Add BOS/EOS to each sequence.

        Returns:
            torch.LongTensor of shape (batch_size, max_seq_len). An empty batch
            yields a tensor of shape (0, max_seq_len).
        """
        encoded_seqs = [
            self.encode(seq, add_special_tokens=add_special_tokens) for seq in phoneme_seqs
        ]

        if max_length is None:
            max_seq_len = max((len(seq) for seq in encoded_seqs), default=0)
        else:
            max_seq_len = max_length

        if not encoded_seqs:
            # torch.tensor([]) would be 1-D; keep the documented 2-D shape.
            return torch.empty((0, max_seq_len), dtype=torch.long)

        padded_seqs = [
            seq[:max_seq_len] + [self.pad_token_id] * (max_seq_len - len(seq))
            for seq in encoded_seqs
        ]

        return torch.tensor(padded_seqs, dtype=torch.long)

    def decode(self, id_seq, skip_special_tokens=True):
        """
        Convert a sequence of IDs back to phonemes.

        Args:
            id_seq: Tensor or list of token IDs.
            skip_special_tokens: Whether to drop [PAD], [BOS] and [EOS]
                ([UNK] is always kept).

        Returns:
            List of phoneme strings. IDs outside the vocabulary (e.g. produced
            by a model with a larger output layer) decode to "[UNK]".
        """
        if isinstance(id_seq, torch.Tensor):
            id_seq = id_seq.tolist()

        if skip_special_tokens:
            skipped = {self.pad_token_id, self.bos_token_id, self.eos_token_id}
        else:
            skipped = set()

        return [self.id_to_token(i) for i in id_seq if i not in skipped]

    def batch_decode(self, id_seqs, skip_special_tokens=True):
        """
        Convert a batch of ID sequences back to phoneme lists.

        Args:
            id_seqs: List of sequences (or tensor of shape (batch, seq)).
            skip_special_tokens: Forwarded to ``decode``.

        Returns:
            List of lists of phoneme strings.
        """
        if isinstance(id_seqs, torch.Tensor):
            id_seqs = id_seqs.tolist()

        return [self.decode(seq, skip_special_tokens=skip_special_tokens) for seq in id_seqs]

    def get_vocab_size(self):
        return len(self.vocab)

    def token_to_id(self, token):
        """Return the ID of ``token``, or the [UNK] ID if it is unknown."""
        return self.vocab.get(token, self.unk_token_id)

    def id_to_token(self, id_):
        """Return the token for ``id_``, or "[UNK]" if the ID is unknown."""
        return self.id2phoneme.get(id_, self.unk_token)



class LightningGPT2PhonemeModel(pl.LightningModule):
    """Lightning wrapper that trains a causal LM (GPT-2) on phoneme sequences.

    Training sequences are encoded by ``tokenizer`` as ``[BOS] phonemes [EOS]``
    (right-padded to ``max_length``) and passed to the wrapped model together
    with ``labels``, so the model's own LM loss is used.
    """

    # Label value the Hugging Face LM loss ignores; used to mask padding.
    IGNORE_INDEX = -100

    def __init__(self, gpt2_model, tokenizer, learning_rate=5e-5, weight_decay=1e-5, max_length=64):
        """
        Args:
            gpt2_model: A ``GPT2LMHeadModel``. Presumably its vocabulary size
                matches ``len(tokenizer)`` — TODO confirm at the call site.
            tokenizer: Phoneme tokenizer exposing ``batch_encode``, ``encode``,
                ``batch_decode``, ``vocab``, ``pad_token_id``, ``bos_token_id``
                and ``eos_token``.
            learning_rate: AdamW learning rate.
            weight_decay: AdamW weight decay.
            max_length: Maximum encoded length (including special tokens).
        """
        super().__init__()
        self.model = gpt2_model  # GPT2LMHeadModel
        self.tokenizer = tokenizer
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.max_length = max_length

    @staticmethod
    def _as_token_lists(phoneme_seqs):
        """Split whitespace-separated phoneme strings; pass token lists through."""
        return [seq.split() if isinstance(seq, str) else seq for seq in phoneme_seqs]

    def _encode_batch(self, phoneme_seqs):
        """Encode a training batch as right-padded ``[BOS] ... [EOS]`` IDs plus mask."""
        input_ids = self.tokenizer.batch_encode(
            self._as_token_lists(phoneme_seqs), max_length=self.max_length, add_special_tokens=True
        ).to(self.device)
        attention_mask = (input_ids != self.tokenizer.pad_token_id).long()
        return input_ids, attention_mask

    def _encode_prompts(self, phoneme_seqs):
        """Encode generation prompts as ``[BOS] phonemes``, left-padded.

        Decoder-only models must be left-padded for batched generation;
        otherwise new tokens are appended after padding. No EOS is added,
        because the model is asked to continue the sequence.
        """
        pad_id = self.tokenizer.pad_token_id
        bos_id = self.tokenizer.bos_token_id
        # Reserve one position of self.max_length for BOS.
        body_len = max(self.max_length - 1, 0)
        prompts = [
            [bos_id] + self.tokenizer.encode(seq, add_special_tokens=False)[:body_len]
            for seq in self._as_token_lists(phoneme_seqs)
        ]
        width = max(len(p) for p in prompts)
        input_ids = torch.tensor(
            [[pad_id] * (width - len(p)) + p for p in prompts],
            dtype=torch.long,
            device=self.device,
        )
        attention_mask = torch.tensor(
            [[0] * (width - len(p)) + [1] * len(p) for p in prompts],
            dtype=torch.long,
            device=self.device,
        )
        return input_ids, attention_mask

    def forward(self, phoneme_seqs):
        """
        phoneme_seqs: List of phoneme strings (e.g., ["B D AH", "P EY T"])
        or lists of phoneme tokens (e.g., [["B", "D", "AH"], ["P", "EY", "T"]])

        Returns:
            (loss, logits) from the wrapped model.
        """
        input_ids, attention_mask = self._encode_batch(phoneme_seqs)
        # Padding must not contribute to the loss; the HF LM loss skips -100.
        labels = input_ids.masked_fill(attention_mask == 0, self.IGNORE_INDEX)

        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        return outputs.loss, outputs.logits

    def training_step(self, batch, batch_idx):
        phoneme_seqs, _ = batch
        loss, _ = self(phoneme_seqs)
        self.log("train_loss", loss, prog_bar=True, logger=True, batch_size=len(phoneme_seqs))
        return loss

    def validation_step(self, batch, batch_idx):
        phoneme_seqs, _ = batch
        loss, _ = self(phoneme_seqs)
        self.log("val_loss", loss, prog_bar=True, logger=True, batch_size=len(phoneme_seqs))

        if batch_idx == 0:
            self.log_validation_examples(phoneme_seqs)
        return loss

    def log_validation_examples(self, phoneme_seqs, num_examples=5):
        """Log up to ``num_examples`` (input, generated) pairs as a W&B table.

        Does nothing when no W&B run is active, because ``wandb.log`` would
        raise in that case.
        """
        if wandb.run is None:
            return
        # Only the logged rows are generated; beam search over the full batch is wasted work.
        examples = list(phoneme_seqs[:num_examples])
        if not examples:
            return

        generated = self.generate(examples)

        table = wandb.Table(columns=["Input", "Generated"])
        for seq, gen in zip(examples, generated):
            original = " ".join(seq.split() if isinstance(seq, str) else seq)
            table.add_data(original, " ".join(gen))
        wandb.log({"Validation Examples": table})

    def generate(self, phoneme_seqs, max_length=40):
        """Continue each phoneme sequence with beam search.

        Args:
            phoneme_seqs: Phoneme strings or token lists (see ``forward``).
            max_length: Maximum number of *new* tokens to generate.

        Returns:
            List of phoneme lists: the prompt followed by the generated
            phonemes, with PAD/BOS/EOS removed by the tokenizer.
        """
        input_ids, attention_mask = self._encode_prompts(phoneme_seqs)
        eos_token_id = self.tokenizer.vocab[self.tokenizer.eos_token]

        generated_ids = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_length,
            num_beams=5,
            early_stopping=True,
            # GPT-2's default special IDs do not belong to the phoneme vocabulary.
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=eos_token_id,
        )
        return self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
