import copy
import os
from pathlib import Path

import torch
from timm.models import create_model
from torchvision.models import get_model

from models import ivpt_vit_bb, ivptnet_vit_bb
from models.individual_landmark_vit import IndividualLandmarkViT
from utils import load_state_dict_ivpt


def load_model_arch(args, num_cls):
    """
    Instantiate the backbone network selected by ``args.model_arch``.

    Dispatch rules, checked in this order:
      * name contains ``"resnet"``: torchvision ``get_model`` when
        ``args.use_torchvision_resnet_model`` is truthy, otherwise timm with the
        ``.a1h_in1k`` tag (depth >= 100) or ``.a1_in1k`` tag (depth < 100);
      * name contains ``"convnext"``: timm;
      * name contains ``"patch"``: timm ViT built at ``args.image_size``;
      * anything else is rejected.

    For the timm paths, ``drop_path_rate=args.drop_path`` is only passed when
    ``args.eval_only`` is falsy.

    :param args: Arguments from the command line
    :param num_cls: Number of classes in the dataset (passed to the timm
        ResNet/ConvNeXt builders only)
    :return: The instantiated backbone model
    :raises ValueError: if the architecture is unsupported, if a ResNet name
        contains no digits, or if a ViT patch size does not divide
        ``args.image_size``
    """
    arch = args.model_arch
    is_resnet = 'resnet' in arch

    if is_resnet:
        # Depth is every digit of the name glued together, e.g. "resnet101" -> 101.
        depth = int(''.join(str(int(ch)) for ch in arch if ch.isdigit()))
        timm_arch = arch + (".a1h_in1k" if depth >= 100 else ".a1_in1k")
        if args.use_torchvision_resnet_model:
            tv_weights = "DEFAULT" if args.pretrained_start_weights else None
            return get_model(arch, weights=tv_weights)
    elif 'convnext' in arch:
        timm_arch = arch
    elif 'patch' not in arch:
        raise ValueError('Model not supported.')

    # Stochastic depth is only configured for training runs.
    optional_kwargs = {} if args.eval_only else {'drop_path_rate': args.drop_path}

    if is_resnet or 'convnext' in arch:
        return create_model(
            timm_arch,
            pretrained=args.pretrained_start_weights,
            num_classes=num_cls,
            output_stride=args.output_stride,
            **optional_kwargs,
        )

    # Only the ViT ("patch") case remains here.
    vit = create_model(
        arch,
        pretrained=args.pretrained_start_weights,
        img_size=args.image_size,
        **optional_kwargs,
    )
    patch_size = vit.patch_embed.proj.kernel_size[0]
    if args.image_size % patch_size != 0:
        raise ValueError(f"Image size {args.image_size} must be divisible by patch size {patch_size}")
    return vit


def init_ivpt_model(base_model, args, num_cls):
    """
    Wrap a backbone in the IVPT part-discovery head.

    Only ViT backbones (``"patch"`` in ``args.model_arch``) are supported; they
    are wrapped in ``IndividualLandmarkViT``, configured from the matching
    fields of ``args``.

    :param base_model: Backbone returned by ``load_model_arch``
    :param args: Arguments from the command line
    :param num_cls: Number of classes in the dataset
    :return: The wrapped ``IndividualLandmarkViT`` model
    :raises ValueError: if the backbone is not a ViT
    """
    if 'patch' not in args.model_arch:
        raise ValueError('Model not supported.')

    return IndividualLandmarkViT(
        base_model,
        num_classes=num_cls,
        part_dropout=args.part_dropout,
        modulation_type=args.modulation_type,
        gumbel_softmax=args.gumbel_softmax,
        gumbel_softmax_temperature=args.gumbel_softmax_temperature,
        gumbel_softmax_hard=args.gumbel_softmax_hard,
        classifier_type=args.classifier_type,
        noise_variance=args.noise_variance,
        n_pro=args.n_pro,
    )


def load_model_ivpt(args, num_cls):
    """
    Build the complete IVPT model from command-line arguments.

    Creates the backbone with ``load_model_arch`` and then wraps it with
    ``init_ivpt_model``.

    :param args: Arguments from the command line
    :param num_cls: Number of classes in the dataset
    :return: The initialized IVPT model
    """
    return init_ivpt_model(load_model_arch(args, num_cls), args, num_cls)


def ivpt_vit(pretrained=True, model_dataset="cub", k=8, model_url="", img_size=224, num_cls=200):
    """
    Build the IVPT (PDiscoFormer-style) model on the timm
    ``vit_base_patch14_reg4_dinov2.lvd142m`` backbone, optionally loading weights.

    When ``pretrained`` is true, the checkpoint is downloaded from
    ``f"{model_url}{k}_parts_snapshot_best.pt"`` and cached under
    ``<torch hub dir>/ivpt_checkpoints/ivpt_<model_dataset>``.

    :param pretrained: Boolean flag to load the pretrained weights
    :param model_dataset: Dataset for which the model is trained (used only for
        the cache directory name)
    :param k: Number of unsupervised landmarks the model is trained on
    :param model_url: URL prefix to load the model weights from
    :param img_size: Image size
    :param num_cls: Number of classes in the dataset
    :return: IVPT model with ViT backbone
    """
    model = ivpt_vit_bb("vit_base_patch14_reg4_dinov2.lvd142m", num_cls=num_cls, k=k, img_size=img_size)
    if pretrained:
        model_dir = os.path.join(torch.hub.get_dir(), "ivpt_checkpoints", f"ivpt_{model_dataset}")
        Path(model_dir).mkdir(parents=True, exist_ok=True)

        url_path = f"{model_url}{k}_parts_snapshot_best.pt"
        snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu')
        if 'model_state' in snapshot_data:
            # Training snapshot: let the project helper extract the weights
            # (it returns a pair; presumably the second item is the state dict — confirm).
            _, state_dict = load_state_dict_ivpt(snapshot_data)
        else:
            # Plain state dict. No defensive copy is needed: load_state_dict copies
            # the values into the model's own parameters, and snapshot_data is not
            # used afterwards.
            state_dict = snapshot_data
        model.load_state_dict(state_dict, strict=True)
    return model


def ivptnet_vit(pretrained=True, model_dataset="nabirds", k=8, model_url="", img_size=224, num_cls=555):
    """
    Build the IVPT-Net (PDiscoNet-style) model on the timm
    ``vit_base_patch14_reg4_dinov2.lvd142m`` backbone, optionally loading weights.

    When ``pretrained`` is true, the checkpoint is downloaded from
    ``f"{model_url}{k}_parts_snapshot_best.pt"`` and cached under
    ``<torch hub dir>/ivpt_checkpoints/ivptnet_<model_dataset>``.

    :param pretrained: Boolean flag to load the pretrained weights
    :param model_dataset: Dataset for which the model is trained (used only for
        the cache directory name)
    :param k: Number of unsupervised landmarks the model is trained on
    :param model_url: URL prefix to load the model weights from
    :param img_size: Image size
    :param num_cls: Number of classes in the dataset
    :return: IVPT-Net model with ViT backbone
    """
    model = ivptnet_vit_bb("vit_base_patch14_reg4_dinov2.lvd142m", num_cls=num_cls, k=k, img_size=img_size)
    if pretrained:
        model_dir = os.path.join(torch.hub.get_dir(), "ivpt_checkpoints", f"ivptnet_{model_dataset}")
        Path(model_dir).mkdir(parents=True, exist_ok=True)

        url_path = f"{model_url}{k}_parts_snapshot_best.pt"
        snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu')
        if 'model_state' in snapshot_data:
            # Training snapshot: let the project helper extract the weights
            # (it returns a pair; presumably the second item is the state dict — confirm).
            _, state_dict = load_state_dict_ivpt(snapshot_data)
        else:
            # Plain state dict. No defensive copy is needed: load_state_dict copies
            # the values into the model's own parameters, and snapshot_data is not
            # used afterwards.
            state_dict = snapshot_data
        model.load_state_dict(state_dict, strict=True)
    return model


def ivptnet_resnet101(pretrained=True, model_dataset="nabirds", k=8, model_url="", num_cls=555):
    """
    Build the IVPT-Net (PDiscoNet-style) model on a torchvision ResNet-101
    backbone, optionally loading weights.

    When ``pretrained`` is true, the checkpoint is downloaded from
    ``f"{model_url}{k}_parts_snapshot_best.pt"`` and cached under
    ``<torch hub dir>/ivpt_checkpoints/ivptnet_<model_dataset>``.

    :param pretrained: Boolean flag to load the pretrained weights
    :param model_dataset: Dataset for which the model is trained (used only for
        the cache directory name)
    :param k: Number of unsupervised landmarks the model is trained on
    :param model_url: URL prefix to load the model weights from
    :param num_cls: Number of classes in the dataset
    :return: IVPT-Net model with ResNet-101 backbone
    :raises ImportError: if ``models`` does not provide ``ivptnet_resnet_torchvision_bb``
    """
    # This builder is not among the module-level imports, so referencing it
    # unqualified used to raise NameError. Import it here so that only this entry
    # point depends on it. NOTE(review): confirm that `models` exports this builder.
    from models import ivptnet_resnet_torchvision_bb

    model = ivptnet_resnet_torchvision_bb("resnet101", num_cls=num_cls, k=k)
    if pretrained:
        model_dir = os.path.join(torch.hub.get_dir(), "ivpt_checkpoints", f"ivptnet_{model_dataset}")
        Path(model_dir).mkdir(parents=True, exist_ok=True)

        url_path = f"{model_url}{k}_parts_snapshot_best.pt"
        snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu')
        if 'model_state' in snapshot_data:
            # Training snapshot: let the project helper extract the weights
            # (it returns a pair; presumably the second item is the state dict — confirm).
            _, state_dict = load_state_dict_ivpt(snapshot_data)
        else:
            # Plain state dict. No defensive copy is needed: load_state_dict copies
            # the values into the model's own parameters, and snapshot_data is not
            # used afterwards.
            state_dict = snapshot_data
        model.load_state_dict(state_dict, strict=True)
    return model
