import torch
import argparse
from training import Trainer, Config

def _build_parser():
    """Build the command-line interface for a MixBin training run.

    Only ``--experiment_name`` is required; every other option has a default.
    The last four options (``--device``, ``--lr``, ``--weight_decay``,
    ``--num_workers``) default to the fixed recipe values (cuda, 0.2, 0.001, 4),
    so omitting them yields the same ``Config`` as always.

    Returns:
        argparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    p = argparse.ArgumentParser(description='Trainer for MixBin Experiments')
    p.add_argument('--experiment_name', type=str, required=True, help="Name of the experiment")
    p.add_argument('--epochs', type=int, default=160, help="Number of epochs to train")
    p.add_argument('--seed', type=int, default=42, help="Fix random seed")
    p.add_argument('--model_compression', type=str, default='mixbin', help="Compression method for the model")
    p.add_argument('--model_name', type=str, default="cifar_dsresnet_20", help="Model architecture")
    p.add_argument('--dataset_name', type=str, default="cifar100", help="Dataset name")
    # NOTE: not derived from --dataset_name; the default of 100 only matches
    # cifar100, so pass this explicitly whenever the dataset changes.
    p.add_argument('--dataset_num_classes', type=int, default=100, help="Number of classes in the dataset")
    p.add_argument('--keep_full_precision', nargs="+", type=int, default=[], help="Layers to keep full precision")
    p.add_argument('--scheduler_milestones', nargs="+", type=int, default=[80, 120], help="Milestones for the scheduler")
    p.add_argument('--batch_size', type=int, default=256, help="Batch size for training")
    # Formerly hard-coded settings; defaults keep the original recipe.
    p.add_argument('--device', type=str, default="cuda", help="Device string passed to the trainer (default: cuda)")
    p.add_argument('--lr', type=float, default=0.2, help="Initial SGD learning rate")
    p.add_argument('--weight_decay', type=float, default=0.001, help="SGD weight decay")
    p.add_argument('--num_workers', type=int, default=4, help="Number of data-loading workers")
    return p


parser = _build_parser()
# Parsed at import time: importing this module reads (and validates) sys.argv.
args = parser.parse_args()

config = Config(
    experiment_name=args.experiment_name,
    root_dir="./",  # relative path, i.e. resolved against the current working directory
    device=args.device,
    epochs=args.epochs,
    seed=args.seed,
    keep_full_precision=args.keep_full_precision,
    training_batch_size=args.batch_size,
    # Evaluation uses batches twice the training size.
    test_batch_size=2 * args.batch_size,
    train_shuffle=True,
    test_shuffle=False,
    # Empty list: no epochs requested; presumably means no extra checkpoints — confirm in Config.
    save_at_epochs=[],
    num_workers=args.num_workers,
    clip_grad_norm=1,
    # Presumably MultiStepLR-style: LR multiplied by gamma at each milestone — confirm in Trainer.
    scheduler_type="multi_step",
    scheduler_gamma=0.1,
    scheduler_milestones=args.scheduler_milestones,
    optimizer_type="sgd",
    optimizer_lr=args.lr,
    optimizer_weight_decay=args.weight_decay,
    dataset_name=args.dataset_name,
    dataset_num_classes=args.dataset_num_classes,
    # Fixed at 42 independently of --seed.
    dataset_order_seed=42,
    model_name=args.model_name,
    model_compression_strategy=args.model_compression,
)

# Shared training objective: cross-entropy with light (0.01) label smoothing.
loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=0.01)


class MixBinTrainer(Trainer):
    """``Trainer`` subclass that plugs the module-level ``loss_fn`` in as its criterion.

    Only ``criterion`` is overridden here; all other behaviour is inherited
    from ``Trainer``.
    """

    def criterion(self, y_pred, y_true):
        """Compute the loss for one batch.

        Args:
            y_pred: model outputs, forwarded unchanged to ``loss_fn``
                (presumably raw logits of shape (batch, num_classes) — TODO confirm).
            y_true: targets, forwarded unchanged to ``loss_fn``
                (presumably integer class indices — TODO confirm).

        Returns:
            The value produced by ``loss_fn``; with ``CrossEntropyLoss``'s
            default ``reduction="mean"`` this is a scalar tensor.
        """
        smoothed_ce = loss_fn(y_pred, y_true)
        return smoothed_ce

if __name__ == "__main__":
    # Script entry point: build the trainer from the module-level ``config``
    # and start training with resume_training=False (presumably a fresh run
    # that does not reload a prior checkpoint — confirm in Trainer.fit).
    trainer = MixBinTrainer(config)
    trainer.fit(resume_training=False)