import argparse

import torch
from transformers import WhisperFeatureExtractor

from config import Config
from models.salmonn import SALMONN
from utils import prepare_one_sample

import pandas as pd
from collections import defaultdict

import evaluate
import os
import tqdm



class Metrics:
    """Accumulate per-sample similarity scores between predicted and reference LaTeX.

    Every call to :meth:`add_metrics` appends one row to ``outputs``. A row holds
    the prediction, the reference and the pronunciation string. It then holds
    each metric computed on the original strings, followed by the same metrics
    computed on lower-cased copies (columns with a ``_lower`` suffix).
    :meth:`save_csv` writes the rows to a CSV file and prints the mean of every
    metric column.
    """

    # Base metric columns, in the order they are computed, stored and printed.
    _BASE_METRICS = ("cer", "wer", "rouge1", "chrf", "chrfpp", "bleu", "sbleu", "meteor")
    # All metric columns: cased scores first, then their lower-cased variants.
    _METRIC_NAMES = _BASE_METRICS + tuple(f"{name}_lower" for name in _BASE_METRICS)
    # Exceptions meaning "this sample cannot be scored". ValueError is what the
    # metric backends raise for some inputs (the original code caught it).
    # ZeroDivisionError is presumably what the BLEU brevity penalty raises on an
    # empty prediction; confirm against the installed `evaluate` version.
    _SCORING_ERRORS = (ValueError, ZeroDivisionError)

    def __init__(self, outputs, wer, cer, bleu, meteor, sacrebleu, rouge1, chrf):
        """Store the row accumulator and the metric modules.

        Args:
            outputs: mapping of column name -> list of per-row values. It is
                mutated in place; missing columns are created on first use.
            wer, cer, bleu, meteor, sacrebleu, rouge1, chrf: objects exposing
                ``compute(predictions=[...], references=[...])``, as returned by
                ``evaluate.load``. ``rouge1`` must report a ``'rouge1'`` key and
                ``chrf`` must accept ``word_order``.
        """
        self.outputs = outputs
        self.wer = wer
        self.cer = cer
        self.bleu = bleu
        self.meteor = meteor
        self.sacrebleu = sacrebleu
        self.rouge1 = rouge1
        self.chrf = chrf

    def _score(self, pred, ref):
        """Return ``{base_metric_name: value}`` for one prediction/reference pair.

        chrF, chrF++ and sacreBLEU report on a 0-100 scale. They are divided by
        100 so they are comparable with the other similarity scores.
        """
        preds, refs = [pred], [ref]
        return {
            "cer": self.cer.compute(predictions=preds, references=refs),
            "wer": self.wer.compute(predictions=preds, references=refs),
            "rouge1": self.rouge1.compute(predictions=preds, references=refs)["rouge1"],
            "chrf": self.chrf.compute(predictions=preds, references=refs)["score"] / 100,
            # word_order=2 turns chrF into chrF++ (adds word n-grams).
            "chrfpp": self.chrf.compute(predictions=preds, references=refs, word_order=2)["score"] / 100,
            "bleu": self.bleu.compute(predictions=preds, references=refs)["bleu"],
            "sbleu": self.sacrebleu.compute(predictions=preds, references=refs)["score"] / 100,
            "meteor": self.meteor.compute(predictions=preds, references=refs)["meteor"],
        }

    def add_metrics(self, pr_latex, gt_latex, pron):
        """Score one sample and append it as a row to ``outputs``.

        Scores are computed before anything is appended. If any metric fails,
        the sample is reported and its metric columns are filled with NaN. This
        keeps every column the same length, so ``save_csv`` can still build a
        DataFrame, and pandas' ``mean()`` ignores the NaNs.

        Args:
            pr_latex: predicted LaTeX string.
            gt_latex: reference LaTeX string.
            pron: pronunciation text, stored alongside for inspection.
        """
        try:
            cased = self._score(pr_latex, gt_latex)
            lowered = self._score(pr_latex.lower(), gt_latex.lower())
        except self._SCORING_ERRORS:
            print(f"{pron=}, {gt_latex=}")
            scores = dict.fromkeys(self._METRIC_NAMES, float("nan"))
        else:
            scores = {**cased, **{f"{name}_lower": value for name, value in lowered.items()}}

        record = {"latex_pred": pr_latex, "latex_true": gt_latex, "pron": pron, **scores}
        for column, value in record.items():
            # setdefault lets plain dicts work as well as defaultdict(list).
            self.outputs.setdefault(column, []).append(value)

    def save_csv(self, output_path):
        """Write all accumulated rows to ``output_path`` and print each metric's mean.

        Metrics with no column (e.g. nothing was added) are printed as ``nan``.
        """
        res_df = pd.DataFrame(self.outputs)
        res_df.to_csv(output_path, index=False)

        for name in self._METRIC_NAMES:
            mean = res_df[name].mean() if name in res_df.columns else float("nan")
            print(name, mean)


if __name__ == "__main__":
    # Score a spreadsheet of predictions (columns gt_latex, pr_latex, pron)
    # against the reference LaTeX. Per-sample metrics are written to CSV;
    # Metrics.save_csv prints the column means.
    parser = argparse.ArgumentParser(description="Compute text metrics for predicted LaTeX.")
    parser.add_argument(
        "--input",
        default="/home/jovyan/Nikita/speech2latex/salmonn_exps/SALMONN/test/output/test_predicted_ru.xlsx",
        help="Excel file with gt_latex, pr_latex and pron columns.",
    )
    parser.add_argument(
        "--output",
        default="/home/jovyan/Nikita/speech2latex/salmonn_exps/SALMONN/test/output/metrics_ru_0.csv",
        help="Destination CSV for per-sample metrics.",
    )
    args = parser.parse_args()

    df = pd.read_excel(args.input)

    # Passed by keyword so the metric modules cannot be mis-ordered.
    metrics = Metrics(
        defaultdict(list),
        wer=evaluate.load("wer"),
        cer=evaluate.load("cer"),
        bleu=evaluate.load("bleu"),
        meteor=evaluate.load("meteor"),
        sacrebleu=evaluate.load("sacrebleu"),
        rouge1=evaluate.load("rouge"),
        chrf=evaluate.load("chrf"),
    )

    # Zipping the columns avoids building a Series per row (as iterrows does);
    # total= lets tqdm show a real progress bar.
    rows = zip(df["pr_latex"], df["gt_latex"], df["pron"])
    for predict_latex, gt_latex, pron in tqdm.tqdm(rows, total=len(df)):
        metrics.add_metrics(predict_latex, gt_latex, pron)

    metrics.save_csv(args.output)


