import hydra
from omegaconf import DictConfig
import torch
from src.data import (
    load_celeba_datasets,
    build_celeba_dataloader,
)
from src.trainer import train_debiased_classifier
from src.spaco_trainer import train_classifier_spaco
from src.models import ClassifierResNet18, AdversaryResNet


# Values accepted for ``cfg.name``. Checked up front so a typo fails before the
# dataset load (which may download CelebA) and before model construction.
_KNOWN_ALGOS = ("no_debias", "adv_debias", "spaco")


def _seed_everything(seed) -> None:
    """Seed the torch CPU RNG and, if CUDA is available, every CUDA device RNG."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _select_device(cfg: DictConfig) -> torch.device:
    """Return CUDA when it is available and ``cfg.no_cuda`` is not set, else CPU."""
    use_cuda = torch.cuda.is_available() and not cfg.get("no_cuda", False)
    return torch.device("cuda" if use_cuda else "cpu")


def _build_loaders(cfg: DictConfig):
    """Load the CelebA train/test splits and wrap each in a dataloader.

    Reads ``cfg.data_dir`` (made absolute with ``hydra.utils.to_absolute_path``),
    ``cfg.batch_size`` and ``cfg.num_workers``. The optional keys
    ``cfg.target_resolution`` (default ``(224, 224)``) and ``cfg.download``
    (default ``True``) replace values that used to be hard-coded.

    Returns:
        ``(train_loader, test_loader)``. The train loader shuffles and drops the
        last incomplete batch; the test loader does neither.
    """
    data_dir = hydra.utils.to_absolute_path(cfg.data_dir)

    train_ds, test_ds = load_celeba_datasets(
        root=data_dir,
        # tuple() so a YAML list (ListConfig) arrives in the same form as the default.
        target_resolution=tuple(cfg.get("target_resolution", (224, 224))),
        download=cfg.get("download", True),
    )

    train_loader = build_celeba_dataloader(
        train_ds,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        drop_last=True,
    )
    test_loader = build_celeba_dataloader(
        test_ds,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        drop_last=False,
    )
    return train_loader, test_loader


def _build_models(cfg: DictConfig, device: torch.device):
    """Build the classifier and the adversary, both seeded with ``cfg.seed``, on ``device``.

    The adversary is built for every algorithm because every trainer call below
    receives it, including the ``no_debias`` baseline. The classifier is built first,
    matching the original construction order.
    """
    classifier = ClassifierResNet18(
        num_classes=1,
        pretrained=cfg.get("pretrained", True),
        seed=cfg.seed,
    ).to(device)
    adversary = AdversaryResNet(logit_dim=1, seed=cfg.seed).to(device)
    return classifier, adversary


@hydra.main(config_path="conf", config_name="config", version_base=None)
def main(cfg: DictConfig):
    """Train a CelebA classifier with the algorithm selected by ``cfg.name``.

    Supported algorithms:
        * ``no_debias``: ``train_debiased_classifier`` with ``debias=False``.
        * ``adv_debias``: ``train_debiased_classifier`` with ``debias=True`` and
          ``cfg.adv_loss_weight``.
        * ``spaco``: ``train_classifier_spaco`` with the SPACO-specific keys.

    Raises:
        ValueError: if ``cfg.name`` is not a supported algorithm. This is raised
            before seeding, data loading or model construction.
    """
    if cfg.name not in _KNOWN_ALGOS:
        raise ValueError(f"Unknown algo: {cfg.name}")

    _seed_everything(cfg.seed)

    device = _select_device(cfg)
    print(f"Using device: {device}")
    print(f"Algo: {cfg.name}")

    train_loader, test_loader = _build_loaders(cfg)
    classifier, adversary = _build_models(cfg, device)

    # Arguments shared by every trainer. The datasets are passed as None and
    # the prebuilt loaders are supplied instead.
    common = dict(
        classifier=classifier,
        adversary=adversary,
        train_dataset=None,
        train_loader=train_loader,
        eval_dataset=None,
        eval_loader=test_loader,
        device=device,
        num_epochs=cfg.epochs,
        batch_size=cfg.batch_size,
        base_lr=cfg.base_lr,
    )

    if cfg.name == "spaco":
        print("\nTraining debiased classifier (SPACO)...")
        train_classifier_spaco(
            **common,
            beta=cfg.beta,
            penalty_eps=cfg.penalty_eps,
            penalty_rho0=cfg.penalty_rho0,
            pilot_t=cfg.pilot_t,
            pilot_s=cfg.pilot_s,
            use_storm=cfg.use_storm,
            storm_eta0=cfg.storm_eta0,
            storm_eta_min=cfg.storm_eta_min,
            prox_y=cfg.prox_y,
            prox_x=cfg.prox_x,
            effective_implementation=cfg.effective_implementation,
        )
        return

    # Remaining cases: "no_debias" or "adv_debias" (validated above).
    debias = cfg.name == "adv_debias"
    if debias:
        print("\nTraining debiased classifier (adv_debias)...")
        # Read only when debiasing, so baseline configs need not define it.
        extra = {"adversary_loss_weight": cfg.adv_loss_weight}
    else:
        print("\nTraining classifier (no debias)...")
        extra = {}

    train_debiased_classifier(
        **common,
        debias=debias,
        # NOTE(review): the baseline trainer reuses ``prox_x`` as its ``prox``
        # term; confirm this is intended.
        prox=cfg.prox_x,
        **extra,
    )


if __name__ == "__main__":
    # The @hydra.main wrapper composes ``cfg`` (config "config" from the "conf"
    # directory, plus command-line overrides) and passes it in, which is why
    # ``main`` is called here with no arguments.
    main()
