import os
import json
import torch
from inference import InferenceHandler
from glob import glob
import os
from tqdm import tqdm
import librosa
import hydra
import numpy as np
from evaluate import evaluate_main
from tasks.mt3_net import MT3Net


def get_scores(
    model,
    eval_audio_dir=None,
    mel_norm=True,
    eval_dataset="Score_Informed",
    exp_tag_name="test_midis",
    ground_truth=None,
    verbose=True,
    contiguous_inference=False,
    batch_size=1,
    max_length=1024,
):
    """Transcribe a list of audio files and evaluate the resulting MIDI files.

    Each audio file is loaded at 16 kHz, passed to ``InferenceHandler.inference``
    and written to ``<exp_tag_name>/<track_id>/mix.mid``, where ``track_id`` is
    the name of the directory that contains the audio file. Afterwards
    ``evaluate_main`` is run on ``<hydra output dir>/<exp_tag_name>``.

    Args:
        model: Model handed to ``InferenceHandler``.
        eval_audio_dir: Sequence of audio file paths (despite the name, this is
            a list of files, not a directory). ``None`` is treated as empty.
        mel_norm: Forwarded to ``InferenceHandler``.
        eval_dataset: Dataset name forwarded to ``evaluate_main``.
        exp_tag_name: Sub-directory that receives the predicted MIDI files.
        ground_truth: Ground-truth description forwarded to ``evaluate_main``.
        verbose: Print progress information and the final scores.
        contiguous_inference: Forwarded to ``InferenceHandler``.
        batch_size: Forwarded to ``InferenceHandler.inference``.
        max_length: Forwarded to ``InferenceHandler.inference``.

    Returns:
        The scores object returned by ``evaluate_main`` (iterated and indexed
        by key below, so it is expected to behave like a mapping).
    """
    sample_rate = 16000
    # The default of None used to crash on len(); treat it as "nothing to do".
    audio_files = [] if eval_audio_dir is None else eval_audio_dir

    # Prefer the GPU, but fall back to the CPU instead of hard-coding "cuda".
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    handler = InferenceHandler(
        model=model,
        device=device,
        mel_norm=mel_norm,
        contiguous_inference=contiguous_inference,
    )

    def _load_audio(path):
        # librosa resamples to `sample_rate`; the returned rate is not needed.
        audio, _ = librosa.load(path, sr=sample_rate)
        print(f"audio_len in seconds: {len(audio)/sample_rate}")
        return audio

    if verbose:
        print("Total audio files:", len(audio_files))

    print(f"batch_size: {batch_size}")

    for audio_file in tqdm(audio_files, total=len(audio_files)):
        print("Processing:", audio_file)
        audio = _load_audio(audio_file)

        # The track id is the parent directory name. Use os.path helpers rather
        # than splitting on "/" so this works with any path separator.
        track_id = os.path.basename(os.path.dirname(os.path.normpath(audio_file)))
        # Relative path: assumes the working directory is the Hydra output
        # directory that is evaluated below -- TODO confirm.
        outpath = os.path.join(exp_tag_name, track_id, "mix.mid")

        handler.inference(
            audio=audio,
            audio_path=track_id,
            outpath=outpath,
            batch_size=batch_size,
            max_length=max_length,
            verbose=verbose,
        )

    if verbose:
        print("Evaluating...")
    current_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir

    # TODO: results are not split by mistake type here; splitting by
    # instrument can be done later.
    scores = evaluate_main(
        dataset_name=eval_dataset,
        test_midi_dir=os.path.join(current_dir, exp_tag_name),
        ground_truth=ground_truth,
    )

    if verbose:
        for key in sorted(scores):
            print("{}: {:.4}".format(key, scores[key]))

    return scores

def _load_MAESTRO_split_info(json_path):
    """Read the split metadata JSON and build two lookup tables.

    The JSON is expected to contain a ``"midi_filename"`` mapping
    (number -> MIDI path) and a ``"split"`` mapping (number -> split name).

    Args:
        json_path: Path of the JSON metadata file.

    Returns:
        A ``(name_to_number, numbers_by_split)`` tuple where
        ``name_to_number`` maps the MIDI base name (with every ``".midi"``
        occurrence removed) to its number as a string, and
        ``numbers_by_split`` maps each split name to the set of numbers
        (as strings) assigned to it.
    """
    with open(json_path, "r") as handle:
        metadata = json.load(handle)

    name_to_number = {}
    for number, midi_path in metadata["midi_filename"].items():
        stem = os.path.basename(midi_path).replace(".midi", "")
        name_to_number[stem] = str(number)

    numbers_by_split = {}
    for number, split_name in metadata["split"].items():
        # Create the bucket on first sight of a split, then record the number.
        numbers_by_split.setdefault(split_name, set()).add(str(number))

    return name_to_number, numbers_by_split
    
def _build_dataset(root_dir, json_path, split, train_mistake_audio=True, train_score_audio=False):
    """Collect the audio/MIDI pairs of one dataset split.

    Audio files are discovered with recursive globs below ``root_dir``:
    ``mistake/**/mix.wav`` and ``score/**/mix_Audio_Aligned.wav``. The name of
    the directory containing each file is used as its track id, which is
    looked up in the split metadata loaded from ``json_path``.

    Args:
        root_dir: Dataset root containing the ``mistake`` and ``score`` trees.
        json_path: Split metadata file read by ``_load_MAESTRO_split_info``.
        split: Name of the split to keep (must be a key of the metadata).
        train_mistake_audio: Include the "mistake" recordings.
        train_score_audio: Include the "score" recordings. When both flags are
            true, mistake recordings are limited to tracks that also have a
            score recording, and all score recordings are included as well.
            When both flags are false, only score recordings are returned.

    Returns:
        ``(entries, mistakes_audio_dir, scores_audio_dir)`` where ``entries``
        is a list of dicts with ``track_id``, ``midi_path`` and ``audio_path``
        keys, and the other two are lists of the selected audio paths.

    Raises:
        KeyError: If ``split`` is not present in the metadata.
        AssertionError: If no file matches the requested split.
    """
    midi_filename_to_number, split_to_numbers = _load_MAESTRO_split_info(
        json_path
    )
    desired_file_numbers = split_to_numbers[split]

    def _index_by_track(subdir, filename):
        # Map "name of the containing directory" -> audio path.
        pattern = os.path.join(root_dir, subdir, "**", filename)
        return {
            os.path.normpath(path).split(os.sep)[-2]: path
            for path in glob(pattern, recursive=True)
        }

    mistake_files = _index_by_track("mistake", "mix.wav")
    score_files = _index_by_track("score", "mix_Audio_Aligned.wav")

    if train_mistake_audio and train_score_audio:
        # The original `mistake_files.keys() and score_files.keys()` is not an
        # intersection: `and` returns its second operand, so score ids were
        # used to index mistake_files and unmatched ids raised KeyError.
        # Keep only ids present in both, in score_files order.
        mistake_ids = [tid for tid in score_files if tid in mistake_files]
        score_ids = list(score_files)
    elif train_mistake_audio:
        mistake_ids, score_ids = list(mistake_files), []
    else:
        mistake_ids, score_ids = [], list(score_files)

    df = []
    mistakes_audio_dir = []
    scores_audio_dir = []

    selections = (
        (mistake_ids, mistake_files, mistakes_audio_dir),
        (score_ids, score_files, scores_audio_dir),
    )
    for track_ids, files, audio_paths in selections:
        for track_id in track_ids:
            # Unknown track ids map to None, which is never in the split set.
            if midi_filename_to_number.get(track_id) not in desired_file_numbers:
                continue
            audio_path = files[track_id]
            audio_paths.append(audio_path)
            df.append(
                {
                    "track_id": track_id,
                    # Swap only the extension; str.replace would also rewrite
                    # ".wav"/".mid" occurring elsewhere in the path.
                    "midi_path": os.path.splitext(audio_path)[0] + ".mid",
                    "audio_path": audio_path,
                }
            )

    assert len(df) > 0, "No matching files found. Check the dataset directory."

    return df, mistakes_audio_dir, scores_audio_dir

@hydra.main(config_path="config", config_name="config", version_base="1.1")
def main(cfg):
    """Load a trained model and evaluate it on the test split.

    Reads from ``cfg``: ``path`` (checkpoint file), ``model``, ``optim`` and
    the ``eval`` group (``exp_tag_name``, ``load_weights_strict``,
    ``eval_first_n_examples``, ``eval_dataset``, ``contiguous_inference``,
    ``batch_size``).

    ``.ckpt`` files are loaded through the model class'
    ``load_from_checkpoint``; ``.pt``/``.pth`` files are loaded as a state
    dict into a freshly instantiated model.

    Raises:
        AssertionError: If ``cfg.path`` or ``cfg.eval.exp_tag_name`` is empty,
            or if the checkpoint has an unsupported extension.
    """
    assert cfg.path
    # Include the leading dot for every extension so that names such as
    # "modelckpt" are rejected, as the error message promises.
    assert cfg.path.endswith(
        (".pt", ".pth", ".ckpt")
    ), "Only .pt, .pth, .ckpt files are supported."
    assert cfg.eval.exp_tag_name

    # Which recordings to evaluate; currently fixed to the score audio.
    use_mistake_audio = False
    use_score_audio = True

    dataset_root = "/home/depotdatasets/Score_Informed_with_mistakes_aligned"
    dataset, mistakes_audio_dir, scores_audio_dir = _build_dataset(
        root_dir=dataset_root,
        json_path=os.path.join(dataset_root, "split.json"),
        split="test",
        train_mistake_audio=use_mistake_audio,
        train_score_audio=use_score_audio,
    )

    print(f"Loading weights from: {cfg.path}")
    if cfg.path.endswith(".ckpt"):
        # Lightning checkpoint: let the module class rebuild itself.
        model_cls = hydra.utils.get_class(cfg.model._target_)
        print("torch.cuda.device_count():", torch.cuda.device_count())
        pl = model_cls.load_from_checkpoint(
            cfg.path,
            config=cfg.model.config,
            optim_cfg=cfg.optim,
        )
        model = pl.model
    else:
        # Plain state dict: only this branch needs a freshly built model, so
        # instantiate here instead of building one that the .ckpt branch
        # would immediately discard.
        pl = hydra.utils.instantiate(cfg.model, optim_cfg=cfg.optim)
        model = pl.model
        strict = cfg.eval.load_weights_strict
        if strict is None:
            strict = False
        # map_location="cpu" lets GPU-saved weights load on any host;
        # load_state_dict copies the values into the model's own parameters.
        state_dict = torch.load(cfg.path, map_location="cpu")
        model.load_state_dict(state_dict, strict=strict)

    model.eval()
    # The model is not moved to a device here; get_scores passes a device to
    # InferenceHandler.

    mel_norm = True

    if use_mistake_audio:
        eval_audio_dir = mistakes_audio_dir
    elif use_score_audio:
        eval_audio_dir = scores_audio_dir
    else:
        eval_audio_dir = mistakes_audio_dir + scores_audio_dir

    n_examples = cfg.eval.eval_first_n_examples
    if n_examples:
        # Evaluate a random contiguous window of n_examples files.
        random_offset = 0
        if len(eval_audio_dir) > n_examples:
            # randint's upper bound is exclusive; +1 makes the last possible
            # window (ending at the final file) selectable too.
            random_offset = np.random.randint(
                0, len(eval_audio_dir) - n_examples + 1
            )
        eval_audio_dir = eval_audio_dir[random_offset: random_offset + n_examples]

    get_scores(
        model,
        eval_audio_dir=eval_audio_dir,
        mel_norm=mel_norm,
        eval_dataset=cfg.eval.eval_dataset,
        exp_tag_name=cfg.eval.exp_tag_name,
        ground_truth=dataset,
        contiguous_inference=cfg.eval.contiguous_inference,
        batch_size=cfg.eval.batch_size,
    )


# Run the evaluation only when this file is executed as a script; `main` takes
# no arguments here because the @hydra.main decorator supplies `cfg`.
if __name__ == "__main__":
    main()
