import io
import time
import torch
import torchaudio

import numpy as np
import matplotlib.pyplot as plt
from .utils.eval_utmos import eval_utmos_fname
from .utils.eval_utilization import eval_utilization
from .utils.eval_sim import extract_wavlm_similarity
import glob
import pickle
import pandas as pd
import editdistance
from transformers import Wav2Vec2Processor, HubertForCTC
import whisper.normalizers
# Selects which reference clip(s) ``test_one`` writes when ``args.save_ref`` is
# set: 'SIM' saves the full aligned reference, 'SIM_O' the first 3 s of the
# original, 'SIM_R' the first 3 s of the reconstruction, 'SIM_OR' all of them.
SIM_METRIC = 'SIM'

# When False, UTMOS scoring is skipped and 0.0 is reported instead.
EVAL_UTMOS = True

# Explicit check rather than ``assert`` so the validation survives ``python -O``.
if SIM_METRIC not in ('SIM_O', 'SIM_R', 'SIM', 'SIM_OR'):
    raise ValueError(
        f"SIM_METRIC must be one of 'SIM_O', 'SIM_R', 'SIM', 'SIM_OR'; got {SIM_METRIC!r}"
    )

import os
import sys
# True only when the environment variable IS_CLUSTER is set to exactly '1'.
IS_CLUSTER = os.environ.get('IS_CLUSTER') == '1'

# Short dataset label -> glob pattern matching that dataset's audio files.
# NOTE(review): the patterns look relative to a dataset root chosen by the
# caller - confirm where they are joined.
data_map = {
    'librispeech-cropped': 'librispeech-cropped/*/*/*.wav',
    'librispeech': 'test-clean/*/*/*.flac',
}
# ----------------------- DEFINITION OF MODEL UNDER EVAL -----------------------
# from .models.official_xcodec2 import prepare_model, infer
# from .models.official_snac import prepare_model, infer
# from .models.official_xytokenizer import prepare_model, infer
# from .models.official_dualcodec import prepare_model, infer

# from .models.retrain_speechtokenizer import prepare_model, infer
# from .models.official_mimi import prepare_model, infer
# from .models.official_speechtokenizer import prepare_model, infer
# from .models.official_encodec import prepare_model, infer
# from .models.official_semanticodec import prepare_model, infer
# from .models.official_xcodec import prepare_model, infer

# ----------------------- END DEFINITION OF MODEL UNDER EVAL -----------------------

import argparse
import os
import sys
import typing as tp
from collections import OrderedDict
from pathlib import Path
import torchaudio
import torch
from .utils.dataset import TestDataset, JsonDataset
from pesq import pesq
from pystoi import stoi
from tqdm import tqdm
from pathlib import Path

def _vq_audio_path(store_root, q):
    """Return ``store_root`` with each known audio suffix replaced by ``_{q}q.wav``.

    ``.wav``, ``.flac``, ``.mp3`` and ``.opus`` are substituted wherever they
    occur in the string; a path containing none of them is returned unchanged.
    """
    suffix = f'_{q}q.wav'
    path = store_root
    for ext in ('.wav', '.flac', '.mp3', '.opus'):
        path = path.replace(ext, suffix)
    return path


def _unpack_infer_output(raw, args, store_root):
    """Normalise the result of ``infer`` (or the offline stand-in).

    Accepted shapes of ``raw``:

    * a dict with key ``'out'`` and optional ``'compressed'``, ``'encode_rtf'``,
      ``'decode_rtf'``, ``'semantic_features'`` and ``'sim'``;
    * a tuple/list of length 1, 2, 4 or 5, read positionally as
      ``(out, compressed, encode_rtf, decode_rtf, feature)``;
    * anything else, which is treated as the reconstructed audio itself.

    The container kind is decided by type, not by ``len()``, so an audio tensor
    whose first dimension happens to be 2, 4 or 5 is never mistaken for a tuple.

    Side effect: a feature is written next to ``store_root`` with a ``.pt``
    extension when it comes from a dict (always) or from a 5-tuple (only when
    ``args.feature`` is truthy).

    Returns:
        ``(audio, compressed, encode_rtf, decode_rtf, sim_values)``; fields the
        model did not provide are ``None``.

    Raises:
        ValueError: if ``raw`` is a tuple/list of an unsupported length.
    """
    compressed = encode_rtf = decode_rtf = sim_values = feature = None

    if isinstance(raw, dict):
        compressed = raw.get('compressed')
        encode_rtf = raw.get('encode_rtf')
        decode_rtf = raw.get('decode_rtf')
        sim_values = raw.get('sim')
        feature = raw.get('semantic_features')
        audio = raw['out']
    elif isinstance(raw, (tuple, list)):
        n_fields = len(raw)
        if n_fields == 1:
            (audio,) = raw
        elif n_fields == 2:
            audio, compressed = raw
        elif n_fields == 4:
            audio, compressed, encode_rtf, decode_rtf = raw
        elif n_fields == 5:
            audio, compressed, encode_rtf, decode_rtf, feature = raw
            if not args.feature:
                # Tuple-style models only dump features on request.
                feature = None
        else:
            raise ValueError(
                f'infer() returned a sequence of unsupported length {n_fields}; '
                'expected 1, 2, 4 or 5 items'
            )
    else:
        audio = raw

    if feature is not None:
        # splitext instead of replace('.wav', ...) so a non-.wav store_root is
        # not overwritten by the feature dump.
        feature_path = os.path.splitext(store_root)[0] + '.pt'
        torch.save(feature.cpu(), feature_path)
        print(f'feature saved to {feature_path}')

    return audio, compressed, encode_rtf, decode_rtf, sim_values


def _save_reference_audio(args, store_root_ref, wav_aligned, out):
    """Write the reference clip(s) selected by the module-level ``SIM_METRIC``.

    ``.mp3`` / ``.opus`` in ``store_root_ref`` are rewritten to ``.wav`` first.
    The 3-second variants keep the first ``args.sr * 3`` samples.
    """
    ref_path = store_root_ref.replace('.mp3', '.wav').replace('.opus', '.wav')
    head = args.sr * 3

    def _save(path, samples):
        torchaudio.save(path, samples.unsqueeze(0), args.sr)
        print(f'ref saved to {path}')

    if SIM_METRIC == 'SIM':
        # This mode has never logged the path; keep it quiet.
        torchaudio.save(ref_path, wav_aligned.unsqueeze(0), args.sr)
    elif SIM_METRIC == 'SIM_O':
        _save(ref_path.replace('sim_r', 'sim_o'), wav_aligned[:head])
    elif SIM_METRIC == 'SIM_R':
        _save(ref_path.replace('sim_o', 'sim_r'), out[:head])
    else:
        _save(ref_path.replace('reference_3s_sim_r', 'reference_full'), wav_aligned)
        _save(ref_path.replace('sim_r', 'sim_o'), wav_aligned[:head])
        _save(ref_path.replace('sim_o', 'sim_r'), out[:head])


def _count_word_errors(out16, transcript, asr_processor, asr_model, normalizer, store_root):
    """Transcribe ``out16`` (16 kHz) and return ``(edit_distance, n_reference_words)``.

    Both the hypothesis and ``transcript`` are passed through ``normalizer``
    before a word-level edit distance is computed.
    """
    device = asr_model.device
    input_values = asr_processor(
        out16, return_tensors="pt", sampling_rate=16000
    ).input_values
    # Evaluation only: skip autograd bookkeeping for the ASR forward pass.
    with torch.no_grad():
        logits = asr_model(input_values.to(device)).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    hypothesis_text = normalizer(asr_processor.decode(predicted_ids[0]))
    ground_truth_text = normalizer(transcript)

    # NOTE(review): split(" ") yields an empty token for empty strings or
    # leading/trailing spaces; kept as-is so WER stays comparable with numbers
    # already produced by this script.
    reference_words = ground_truth_text.split(" ")
    error = editdistance.eval(reference_words, hypothesis_text.split(" "))
    num_word = len(reference_words)
    if num_word > 0:
        print(
            f"Path: {store_root} Error: {error} Num word: {num_word} Local WER: {error / num_word} Hypothesis: {hypothesis_text} Ground truth: {ground_truth_text}"
        )
    return error, num_word


def test_one(args, wav: torch.Tensor, store_root, store_root_ref, soundstream, transcript=None, asr_processor=None, asr_model=None, normalizer=None):
    """Evaluate one utterance at every quantizer setting in ``args.nq``.

    For each ``q`` the utterance is reconstructed with ``infer`` (or, when the
    module-level ``EVALUATE_OFFLINE`` is truthy, loaded from
    ``EVALUATE_OFFLINE_ROOT`` and resampled to ``args.sr``), trimmed to the
    common length with ``wav``, optionally saved, and scored.

    Args:
        args: Options object. Attributes read here: ``nq``, ``sr``,
            ``save_audio``, ``save_ref``, ``wer``, ``infer``, ``num_workers``
            and (for 5-tuple model outputs) ``feature``.
        wav: Reference waveform; it is flattened to 1-D before scoring.
        store_root: Output audio path (string); per-``q`` files are derived
            from it by inserting ``_{q}q`` before the extension.
        store_root_ref: Path (string) used for reference clips when
            ``args.save_ref`` is set.
        soundstream: Model handle forwarded to ``infer``.
        transcript: Ground-truth text; WER is computed only when this, the ASR
            model, the processor and the normalizer are all truthy.
        asr_processor: Processor used to featurise audio and decode ids.
        asr_model: CTC model providing ``.device`` and ``.logits``.
        normalizer: Callable applied to both hypothesis and reference text.

    Returns:
        Dict mapping each ``q`` to its metrics dict. Always contains
        ``code_util``; may contain ``codec_sim_values``, ``wer_error`` and
        ``wer_num_word``. Unless ``args.infer`` is truthy it also contains
        ``stoi``, ``utmos``, ``sim``, ``pesq_wb``, ``pesq_nb``, ``visqol``,
        ``mcd`` and, when known, ``encode_rtf`` / ``decode_rtf``. A metric
        whose computation raised is reported as ``-1.``. The dict is empty
        when ``args.nq`` is empty or ``None``.
    """
    wav.requires_grad = False
    vq_metrics = {}  # q -> metrics dict

    if not args.nq:
        return vq_metrics

    def _eval_single_vq(q):
        """Evaluate a single VQ setting and return (q, metrics_dict)."""
        if EVALUATE_OFFLINE:
            fname = Path(store_root).name
            waveform, sr = torchaudio.load(os.path.join(EVALUATE_OFFLINE_ROOT, fname))
            resampled_waveform = torchaudio.transforms.Resample(sr, args.sr)(waveform)
            raw = {'out': resampled_waveform, 'compressed': None, 'encode_rtf': 0.0, 'decode_rtf': 0.0}
        else:
            raw = infer(wav, soundstream, num_quantizers=q)
        out, compressed, encode_rtf, decode_rtf, sim_values = _unpack_infer_output(raw, args, store_root)

        # Flatten both signals and trim to the shorter one so that the
        # sample-aligned metrics below compare equal-length arrays.
        out = out.detach().cpu().reshape(-1)
        wav_aligned = wav.detach().cpu().reshape(-1)
        min_len = min(len(wav_aligned), len(out))
        wav_aligned = wav_aligned[:min_len]
        out = out[:min_len]

        metrics = {}
        vq_path = _vq_audio_path(store_root, q)
        if args.save_audio:
            torchaudio.save(vq_path, out.unsqueeze(0), args.sr)
            print(f"saved to {vq_path}")
        if args.save_ref:
            _save_reference_audio(args, store_root_ref, wav_aligned, out)

        code_utilization = eval_utilization(compressed)
        print(f'code_util: {code_utilization}')
        metrics['code_util'] = code_utilization
        if sim_values is not None:
            metrics['codec_sim_values'] = sim_values  # raw values kept for analysis

        if args.save_audio and q == args.nq[-1]:
            # For the last setting also store the reconstruction under the
            # plain name and the aligned reference as the "0 quantizer" file.
            torchaudio.save(store_root, out.unsqueeze(0), args.sr)
            torchaudio.save(_vq_audio_path(store_root, 0), wav_aligned.unsqueeze(0), args.sr)

        to_16k = torchaudio.transforms.Resample(args.sr, 16000)
        wav16 = to_16k(wav_aligned)
        out16 = to_16k(out)

        if args.wer and transcript and asr_model and asr_processor and normalizer:
            error, num_word = _count_word_errors(
                out16, transcript, asr_processor, asr_model, normalizer, store_root
            )
            metrics['wer_error'] = error
            metrics['wer_num_word'] = num_word

        if args.infer:
            # Inference-only run: skip the signal-level metrics.
            return q, metrics

        metrics['stoi'] = stoi(wav_aligned, out, args.sr, extended=False)

        if EVAL_UTMOS:
            # NOTE(review): this scores the per-q file on disk, which this
            # function only writes when args.save_audio is set - confirm that
            # callers always enable it (or pre-populate the file).
            utmos_score = eval_utmos_fname(vq_path)
        else:
            utmos_score = 0.0
        metrics['utmos'] = utmos_score

        # ------ SIM ------
        try:
            # Unused here; a missing package makes SIM fall back to -1.
            # Presumably needed by the similarity backend - TODO confirm.
            import cached_path  # noqa: F401
            # NOTE(review): `checkpoint_path` is not defined in this function;
            # it must exist at module level or SIM is reported as -1.
            sim_score = extract_wavlm_similarity(out16, wav16, device='cuda', checkpoint_path=checkpoint_path)
        except Exception as e:
            print(f'Error computing SIM: {e}')
            sim_score = -1.
        metrics['sim'] = sim_score

        # ------ PESQ ------
        try:
            pesq_wb = pesq(16000, wav16.numpy(), out16.numpy(), 'wb')
            pesq_nb = pesq(16000, wav16.numpy(), out16.numpy(), 'nb')
        except Exception as e:
            print(f'Error computing PESQ: {e}')
            pesq_wb = -1.
            pesq_nb = -1.
        metrics['pesq_wb'] = pesq_wb
        metrics['pesq_nb'] = pesq_nb

        # ------ Visqol ------
        try:
            visqol_score = eval_visqol(wav16.numpy(), out16.numpy())
        except Exception as e:
            print(f'Error computing ViSQOL: {e}')
            visqol_score = -1.
        metrics['visqol'] = visqol_score

        # ------ MCD ------
        from .utils.eval_mcd import eval_mcd
        metrics['mcd'] = eval_mcd(wav16.numpy(), out16.numpy(), sr=16000)

        if encode_rtf is not None:
            metrics['encode_rtf'] = encode_rtf
        if decode_rtf is not None:
            metrics['decode_rtf'] = decode_rtf
        return q, metrics

    if args.num_workers > 1 and len(args.nq) > 1:
        # Do a warm-up pass on the first VQ setting to ensure torch.compile has
        # finished before we enter multi-threaded execution. This avoids the FX /
        # dynamo tracing race condition seen when compiling concurrently.
        first_q, *remaining_qs = args.nq
        q_val, met = _eval_single_vq(first_q)
        vq_metrics[q_val] = met

        from concurrent.futures import ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=min(args.num_workers, len(remaining_qs))) as executor:
            for q_val, met in executor.map(_eval_single_vq, remaining_qs):
                vq_metrics[q_val] = met
    else:
        # Either single-threaded or only one VQ to evaluate.
        for q in args.nq:
            q_val, met = _eval_single_vq(q)
            vq_metrics[q_val] = met
    return vq_metrics


# Despite its name this is an evaluation routine, not a pytest test; the marker
# keeps pytest from collecting it (and failing on its non-fixture parameters)
# when the name ends up in a test module's namespace.
test_one.__test__ = False
