#!/usr/bin/env python3

import argparse
import numpy as np
from datasets import load_from_disk
import random
import pathlib
import os
import logging

# Give the root logger a stderr handler at INFO level at import time
# (basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level="INFO")
# Logger for this module, shared by every function in the file.
logger = logging.getLogger(name=__name__)

def parse_args():
    """
    Parse the command-line arguments of this script.

    Returns:
        argparse.Namespace with these attributes:
          dataset_dir (str): path to the HF DatasetDict on disk (required).
          perplexity (float): t-SNE perplexity. The default is the int 30;
              values given on the command line are converted to float.
          random_state (int): seed for reproducibility (default 42).
          mode (str): one of "train", "test", "both" (default "test").
          classes (str): comma-separated original class labels, or "all"
              for no filtering (default "all").
          save_dir (str): directory for run outputs (default "./random/").
          concat_embs (int): per its help text, concatenate band embeddings
              when set (default 1), otherwise average them.
          embs_path (str): directory under which the dataset with random
              embeddings is saved (default "embs").

    The flag spellings, including the underscore in ``--embs_path``, are part
    of the CLI contract and are intentionally left unchanged.
    """
    p = argparse.ArgumentParser(description="Cluster & t-SNE embed pipeline")
    p.add_argument("--dataset-dir", type=str, required=True,
                   help="Path to HF DatasetDict on disk (train/validation/test)")
    p.add_argument("--perplexity", type=float, default=30,
                   help="t-SNE perplexity")
    p.add_argument("--random-state", type=int, default=42,
                   help="Seed for reproducibility")
    p.add_argument("--mode", choices=["train", "test", "both"], default="test",
                   help="Which split to process")
    p.add_argument("--classes", type=str, default="all",
                   help="Comma-separated list of original class labels to include, or 'all' for no filtering")
    p.add_argument("--save-dir", type=str, default="./random/",
                   help="Directory to save the t-SNE plots")
    p.add_argument("--concat-embs", type=int, default=1,
                   help="Concatenate embeddings of different bands, if not, do average")
    p.add_argument("--embs_path", type=str, default="embs",
                   help="Directory under which the dataset with random embeddings is saved")
    return p.parse_args()

def add_embedding_g(example, dim=256):
    """
    Attach one random ``dim``-dimensional vector to a single row under ``"embedding_g"``.

    Presumably meant as a per-row (non-batched) ``Dataset.map`` callback.
    TODO confirm: a ``batched=True`` map would need one vector per row instead.

    Args:
        example: mapping for one row; it is mutated in place.
        dim: length of the random vector (default 256).

    Returns:
        The same ``example`` mapping, with ``"embedding_g"`` set to a float
        array of shape ``(dim,)`` with values drawn uniformly from [0, 1).
    """
    # One vector per row. ``len(example)`` counts the row's columns, not
    # rows, so it must not be used to size the output.
    example["embedding_g"] = np.random.rand(dim)
    return example


def add_embedding_r(example, dim=256):
    """
    Attach one random ``dim``-dimensional vector to a single row under ``"embedding_r"``.

    Presumably meant as a per-row (non-batched) ``Dataset.map`` callback.
    TODO confirm: a ``batched=True`` map would need one vector per row instead.

    Args:
        example: mapping for one row; it is mutated in place.
        dim: length of the random vector (default 256).

    Returns:
        The same ``example`` mapping, with ``"embedding_r"`` set to a float
        array of shape ``(dim,)`` with values drawn uniformly from [0, 1).
    """
    # One vector per row. ``len(example)`` counts the row's columns, not
    # rows, so it must not be used to size the output.
    example["embedding_r"] = np.random.rand(dim)
    return example


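# Leftover usage example; `add_length` is not defined in this file, so this
# line would fail if uncommented:
# dataset = dataset.map(add_length)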


def gen_emb(len, dim=256):
    """
    Generate a list of independent random embedding vectors.

    Args:
        len: number of vectors to generate (one per dataset row). The name
            shadows the builtin ``len`` inside this function; it is kept so
            existing keyword callers (``gen_emb(len=...)``) keep working.
        dim: length of each vector (default 256).

    Returns:
        A list of ``len`` NumPy float arrays of shape ``(dim,)``. Values are
        drawn uniformly from [0, 1) using the global ``np.random`` state, so
        seeding ``np.random`` makes the output reproducible. A zero or
        negative ``len`` yields ``[]``.
    """
    return [np.random.rand(dim) for _ in range(len)]

# Splits and columns that receive random embeddings. Every split is filled
# for the first column before moving to the next column.
_EMBEDDING_SPLITS = ("train", "test", "validation")
_EMBEDDING_COLUMNS = ("embeddings_g", "embeddings_r")


def _make_result_dir(args):
    """Create the per-run directory under ``args.save_dir`` and return its path, or None if save_dir is empty."""
    if not args.save_dir:
        return None
    input_emb_name = pathlib.Path(args.dataset_dir).name
    run_name = f"s{args.random_state}p{args.perplexity}m{args.mode}c{args.concat_embs}i{input_emb_name}"
    # parents=True also creates save_dir itself when it does not exist yet.
    pathlib.Path(args.save_dir, run_name).mkdir(parents=True, exist_ok=True)
    return os.path.join(args.save_dir, run_name)


def _attach_file_log(result_dir):
    """Mirror this module's INFO+ records into ``<result_dir>/log.txt``; return the handler, or None when result_dir is None."""
    if result_dir is None:
        return None
    handler = logging.FileHandler(os.path.join(result_dir, "log.txt"))
    handler.setLevel(logging.INFO)
    logger.addHandler(handler)
    return handler


def _add_random_embedding_columns(ds):
    """Add the random ``embeddings_g`` and ``embeddings_r`` columns to the train/test/validation splits of ``ds``, in place."""
    # Draw order: all splits for embeddings_g, then all splits for
    # embeddings_r. Changing this nesting would change which vectors a given
    # --random-state produces.
    for column in _EMBEDDING_COLUMNS:
        for split in _EMBEDDING_SPLITS:
            ds[split] = ds[split].add_column(column, gen_emb(len(ds[split])))


def main():
    """
    Add random-noise embedding columns to a DatasetDict and save the result.

    Steps:
      1. Parse CLI arguments and seed ``random`` and ``numpy`` with ``--random-state``.
      2. If ``--save-dir`` is non-empty, create a run directory named from the
         seed, perplexity, mode, concat flag and input dataset name. Mirror
         this module's log records into ``<run dir>/log.txt``.
      3. Load the DatasetDict from ``--dataset-dir``. Add ``embeddings_g`` and
         ``embeddings_r`` columns to its train/test/validation splits, each
         holding one random vector per row, as generated by ``gen_emb``.
         A missing split raises ``KeyError``.
      4. Save the augmented DatasetDict to
         ``<embs_path>/random_embs_on_raw_train_test_split``.

    ``--perplexity``, ``--mode`` and ``--concat-embs`` only affect the run
    directory name. ``--classes`` is not used here.
    """
    args = parse_args()
    random.seed(args.random_state)
    np.random.seed(args.random_state)

    result_dir = _make_result_dir(args)

    logger.setLevel(logging.INFO)
    file_handler = _attach_file_log(result_dir)
    try:
        if result_dir is None:
            # Previously this path crashed in os.path.join(None, ...).
            logger.info("No --save-dir given; not writing log.txt")
        else:
            print(f"Saving results to {result_dir}")
            logger.info("Saving results to %s", result_dir)

        # Step 1: Load dataset
        print("Step 1: Loading dataset...")
        logger.info("Step 1: Loading dataset...")
        ds = load_from_disk(args.dataset_dir)

        _add_random_embedding_columns(ds)

        # os.path.join avoids writing to "/random_embs_..." when --embs_path is "".
        ds.save_to_disk(os.path.join(args.embs_path, "random_embs_on_raw_train_test_split"))
    finally:
        # Detach and close the handler so log.txt is flushed and released,
        # and so a repeated main() call does not stack duplicate handlers.
        if file_handler is not None:
            logger.removeHandler(file_handler)
            file_handler.close()


if __name__ == '__main__':
    main()