from PIL import Image
import numpy as np
from scipy.ndimage import convolve

import torch.nn.functional as F
import torch
import pickle
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim
from torch.optim.lr_scheduler import LRScheduler

# from augmentation.cutout import Cutout, TransformCutout

# aug_func = torch.nn.Sequential(
#    transforms.RandomCrop(32, 4),
#    #Cutout(n_holes=1, length=4, p=1.0),
#    transforms.RandomHorizontalFlip(),
# )

# Training pipeline without augmentation: the PIL image is only converted to a
# tensor via ToTensor; no Normalize step is applied here.
# NOTE(review): normalization is presumably done inside the model (e.g. by
# ModelwNorm) — confirm against the training script.
transform_train = transforms.Compose([
    transforms.ToTensor(),
])

# transform_train_aug = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.RandomCrop(64, padding=4),
#     #TransformCutout(n_holes=1, length=6),
#     transforms.RandomHorizontalFlip(),
# ])

# Training pipeline with flip augmentation: ToTensor, then independent
# horizontal and vertical flips, each with probability 0.5. The flips operate
# on the tensor produced by ToTensor.
transform_train_aug = transforms.Compose([
    transforms.ToTensor(),
    # TransformCutout(n_holes=1, length=8, p=1.0),  # disabled; its import at the top of the file is commented out
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
])

# Same flip augmentation as transform_train_aug, plus RandomErasing applied
# with probability 0.25. RandomErasing must come after ToTensor because it
# operates on tensors, not PIL images.
transform_train_aug_atm = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomErasing(p=0.25)
])

# Evaluation pipeline: ToTensor only, no augmentation (identical in content
# to transform_train).
transform_test = transforms.Compose([
    transforms.ToTensor(),
])


class TINdata(data.Dataset):
    """Map-style dataset over an in-memory image array and its labels.

    Each stored sample has its axes permuted with ``transpose(1, 2, 0)``
    (presumably channel-first -> channel-last; confirm against the loader),
    is cast to ``uint8`` and wrapped as a PIL image before ``transform`` runs.

    Args:
        data: indexable collection of image arrays.
        labels: per-sample labels; ``len(labels)`` is the dataset length.
        transform: callable applied to each PIL image.
    """

    def __init__(self, data, labels, transform):
        self.data, self.labels, self.transform = data, labels, transform

    def _to_pil(self, i):
        """Convert stored sample ``i`` into a uint8 PIL image."""
        channels_last = self.data[i].transpose(1, 2, 0)
        return Image.fromarray(channels_last.astype(np.uint8))

    def __getitem__(self, index):
        pil_img = self._to_pil(index)
        target = self.labels[index]
        return self.transform(pil_img), target

    def __len__(self):
        return len(self.labels)


class CSKDTINdata(data.Dataset):
    """Dataset that returns each sample together with a random same-class partner.

    Presumably used for class-wise self-knowledge distillation (CS-KD); each
    item is ``(img, label, img_kd, label_kd)``. The partner is drawn uniformly
    from all samples of the same class, which can include the sample itself.

    Args:
        data: indexable collection of image arrays; each is permuted with
            ``transpose(1, 2, 0)``, cast to ``uint8`` and wrapped as a PIL image.
        labels: integer class labels, one per sample (list, NumPy array, ...).
        transform: callable applied to both PIL images.
        num_classes: number of classes; labels must lie in ``[0, num_classes)``.
    """

    def __init__(self, data, labels, transform, num_classes=200):
        self.data = data
        self.transform = transform
        self.labels = labels
        # ctab[c] holds the indices of every sample whose label is c.
        # np.asarray makes the comparison element-wise even for list labels
        # (a plain ``list == c`` would be a single bool).
        label_arr = np.asarray(labels)
        self.ctab = [np.flatnonzero(label_arr == c) for c in range(num_classes)]

    def _to_pil(self, i):
        """Convert stored sample ``i`` into a uint8 PIL image."""
        return Image.fromarray(self.data[i].transpose(1, 2, 0).astype(np.uint8))

    def __getitem__(self, index):
        img = self._to_pil(index)
        label = self.labels[index]
        img = self.transform(img)

        # Sample over *all* members of the class. The upper bound is exclusive,
        # so len(same_class) (not len - 1) is needed to make the last member
        # reachable and to avoid ValueError for single-sample classes.
        same_class = self.ctab[label]
        idx = same_class[np.random.randint(len(same_class))]
        img_kd = self.transform(self._to_pil(idx))
        label_kd = self.labels[idx]

        return img, label, img_kd, label_kd

    def __len__(self):
        return len(self.labels)


class DistillTINdata(data.Dataset):
    """Dataset pairing each image with precomputed per-sample ``confs``/``conf_labels``.

    Items are returned as ``(img, conf, conf_label, label)``. Only the image
    goes through ``transform``; the other three values are returned as stored.
    NOTE(review): ``confs`` are presumably teacher confidences for
    distillation — confirm against the producer of these arrays.

    Args:
        data: indexable collection of image arrays; each is permuted with
            ``transpose(1, 2, 0)``, cast to ``uint8`` and wrapped as a PIL image.
        confs: per-sample values indexed in parallel with ``data``.
        conf_labels: per-sample values indexed in parallel with ``data``.
        labels: per-sample labels; ``len(labels)`` is the dataset length.
        transform: callable applied to each PIL image.
    """

    def __init__(self, data, confs, conf_labels, labels, transform):
        self.data = data
        self.confs = confs
        self.conf_labels = conf_labels
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        pixels = self.data[index].transpose(1, 2, 0).astype(np.uint8)
        pil_img = Image.fromarray(pixels)
        target = self.labels[index]
        sample = self.transform(pil_img)
        return sample, self.confs[index], self.conf_labels[index], target

    def __len__(self):
        return len(self.labels)


class WarmUpLR(LRScheduler):
    """Linear learning-rate warm-up scheduler.

    After ``k`` calls to ``step()`` (starting from the default
    ``last_epoch=-1``), every parameter group's learning rate is
    ``base_lr * k / total_iters``. The base-class constructor performs an
    initial step, so the rate starts at 0.

    The ``1e-8`` added to the denominator only prevents ZeroDivisionError. With
    ``total_iters == 0``, any step after the first yields a very large rate.

    Args:
        optimizer: wrapped optimizer (e.g. SGD).
        total_iters: number of steps (e.g. batches) in the warm-up phase.
        last_epoch: index of the last step already taken; ``-1`` starts fresh.
    """

    def __init__(self, optimizer, total_iters, last_epoch=-1):
        # Assigned before the base constructor, because it already calls
        # get_lr() through its initial step().
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return ``base_lr * last_epoch / total_iters`` for each parameter group."""
        progress = self.last_epoch
        denom = self.total_iters + 1e-8
        return [base * progress / denom for base in self.base_lrs]


class ModelwNorm(nn.Module):
    """Wrap ``model`` so inputs are normalized per channel before the forward pass.

    ``forward(x)`` returns ``model((x - mean) / std)``. ``mean`` and ``std`` are
    broadcast as shape ``(1, C, 1, 1)``, so ``x`` is expected to be NCHW with
    ``C == len(mean)``.

    Args:
        model: the wrapped network.
        mean: per-channel means; the default keeps the previously hard-coded
            values.
        std: per-channel standard deviations; the default keeps the previously
            hard-coded values.

    NOTE(review): the default statistics look like the commonly quoted CIFAR-100
    values although this file targets Tiny-ImageNet data. Confirm they are
    intended before retraining.
    """

    def __init__(self, model, mean=(0.507, 0.487, 0.441), std=(0.267, 0.256, 0.276)):
        super().__init__()
        self.model = model
        dtype = torch.get_default_dtype()
        # Registered as non-persistent buffers so they follow .to()/.cuda()/
        # .half()/.double() with the module, while staying out of state_dict.
        # Checkpoints saved before this change therefore still load.
        self.register_buffer(
            "mean", torch.as_tensor(mean, dtype=dtype).view(1, -1, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.as_tensor(std, dtype=dtype).view(1, -1, 1, 1), persistent=False
        )

    def forward(self, x):
        # .to() is a no-op when the buffers already live on x's device.
        m, s = self.mean.to(x.device), self.std.to(x.device)
        return self.model((x - m) / s)


def gaussian_kernel(size, sigma=0.5):
    """Return a normalized 2-D Gaussian kernel of shape ``(size, size)``.

    Offsets along each axis run from ``(-size) // 2 + 1`` to ``size // 2``.
    The kernel is centred for odd ``size``; for even ``size`` the extra offset
    lies on the positive side. Entries are
    ``exp(-(x**2 + y**2) / (2 * sigma**2))``, divided by their sum so the
    kernel sums to 1.

    Args:
        size: kernel side length.
        sigma: standard deviation of the Gaussian.
    """
    x, y = np.mgrid[-size // 2 + 1:size // 2 + 1, -size // 2 + 1:size // 2 + 1]
    g = np.exp(-(x ** 2 + y ** 2) / (2 * sigma ** 2))
    return g / g.sum()


def apply_gaussian_blur(image, kernel_size=5, sigma=1.0):
    """Blur a 2-D image, or every channel of a 3-D channels-last image.

    Each channel of a 3-D ``(H, W, C)`` array is convolved independently with
    the same ``gaussian_kernel(kernel_size, sigma)``, using
    ``scipy.ndimage.convolve``'s default boundary mode ('reflect'). The result
    has the same shape and dtype as ``image``.

    Args:
        image: 2-D ``(H, W)`` or 3-D ``(H, W, C)`` NumPy array, any ``C >= 1``.
        kernel_size: side length of the Gaussian kernel.
        sigma: standard deviation of the Gaussian.

    Returns:
        The blurred image.
    """
    kernel = gaussian_kernel(kernel_size, sigma)
    if image.ndim == 3:
        # Blur every channel, not just the first three. A fixed range(3) left
        # an RGBA alpha channel zeroed and raised IndexError when C < 3.
        blurred_image = np.zeros_like(image)
        for c in range(image.shape[2]):
            blurred_image[:, :, c] = convolve(image[:, :, c], kernel)
    else:
        blurred_image = convolve(image, kernel)
    return blurred_image
