import argparse
import random
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

from conf import (
    aves_bio_config,
    aves_bio_model,
    aves_core_config,
    aves_core_model,
    dolph2vec_config_path,
)
from models import MFCC, Aves, BioLingual, Dolph2Vec, SpectralFeatures, Spectrogram, ShuffledWav2Vec2Wrapper


def set_seed(seed: int = 42):
    """Seed every random number generator the script touches.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    CPU generator. When CUDA is available it additionally seeds all GPU
    generators and puts cuDNN into deterministic, non-benchmarking mode.

    Args:
        seed: Value passed to every generator. Defaults to 42.
    """
    random.seed(seed)
    np.random.seed(seed)
    for torch_seeder in (torch.random.manual_seed, torch.manual_seed):
        torch_seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # every visible GPU, not just the current one
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


def get_args(argv=None):
    """Parse the command-line options for embedding extraction.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly like ``parse_args()``.

    Returns:
        argparse.Namespace with ``seed``, ``dataset_name``, ``model``,
        ``outfolder`` and ``target_sample_rate``.
    """
    parser = argparse.ArgumentParser(
        description="Extract audio embeddings for a dataset and save them as .npy files."
    )
    parser.add_argument("--seed", default=42, type=int, help="random seed (default: 42)")
    parser.add_argument(
        "--dataset_name",
        choices=["dolphin_reef_balanced", "dolphin_reef_unbalanced", "watkins", "watkins_dolphins"],
        # Required: main() looks this name up in a dict, so omitting it used to
        # surface much later as an opaque KeyError(None).
        required=True,
        help="dataset whose audio files are embedded",
    )

    parser.add_argument(
        "--model",
        choices=[
            "dolph2vec",
            "dolph2vec-shuffle",
            "aves_core",
            "aves_bio",
            "biolingual",
            "mfcc",
            "spectrogram",
            "spectral_features",
        ],
        default="dolph2vec",
        help="embedding model / feature extractor (default: dolph2vec)",
    )
    parser.add_argument(
        "--outfolder",
        default="",
        type=str,
        help="folder for the output .npy files (default: current directory)",
    )

    parser.add_argument(
        "--target_sample_rate",
        default=44100,
        type=int,
        help="sample rate passed to the model (default: 44100)",
    )

    return parser.parse_args(argv)


def _build_model(model_name, sample_rate):
    """Instantiate the embedding model selected by ``--model``."""
    name2model = {
        "aves_core": Aves,
        "aves_bio": Aves,
        "biolingual": BioLingual,
        "dolph2vec": Dolph2Vec,
        "dolph2vec-shuffle": ShuffledWav2Vec2Wrapper,
        "mfcc": MFCC,
        "spectrogram": Spectrogram,
        "spectral_features": SpectralFeatures,
    }

    # Checkpoint/config per AVES variant. The previous if/elif chain had these
    # crossed (aves_bio loaded the core checkpoint and vice versa).
    aves_paths = {
        "aves_core": (aves_core_model, aves_core_config),
        "aves_bio": (aves_bio_model, aves_bio_config),
    }
    aves_model_path, aves_config_path = aves_paths.get(model_name, ("", ""))

    # Every model class receives the same keyword set, as before.
    model_args = dict(
        sample_rate=sample_rate,
        dolph2vec_config_path=dolph2vec_config_path,
        aves_model_path=aves_model_path,
        aves_config_path=aves_config_path,
    )
    return name2model[model_name](**model_args)


def _load_dataset(dataset_name):
    """Read the CSV for ``dataset_name``; it must have ``path`` and ``label`` columns."""
    dataset_name2path = {
        "dolphin_reef_balanced": "data/dolphin_reef/balanced/all.csv",
        "dolphin_reef_unbalanced": "data/dolphin_reef/unbalanced/all.csv",
        # NOTE(review): Watkins CSV locations are not configured yet.
        "watkins": "",
        "watkins_dolphins": "",
    }

    data_path = dataset_name2path.get(dataset_name)
    if not data_path:
        # Fail with a readable message instead of KeyError(None) or a
        # confusing pd.read_csv("") error.
        raise SystemExit(f"no CSV path configured for dataset {dataset_name!r}")
    return pd.read_csv(data_path)


def _embed_dataset(model, df):
    """Run ``model`` on every ``path`` in ``df``; return row-aligned (embeddings, labels) arrays.

    Files the model fails on are reported and skipped (best effort), so the
    returned arrays only contain the successfully processed rows.
    """
    embeddings = []
    labels = []
    # Pure inference: disable autograd so no computation graph is kept alive.
    with torch.no_grad():
        for path, label in tqdm(
            zip(df["path"], df["label"]), desc="processing audio files", total=len(df)
        ):
            try:
                # Convert each result to NumPy right away: np.array() over a
                # list of torch tensors is slow and fails for tensors that
                # require grad.
                embedding = model(path).detach().cpu().numpy()
            except Exception as e:  # best effort: one bad file must not abort the run
                print(f"error processing {path}: {e}")
                continue
            embeddings.append(embedding)
            labels.append(label)

    return np.array(embeddings), np.array(labels)


def main():
    """Embed every audio file of the selected dataset and save the results.

    Writes two files into ``--outfolder`` (created if missing):
    ``embeddings_<model>_<dataset>.npy`` and ``labels_<dataset>.npy``.
    """
    args = get_args()
    set_seed(args.seed)

    model = _build_model(args.model, args.target_sample_rate)
    df = _load_dataset(args.dataset_name)
    x, y = _embed_dataset(model, df)

    # Path("") / name stays relative to the working directory; the previous
    # f"{outfolder}/..." form wrote to the filesystem root when --outfolder
    # was left at its default of "".
    outdir = Path(args.outfolder)
    outdir.mkdir(parents=True, exist_ok=True)
    np.save(outdir / f"embeddings_{args.model}_{args.dataset_name}.npy", x)
    np.save(outdir / f"labels_{args.dataset_name}.npy", y)

if __name__ == "__main__":
    # Script entry point: parse CLI args, extract embeddings, save .npy files.
    main()

