import torch
import torch.nn.functional as F
import torchvision.models as models
from torchvision import transforms
from PIL import Image
import cv2
import numpy as np


import torch
from torch import nn
from torch.utils.data import DataLoader
from torch import nn, Tensor
import torchvision
from torchvision import datasets,transforms
from torchvision.transforms import ToTensor
from torch.nn.parallel import DataParallel
from typing import Any, Callable, List, Optional, Type, Union
from torchvision.utils import _log_api_usage_once

def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """Return a bias-free 3x3 ``nn.Conv2d`` whose padding equals ``dilation``.

    Padding by the dilation keeps the spatial size unchanged when ``stride`` is 1.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Return a bias-free pointwise (1x1) ``nn.Conv2d``; a ``stride`` above 1 subsamples."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
class BasicBlock(nn.Module):
    """Two 3x3 convolutions plus a shortcut connection (the ResNet-18/34 block).

    Only ``groups=1``, ``base_width=64`` and ``dilation=1`` are accepted. The
    output has ``planes * expansion`` (== ``planes``) channels. When ``stride``
    is not 1, both ``conv1`` and the optional ``downsample`` module reduce the
    spatial resolution, so the two paths line up for the addition.
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # Submodules are registered in the same order as before
        # (conv1, bn1, relu, conv2, bn2, downsample): this order fixes the
        # state_dict key order and the order in which modules are visited.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = make_norm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = make_norm(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        # Residual branch: conv-bn-relu, then conv-bn with NO activation;
        # the ReLU is applied only after the shortcut has been added.
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        shortcut = x if self.downsample is None else self.downsample(x)
        branch += shortcut
        return self.relu(branch)
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand, plus a shortcut.

    The spatial stride is applied by the 3x3 convolution (``conv2``). The
    original paper, "Deep residual learning for image recognition"
    (https://arxiv.org/abs/1512.03385), applies it in the first 1x1
    convolution instead. This variant is known as ResNet V1.5 and is reported to
    improve accuracy in
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    The inner width is ``int(planes * base_width / 64) * groups``. The output has
    ``planes * expansion`` channels. When ``stride`` is not 1, both ``conv2`` and
    the optional ``downsample`` module reduce the spatial resolution.

    This block also applies ``nn.Dropout(0.5)`` right after the first ReLU.
    NOTE(review): dropout is not part of the standard ResNet bottleneck;
    confirm that it is intended.
    """

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        mid_channels = int(planes * (base_width / 64.0)) * groups
        out_channels = planes * self.expansion

        # Submodules are registered in the same order as before
        # (conv1, bn1, conv2, bn2, conv3, bn3, relu, downsample, dropout).
        self.conv1 = conv1x1(inplanes, mid_channels)
        self.bn1 = make_norm(mid_channels)
        self.conv2 = conv3x3(mid_channels, mid_channels, stride, groups, dilation)
        self.bn2 = make_norm(mid_channels)
        self.conv3 = conv1x1(mid_channels, out_channels)
        self.bn3 = make_norm(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dropout = nn.Dropout(0.5)

    def forward(self, x: Tensor) -> Tensor:
        # Residual branch: reduce (+ dropout), 3x3, expand without activation.
        branch = self.dropout(self.relu(self.bn1(self.conv1(x))))
        branch = self.relu(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))
        shortcut = x if self.downsample is None else self.downsample(x)
        branch += shortcut
        return self.relu(branch)
class ResNet(nn.Module):
    """ResNet classifier: conv stem, four residual stages, global average pool, linear head.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``) used in every stage.
        layers: number of blocks in each of the four stages (``layer1`` .. ``layer4``).
        num_classes: output size of the final ``nn.Linear`` (``self.fc``).
        zero_init_residual: if True, set the weight of the last norm layer of
            every residual branch (``bn3`` of Bottleneck, ``bn2`` of BasicBlock) to 0.
        groups: forwarded to every block as ``groups``.
        width_per_group: forwarded to every block as ``base_width``.
        replace_stride_with_dilation: three flags for ``layer2`` .. ``layer4``;
            a True entry makes that stage dilate instead of using stride 2.
        norm_layer: normalization-layer factory; defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have exactly 3 entries.

    The stem convolution takes 3 input channels.
    """
    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        # torchvision.utils helper; by its name, logs API usage once.
        _log_api_usage_once(self)
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        # Running state read and updated by _make_layer: channel count of the
        # previous stage's output and the current cumulative dilation.
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        # Stem: 7x7 stride-2 conv + norm + ReLU + 3x3 stride-2 max-pool.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four stages; layer2..layer4 halve the resolution unless dilated.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Weight initialization: Kaiming-normal for every conv; norm layers start
        # as identity (weight 1, bias 0). The Linear head keeps its default init.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck) and m.bn3.weight is not None:
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]


    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        planes: int,
        blocks: int,
        stride: int = 1,
        dilate: bool = False,
    ) -> nn.Sequential:
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and, when the stride is not 1 or
        the channel count changes, a ``conv1x1`` + norm projection on its shortcut.
        If ``dilate`` is True, the stride is folded into ``self.dilation`` and
        the stage runs at stride 1. Mutates ``self.inplanes`` (set to
        ``planes * block.expansion``) and possibly ``self.dilation``.
        """
        norm_layer = self._norm_layer
        downsample = None
        # The first block of the stage still uses the dilation of the previous stage.
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
            )
        )
        self.inplanes = planes * block.expansion
        # Remaining blocks: stride 1, no projection, current dilation.
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Run stem, layer1..layer4, global pool and ``fc``; return unnormalized scores (no softmax)."""

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Flatten everything but the batch dimension before the linear head.
        x = torch.flatten(x, 1)
        x = self.fc(x)


        return x

    def forward(self, x: Tensor) -> Tensor:
        """Delegate to ``_forward_impl``."""
        return self._forward_impl(x)
def _resnet(
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    weights,
    progress: bool,
    num_classes,
):
    """Build a ``ResNet`` and optionally load pretrained weights into it.

    Args:
        block: residual block class used in every stage.
        layers: number of blocks in each of the four stages.
        weights: object exposing ``get_state_dict(progress=..., check_hash=...)``
            (e.g. a torchvision weights enum member), or ``None`` to keep the
            random initialization.
        progress: forwarded to ``weights.get_state_dict``.
        num_classes: size of the classifier head. When ``weights`` is given, it
            must match the checkpoint, or ``load_state_dict`` will fail.
    """
    network = ResNet(block, layers, num_classes=num_classes)
    if weights is None:
        return network
    pretrained_state = weights.get_state_dict(progress=progress, check_hash=True)
    network.load_state_dict(pretrained_state)
    return network
def resnet18(num_classes, weights = None, progress: bool = True) -> ResNet:
    """ResNet-18 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.

    Four stages of two ``BasicBlock`` each.

    Args:
        num_classes: size of the final classification layer.
        weights (optional): pretrained weights exposing
            ``get_state_dict(progress=..., check_hash=...)``, e.g.
            :class:`~torchvision.models.ResNet18_Weights`. They are not
            validated here, so ``num_classes`` must match the checkpoint.
            ``None`` (default) keeps the random initialization.
        progress (bool, optional): forwarded to ``weights.get_state_dict``;
            in torchvision it shows a download progress bar on stderr.
            Default is True.
    """
    blocks_per_stage = [2, 2, 2, 2]
    return _resnet(BasicBlock, blocks_per_stage, weights, progress, num_classes)


import torchmetrics

import ssl
# NOTE(review): this disables HTTPS certificate verification for every download
# made by this process, which is a man-in-the-middle risk. It looks like a
# workaround for a local certificate problem; prefer fixing the certificate
# store and deleting this line.
ssl._create_default_https_context = ssl._create_unverified_context

from MuchInjection.CIFAR100.CIFAR100_Dataloader import CIFAR100DataLoader
from MuchInjection.FER.FER2013_Dataloader import FER2013DataLoader
from MuchInjection.AffectNet.AffectNet_Dataloader import AffectNetDataLoader

import argparse

# Grad-CAM visualisation: load a trained ResNet checkpoint, classify one image,
# and write the class-activation map plus an overlay on the input image.

# Output classes of a model trained on each dataset. Only this number is needed
# here; the dataloaders are not built because they were never used.
NUM_CLASSES = {"CIFAR100": 100, "FER2013": 7, "AffectNet": 8}

# Architectures selectable with --model. No resnet34() is defined in this file
# (the old branch raised NameError), so it is built directly from its standard
# BasicBlock stage depths [3, 4, 6, 3].
MODEL_BUILDERS = {
    "resnet18": lambda n_classes: resnet18(n_classes),
    "resnet34": lambda n_classes: _resnet(BasicBlock, [3, 4, 6, 3], None, True, n_classes),
}

parser = argparse.ArgumentParser(description="Grad-CAM visualisation for a ResNet classifier.")
parser.add_argument('--dataset', default="AffectNet", type=str, choices=sorted(NUM_CLASSES))
parser.add_argument('--batch_size', default=224, type=int)  # accepted for CLI compatibility; unused here
parser.add_argument('--lr', default=0.0001, type=float)  # accepted for CLI compatibility; unused here
parser.add_argument('--model', default="resnet18", type=str, choices=sorted(MODEL_BUILDERS))
parser.add_argument('--pretrainedPath', type=str, required=True)
parser.add_argument('--img_path', default='/VirtualSanghwa/pythonProject/1505.jpg', type=str)
parser.add_argument('--grad_cam_path', default='/VirtualSanghwa/pythonProject/grad_cam.jpg', type=str)
parser.add_argument('--superimposed_path', default='/VirtualSanghwa/pythonProject/superimposed.jpg', type=str)
opt = parser.parse_args()

num_classes = NUM_CLASSES[opt.dataset]

# Pick the device to run on: CUDA GPU, Apple MPS, or CPU.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

model = MODEL_BUILDERS[opt.model](num_classes).to(device)

# Load on CPU first so that a checkpoint saved on CUDA also loads on CPU/MPS
# hosts; load_state_dict copies the tensors onto the model's device.
# torch.load unpickles the file, so only load trusted checkpoints.
state_dict = torch.load(opt.pretrainedPath, map_location="cpu")
# Checkpoints saved from an nn.DataParallel-wrapped model prefix keys with "module.".
prefix = "module."
state_dict = {
    (key[len(prefix):] if key.startswith(prefix) else key): value
    for key, value in state_dict.items()
}
model.load_state_dict(state_dict)
model.eval()

# Preprocess the input image.
with Image.open(opt.img_path) as raw_img:
    img = raw_img.convert('RGB')
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_img = transform(img).unsqueeze(0).to(device)

# Grad-CAM needs the target layer's activations AND the gradient of the class
# score with respect to those activations (not the gradient of a conv weight).
# The output of the whole last stage works for every ResNet depth.
target_layer = model.layer4
captured = {}


def _capture_activations(module, inputs, output):
    """Forward hook: store the layer output and register a hook that stores its gradient."""
    captured["activations"] = output.detach()
    output.register_hook(lambda grad: captured.__setitem__("gradients", grad.detach()))


hook = target_layer.register_forward_hook(_capture_activations)
try:
    model.zero_grad()
    output = model(input_img)
    target_class = int(output.argmax(dim=1).item())
    # Backpropagate the predicted class score of the single image in the batch.
    output[0, target_class].backward()
finally:
    hook.remove()
print(f"Grad-CAM target class index: {target_class}")

activations = captured["activations"]  # (1, C, h, w) output of target_layer
gradients = captured["gradients"]      # same shape as activations
# Channel weights = spatially averaged gradients, shape (1, C, 1, 1).
channel_weights = gradients.mean(dim=(2, 3), keepdim=True)
cam = (channel_weights * activations).sum(dim=1).relu()[0].cpu().numpy()  # (h, w)

# Upsample to the original image size (cv2 dsize is (width, height)) and
# normalize to [0, 1]. A constant map becomes all zeros instead of NaN.
cam = cv2.resize(cam, (img.width, img.height))
cam_min, cam_max = float(cam.min()), float(cam.max())
if cam_max > cam_min:
    cam = (cam - cam_min) / (cam_max - cam_min)
else:
    cam = np.zeros_like(cam)

# Colour the map and overlay it on the original-size RGB image (same H x W).
heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)  # BGR
heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB).astype(np.float32) / 255
superimposed_img = heatmap + np.asarray(img, dtype=np.float32) / 255
# Rescale to [0, 255] once; the result is already uint8 and must not be scaled again.
superimposed_img = np.uint8(255 * superimposed_img / superimposed_img.max())

# cv2.imwrite expects BGR and reports failure by returning False.
for out_path, out_image in (
    (opt.grad_cam_path, np.uint8(255 * cam)),
    (opt.superimposed_path, cv2.cvtColor(superimposed_img, cv2.COLOR_RGB2BGR)),
):
    if not cv2.imwrite(out_path, out_image):
        raise OSError(f"cv2.imwrite failed to write {out_path}")