import torch
from pathlib import Path
from typing import Union

from .base import BaseModel
from .cnn import get_cnn_model
from .vgg import get_vgg_model
from .resnet import get_res_model
from .fc import get_fc_model
from .utils import get_torchvision_model, split_up_model

from robustbench.utils import load_model
from robustbench.model_zoo.enums import ThreatModel

from data import get_dataset_shape, gradual_domains, corruption_domains
from config import get_config


# Project configuration, loaded once when this module is imported.
# NOTE(review): "config/path.yaml" is a relative path — presumably resolved
# against the current working directory; confirm in get_config.
cf = get_config(["config/path.yaml"])
# Architecture family names exported by this module. get_model matches "cnn"
# and "fc" exactly, and "vgg"/"resnet" as name prefixes.
models_name = ["cnn", "vgg", "resnet", "fc"]

def get_model(data_name, model_name):
    """Build a fresh model of architecture ``model_name`` for ``data_name``.

    The input shape and number of classes are obtained from
    ``get_dataset_shape(data_name)``. Dispatch is on the architecture name:
    exactly ``"cnn"`` or ``"fc"``, or any name starting with ``"vgg"`` or
    ``"resnet"`` (the full name is forwarded to the matching builder).

    Args:
        data_name: Dataset identifier understood by ``get_dataset_shape``.
        model_name: Architecture name as described above.

    Returns:
        The object returned by the selected ``get_*_model`` builder.

    Raises:
        ValueError: If the architecture is not supported, if ``"fc"`` is
            requested for a dataset other than ``"covertype"``, or if a VGG
            model is requested while ``shape[1]`` or ``shape[2]`` is ``-1``.
    """
    shape, num_classes = get_dataset_shape(data_name)

    if model_name == "cnn":
        return get_cnn_model(shape, num_classes)

    if model_name == "fc":
        # Raise explicitly instead of using ``assert``: assertions are removed
        # under ``python -O``, which would let an unsupported combination through.
        if data_name != "covertype":
            raise ValueError("FC model is only supported for covertype dataset.")
        return get_fc_model(shape[0], num_classes)

    if model_name.startswith("vgg"):
        if shape[1] == -1 or shape[2] == -1:
            raise ValueError(
                "VGG model requires input shape (channels, height, width)."
            )
        return get_vgg_model(model_name, shape, num_classes)

    if model_name.startswith("resnet"):
        return get_res_model(model_name, shape[0], num_classes)

    raise ValueError(f"Model '{model_name}' not supported.")

def get_trained_model(data_name, model_name=None):
    """Return a trained model for ``data_name``.

    For datasets in ``gradual_domains`` a model is built with ``get_model``
    and its weights are loaded (onto the CPU) from
    ``cf.ckpt[data_name][model_name] / "<model.name>.pth"``.

    For datasets in ``corruption_domains`` the architecture is always taken
    from ``cf.arch[data_name]`` (``model_name`` is ignored in this branch).
    The model comes from ``load_model`` for ``"cifar10"``/``"cifar100"`` and
    from ``get_torchvision_model`` for ``"imagenet"``; it is then split with
    ``split_up_model`` and wrapped in a ``BaseModel``.

    Args:
        data_name: Dataset identifier.
        model_name: Architecture name for gradual-domain datasets; defaults to
            ``cf.arch[data_name]`` when ``None``.

    Returns:
        The loaded model.

    Raises:
        FileNotFoundError: If the checkpoint file for a gradual-domain
            dataset does not exist.
        ValueError: If ``data_name`` is not supported.
    """
    if data_name in gradual_domains:
        if model_name is None:
            model_name = cf.arch[data_name]
        model = get_model(data_name, model_name)
        ckpt_path = Path(cf.ckpt[data_name][model_name]) / f"{model.name}.pth"
        print(f"Loading model from {ckpt_path}")
        try:
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints that come from a trusted source.
            state_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))
        except FileNotFoundError as err:
            # Keep the original exception as the cause for easier debugging.
            raise FileNotFoundError(
                f"Model '{model_name}' not found. Need to train from scratch."
            ) from err
        model.load_state_dict(state_dict)
        return model

    if data_name in corruption_domains:
        # Reject corruption datasets without a loader up front; otherwise
        # ``model`` would be unbound below and fail with UnboundLocalError.
        if data_name not in ("cifar10", "cifar100", "imagenet"):
            raise ValueError(f"Data '{data_name}' not supported.")
        arch = cf.arch[data_name]
        if data_name == "imagenet":
            model, _ = get_torchvision_model(arch)
        else:
            model = load_model(arch, cf.ckpt_dir, data_name, ThreatModel.corruptions)
        encoder, classifier = split_up_model(model, arch, data_name)
        return BaseModel(encoder, classifier, arch)

    raise ValueError(f"Data '{data_name}' not supported.")

def save_model(model, data_name, model_name):
    """Save ``model``'s state dict under its configured checkpoint directory.

    The directory is ``cf.ckpt[data_name][model_name]``; it is created together
    with any missing parents. The file is named ``<model.name>.pth`` and the
    destination is printed once the save has completed.

    Args:
        model: Object exposing ``name`` and ``state_dict()``.
        data_name: Dataset key into ``cf.ckpt``.
        model_name: Architecture key into ``cf.ckpt[data_name]``.
    """
    target_dir = Path(cf.ckpt[data_name][model_name])
    target_dir.mkdir(parents=True, exist_ok=True)

    target_file = target_dir.joinpath(f"{model.name}.pth")
    torch.save(model.state_dict(), target_file)
    print(f"Model {model.name} saved to {target_file}")

def print_model_info(model):
    """Print ``model`` followed by a summary of its parameter counts.

    Reported counts, in order: all parameters, parameters with
    ``requires_grad`` set, parameters of ``model.encoder`` and parameters of
    ``model.classifier``.

    Args:
        model: Object exposing ``parameters()`` as well as ``encoder`` and
            ``classifier`` attributes that expose ``parameters()`` too.
    """
    def _numel(params):
        # Total number of elements across an iterable of parameters.
        return sum(p.numel() for p in params)

    print(model)

    n_total = _numel(model.parameters())
    n_trainable = _numel(p for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {n_total}")
    print(f"Trainable parameters: {n_trainable}")

    # The sub-modules are only accessed after the overall totals are printed.
    n_encoder = _numel(model.encoder.parameters())
    n_classifier = _numel(model.classifier.parameters())
    print(f"Encoder parameters: {n_encoder}")
    print(f"Classifier parameters: {n_classifier}")

# Names exported by a star-import of this module.
__all__ = ["get_model", "get_trained_model", "save_model", "print_model_info", "models_name"]


# ------------------------------------------------------------

if __name__ == "__main__":
    # Manual smoke test: load the ImageNet model and print its summary.
    model = get_trained_model("imagenet")
    # print_model_info prints the model again before the parameter counts,
    # so the model representation appears twice in the output.
    print(model)
    print_model_info(model)


