import os
import subprocess
import pandas as pd
import numpy as np
import librosa
import sys
from tqdm import tqdm
import json

def download_audiomnist(repo_path="AudioMNIST"):
    """
    Clones the AudioMNIST repository from GitHub if it doesn't already exist.

    Args:
        repo_path: Directory the repository is cloned into. If this path
            already exists, nothing is downloaded.

    Returns:
        None. Progress is printed to stdout. Failures of the git commands
        (non-zero exit status or a missing executable) are caught and
        reported with a message instead of being raised.
    """
    if os.path.exists(repo_path):
        print(f"'{repo_path}' directory already exists. Skipping download.")
        return

    # Defined before the try block so the error handler can always reference it.
    repo_url = "https://github.com/soerenab/AudioMNIST.git"

    print("Cloning the AudioMNIST repository...")
    try:
        # Pass the destination explicitly; without it git always clones into
        # "./AudioMNIST" and a custom repo_path would be silently ignored.
        subprocess.run(
            ["git", "clone", repo_url, repo_path],
            check=True, capture_output=True, text=True,
        )
        print("Repository cloned successfully.")

        lfs_path = os.path.join(repo_path, ".gitattributes")
        if os.path.exists(lfs_path):
            print("Git LFS detected, pulling audio files...")
            try:
                # Attempt to install git-lfs if not present.
                # capture_output=True already pipes stderr; passing stderr=
                # as well makes subprocess.run raise ValueError.
                subprocess.run(
                    [sys.executable, "-m", "pip", "install", "git-lfs"],
                    check=True, capture_output=True, text=True,
                )
            except subprocess.CalledProcessError:
                # This might fail if pip isn't in a standard location, but it's worth a try.
                print("Could not auto-install git-lfs. Please ensure it's installed manually (see https://git-lfs.com).")

            subprocess.run(["git", "lfs", "install"], check=True, capture_output=True, text=True)
            subprocess.run(["git", "lfs", "pull"], cwd=repo_path, check=True, capture_output=True, text=True)
            print("Git LFS files pulled successfully.")

    except subprocess.CalledProcessError as e:
        print("An error occurred while cloning the repository or pulling LFS files.")
        print(f"URL: {repo_url}")
        print(f"Error: {e.stderr}")
    except FileNotFoundError:
        print("Error: 'git' command not found. Please install Git and ensure it's in your PATH.")

def create_spectrogram(file_path, target_shape=(112, 112)):
    """
    Loads a .wav file and creates a normalized Mel spectrogram of a target shape.

    Args:
        file_path: Path of the audio file handed to ``librosa.load``.
        target_shape: ``(n_mels, n_frames)`` of the returned array. The first
            entry is used as the number of Mel bands, the second as the
            number of time frames the spectrogram is cropped or padded to.

    Returns:
        A ``float32`` array of shape ``target_shape`` min-max scaled to
        ``[0, 1]``. An all-zero array is returned when the spectrogram is
        (nearly) constant or when the file cannot be processed; in the
        latter case the error is reported on stderr.
    """
    try:
        y, sr = librosa.load(file_path, sr=None)

        mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=target_shape[0], n_fft=512, hop_length=128)
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

        current_width = mel_spec_db.shape[1]
        target_width = target_shape[1]

        if current_width > target_width:
            # Too long: keep only the leading frames.
            mel_spec_db = mel_spec_db[:, :target_width]
        elif current_width < target_width:
            # Too short: extend on the right. The valid numpy mode name is
            # 'minimum' ('min' raises ValueError, which the handler below
            # would turn into an all-zero result). Each Mel band is padded
            # with its own minimum value.
            pad_width = target_width - current_width
            mel_spec_db = np.pad(mel_spec_db, ((0, 0), (0, pad_width)), mode='minimum')

        spec_min, spec_max = mel_spec_db.min(), mel_spec_db.max()
        if (spec_max - spec_min) > 1e-6:
            normalized_spec = (mel_spec_db - spec_min) / (spec_max - spec_min)
        else:
            # Avoid dividing by ~0 for a flat spectrogram.
            normalized_spec = np.zeros(target_shape)

        return normalized_spec.astype(np.float32)

    except Exception as e:
        # Deliberate best-effort: one unreadable file must not abort a whole
        # dataset conversion, so report it and substitute a blank spectrogram.
        print(f"Could not process {file_path}: {e}", file=sys.stderr)
        return np.zeros(target_shape, dtype=np.float32)

def _collect_audio_records(data_dir):
    """
    Scans ``data_dir/<speaker_id>/<digit>_<x>_<y>.wav`` and returns one record
    dict (``file_path``, ``speaker_id``, ``digit``) per well-formed file.

    Entries are visited in sorted order so the result is reproducible.
    Directories whose name is not an integer and files whose name does not
    match the expected pattern are skipped individually.
    """
    records = []
    for speaker_id_str in sorted(os.listdir(data_dir)):
        speaker_dir = os.path.join(data_dir, speaker_id_str)
        if not os.path.isdir(speaker_dir):
            continue
        try:
            speaker_id = int(speaker_id_str)
        except ValueError:
            continue

        for filename in sorted(os.listdir(speaker_dir)):
            stem, ext = os.path.splitext(filename)
            if ext != ".wav":
                continue
            parts = stem.split("_")
            if len(parts) != 3:
                continue
            try:
                digit = int(parts[0])
            except ValueError:
                # Skip only this file; the rest of the speaker's files are kept.
                continue
            records.append({
                "file_path": os.path.join(speaker_dir, filename),
                "speaker_id": speaker_id,
                "digit": digit,
            })
    return records


def process_and_save_audiomnist(repo_path="AudioMNIST", output_dir="avmnist_data_from_source"):
    """
    Processes AudioMNIST data, creates spectrograms, splits by speaker,
    and saves them as numpy arrays compatible with the AVMNIST pipeline.

    Speakers 1-50 form the training split and speakers 51-60 the test split.
    For each split three files are written below ``output_dir``:
    ``audio/<split>_data.npy``, ``<split>_labels.npy`` (digits) and
    ``<split>_speaker_labels.npy``.

    Args:
        repo_path: Root of the cloned AudioMNIST repository; must contain a
            ``data`` directory and ``audioMNIST_meta.txt``.
        output_dir: Directory the ``.npy`` files are written to (created if
            missing).

    Returns:
        None. If the data directory or metadata file is missing, the metadata
        is not valid JSON, or no audio files are found, an error is printed
        to stderr and nothing is written.
    """
    data_dir = os.path.join(repo_path, "data")
    meta_path = os.path.join(repo_path, "audioMNIST_meta.txt")

    if not os.path.isdir(data_dir):
        print(f"Error: Data directory not found at '{data_dir}'", file=sys.stderr)
        return

    try:
        # The metadata content itself is not needed (labels come from the file
        # paths); parsing it only verifies that the checkout is intact.
        with open(meta_path, 'r', encoding='utf-8') as f:
            json.load(f)
        print("Successfully loaded JSON metadata file.")
    except FileNotFoundError:
        print(f"Error: Metadata file not found at '{meta_path}'", file=sys.stderr)
        return
    except json.JSONDecodeError:
        print(f"Error: Could not parse JSON from '{meta_path}'. The file might be corrupted.", file=sys.stderr)
        return

    audio_records = _collect_audio_records(data_dir)
    if not audio_records:
        # An empty record list would otherwise produce a column-less DataFrame
        # and a KeyError on 'speaker_id' below.
        print(f"Error: No audio files found under '{data_dir}'", file=sys.stderr)
        return

    full_df = pd.DataFrame(audio_records, columns=["file_path", "speaker_id", "digit"])

    train_speakers = list(range(1, 51))
    test_speakers = list(range(51, 61))
    train_df = full_df[full_df['speaker_id'].isin(train_speakers)].reset_index(drop=True)
    test_df = full_df[full_df['speaker_id'].isin(test_speakers)].reset_index(drop=True)

    print(f"Train samples: {len(train_df)}, Test samples: {len(test_df)}")

    os.makedirs(os.path.join(output_dir, 'audio'), exist_ok=True)

    for name, df in [('train', train_df), ('test', test_df)]:
        print(f"Processing {name} audio spectrograms...")
        audio_data = [create_spectrogram(fp) for fp in tqdm(df['file_path'])]

        np.save(os.path.join(output_dir, 'audio', f'{name}_data.npy'), np.array(audio_data, dtype=np.float32))
        np.save(os.path.join(output_dir, f'{name}_labels.npy'), df['digit'].values)
        np.save(os.path.join(output_dir, f'{name}_speaker_labels.npy'), df['speaker_id'].values)

    print(f"\nData successfully processed and saved to '{output_dir}'")

def main():
    """
    Script entry point: make sure the AudioMNIST checkout is present, then
    convert it into the numpy arrays used by the AVMNIST pipeline.
    """
    source_checkout = "AudioMNIST"
    export_root = "avmnist_data_from_source"

    # Step 1: fetch the raw dataset (skipped when the directory already exists).
    download_audiomnist(repo_path=source_checkout)

    # Step 2: build spectrograms and write the train/test arrays.
    process_and_save_audiomnist(repo_path=source_checkout, output_dir=export_root)


# Run the pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()

