import torch
import os
# Report the installed torch version at start-up.
print(torch.__version__)
import os

import pickle
import sys
# Extend the import path with the parent directory before importing `utils`
# (presumably where the `utils` package lives — confirm against the repo layout).
sys.path.append('../')
import  utils.lmDecoderUtils as lmDecoderUtils
import numpy as np
import os 
import time
import tqdm
import pandas as pd
import torch
# results_dir = "../results/mfcc_sm_gru_ctc_LONGRUN/"
# results_dir ="../results/gru_ctc_diphones_dualhead"


#### DATA
COMPETITION_DATA = True
PRELOAD_NBEST = False
DEVICE = "cuda:6"
# results_dir = "../competition_evaluation/results/mfcc_sm_gru_ctc_LONGRUN_competition/"
# results_dir = "../results/sm_gru_ctc_diphones/"
# results_dir = "../results/mfcc_sm_mamba_ctc_LONGRUN/"
results_dir = "../results/gru_ctc/"
# results_dir = "/data/matteo/speech_decoding_BCI/results/gru_ctc_mfcc_bart"

# results_dir = "/data/matteo/speech_decoding_BCI/after_go_exps/results/gru_ctc/aftergo"

# Competition runs read "test_pred_logits.pkl"; all other runs read "pred_logits.pkl".
logits_filename = "test_pred_logits.pkl" if COMPETITION_DATA else "pred_logits.pkl"

# NOTE: unpickling can execute arbitrary code; only load files you produced yourself.
# The context manager closes the file handle as soon as the logits are loaded.
with open(os.path.join(results_dir, logits_filename), "rb") as logits_file:
    pred_logits = pickle.load(logits_file)

print(f"pred_logits len: {len(pred_logits)}")

if not COMPETITION_DATA:
    # load the ground truth labels
    df = pd.read_csv(os.path.join(results_dir, "results.csv"))



######


def my_rescore_with_gpt2(model, tokenizer, hypotheses, lengthPenalty):
    """Score each hypothesis with the model's next-token log-probabilities.

    For every hypothesis the score is the sum, over tokens 1..n_tokens-1, of
    the log-probability the model assigns to that token at the preceding
    position, minus ``n_tokens * lengthPenalty``. The first token is not
    scored, and ``n_tokens`` is the sum of the attention mask.

    Args:
        model: Callable torch module; invoked as ``model(**inputs)`` and
            expected to return a mapping with a ``'logits'`` entry.
        tokenizer: Callable invoked as
            ``tokenizer(hypotheses, return_tensors='pt', padding=True)`` that
            returns ``input_ids`` and ``attention_mask`` tensors.
        hypotheses: List of hypothesis strings to score.
        lengthPenalty: Per-token penalty subtracted from each score.

    Returns:
        list[float]: One score per hypothesis, in input order.
    """
    encoded = tokenizer(hypotheses, return_tensors='pt', padding=True)

    # Send the inputs to wherever the model's weights live, so the function
    # does not depend on a module-level device constant.
    device = next(model.parameters()).device
    with torch.no_grad():
        model_inputs = {k: v.to(device) for k, v in encoded.items()}
        outputs = model(**model_inputs)
        logProbs = torch.nn.functional.log_softmax(outputs['logits'].cpu().float(), -1)

    input_ids = encoded['input_ids'].cpu()
    n_tokens = encoded['attention_mask'].cpu().sum(dim=1)

    # Token j is predicted from position j - 1: pair logProbs[:, :-1] with
    # input_ids[:, 1:] and pick the log-probability of each realized token.
    next_token_log_probs = (
        logProbs[:, :-1, :]
        .gather(-1, input_ids[:, 1:].unsqueeze(-1))
        .squeeze(-1)
        .double()
    )

    # Only positions j < n_tokens contribute.
    # NOTE(review): this assumes padding sits at the end of each row — confirm
    # the tokenizer pads on the right.
    positions = torch.arange(1, input_ids.shape[1])
    is_scored = positions.unsqueeze(0) < n_tokens.unsqueeze(1)

    lm_scores = (next_token_log_probs * is_scored).sum(dim=1)
    return (lm_scores - n_tokens.double() * lengthPenalty).tolist()

def my_gpt2_lm_decode(model, tokenizer, nbest, acousticScale, lengthPenlaty, alpha,
                   returnConfidence=False):
    """Pick the best hypothesis from an n-best list after LM rescoring.

    Each hypothesis is scored as
    ``alpha * newLM + (1 - alpha) * oldLM + acousticScale * acoustic`` and the
    highest-scoring one is returned.

    Args:
        model: Language model forwarded to ``my_rescore_with_gpt2``.
        tokenizer: Tokenizer forwarded to ``my_rescore_with_gpt2``.
        nbest: Iterable of entries where ``entry[0]`` is the hypothesis text,
            ``entry[1]`` its acoustic score and ``entry[2]`` its previous LM
            score.
        acousticScale: Weight applied to the acoustic scores.
        lengthPenlaty: Per-token length penalty forwarded to the rescorer
            (the parameter name's spelling is kept for caller compatibility).
        alpha: Interpolation weight between the new and the previous LM score.
        returnConfidence: If True, also return the softmax probability of the
            winning hypothesis over all total scores.

    Returns:
        str, or ``(str, float)`` when ``returnConfidence`` is True. An n-best
        list with no usable hypothesis yields ``""`` (confidence ``0.``).
    """
    hypotheses = []
    acousticScores = []
    oldLMScores = []
    for out in nbest:
        hyp = out[0].strip()
        hyp = hyp.replace('>', '')
        hyp = hyp.replace('  ', ' ')
        hyp = hyp.replace(' ,', ',')
        hyp = hyp.replace(' .', '.')
        hyp = hyp.replace(' ?', '?')
        # Test for emptiness after normalization: a hypothesis made only of
        # '>' markers collapses to nothing and must not reach the rescorer.
        if not hyp.strip():
            continue
        hypotheses.append(hyp)
        acousticScores.append(out[1])
        oldLMScores.append(out[2])

    if len(hypotheses) == 0:
        return "" if not returnConfidence else ("", 0.)

    acousticScores = np.array(acousticScores)
    newLMScores = np.array(my_rescore_with_gpt2(model, tokenizer, hypotheses, lengthPenlaty))
    oldLMScores = np.array(oldLMScores)

    totalScores = alpha * newLMScores + (1 - alpha) * oldLMScores + acousticScale * acousticScores
    maxIdx = np.argmax(totalScores)
    bestHyp = hypotheses[maxIdx]
    if not returnConfidence:
        return bestHyp

    # Subtract the maximum before exponentiating for numerical stability.
    totalScores = totalScores - np.max(totalScores)
    probs = np.exp(totalScores)
    return bestHyp, probs[maxIdx] / np.sum(probs)



# # Load OPT 6B model
# llm, llm_tokenizer = lmDecoderUtils.build_opt(
#     cacheDir="/data/", device="auto", load_in_8bit=True
# )

# Hyperparameters for the language-model decoding stage
acoustic_scale = 0.5
blank_penalty = np.log(7)
llm_weight = 0.5

# The n-gram decoder is only built when the n-best lists are computed in this run
if not PRELOAD_NBEST:
    ngramDecoder = lmDecoderUtils.build_lm_decoder(
        "/data/speech_5gram/lang_test",
        acoustic_scale=0.5,
        nbest=100,
        beam=18,
    )

llm, llm_tokenizer = lmDecoderUtils.build_gpt2_torch()

# Replace, in place, every tensor entry of pred_logits with a CPU numpy array
for idx, entry in enumerate(pred_logits):
    if isinstance(entry, torch.Tensor):
        pred_logits[idx] = entry.cpu().numpy()

if not COMPETITION_DATA:
    sentences = df["True Sentence"].tolist()

# Flatten pred_logits by one nesting level
logits_unfolded = []
for sublist in pred_logits:
    logits_unfolded.extend(sublist)

print(f"logits_unfolded len: {len(logits_unfolded)}")

llm_outputs = []
# Generate nbest outputs from 5gram LM
start_t = time.time()

if not PRELOAD_NBEST:
    nbest_outputs = []
    for logits in tqdm.tqdm(logits_unfolded):
        # Move column 0 to the end of the last axis.
        logits = np.concatenate(
            [logits[:, 1:], logits[:, 0:1]], axis=-1
        )  # Blank is last token
        # The helper is called on a batch of one, hence the added leading axis.
        logits = lmDecoderUtils.rearrange_speech_logits(logits[None, :, :], has_sil=True)
        nbest = lmDecoderUtils.lm_decode(
            ngramDecoder,
            logits[0],
            blankPenalty=blank_penalty,
            returnNBest=True,
            rescore=True,
        )
        nbest_outputs.append(nbest)

# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(results_dir, exist_ok=True)

nbest_path = os.path.join(results_dir, "nbest_outputs.pkl")
if not PRELOAD_NBEST:
    # save nbest_outputs as pickle; the context manager flushes and closes the file
    with open(nbest_path, "wb") as nbest_file:
        pickle.dump(nbest_outputs, nbest_file)
else:
    # NOTE: unpickling can execute arbitrary code; only load files you produced yourself.
    with open(nbest_path, "rb") as nbest_file:
        nbest_outputs = pickle.load(nbest_file)

### move lm to device
llm = llm.to(DEVICE)

# Rescore every n-best list with the LLM and keep the winner and its confidence.
decoded_sentences = []
confidences = []
for nbest_output in tqdm.tqdm(nbest_outputs):
    decoded, confidence = my_gpt2_lm_decode(
        llm,
        llm_tokenizer,
        nbest_output,
        acoustic_scale,
        0,  # length penalty
        alpha=llm_weight,
        returnConfidence=True,
    )
    decoded_sentences.append(decoded)
    confidences.append(confidence)


if COMPETITION_DATA:
    # Competition runs only export the decoded text and stop here;
    # the metric computation further down is skipped.
    res_df = pd.DataFrame({
        "id": np.arange(len(decoded_sentences)),
        "text": decoded_sentences,
    })
    res_df.to_csv(os.path.join(results_dir, "decoded_competition_wfst.csv"), index=False)
    print("fine")
    # sys.exit() instead of quit(): quit is an interactive helper installed by
    # the site module and is not guaranteed to exist in every interpreter setup.
    sys.exit()
    
    
import re

def clean_prediction(pred_sentence, remove_word_repeats=True, min_char_repeat=4):
    """
    Normalize a predicted sentence by dropping repetition artifacts.

    The sentence is lower-cased, runs of one repeated space-separated
    character are collapsed, runs of a repeated word are optionally collapsed,
    whitespace is squeezed, surrounding punctuation is stripped, and a lone
    trailing single character is removed.

    Args:
        pred_sentence (str): Raw predicted sentence.
        remove_word_repeats (bool): Collapse a word repeated three or more
            times in a row (e.g. 'well well well') into one occurrence.
        min_char_repeat (int): Minimum number of consecutive occurrences of a
            space-separated character (e.g. 'd d d d') that gets collapsed.

    Returns:
        str: Cleaned sentence.
    """
    extra_repeats = min_char_repeat - 1
    char_run = re.compile(rf'\b(\w)(?:\s\1){{{extra_repeats},}}\b')

    cleaned = pred_sentence.lower()

    # "d d d d d" -> "d"
    cleaned = char_run.sub(r'\1', cleaned)

    # "well well well" -> "well"
    if remove_word_repeats:
        cleaned = re.sub(r'\b(\w+)(?:\s+\1){2,}\b', r'\1', cleaned)

    # Squeeze whitespace runs, then trim punctuation/whitespace at both ends
    cleaned = re.sub(r'\s{2,}', ' ', cleaned)
    cleaned = cleaned.strip(" ,.;!?\"'-\n\t")

    # Drop a dangling one-character tail such as the "d" in "everything d"
    return re.sub(r'\s+\w$', '', cleaned)

import string
def preprocess_text(text):
    """
    Return text without ASCII punctuation, trimmed and lower-cased.
    """
    # Mapping a character to None makes str.translate delete it
    punctuation_remover = str.maketrans(dict.fromkeys(string.punctuation))
    without_punctuation = text.translate(punctuation_remover)
    return without_punctuation.strip().lower()


import jiwer  # For WER
import sacrebleu  # For BLEU
from rouge_score import rouge_scorer  # For ROUGE
from nltk.translate.meteor_score import meteor_score  # For METEOR
import bert_score  # For BERTScore
import numpy as np

def compute_metrics(text_transcriptions, gpt_decoded):
    """
    Evaluate decoded sentences against their references with several metrics.

    Both inputs are first normalized with ``preprocess_text``. Corpus-level
    values are stored under plain keys; per-sentence values are stored under
    keys ending in ``_scores``.

    Args:
        text_transcriptions (list): Ground-truth reference sentences.
        gpt_decoded (list): Model-generated sentences, aligned with the references.

    Returns:
        dict: WER, BLEU, ROUGE-1/2/L, METEOR and BERTScore precision/recall/F1,
        followed by the per-sentence METEOR, ROUGE, WER and BERTScore-F1 values.
    """
    references = [preprocess_text(sentence) for sentence in text_transcriptions]
    predictions = [preprocess_text(sentence) for sentence in gpt_decoded]
    sentence_pairs = list(zip(references, predictions))

    # Corpus-level word error rate and BLEU
    corpus_wer = jiwer.wer(references, predictions)
    corpus_bleu = sacrebleu.corpus_bleu(predictions, [references]).score

    # Per-sentence ROUGE
    scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
    rouge_scores = [scorer.score(ref, pred) for ref, pred in sentence_pairs]

    # Per-sentence METEOR on whitespace-tokenized text
    meteor_scores = [
        meteor_score([ref.split()], pred.split()) for ref, pred in sentence_pairs
    ]

    # BERTScore (semantic similarity)
    precision, recall, f1 = bert_score.score(
        predictions, references, lang="en", rescale_with_baseline=True
    )

    # Per-sentence WER
    sentence_wers = [jiwer.wer([ref], [pred]) for ref, pred in sentence_pairs]

    # Insertion order is kept as-is: callers iterate this dict when printing
    # the metrics and when building the metrics table.
    return {
        "WER": corpus_wer,
        "BLEU": corpus_bleu,
        "ROUGE-1": np.mean([score["rouge1"].fmeasure for score in rouge_scores]),
        "ROUGE-2": np.mean([score["rouge2"].fmeasure for score in rouge_scores]),
        "ROUGE-L": np.mean([score["rougeL"].fmeasure for score in rouge_scores]),
        "METEOR": np.mean(meteor_scores),
        "BERTScore_Precision": precision.mean().item(),
        "BERTScore_Recall": recall.mean().item(),
        "BERTScore_F1": f1.mean().item(),
        "METEOR_scores": meteor_scores,
        "ROUGE_scores": rouge_scores,
        "WER_scores": sentence_wers,
        "BERTScore_F1_scores": f1.cpu().numpy().tolist(),
    }

# Set to True to post-process the decoded sentences with clean_prediction
# before scoring.
CLEAN = False
if CLEAN:
    # Clean the decoded sentences
    decoded_sentences = [clean_prediction(sentence) for sentence in decoded_sentences]


metrics = compute_metrics(sentences, decoded_sentences)
# Print only the aggregate values; the per-sentence lists live under "*scores" keys.
for metric, score in metrics.items():
    if "scores" not in metric:
        print(f"{metric}: {score:.4f}")


import pandas as pd

# One row per sentence: reference, prediction and the per-sentence metrics.
results_df = pd.DataFrame({
    "target_sentence": sentences,
    "pred_sentence": decoded_sentences,
})

results_df["WER_scores"] = metrics["WER_scores"]
results_df["METEOR_scores"] = metrics["METEOR_scores"]
results_df["ROUGE_scores"] = metrics["ROUGE_scores"]
results_df["BERTScore_F1_scores"] = metrics["BERTScore_F1_scores"]

results_df.to_csv(os.path.join(results_dir, "language_results.csv"), index=False)

# Single-row table holding the aggregate metrics.
overall_metrics = {k: v for k, v in metrics.items() if "scores" not in k}

metrics_df = pd.DataFrame(overall_metrics, index=[0])
metrics_df.to_csv(os.path.join(results_dir, "language_metrics.csv"), index=False)


# Show the 20 sentences with the highest WER. As a bare expression this was
# silently discarded when the file runs as a script, so it is printed explicitly.
print(results_df.sort_values("WER_scores", ascending=False).head(20))