import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import torchvision.datasets as datasets
import torchvision.transforms as transforms

from .resnet import get_resnet_model
from .resnet_gn import get_resnet_model_gn
from codes.components.utils import log_dict


class Cutout(object):
    """Cutout regularization: zero one random square patch of an image tensor.

    The image is indexed as ``img[channel, row, col]``; the same patch is
    cleared in every channel. The patch extends ``length // 2`` pixels on each
    side of a uniformly sampled centre (so its side is ``2 * (length // 2)``)
    and is clipped at the image borders.
    """

    def __init__(self, length):
        # Nominal side length of the square patch, in pixels.
        self.length = length

    def __call__(self, img):
        """Zero a random patch of ``img`` in place and return the same tensor."""
        h, w = img.size(1), img.size(2)

        # Draw the centre row first, then the column (keeps the RNG draw order).
        cy = torch.randint(h, (1,)).item()
        cx = torch.randint(w, (1,)).item()

        half = self.length // 2

        def clamp(v, hi):
            # Clip a coordinate into [0, hi], like np.clip(v, 0, hi).
            return min(max(v, 0), hi)

        top, bottom = clamp(cy - half, h), clamp(cy + half, h)
        left, right = clamp(cx - half, w), clamp(cx + half, w)

        # Slice assignment zeroes exactly the region the old boolean mask
        # selected, without allocating an H x W mask or comparing with False.
        img[:, top:bottom, left:right] = 0
        return img


def get_resnet20(use_cuda=False, gn=False):
    """Build the 10-class, fp32, version-1 ResNet-20.

    Args:
        use_cuda: Forwarded unchanged to the model factory.
        gn: When true, build with ``get_resnet_model_gn`` (group
            normalization); otherwise with ``get_resnet_model`` (batch
            normalization).

    Returns:
        Whatever the selected factory returns.
    """
    if not gn:
        print("Using Batch normalization")
        build = get_resnet_model
    else:
        print("Using group normalization")
        build = get_resnet_model_gn

    return build(
        model="resnet20", version=1, dtype="fp32", num_classes=10, use_cuda=use_cuda
    )


class Net(nn.Module):
    """Four-convolution CNN producing 10 unnormalised scores per sample.

    Layout (C = 5x5 conv without padding, R = ReLU, B = BatchNorm,
    M = 2x2 max-pool, D = dropout 0.25, L = linear):

        (3,32x32)-C(64)-R-B-C(64)-R-B-M-D-C(128)-R-B-C(128)-R-B-M-D
        -L(128)-R-D-L(10)

    For a 32x32 input the spatial size goes 32 -> 28 -> 24 -> 12 -> 8 -> 4 -> 2,
    so the flattened feature vector has 128 * 2 * 2 = 512 entries.
    """

    def __init__(self):
        super().__init__()

        def conv_relu_bn(in_ch, out_ch):
            # One C-R-B unit: 5x5 "valid" convolution, ReLU, then BatchNorm.
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=5, padding=0, stride=1),
                nn.ReLU(),
                nn.BatchNorm2d(out_ch),
            ]

        def pool_drop():
            # M-D tail: halve the spatial size, then apply dropout.
            return [nn.MaxPool2d(kernel_size=2), nn.Dropout(0.25)]

        # Modules are created in the same order as the layer listing above, so
        # the Sequential indices (state_dict keys) and random init are unchanged.
        features = []
        features += conv_relu_bn(3, 64)
        features += conv_relu_bn(64, 64)
        features += pool_drop()
        features += conv_relu_bn(64, 128)
        features += conv_relu_bn(128, 128)
        features += pool_drop()
        self.conv_layers = nn.Sequential(*features)

        # L(128)-R-D-L(10); no (log-)softmax, so forward returns raw scores.
        head = [nn.Linear(512, 128), nn.ReLU(), nn.Dropout(0.25), nn.Linear(128, 10)]
        self.fc_layers = nn.Sequential(*head)

    def forward(self, x):
        feats = self.conv_layers(x)
        # Flatten all but the batch dimension (for 32x32 inputs: (N, 512)).
        flat = feats.view(feats.size(0), -1)
        return self.fc_layers(flat)


def cifar10(
        data_dir,
        train,
        download,
        batch_size,
        shuffle=False,
        sampler_callback=None,
        dataset_cls=datasets.CIFAR10,
        drop_last=True,
        **loader_kwargs
):
    """Build a CIFAR-10 ``DataLoader`` with per-channel normalisation.

    Args:
        data_dir: ``root`` passed to ``dataset_cls``.
        train: Load the training split when true. Training images get a random
            horizontal flip; test images are only converted and normalised.
        download: Forwarded to ``dataset_cls``.
        batch_size: Loader batch size.
        shuffle: Forwarded to ``DataLoader``; must be falsy when
            ``sampler_callback`` is given.
        sampler_callback: Optional callable ``dataset -> sampler``; its result
            is passed to the loader as ``sampler=``.
        dataset_cls: Dataset class, constructed with ``root``, ``train``,
            ``download`` and ``transform`` keyword arguments.
        drop_last: Forwarded to ``DataLoader``. With the default ``True`` the
            final partial batch is dropped for the test split as well.
        **loader_kwargs: Extra keyword arguments for ``DataLoader``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the dataset.

    Raises:
        ValueError: If ``sampler_callback`` is given together with a truthy
            ``shuffle``.
    """
    # A sampler fixes the iteration order itself. Test truthiness (as
    # DataLoader does) so shuffle=None/0 is accepted alongside a sampler.
    if sampler_callback is not None and shuffle:
        raise ValueError(
            "shuffle must be False when sampler_callback is provided; "
            "the sampler controls the iteration order"
        )

    # NOTE(review): the std values differ from the per-channel std often quoted
    # for CIFAR-10 (~0.247, 0.243, 0.261). They are kept unchanged so results
    # stay comparable; confirm before changing.
    cifar10_stats = {
        "mean": (0.4914, 0.4822, 0.4465),
        "std": (0.2023, 0.1994, 0.2010),
    }

    if train:
        transform = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                # Optional extra augmentation, currently disabled:
                # transforms.RandomCrop(32, padding=4),
                # transforms.RandomRotation(15),
                # transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
                transforms.ToTensor(),
                transforms.Normalize(cifar10_stats["mean"], cifar10_stats["std"]),
                # Cutout(length=int(32 * 0.25)),
            ]
        )
    else:
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(cifar10_stats["mean"], cifar10_stats["std"]),
            ]
        )

    dataset = dataset_cls(
        root=data_dir, train=train, download=download, transform=transform
    )

    # Store labels as an int64 tensor. as_tensor replaces the legacy
    # torch.LongTensor constructor and also accepts targets that are already
    # tensors or arrays.
    dataset.targets = torch.as_tensor(dataset.targets, dtype=torch.long)

    sampler = sampler_callback(dataset) if sampler_callback is not None else None
    log_dict(
        {
            "Type": "Setup",
            "Dataset": "cifar10",
            "data_dir": data_dir,
            "train": train,
            "download": download,
            "batch_size": batch_size,
            "shuffle": shuffle,
            # `is not None` avoids truth-testing the sampler, which calls __len__.
            "sampler": str(sampler) if sampler is not None else None,
        }
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        drop_last=drop_last,
        **loader_kwargs,
    )
