from models.ResNet import ResNet18, ResNet50, ResNet101
from models.PreActResNet import PreActResNet18
from models.DenseNet import DenseNet121, DenseNet169, DenseNet201
from models.WideResNet import WideResNet40_1
from models.PyramidNet import PyramidNet
from torchvision.models import vit_b_16, resnet152
from torchvision.models import vit_b_16, ViT_B_16_Weights
import os
import torch
from pathlib import Path
import torch.nn as nn
from models.ShakePyramidNet import ShakeDropPyramidNet272


def save_checkpoint(args, save_path, epoch, model, optimizer, scheduler):
    """Write the current training state to a checkpoint file.

    The file name is derived from ``args``, ``save_path`` and ``epoch`` by
    ``construct_checkpoint_name``. The saved dictionary has the keys
    ``'epoch'``, ``'state_dict'``, ``'optimizer'`` and ``'scheduler'``.

    Args:
        args: Run configuration used to build the checkpoint file name.
        save_path: Directory the checkpoint is written into; it is created
            if it does not exist yet.
        epoch: Epoch number stored in the checkpoint and in the file name.
        model: Object whose ``state_dict()`` is saved under ``'state_dict'``.
        optimizer: Object whose ``state_dict()`` is saved under ``'optimizer'``.
        scheduler: Object whose ``state_dict()`` is saved under ``'scheduler'``.
    """
    filename = construct_checkpoint_name(args, save_path, epoch)

    # torch.save does not create missing directories, so make sure the
    # target directory exists before writing.
    Path(filename).parent.mkdir(parents=True, exist_ok=True)

    checkpoint = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
    }
    torch.save(checkpoint, filename)
    print(f"Checkpoint saved to {filename}")

def load_model(model_name: str, num_classes: int):
    """Build the network registered under ``model_name``.

    The lookup is case-insensitive. Each registry entry is a zero-argument
    factory, so only the requested architecture is constructed.

    Args:
        model_name: Architecture key, e.g. ``'resnet18'`` or ``'vit-b'``.
        num_classes: Number of output classes passed to the model factory.

    Returns:
        The model instance produced by the selected factory.

    Raises:
        ValueError: If ``model_name`` does not match any registered key.
    """
    builders = {
        'resnet18': lambda: ResNet18(num_classes),
        'preactresnet18': lambda: PreActResNet18(num_classes),
        'resnet50': lambda: ResNet50(num_classes),
        'resnet101': lambda: ResNet101(num_classes),
        'resnet152': lambda: _make_resnet152(num_classes),
        'densenet121': lambda: DenseNet121(num_classes),
        'densenet169': lambda: DenseNet169(num_classes),
        'densenet201': lambda: DenseNet201(num_classes),
        'wideresnet40_1': lambda: WideResNet40_1(num_classes),
        'vit-b': lambda: _make_vit_b(num_classes),
        # NOTE(review): dataset is fixed to 'cifar100' regardless of
        # num_classes -- confirm this is what PyramidNet expects.
        'pyramidnet': lambda: PyramidNet(dataset='cifar100', depth=272, alpha=200, num_classes=num_classes, bottleneck=True),
        'shakepyramidnet': lambda: ShakeDropPyramidNet272(num_classes=num_classes),
    }

    # Resolve the factory separately from calling it, so that a KeyError
    # raised inside a model constructor is not misreported as an
    # unsupported architecture.
    builder = builders.get(model_name.lower())
    if builder is None:
        raise ValueError(f"Unsupported model architecture: {model_name}")
    return builder()

def _make_resnet152(num_classes: int):
    """Return a torchvision ResNet-152 with a ``num_classes``-way classifier.

    The network is created with ``weights=None`` and its ``fc`` layer is
    replaced by a new ``nn.Linear`` with the same input width.
    """
    net = resnet152(weights=None)
    # Swap the classifier head, keeping its input width.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net

def _make_vit_b(num_classes: int):
    """Return a torchvision ViT-B/16 with a ``num_classes``-way head.

    The network is initialised from ``ViT_B_16_Weights.DEFAULT`` and its
    ``heads.head`` layer is replaced by a new ``nn.Linear`` with the same
    input width.
    """
    net = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
    # Swap the classification head, keeping its input width.
    net.heads.head = nn.Linear(net.heads.head.in_features, num_classes)
    return net

def load_checkpoint(filename, model, optimizer, scheduler):
    """Restore model, optimizer and scheduler state from a checkpoint file.

    Args:
        filename: Path of the checkpoint to read.
        model: Receives ``checkpoint['state_dict']`` via ``load_state_dict``.
        optimizer: Receives ``checkpoint['optimizer']`` via ``load_state_dict``.
        scheduler: Receives ``checkpoint['scheduler']`` via ``load_state_dict``.

    Returns:
        The epoch stored in the checkpoint, or ``None`` when ``filename`` is
        not an existing file (in which case nothing is modified).
    """
    if not os.path.isfile(filename):
        print(f"No checkpoint found at '{filename}'")
        return None

    print(f"Loading checkpoint '{filename}'")
    # Load to CPU so the file can be read even when the device it was saved
    # from is unavailable; load_state_dict copies the values into the
    # existing parameters.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(filename, map_location="cpu")
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    epoch = checkpoint['epoch']
    print(f"Loaded checkpoint '{filename}' (epoch {epoch})")
    return epoch

def construct_checkpoint_name(args, save_path, epoch):
    """Build the checkpoint path for a run from its configuration.

    The file name starts with
    ``{model}_{dataset}_{method}_s{seed}_epoch_{epoch}`` (method lower-cased)
    and is followed by one ``_``-separated suffix per active option, in the
    order of the rule table below. The extension is ``.pth``.

    Args:
        args: Namespace-like object. ``model``, ``dataset``, ``method`` and
            ``seed`` are required; all other options are optional and are
            treated as inactive when missing.
        save_path: Directory the file name is joined onto.
        epoch: Epoch number embedded in the file name.

    Returns:
        The checkpoint path as a string.
    """
    def enabled(name):
        # Missing attributes count as "off".
        return bool(getattr(args, name, False))

    def nonzero(name):
        # A missing or None-valued option is treated like 0, so an unset
        # option never produces a suffix such as "lsNone".
        return getattr(args, name, None) not in (None, 0)

    method = args.method.lower()
    base = f"{args.model}_{args.dataset}_{method}_s{args.seed}_epoch_{epoch}"

    # (is_active, suffix factory) pairs; factories are only called for
    # active rules, so their attributes are only read when needed.
    rules = [
        ('bsd' in method, lambda: f"c{args.c}_g{args.gamma}"),
        (enabled('teacher_model') and enabled('teacher_model_paths'),
         lambda: f"KD_{args.temp}"),
        (enabled('cutmix'), lambda: "cutmix"),
        (enabled('mixup'), lambda: "mixup"),
        (enabled('cutout'), lambda: "cutout"),
        (nonzero('noise_rate'),
         lambda: f"n_{args.noise_type}_r{args.noise_rate}_s{args.noise_seed}"),
        (nonzero('sce_beta'), lambda: "sce"),
        (nonzero('label_smoothing'), lambda: f"ls{args.label_smoothing}"),
    ]
    suffixes = [make_suffix() for active, make_suffix in rules if active]

    final_name = "_".join([base, *suffixes]) + ".pth"
    return str(Path(save_path) / final_name)
