import pickle
import sys
sys.path.append("..")
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
import torch
from dataset import PHONE_DEF
import re 
from g2p_en import G2p
import numpy as np
from model import GRUDecoder, SimpleGRUDecoder, LightningGRUDecoder
import time
import numpy as np
from edit_distance import SequenceMatcher
import tqdm
import pytorch_lightning as pl
import jiwer
import nltk
from nltk.corpus import cmudict
from pytorch_lightning.loggers import WandbLogger
import wandb
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import copy
from difflib import get_close_matches
from phoneme2text import Phoneme2TextModel, PhonemeTokenizer, VanillaTransformerEncoder
from transformers import GPT2LMHeadModel, GPT2Config, GPT2Tokenizer
import pandas as pd
from torchaudio.models.decoder import ctc_decoder
import string


# CTC vocabulary: the blank symbol sits at index 0, followed by the phoneme
# inventory, with the space character last (it is used as the silence token).
tokens = ["<blank>", *PHONE_DEF, " "]

# Lexicon-free torchaudio CTC decoder built over the vocabulary above.
decoder = ctc_decoder(
    lexicon=None,
    tokens=tokens,
    blank_token="<blank>",
    sil_token=" ",
)


def decode_ctc_output(logits):
    """
    Greedily decode CTC logits into predicted phoneme-index sequences.

    For every sequence the most probable class is picked per frame, runs of
    identical consecutive indices are merged, and only then are blank tokens
    (index 0) dropped. The order matters: a blank between two identical
    phonemes marks a genuine repetition, so blanks must still be present
    while repeats are being merged (``A, blank, A`` decodes to ``A, A``).

    Args:
        logits: Tensor of per-frame class scores. The last axis is the class
            axis and the first axis is iterated as the batch, i.e. the
            expected layout is (batch, time, num_classes).

    Returns:
        A list with one 1-D numpy array of phoneme indices per sequence.
    """
    frame_ids = torch.argmax(logits, dim=-1)  # Most probable class per frame

    decoded = []
    for seq in frame_ids:
        collapsed = torch.unique_consecutive(seq)  # Merge repeated frames first
        decoded.append(collapsed[collapsed != 0].cpu().numpy())  # Then drop blanks
    return decoded


def _as_numpy(seq):
    """Return *seq* as a numpy array, moving torch tensors to the CPU first."""
    if torch.is_tensor(seq):
        return seq.detach().cpu().numpy()
    return np.asarray(seq)


def compute_accuracy(preds, targets):
    """
    Mean position-wise accuracy between predicted and target sequences.

    Each prediction/target pair is truncated to the length of the shorter
    sequence and scored as the fraction of positions that match; the
    per-pair scores are then averaged. Length mismatches beyond the common
    prefix are not penalised, and insertions/deletions shift every later
    position, so this is a rough metric rather than an edit-distance one.

    Args:
        preds: Iterable of predicted index sequences (numpy arrays, lists or
            torch tensors).
        targets: Iterable of reference index sequences, paired with `preds`
            by position (extra items on either side are ignored by `zip`).

    Returns:
        The mean accuracy over all pairs, or 0.0 when there are no pairs.
        A pair in which exactly one sequence is empty scores 0.0; a pair in
        which both are empty scores 1.0.
    """
    accs = []
    for pred, target in zip(preds, targets):
        # Coerce to arrays so `==` is element-wise even for plain lists.
        pred = _as_numpy(pred)
        target = _as_numpy(target)

        # Truncate to the length of the shortest sequence.
        min_len = min(len(pred), len(target))
        if min_len == 0:
            # Nothing to compare; avoid 0/0 -> NaN poisoning the batch mean.
            accs.append(1.0 if len(pred) == len(target) else 0.0)
            continue

        accs.append(np.mean(pred[:min_len] == target[:min_len]))

    return np.mean(accs) if accs else 0.0
   
