#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from pathlib import Path

import numpy as np
import skimage
import skimage.io
import skimage.util
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights,\
    efficientnet_v2_m, EfficientNet_V2_M_Weights, \
    efficientnet_b1, EfficientNet_B1_Weights, \
    mobilenet_v3_large, MobileNet_V3_Large_Weights, \
    regnet_y_128gf, RegNet_Y_128GF_Weights, \
    vit_h_14, ViT_H_14_Weights, \
    vit_b_16, ViT_B_16_Weights, \
    regnet_y_16gf, RegNet_Y_16GF_Weights

from xai.utils import convert_np_img_to_torch


class FastMnistNet(nn.Module):
    """Small two-convolution CNN producing 10-class log-probabilities.

    Pipeline: conv(1->24, k=5) -> maxpool(2) -> ReLU -> conv(24->32, k=3)
    -> maxpool(2) -> ReLU -> flatten -> linear(800->256) -> ReLU
    -> linear(256->10) -> log_softmax.

    ``fc1`` expects 800 = 32 * 5 * 5 flattened features, which is what a
    single-channel 28x28 input produces after the two conv/pool stages.
    """

    def __init__(self):
        super().__init__()
        # Attribute names (and creation order) define the state_dict keys and
        # the parameter-initialisation order, so they are kept as-is.
        self.conv1 = nn.Conv2d(1, 24, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(24, 32, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(800, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        """Map a batch of 1-channel images to log-probabilities of shape (N, 10)."""
        features = F.relu(F.max_pool2d(self.conv1(x), 2))
        features = F.relu(F.max_pool2d(self.conv2(features), 2))
        hidden = F.relu(self.fc1(features.flatten(1)))
        return F.log_softmax(self.fc2(hidden), dim=1)

class Problem:
    """Minimal interface shared by the problem classes: samples plus a Grad-CAM layer.

    Subclasses override both methods; the base versions do nothing and
    return ``None``.
    """

    def get_sample(self, index):
        """Return the sample at ``index`` (base implementation: ``None``)."""
        return None

    def get_gradcam_layer(self):
        """Return the layer to use for Grad-CAM (base implementation: ``None``)."""
        return None

class FashionMnistProblem(Problem):
    """Fashion-MNIST test set paired with a pretrained ``FastMnistNet``.

    ``self.model`` is ``nn.Sequential(FastMnistNet, nn.Softmax(1))`` in eval
    mode. The network already ends in ``log_softmax``; applying ``Softmax`` on
    top yields the same probabilities as a softmax of the raw logits.
    """

    def __init__(self, weights_path, data_path, device="cuda:0"):
        """Load the network weights and the test tensors.

        Args:
            weights_path: file holding a ``FastMnistNet`` state_dict.
            data_path: directory (str or path-like) containing
                ``test_data.pt`` and ``test_labels.pt``.
            device: device the network and the test tensors are placed on.
        """
        self.device = device

        network = FastMnistNet()
        # The freshly built network lives on the CPU, so load the checkpoint
        # there too. This also works when the checkpoint was saved from a GPU
        # but ``device`` is the CPU or CUDA is unavailable.
        state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
        network.load_state_dict(state_dict)
        network = network.to(device=device)
        network.eval()
        self.model = nn.Sequential(network, nn.Softmax(1))

        # Accept plain strings as well as Path objects.
        data_dir = Path(data_path)
        self.test_data = torch.load(data_dir / "test_data.pt", map_location=device)
        self.test_labels = torch.load(data_dir / "test_labels.pt", map_location=device)

    def get_sample(self, index):
        """Return ``(sample, label)`` for test item ``index``.

        The ``None`` in the index keeps a leading batch dimension of size 1 on
        the sample.
        """
        return self.test_data[index, None, ...], self.test_labels[index]

    def get_gradcam_layer(self):
        """Return the layer used as the Grad-CAM target.

        NOTE(review): ``fc2`` is a Linear layer with no spatial output; Grad-CAM
        usually targets the last conv layer (``conv2``). Confirm this is intended.
        """
        return self.model[0].fc2

class ImageNetBaseProblem(Problem):
    """Shared setup for problems built on a pretrained torchvision ImageNet classifier.

    ``self.model`` is ``nn.Sequential(network, nn.Softmax(1))`` in eval mode, so
    it outputs class probabilities. ``self.transform`` is the preprocessing
    transform bundled with the selected weights; its ``mean``/``std`` are used
    by ``normalize_intensity`` and ``convert_img_for_imshow``.
    """

    def __init__(self, device, network_name):
        """Build the network selected by ``network_name`` and move it to ``device``.

        Args:
            device: torch device (or device string) for the network.
            network_name: one of the keys of ``model_dict`` below. Each key
                names a torchvision architecture plus a pretrained weights variant.

        Raises:
            KeyError: if ``network_name`` is not a known key.
        """
        self.device = device
        self.network_name = network_name

        # name -> (torchvision model constructor, pretrained weights enum member)
        model_dict = {
            "ResNet50_V1" : (resnet50, ResNet50_Weights.IMAGENET1K_V1),
            "ResNet50_V2" : (resnet50, ResNet50_Weights.IMAGENET1K_V2),
            "EfficientNet_V2_m_V1" : (efficientnet_v2_m, EfficientNet_V2_M_Weights.IMAGENET1K_V1),
            "EfficientNet_B1_V2" : (efficientnet_b1, EfficientNet_B1_Weights.IMAGENET1K_V2),
            "MobileNet_V3_Large_V2" : (mobilenet_v3_large, MobileNet_V3_Large_Weights.IMAGENET1K_V2),
            "RegNet_y_16gf_V1" : (regnet_y_16gf, RegNet_Y_16GF_Weights.IMAGENET1K_V1),
            "RegNet_y_16gf_V2" : (regnet_y_16gf, RegNet_Y_16GF_Weights.IMAGENET1K_V2),
            "RegNet_y_16gf_SWAG_E2E" : (regnet_y_16gf, RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_E2E_V1),
            "RegNet_y_16gf_SWAG_Linear" : (regnet_y_16gf, RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_LINEAR_V1),
            "ViT_b_16_V1" : (vit_b_16, ViT_B_16_Weights.IMAGENET1K_V1),
            "ViT_b_16_SWAG_E2E" : (vit_b_16, ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1),
            "ViT_b_16_SWAG_Linear" : (vit_b_16, ViT_B_16_Weights.IMAGENET1K_SWAG_LINEAR_V1),
            "RegNet_y_128gf_SWAG_E2E" : (regnet_y_128gf, RegNet_Y_128GF_Weights.IMAGENET1K_SWAG_E2E_V1),
            "RegNet_y_128gf_SWAG_Linear" : (regnet_y_128gf, RegNet_Y_128GF_Weights.IMAGENET1K_SWAG_LINEAR_V1),
            "ViT_h_14_SWAG_E2E" : (vit_h_14, ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1),
            "ViT_h_14_SWAG_Linear" : (vit_h_14, ViT_H_14_Weights.IMAGENET1K_SWAG_LINEAR_V1)
            }

        try:
            model_fn, weights = model_dict[network_name]
        except KeyError:
            # Keep the exception type callers may rely on, but say what is valid.
            known = ", ".join(sorted(model_dict))
            raise KeyError(
                f"Unknown network_name {network_name!r}; expected one of: {known}"
            ) from None

        # Passing ``weights`` makes torchvision load the pretrained parameters
        # (it may download them on first use).
        network = model_fn(weights=weights)
        network = network.to(device=device)
        network.eval()
        self.model = nn.Sequential(network, nn.Softmax(1))

        self.transform = weights.transforms()

    def convert_img_for_imshow(self, img):
        """Undo ``self.transform``'s mean/std normalization of an image, for display.

        Only the first image of the batch is converted.

        Args:
            img: tensor indexed as ``img[0, c, h, w]``. It is assumed to be
                (N, C, H, W), with C equal to ``len(self.transform.mean)``.

        Returns:
            numpy array of shape (H, W, C) with values clipped to [0, 1].
        """
        np_img = img.detach().cpu().numpy()
        # First image of the batch, channels moved last: (C, H, W) -> (H, W, C).
        np_img = np.moveaxis(np_img[0, :, :, :], 0, 2)
        mean = np.array(self.transform.mean, dtype=np.float32)
        std = np.array(self.transform.std, dtype=np.float32)
        return np.clip(np_img * std[None, None, :] + mean[None, None, :], 0, 1)

    def normalize_intensity(self, img):
        """Apply ``self.transform``'s per-channel mean/std normalization to ``img``.

        ``img`` is assumed to be (N, C, H, W). Mean and std are broadcast along
        dim 1 and created on ``img``'s device.
        """
        mean = torch.as_tensor(self.transform.mean, dtype=torch.float32, device=img.device)[None, :, None, None]
        std = torch.as_tensor(self.transform.std, dtype=torch.float32, device=img.device)[None, :, None, None]
        return (img - mean) / std

    def get_gradcam_layer(self):
        """Return the Grad-CAM target layer, or ``None`` if none is configured.

        Only the ResNet50 variants are configured (their ``layer4`` block).
        """
        if self.network_name.startswith("ResNet50"):
            return self.model[0].layer4
        return None

class ImageNetValProblem(ImageNetBaseProblem):
    """ImageNet validation images read from per-class sub-folders.

    ``data_path`` must contain one sub-folder per class whose name starts with
    ``n``, each holding ``ILSVRC2012_val*.JPEG`` files.

    Sample ``index`` selects image ``index % num_per_class`` (in file-name
    order) of the class at position
    ``(index // num_per_class) * class_step + class_offset`` in the name-sorted
    folder list.
    """

    def __init__(self, data_path, network_name, num_per_class, class_step=1, class_offset=0, device="cuda:0"):
        """Set up the network and index the class folders.

        Args:
            data_path: directory (str or path-like) with the per-class folders.
            network_name: forwarded to ``ImageNetBaseProblem``.
            num_per_class: number of consecutive sample indices mapped to each class.
            class_step: stride between selected class positions.
            class_offset: position of the first selected class.
            device: device for the network; also passed to ``convert_np_img_to_torch``.
        """
        super().__init__(device, network_name)
        # Accept plain strings as well as Path objects.
        self.data_path = Path(data_path)
        self.num_per_class = num_per_class
        self.class_step = class_step
        self.class_offset = class_offset
        self.class_folders = sorted(self.data_path.glob("n*"), key=lambda folder: folder.name)
        # class position -> name-sorted image paths. This is filled lazily so each
        # folder is listed once, instead of on every get_sample call.
        self._class_image_cache = {}

    def _class_images(self, label):
        """Return the name-sorted validation image paths of the class folder at ``label``."""
        images = self._class_image_cache.get(label)
        if images is None:
            images = sorted(
                self.class_folders[label].glob("ILSVRC2012_val*.JPEG"),
                key=lambda path: path.name,
            )
            self._class_image_cache[label] = images
        return images

    def get_sample(self, index):
        """Load, preprocess and return ``(img, label)`` for sample ``index``.

        ``img`` is the file converted to float32 by scikit-image, passed through
        ``convert_np_img_to_torch`` (with colour channels ensured) and then
        through ``self.transform``. ``label`` is the class-folder position.
        It presumably matches the network's class index; TODO confirm the folder
        ordering.
        """
        label = (index // self.num_per_class) * self.class_step + self.class_offset

        img_path = self._class_images(label)[index % self.num_per_class]

        img = skimage.util.img_as_float32(skimage.io.imread(img_path))
        img = self.transform(convert_np_img_to_torch(img, ensure_color_channels=True, device=self.device))

        return img, label