import os
from typing import Any, Literal

import torch
import torch.nn as nn

from core.models import *

# Registry of supported architectures: name -> constructor from core.models.
# ExperimentModel instantiates the selected entry as ``ctor(num_classes=...)``.
# Insertion order is preserved and is the iteration order callers observe.
PRETRAINED_MODEL_CLASSES = dict(
    resnet18=ResNet18,
    resnet34=ResNet34,
    vgg13_bn=vgg13_bn,
    ViT=ViT,
    SimpleViT=SimpleViT,
)


# Supported dataset names, each mapped to a pretrained-model directory.
# NOTE(review): within ExperimentModel only the keys are consulted (to validate
# ``dataset_name``); the checkpoint it loads is built from a different root.
# Confirm whether these directory values are still used anywhere else.
PRETRAINED_MODEL_PATH = dict(
    cifar10='../.local/share/cifar10/pretrained_models/',
    cifar100='../.local/share/cifar100/pretrained_models/',
    tinyimagenet='../.local/share/tinyimagenet/pretrained_models/',
)

class ExperimentModel:
    """Build a classifier for a supported dataset, optionally with pretrained weights.

    The architecture is looked up in ``PRETRAINED_MODEL_CLASSES`` and
    constructed with the dataset's number of classes. Calling the instance
    returns the underlying ``nn.Module``.
    """

    @classmethod
    def model_list(cls):
        """Return the model names intended for experiments.

        Only names that ``__init__`` can actually construct (i.e. that have an
        entry in ``PRETRAINED_MODEL_CLASSES``) are returned, in their listed
        order. Candidates without a registered constructor (currently
        ``'CCT'``) are dropped, so callers can iterate this list without
        hitting the "not supported" ``ValueError``.
        """
        chosen_models = [
            'resnet18',
            'resnet34',
            'vgg13_bn',
            'ViT',
            'SimpleViT',
            'CCT',
        ]
        return [name for name in chosen_models if name in PRETRAINED_MODEL_CLASSES]

    def __init__(
        self,
        model_name: str,
        dataset_name: Literal['cifar10', 'cifar100', 'tinyimagenet'],
        pretrained: bool = False,
        pretrained_root: str = '../codes/BackdoorAttack/BackdoorBox/pretrained_models',
        ):
        """Construct the model and optionally load a benign checkpoint.

        Args:
            model_name: Key into ``PRETRAINED_MODEL_CLASSES``.
            dataset_name: One of the keys of ``PRETRAINED_MODEL_PATH``; it also
                selects the classifier's output size (10 / 100 / 200).
            pretrained: If True, load a state dict from
                ``<pretrained_root>/<dataset_name>/<model_name>/benign_model.pt``.
            pretrained_root: Directory holding pretrained checkpoints. A
                relative path is resolved against the current working
                directory, not against this file.

        Raises:
            ValueError: If the model or dataset name is not supported.
            FileNotFoundError: If ``pretrained`` is set and the checkpoint is missing.
            RuntimeError: Raised by ``load_state_dict`` if the checkpoint keys
                do not match the constructed model.
        """
        self.model_name = model_name
        self.dataset_name = dataset_name
        self.pretrained = pretrained

        if self.model_name not in PRETRAINED_MODEL_CLASSES:
            raise ValueError(f'Model {self.model_name} not supported')

        if self.dataset_name not in PRETRAINED_MODEL_PATH:
            raise ValueError(f"Dataset {self.dataset_name} not supported")

        num_classes = {
            'cifar10': 10,
            'cifar100': 100,
            'tinyimagenet': 200,
        }[dataset_name]

        self.model = PRETRAINED_MODEL_CLASSES[model_name](num_classes=num_classes)

        if self.pretrained:
            model_path = os.path.join(
                pretrained_root, self.dataset_name, self.model_name, 'benign_model.pt'
            )
            # Fail early with a message naming the dataset/model, rather than
            # relying on torch.load's bare path error.
            if not os.path.isfile(model_path):
                raise FileNotFoundError(
                    f'Pretrained checkpoint for {self.model_name} on '
                    f'{self.dataset_name} not found: {model_path}'
                )

            # NOTE(review): torch.load unpickles the file; only load trusted
            # checkpoints (consider weights_only=True if the installed torch
            # version supports it).
            model_dict = torch.load(model_path, map_location='cpu')

            self.model.load_state_dict(model_dict)

    def __call__(self) -> nn.Module:
        """Return the wrapped ``nn.Module``."""
        return self.model
        