import argparse
import sys
import os

# Put the parent of this script's directory on sys.path so the local `data` module can be imported
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data import get_dataset_loader
from src.training_utils import train


def parse_args(argv=None):
    """Define and parse the command-line options for the training script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``; passing an explicit list lets
            callers (e.g. tests) parse arguments without touching ``sys.argv``.

    Returns:
        argparse.Namespace whose attribute names are the option names with
        dashes replaced by underscores (e.g. ``--model-type`` -> ``model_type``).
    """
    parser = argparse.ArgumentParser(description="Train or evaluate diffusion models")

    # Help strings use argparse's %(default)s placeholder so the advertised
    # default is always the real one and cannot drift when a default changes.
    parser.add_argument(
        "--model-type",
        type=str,
        choices=["unet", "mlp", "discretized_mlp"],
        default="unet",
        help="Type of model to train (unet, mlp, or discretized_mlp)",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["mnist", "fashion_mnist", "cifar10", "ffhq", "celeba_hq", "afhq"],
        default="cifar10",
        help="Dataset to use for training",
    )
    parser.add_argument(
        "--subset-size",
        type=int,
        default=-1,
        help="Subset size for training (default: %(default)s)",
    )
    # Only consulted by the UNet architecture; ignored for the MLP variants.
    parser.add_argument(
        "--use-attention",
        action="store_true",
        help="Use attention mechanism in the model (only for UNet)",
    )
    parser.add_argument(
        "--gpu",
        type=int,
        default=0,
        help="GPU device number to use (default: %(default)s)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=200,
        help="Number of training epochs (default: %(default)s)",
    )
    parser.add_argument(
        "--random-seed",
        type=int,
        default=42,
        help="Random seed for reproducibility (default: %(default)s)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=80,
        help="Batch size for training (default: %(default)s)",
    )
    parser.add_argument(
        "--checkpoint-freq",
        type=int,
        default=20,
        help="Save checkpoint every N epochs (default: %(default)s)",
    )
    parser.add_argument(
        "--sample-freq",
        type=int,
        default=20,
        help="Sample images every N epochs (default: %(default)s)",
    )
    parser.add_argument(
        "--injected-signal-power",
        type=float,
        default=0.0,
        help="Injected signal power (default: %(default)s)",
    )
    # Dataset-specific arguments
    parser.add_argument(
        "--ffhq-dir",
        type=str,
        default="data/ffhq_70k",
        help="Directory containing FFHQ dataset",
    )
    parser.add_argument(
        "--celeba-dir",
        type=str,
        default="data/celebahq-resized-256x256/versions/1/celeba_hq_256",
        help="Directory containing CelebA-HQ dataset",
    )
    parser.add_argument(
        "--afhq-dir",
        type=str,
        default="./data/afhq",
        help="Directory containing AFHQ dataset",
    )

    return parser.parse_args(argv)


def get_dataset_config(dataset_name: str) -> dict:
    """Return the image geometry used for a named dataset.

    Args:
        dataset_name: One of ``mnist``, ``fashion_mnist``, ``cifar10``,
            ``ffhq``, ``celeba_hq`` or ``afhq``.

    Returns:
        A new dict with keys ``img_size``, ``in_channels`` and
        ``out_channels``.

    Raises:
        KeyError: if ``dataset_name`` is not one of the names above.
    """
    # name -> (square image side length, channel count). Every dataset here
    # uses the same channel count for the model's input and output.
    size_and_channels = {
        "mnist": (28, 1),
        "fashion_mnist": (28, 1),
        "cifar10": (32, 3),
        "ffhq": (32, 3),
        "celeba_hq": (64, 3),
        "afhq": (64, 3),
    }
    side, channels = size_and_channels[dataset_name]
    return {
        "img_size": side,
        "in_channels": channels,
        "out_channels": channels,
    }


def get_model_architecture(img_size: int, model_type: str, use_attention: bool) -> dict:
    """Return the network layout settings for a model type and image size.

    Args:
        img_size: Side length of the (square) training images. Only 28, 32
            and 64 are supported for the UNet.
        model_type: ``"unet"``; any other value (``"mlp"`` or
            ``"discretized_mlp"``) selects the MLP settings.
        use_attention: Whether the UNet gets an attention level. Ignored for
            the MLP variants.

    Returns:
        For the UNet: a new dict with ``channel``, ``channel_mult``, ``attn``
        and ``num_res_blocks``. Otherwise: a new dict with ``growth_rate``,
        ``num_blocks`` and ``model_variant`` (set to ``model_type``).

    Raises:
        ValueError: if ``model_type`` is ``"unet"`` and ``img_size`` is not
            one of the supported sizes.
    """
    if model_type != "unet":
        # MLP variants: the layout does not depend on the image size.
        return {
            "growth_rate": 512 * 8,
            "num_blocks": 12,
            "model_variant": model_type,
        }

    # img_size -> (base channel width, channel multipliers, attention level).
    # The table is rebuilt per call, so callers always get fresh lists.
    # NOTE(review): presumably one resolution level per channel_mult entry
    # (e.g. 32->16->8->4->2) and `attn` indexes those levels -- confirm against
    # the UNet implementation.
    unet_layouts = {
        28: (64, [1, 2, 2], 1),  # MNIST / FashionMNIST; width 64 instead of 128 to save memory
        32: (128, [1, 2, 2, 2], 1),  # CIFAR10 / FFHQ
        64: (128, [1, 2, 3, 4], 2),  # CelebA-HQ / AFHQ
    }
    if img_size not in unet_layouts:
        raise ValueError(f"Unsupported image size: {img_size}")

    width, multipliers, attn_level = unet_layouts[img_size]
    return {
        "channel": width,
        "channel_mult": multipliers,
        "attn": [attn_level] if use_attention else [],
        "num_res_blocks": 2,
    }


if __name__ == "__main__":
    cli = parse_args()

    # Image geometry for the chosen dataset, then the network layout that fits
    # it (UNet widths/levels or MLP settings, depending on --model-type).
    data_cfg = get_dataset_config(cli.dataset)
    arch_cfg = get_model_architecture(
        data_cfg["img_size"], cli.model_type, cli.use_attention
    )

    # Single flat dict describing the whole run; handed to train() below.
    # The architecture entries are merged in last and do not collide with
    # any key listed here.
    model_config = {
        "epoch": cli.epochs,
        "batch_size": cli.batch_size,
        "T": 1000,
        "random_seed": cli.random_seed,
        "eval_random_seed": cli.random_seed,
        # Raw CLI value (may be -1); the loader below gets None instead.
        "subset_size": cli.subset_size,
        "dropout": 0.15,
        "lr": 1e-4,
        "multiplier": 2.0,
        "beta_1": 1e-4,
        "beta_T": 0.02,
        "img_size": data_cfg["img_size"],
        "in_channels": data_cfg["in_channels"],
        "out_channels": data_cfg["out_channels"],
        "grad_clip": 1.0,
        "device": f"cuda:{cli.gpu}",
        "dataset_root": "data/",
        "checkpoint_freq": cli.checkpoint_freq,
        "use_wandb": True,
        "sample_freq": cli.sample_freq,
        "training_load_weight": None,
        "save_weight_dir": "./Checkpoints/",
        # NOTE(review): fixed file name, not derived from --epochs; presumably
        # the last checkpoint of a 200-epoch run -- confirm against train().
        "test_load_weight": "ckpt_199_.pt",
        "sampled_dir": "./SampledImgs/",
        "sampledNoisyImgName": "NoisyNoGuidenceImgs.png",
        "sampledImgName": "SampledNoGuidenceImgs.png",
        "nrow": 8,
        "model_type": cli.model_type,
        "injected_signal_power": cli.injected_signal_power,
        "dataset_name": cli.dataset,
        **arch_cfg,
    }

    # A non-positive --subset-size is passed to the loader as None.
    num_images = None if cli.subset_size <= 0 else cli.subset_size
    loader = get_dataset_loader(
        dataset_name=cli.dataset,
        num_images=num_images,
        batch_size=cli.batch_size,
        ffhq_dir=cli.ffhq_dir,
        celeba_dir=cli.celeba_dir,
        afhq_dir=cli.afhq_dir,
    )

    train(model_config, loader)
