import sys
from datetime import datetime

import torch as t
import wandb
from loguru import logger

from auto_encoder import debug
from auto_encoder.config import AutoEncoderConfig
from auto_encoder.config_enums import AutoEncoderType
from auto_encoder.training.sae_trainer import SAETrainer

def parse_step_count(raw: str, default: int) -> int:
    """Return the step count typed at the prompt, or *default* for blank input.

    Args:
        raw: Text as returned by ``input()``; surrounding whitespace is ignored.
        default: Value used when *raw* is empty or contains only whitespace.

    Returns:
        The parsed integer, or *default* when nothing was entered.

    Raises:
        ValueError: If *raw* is non-blank but is not a valid integer literal.
    """
    cleaned = raw.strip()
    return int(cleaned) if cleaned else default


if __name__ == "__main__":
    # Interactive launcher: configure logging, prompt for the run settings,
    # build a descriptive wandb run name and start SAE training.
    logger.remove()
    logger.add(sys.stdout, level="DEBUG" if debug else "INFO")

    wandb.login()

    logger.info(f"DEBUG: {debug}")

    cuda_device_count = t.cuda.device_count()
    logger.info(f"Number of GPUs: {cuda_device_count}")

    # Fail fast: outside debug mode only a single-GPU machine is supported, so
    # reject other setups before prompting the user or building the trainer.
    if not debug and cuda_device_count != 1:
        raise ValueError(f"Unsupported number of GPUs {cuda_device_count}")

    ### GET USER INPUT FOR CONFIG
    # Strip once so the enum lookup and the run name see the same text.
    ae_type_str = input(
        "Enter autoencoder type (vanilla, topk, mutual_choice, jump): "
    ).strip()
    try:
        ae_type = AutoEncoderType[ae_type_str.upper()]
    except KeyError as err:
        raise ValueError(f"Unknown autoencoder type {ae_type_str!r}") from err

    default_num_steps = AutoEncoderConfig.num_total_steps
    num_steps = parse_step_count(
        input(f"Enter number of steps (default {default_num_steps}): "),
        default_num_steps,
    )

    salient_experiment_feature = input("Enter salient experiment feature: ").strip().lower()

    # Anything other than "y" (case-insensitive) is treated as "no".
    save_best_model = input("Save best model? (y/n): ").strip().lower() == "y"

    ae_config = AutoEncoderConfig(
        autoencoder_type=ae_type,
        num_total_steps=num_steps,
    )

    logger.info(ae_config)

    # Run name layout: "<MM-DD HH:MM> [| <feature>] | <type> | <N> steps".
    run_name_parts = [datetime.now().strftime("%m-%d %H:%M")]
    if salient_experiment_feature:
        run_name_parts.append(salient_experiment_feature)
    run_name_parts.append(ae_type_str)
    # run_name_parts.append(f"{ae_config.num_features} features")
    run_name_parts.append(f"{ae_config.num_total_steps} steps")
    # run_name_parts.append(f"{ae_config.router_initialisation_strategy.value} router init")
    wandb_run_name = " | ".join(run_name_parts)

    sae_trainer = SAETrainer(ae_config=ae_config, use_wandb=True)

    # Only runs of at least 10k steps have their final model saved.
    save_final_model = num_steps >= 10_000
    # compile_model = num_steps >= 20_000
    sae_trainer.train(
        alert_on_success=False,
        wandb_run_name=wandb_run_name,
        save_final_model=save_final_model,
        save_best_model=save_best_model,
        #     compile_model=compile_model,
    )
