import torch.nn.functional as F
import torch.nn as nn
from argument import args as sys_args
import torch
import numpy as np
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import torchvision.models as thmodels
from torchvision.models._api import WeightsEnum
from torch.hub import load_state_dict_from_url
from model_utils.models import ConvNet, mobilenetv2_tinyimagenet, mobilenetv2_cifar10

import math

# Zero-pad images to a target size (0 is used as the fill for the "other three pictures").
def pad(input_tensor, target_height, target_width=None):
    """Zero-pad the last two dimensions of ``input_tensor`` to a target size.

    The input is centred: when the required padding along an axis is odd, the
    extra row/column goes to the bottom/right side.

    Args:
        input_tensor: Tensor whose dims 2 and 3 are height and width.
        target_height: Desired height after padding.
        target_width: Desired width; defaults to ``target_height`` (square).

    Returns:
        The padded tensor produced by ``F.pad`` (constant 0 fill).
    """
    width = target_height if target_width is None else target_width
    extra_rows = target_height - input_tensor.size(2)
    extra_cols = width - input_tensor.size(3)

    top = extra_rows // 2
    left = extra_cols // 2
    # F.pad takes (left, right, top, bottom) for the last two dims.
    return F.pad(input_tensor, (left, extra_cols - left, top, extra_rows - top))


def batched_forward(model, tensor, batch_size):
    """Evaluate ``model`` on ``tensor`` in chunks of at most ``batch_size`` rows.

    The model is put into eval mode and run under ``torch.no_grad()``; the
    per-chunk outputs are concatenated along dimension 0.

    Args:
        model: Module to evaluate (its ``eval()`` is called).
        tensor: Input batch, split along dimension 0.
        batch_size: Maximum number of rows per forward call.

    Returns:
        Concatenation of all chunk outputs along dimension 0.
    """
    model.eval()
    with torch.no_grad():
        outputs = [
            model(tensor[start : start + batch_size])
            for start in range(0, tensor.size(0), batch_size)
        ]
    return torch.cat(outputs, dim=0)


class MultiRandomCrop(torch.nn.Module):
    """Produce ``num_crop`` independent random square crops of one image.

    Each crop comes from ``transforms.RandomResizedCrop`` with a fixed 1:1
    aspect ratio and is resized to ``size // factor`` pixels per side; the
    crops are stacked along a new leading dimension.

    Args:
        num_crop: Number of crops produced per call.
        size: Base size; the output side length is ``size // factor``.
        factor: Divisor applied to ``size`` to get the crop side length.
    """

    def __init__(self, num_crop=5, size=224, factor=2):
        super().__init__()
        self.num_crop = num_crop
        self.size = size
        self.factor = factor

    def forward(self, image):
        """Return ``num_crop`` random crops of ``image`` stacked along dim 0."""
        cropper = transforms.RandomResizedCrop(
            self.size // self.factor,
            ratio=(1, 1),
            antialias=True,
        )
        patches = []
        for _ in range(self.num_crop):
            patches.append(cropper(image))
        return torch.stack(patches, 0)

    def __repr__(self) -> str:
        # factor is included because it determines the crop size (size // factor).
        detail = f"(num_crop={self.num_crop}, size={self.size}, factor={self.factor})"
        return f"{self.__class__.__name__}{detail}"


# ImageNet per-channel statistics used by both transforms below.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

normalize = transforms.Normalize(mean=list(_IMAGENET_MEAN), std=list(_IMAGENET_STD))

# Inverse of `normalize`: the first step multiplies by std (divides by 1/std),
# the second adds the mean back.
denormalize = transforms.Compose(
    [
        transforms.Normalize(
            mean=[0.0] * 3, std=[1 / channel_std for channel_std in _IMAGENET_STD]
        ),
        transforms.Normalize(
            mean=[-channel_mean for channel_mean in _IMAGENET_MEAN], std=[1.0] * 3
        ),
    ]
)


def get_state_dict(self, *args, **kwargs):
    """Download this weight entry's state dict without torchvision's hash check.

    Installed below as ``WeightsEnum.get_state_dict``. Any ``check_hash``
    keyword is discarded, so ``load_state_dict_from_url`` falls back to its
    default (``check_hash=False``); all other arguments are forwarded unchanged
    together with ``self.url``.
    """
    # Pop with a default so callers that omit check_hash do not hit a KeyError.
    kwargs.pop("check_hash", None)
    return load_state_dict_from_url(self.url, *args, **kwargs)


# Monkeypatch applied at import time: affects every torchvision WeightsEnum.
WeightsEnum.get_state_dict = get_state_dict


def cross_entropy(y_pre, y):
    """Per-sample cross-entropy between logits and integer class labels.

    Args:
        y_pre: Logits of shape ``[N, n_class]``.
        y: Class indices with ``N`` elements (viewed as ``[N, 1]``).

    Returns:
        Tensor of shape ``[N]`` holding ``-log p(y_i | y_pre_i)``.
    """
    # log_softmax is computed stably (log-sum-exp); log(softmax(x)) underflows
    # to log(0) = -inf for very confident wrong predictions.
    log_probs = F.log_softmax(y_pre, dim=1)
    return -log_probs.gather(1, y.view(-1, 1))[:, 0]


def selector(n, model, images, labels, size, m=5, device=None):
    """Pick the ``n`` lowest-loss candidate patches and return them shuffled.

    The ``m`` candidate patches of every source image are zero-padded to
    ``size`` (via ``pad``), classified by ``model`` (via ``batched_forward``)
    and scored with the per-sample ``cross_entropy`` against that image's
    label. For each source image the ``min(m, ceil(n / mipc))`` lowest-loss
    candidates are kept. The ``n`` lowest-loss patches among those are then
    returned (unpadded) in a random order drawn with ``torch.randperm``.

    Args:
        n: Number of patches to return.
        model: Classifier applied to the padded patches.
        images: Candidate patches of shape ``[mipc, m, C, H, W]``.
        labels: 1-D tensor with one label per source image (``mipc`` entries).
        size: Side length the patches are zero-padded to before scoring.
        m: Number of candidates per image; must equal ``images.shape[1]``.
        device: Device used for scoring and for the returned tensor. Defaults
            to the current CUDA device.

    Returns:
        Detached tensor of ``min(n, k * mipc)`` patches with shape
        ``[*, C, H, W]``, where ``k = min(m, ceil(n / mipc))``.

    Raises:
        ValueError: If ``images.shape[1] != m``.
    """
    device = torch.device("cuda") if device is None else torch.device(device)

    with torch.no_grad():
        # [mipc, m, C, H, W]
        images = images.to(device)
        mipc, n_candidates = images.shape[0], images.shape[1]
        if n_candidates != m:
            raise ValueError(
                f"selector expected m={m} candidates per image, got {n_candidates}"
            )
        patch_shape = images.shape[2:]

        # [m * mipc, C, H, W]; row j * mipc + i holds candidate j of image i.
        images = images.permute(1, 0, 2, 3, 4).reshape(m * mipc, *patch_shape)

        # [m * mipc]; labels tiled to follow the same candidate-major order.
        labels = labels.repeat(m).to(device)

        # [m * mipc, n_class]
        batch_size = mipc  # Change it for small GPU memory
        preds = batched_forward(model, pad(images, size).to(device), batch_size)

        # [m, mipc]
        dist = cross_entropy(preds, labels).reshape(m, mipc)

        # Clamp to m: when n > m * mipc there are not enough candidates, and an
        # unclamped count would make the reshapes below fail.
        n_patch_per_image = min(m, math.ceil(n / mipc))

        # [n_patch_per_image, mipc]: lowest-loss candidates of each image (column).
        index = torch.argsort(dist, dim=0)[:n_patch_per_image, :]
        columns = torch.arange(mipc, device=dist.device)
        dist = dist[index, columns].reshape(n_patch_per_image * mipc)

        # [n_patch_per_image * mipc, C, H, W]
        images = images.reshape(m, mipc, *patch_shape)[index, columns]
        images = images.reshape(n_patch_per_image * mipc, *patch_shape)

    indices = torch.argsort(dist, descending=False)[:n]
    torch.cuda.empty_cache()
    images = images[indices]
    # shuffle
    order = torch.randperm(images.shape[0])
    return images[order].detach()


def mix_images(input_img, out_size, factor, n):
    """Tile ``factor * factor`` groups of ``n`` images into ``n`` mosaics.

    Group ``k`` (rows ``k * n`` to ``(k + 1) * n`` of ``input_img``) is resized
    with ``F.interpolate`` (its default mode) and written into tile
    ``(k // factor, k % factor)`` of an ``out_size`` x ``out_size`` canvas, so
    tiles are filled row by row. When ``out_size`` is not divisible by
    ``factor``, the first ``out_size % factor`` tile rows/columns are one pixel
    larger, so the tiles cover the canvas exactly.

    Args:
        input_img: Images with 3 channels and at least ``factor**2 * n`` rows.
        out_size: Side length of each output mosaic.
        factor: Number of tiles per side.
        n: Number of mosaics to build.

    Returns:
        A new ``torch.float`` tensor of shape ``[n, 3, out_size, out_size]``
        created without an explicit device; the input is read detached.
    """
    base, extra = divmod(out_size, factor)
    edges = [base + 1 if idx < extra else base for idx in range(factor)]
    offsets = [sum(edges[:idx]) for idx in range(factor)]

    canvas = torch.zeros(
        (n, 3, out_size, out_size),
        requires_grad=False,
        dtype=torch.float,
    )
    source = input_img.detach()

    tile = 0
    for top, height in zip(offsets, edges):
        for left, width in zip(offsets, edges):
            group = source[tile * n : (tile + 1) * n]
            canvas[:n, :, top : top + height, left : left + width] = F.interpolate(
                group, size=(height, width)
            )
            tile += 1
    return canvas


def load_model(
    model_name="resnet18",
    dataset="cifar10",
    pretrained=True,
    classes=None,
    save_dir="/workspace/save",
):
    """Build a model for ``dataset`` and optionally load pretrained weights.

    Args:
        model_name: Architecture name.
            - Names containing ``"conv"`` build a ``ConvNet`` whose depth is
              the last character of the name.
            - ``"resnet18_modified"`` / ``"resnet101_modified"`` are
              torchvision ResNets with a 3x3 stride-1 stem and no max-pool.
            - ``"mobilenet_v2_modified"`` (non-pretrained, tinyimagenet or
              cifar10) uses the project MobileNetV2 variants.
            - Other names are looked up in ``torchvision.models``.
        dataset: Dataset name; selects the ConvNet input size and where
            pretrained weights are loaded from.
        pretrained: If True, load weights: from
            ``{save_dir}/{dataset}_{model_name}.pth`` for the locally trained
            datasets and ImageNet-1k ConvNets, or from torchvision for other
            ImageNet-1k models.
        classes: Class indices to keep; defaults to ``[]``. The last two
            parameters of the model are sliced to these rows, and ConvNet /
            ``mobilenet_v2`` are built with ``len(classes)`` outputs.
        save_dir: Directory holding the local ``.pth`` checkpoints.

    Returns:
        The constructed ``torch.nn.Module``.

    Raises:
        ValueError: For a ConvNet name combined with an unrecognized dataset.
    """
    classes = [] if classes is None else classes

    def get_model(model_name="resnet18"):
        """Build a randomly initialized model for ``model_name``."""
        if "conv" in model_name:
            if dataset in ("cifar10", "cifar100"):
                size = 32
            elif dataset in ("tinyimagenet", "imagenet-1k"):
                size = 64
            elif dataset in ("imagenet-nette", "imagenet-woof", "imagenet-100"):
                size = 128
            else:
                raise ValueError("Unrecognized dataset")

            return ConvNet(
                num_classes=len(classes),
                net_norm="batch",
                net_act="relu",
                net_pooling="avgpooling",
                net_depth=int(model_name[-1]),  # e.g. "conv3" -> depth 3
                net_width=128,
                channel=3,
                im_size=(size, size),
            )

        if model_name in ("resnet18_modified", "resnet101_modified"):
            base_name = model_name.replace("_modified", "")
            model = thmodels.__dict__[base_name](weights=None)
            # 3x3 stride-1 stem and no max-pool: no early spatial downsampling.
            model.conv1 = nn.Conv2d(
                3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
            )
            model.maxpool = nn.Identity()
            return model

        return thmodels.__dict__[model_name](weights=None)

    def pruning_classifier(model=None, classes=None):
        """Slice the model's last two parameters to the rows in ``classes``."""
        classes = [] if classes is None else classes
        try:
            # The last two parameters are taken to be the classifier weight and bias.
            for param in list(model.parameters())[-2:]:
                param.data = param[classes]
        except Exception as err:
            # Best effort: report the failure and return the model unchanged
            # from this point on.
            print(f"ERROR in changing the number of classes: {err}")

        return model

    def load_local_checkpoint(model):
        """Load ``checkpoint["model"]`` from the local checkpoint for this model."""
        checkpoint = torch.load(
            f"{save_dir}/{dataset}_{model_name}.pth", map_location="cpu"
        )
        model.load_state_dict(checkpoint["model"])

    if not pretrained:
        # Models built directly, without get_model/classifier pruning.
        if model_name == "mobilenet_v2_modified" and dataset == "tinyimagenet":
            return mobilenetv2_tinyimagenet()
        if model_name == "mobilenet_v2_modified" and dataset == "cifar10":
            return mobilenetv2_cifar10()
        if model_name == "mobilenet_v2":
            return thmodels.get_model(
                "mobilenet_v2", weights=None, num_classes=len(classes)
            )

    # "imagenet-100" "imagenet-10" "imagenet-first" "imagenet-nette" "imagenet-woof"
    model = pruning_classifier(get_model(model_name), classes)
    if not pretrained:
        return model

    if dataset in (
        "imagenet-100",
        "imagenet-10",
        "imagenet-nette",
        "imagenet-woof",
        "tinyimagenet",
        "cifar10",
        "cifar100",
    ):
        load_local_checkpoint(model)
    elif dataset == "imagenet-1k":
        if model_name == "efficientNet-b0":
            # WeightsEnum.get_state_dict is already patched at module import.
            # NOTE(review): this branch loads no weights, and get_model looks up
            # "efficientNet-b0" in torchvision.models (which names it
            # "efficientnet_b0"), so it likely raises KeyError before reaching
            # here — confirm the intended behavior.
            pass
        elif "conv" in model_name:
            load_local_checkpoint(model)
        else:
            # Replaces the pruned model built above with torchvision's
            # legacy pretrained weights.
            model = thmodels.__dict__[model_name](pretrained=True)
    # NOTE(review): for any other dataset, pretrained=True returns the model
    # without loading weights — confirm this is intended.
    return model
