import argparse
import random
import numpy as np
import pandas as pd
import torch
from conf import (
    aves_bio_config,
    aves_bio_model,
    aves_core_config,
    aves_core_model,
    dolph2vec_config_path,
)
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict

from models import MFCC, Aves, BioLingual, Dolph2Vec, SpectralFeatures, Spectrogram, W2VQuantizer
from utils import compute_information_metrics, compute_cooccurrence_matrix, plot_cooccurrence_matrix, plot_umap


def set_seed(seed: int = 42) -> None:
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch.
    When CUDA is available the CUDA generators are seeded as well, cuDNN is
    put into deterministic mode and its benchmark auto-tuner is disabled.

    Args:
        seed: Value used to seed all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    # ``torch.random.manual_seed`` and ``torch.manual_seed`` are the same
    # function, so one call is sufficient.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # for multi-GPU
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def get_args() -> argparse.Namespace:
    """Parse the command-line options of this script.

    Returns:
        Namespace with the attributes ``seed``, ``dataset_name``, ``model``,
        ``outfolder`` and ``target_sample_rate``.

    Raises:
        SystemExit: Raised by argparse when ``--dataset_name`` is missing or
            any option receives an invalid value.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--seed", default=42, type=int, help="seed for all random generators"
    )
    parser.add_argument(
        "--dataset_name",
        choices=[
            "dolphin_reef_balanced",
            "dolphin_reef_unbalanced",
            "watkins",
            "watkins_dolphins",
            "watkins_SW",
        ],
        # There is no default dataset; without ``required`` an omitted flag
        # would silently become ``None``, which is not a valid dataset name.
        required=True,
        help="name of the dataset to process",
    )
    parser.add_argument(
        "--model",
        choices=["custom_quant"],
        default="custom_quant",
        help="name of the model to run",
    )
    parser.add_argument(
        "--outfolder", default="", type=str, help="folder passed to the plotting helpers"
    )
    parser.add_argument(
        "--target_sample_rate",
        default=44100,
        type=int,
        help="sample rate handed to the model",
    )

    return parser.parse_args()



def main():
    """Run the selected model over a dataset and analyse its output indices.

    Parses the command line, seeds the random generators, loads the dataset
    CSV and calls the model once per row on the value of the ``path`` column.
    The returned embeddings and index sequences are collected together with
    the ``label`` column, then used to compute a co-occurrence matrix and
    information metrics (printed to stdout) and to draw the co-occurrence and
    UMAP plots via the helpers from ``utils``.

    Raises:
        ValueError: If no CSV path is configured for the selected dataset.
    """
    args = get_args()
    set_seed(args.seed)

    dataset_name2path = {
        "dolphin_reef_balanced": "data/dolphin_reef/balanced/all.csv",
        "dolphin_reef_unbalanced": "data/dolphin_reef/unbalanced/all.csv",
        "watkins": "",
        "watkins_dolphins": "",
        "watkins_SW": "data/watkins/watkins_SW_small.csv",
    }

    # Resolve the CSV before building the model so that a dataset without a
    # configured path fails immediately and with a clear message, instead of
    # an obscure read error after the model has already been constructed.
    data_path = dataset_name2path.get(args.dataset_name)
    if not data_path:
        raise ValueError(
            f"no CSV path configured for dataset {args.dataset_name!r}"
        )

    name2model = {
        "custom_quant": W2VQuantizer,  # returns quantized + indices
    }
    model = name2model[args.model](
        sample_rate=args.target_sample_rate,
        dolph2vec_config_path=dolph2vec_config_path,
    )

    df = pd.read_csv(data_path)

    embeddings = []
    indices = []
    labels = []
    # Only the two columns below are needed, so iterate them directly rather
    # than materialising a Series per row.
    for path, label in tqdm(
        zip(df["path"], df["label"]),
        desc="processing audio files",
        total=len(df),
    ):
        # Errors are intentionally not caught: a file that cannot be
        # processed aborts the run.
        embedding, seq_indices = model(path)
        embeddings.append(embedding.cpu())
        indices.append(seq_indices)
        labels.append(label)

    # The raw co-occurrence matrix is returned but not used here.
    _co_matrix, probs, unique_labels = compute_cooccurrence_matrix(indices, labels)
    metrics = compute_information_metrics(probs)
    for k, v in metrics.items():
        print(f"{k}: {v:.4f}")

    plot_cooccurrence_matrix(probs, unique_labels, args.outfolder)
    x = np.array(embeddings)
    y = np.array(labels)
    plot_umap(x, y, args.outfolder)

#   np.savez(f'{args.outfolder}/indices{args.model}_{args.dataset_name}.npz', indices)
#   np.save(f'{args.outfolder}/embeddings_{args.model}_{args.dataset_name}.npy', x)
#   np.save(f'{args.outfolder}/labels_{args.dataset_name}.npy', y)


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

