import copy
import json
import os

import torch
from matplotlib import pyplot as plt
from IPython.display import display, Audio


def get_devices(print_info=False):
    """Pick the torch device to run on and count the visible CUDA devices.

    Args:
        print_info: When true, print the chosen device and GPU count, then the
            name plus allocated / reserved memory (in GB) of each CUDA device.

    Returns:
        A ``(device, num_devices)`` tuple: ``device`` is ``"cuda"`` when
        ``torch.cuda.is_available()`` is true and ``"cpu"`` otherwise;
        ``num_devices`` is ``torch.cuda.device_count()``.
    """
    has_cuda = torch.cuda.is_available()
    gpu_count = torch.cuda.device_count()
    device = "cuda" if has_cuda else "cpu"

    if not print_info:
        return device, gpu_count

    print(f"Using {device} with {gpu_count} GPUs....")
    gib = 1024 ** 3  # bytes per GB as reported below
    for idx in range(gpu_count):
        print("Name:", torch.cuda.get_device_name(idx))
        print("-  Memory Allocated:", torch.cuda.memory_allocated(idx) / gib, "GB")
        print("-  Memory Cached:", torch.cuda.memory_reserved(idx) / gib, "GB")
    return device, gpu_count


def plot_audio_pair(audio_sample: dict):
    """Play and plot every clip of an audio sample inside a notebook.

    Prints ``audio_sample["props"]``, then for each prompt in
    ``audio_sample["prompts"]`` displays an IPython audio player for the
    matching clip of ``audio_sample["audio_pair"]`` and plots its waveform.

    Args:
        audio_sample: dict with keys ``"props"`` (printed as-is),
            ``"prompts"`` (one title string per clip), ``"audio_pair"``
            (indexed as ``[i, ...]`` per prompt; presumably a torch tensor —
            TODO confirm) and ``"sr"`` (sample rate passed to ``Audio``).
    """
    print(audio_sample["props"])
    clips = audio_sample["audio_pair"]
    sample_rate = audio_sample["sr"]
    for i, prompt in enumerate(audio_sample["prompts"]):
        clip = clips[i, ...].squeeze()
        display(Audio(clip, rate=sample_rate))

        # A clip that is still 2-D after squeeze presumably has shape
        # (channels, samples); average the channels into one waveform.
        # The result is already 1-D, so no further squeeze is needed.
        if clip.ndim == 2:
            clip = torch.mean(clip, dim=0)

        plt.figure(figsize=(10, 4))
        plt.plot(clip)
        plt.title(f"Waveform of '{prompt}'")
        plt.xlabel("Time (samples)")
        plt.ylabel("Amplitude")
        plt.show()


def batch_list(lst, batch_size):
    """Lazily yield consecutive slices of *lst* of at most *batch_size* items.

    Each yielded batch is a slice of *lst*, so it has the same sequence type
    (list -> list, str -> str, tuple -> tuple). The final batch may be shorter.
    """
    yield from (
        lst[offset:offset + batch_size]
        for offset in range(0, len(lst), batch_size)
    )


def check_config_match(args) -> bool:
    """Compare *args* with the config previously saved in ``args.output_path``.

    Creates ``args.output_path`` if necessary. The keys ``n_chunks``,
    ``curr_chunk`` and ``batch_size`` are excluded from both the saved and the
    compared config (missing ones are ignored).

    If ``config.json`` does not exist yet, the current arguments are written
    to it and ``True`` is returned.

    Args:
        args: namespace-like object accepted by ``vars()`` (e.g. an
            ``argparse.Namespace``) with an ``output_path`` attribute.

    Returns:
        True if no config was saved yet or the saved config equals the current
        arguments (compared in JSON form), False otherwise.

    Raises:
        TypeError: when a config must be written but an argument value is not
            JSON-serializable; nothing is written in that case.
    """
    os.makedirs(args.output_path, exist_ok=True)
    config_path = os.path.join(args.output_path, "config.json")

    args_dict = copy.deepcopy(vars(args))
    for key in ("n_chunks", "curr_chunk", "batch_size"):
        args_dict.pop(key, None)

    try:
        with open(config_path, "r") as config_file:
            config_data = json.load(config_file)
    except FileNotFoundError:
        # No config file, so this is not a continuation of a previous run:
        # record the current configuration. Serialize *before* opening the
        # file so a non-serializable value cannot leave a truncated file.
        serialized = json.dumps(args_dict)
        with open(config_path, "w") as config_file:
            config_file.write(serialized)
        return True

    try:
        # The saved config went through a JSON round-trip (tuples -> lists,
        # non-str keys -> str), so normalize the current args the same way.
        current = json.loads(json.dumps(args_dict))
    except (TypeError, ValueError):
        # Not representable as JSON; it cannot match a saved JSON config
        # anyway, so fall back to a direct comparison.
        current = args_dict
    return config_data == current














