import torch
import torch.nn.functional as F
from ffcv.loader import Loader, OrderOption
from ffcv.transforms import ToTensor, ToDevice, Squeeze, NormalizeImage, ToTorchImage
from ffcv.fields.rgb_image import CenterCropRGBImageDecoder
from ffcv.fields.basics import IntDecoder
from torch.amp import GradScaler
import torchvision.models as models
import numpy as np
from pathlib import Path

class BlurPoolConv2d(torch.nn.Module):
    """Blur the input with a fixed 3x3 kernel, then apply a wrapped conv.

    The blur kernel is ``[[1, 2, 1], [2, 4, 2], [1, 2, 1]] / 16``, repeated
    once per input channel of the wrapped convolution and applied with
    ``groups=conv.in_channels``, stride 1 and padding 1. The blurred tensor
    is then passed to the wrapped convolution, whose output is returned.

    Note on hooks:
        The wrapped convolution is invoked as ``self.conv(blurred)`` and not
        as ``self.conv.forward(blurred)``. Calling ``forward`` directly skips
        ``Module.__call__``, so forward hooks registered on
        ``some_blurpool_layer.conv`` would never run and would capture
        nothing. Since ``self.conv`` is the last operation of ``forward``,
        hooking ``some_blurpool_layer.conv`` and hooking
        ``some_blurpool_layer`` yield the same output tensor.

    Args:
        conv: The convolution module to wrap. It must expose
            ``in_channels`` and be callable on the blurred tensor.
    """

    def __init__(self, conv):
        super().__init__()
        default_filter = torch.tensor([[[[1, 2, 1], [2, 4, 2], [1, 2, 1]]]]) / 16.0
        # One copy of the kernel per input channel (used with groups below).
        filt = default_filter.repeat(conv.in_channels, 1, 1, 1)
        self.conv = conv
        # Registered as a buffer so it follows the module across devices and
        # is stored in the state dict without being a trainable parameter.
        self.register_buffer('blur_filter', filt)

    def forward(self, x):
        """Return ``self.conv`` applied to the blurred version of ``x``."""
        blurred = F.conv2d(x, self.blur_filter, stride=1, padding=(1, 1),
                           groups=self.conv.in_channels, bias=None)
        # Go through __call__ (not .forward) so hooks on self.conv fire.
        return self.conv(blurred)
    
# Three-element mean / std arrays (named for ImageNet), scaled by 255.
# They are handed to NormalizeImage when the image pipeline is built.
IMAGENET_MEAN = 255 * np.array([0.485, 0.456, 0.406])
IMAGENET_STD = 255 * np.array([0.229, 0.224, 0.225])
# Ratio passed to the center-crop decoder: 224 / 256 == 0.875.
DEFAULT_CROP_RATIO = 224 / 256

def create_model_and_scaler(arch, pretrained, distributed, use_blurpool, gpu: int):
    """Build a torchvision model plus a fresh ``GradScaler``.

    Args:
        arch: Name of the constructor to look up on ``torchvision.models``.
        pretrained: Forwarded to that constructor as ``pretrained=``.
        distributed: When truthy, the model is wrapped in
            ``DistributedDataParallel`` with ``device_ids=[gpu]``.
        use_blurpool: When truthy, qualifying ``Conv2d`` children are
            replaced by ``BlurPoolConv2d`` wrappers before the model is moved.
        gpu: Device index the model is moved to.

    Returns:
        A ``(model, scaler)`` tuple.
    """
    scaler = GradScaler()
    model = getattr(models, arch)(pretrained=pretrained)

    def _should_wrap(layer) -> bool:
        # Only strided convolutions with at least 16 input channels qualify.
        if not isinstance(layer, torch.nn.Conv2d):
            return False
        return np.max(layer.stride) > 1 and layer.in_channels >= 16

    def _install_blurpool(parent: torch.nn.Module):
        # Walk the module tree; wrapped layers are not descended into.
        for child_name, child in parent.named_children():
            if _should_wrap(child):
                setattr(parent, child_name, BlurPoolConv2d(child))
                continue
            _install_blurpool(child)

    if use_blurpool:
        _install_blurpool(model)

    # Switch to channels-last memory format, then move to the target device.
    model = model.to(memory_format=torch.channels_last)
    model = model.to(gpu)

    if not distributed:
        return model, scaler

    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
    return model, scaler

def create_val_loader(val_dataset, num_workers, batch_size,
                      resolution, distributed, gpu: int, shuffled = True):
    """Build an FFCV ``Loader`` over ``val_dataset``.

    Args:
        val_dataset: Path to the dataset file; must point at an existing file.
        num_workers: Forwarded to ``Loader`` as ``num_workers``.
        batch_size: Forwarded to ``Loader`` as ``batch_size``.
        resolution: Side length; the crop decoder receives
            ``(resolution, resolution)``.
        distributed: Forwarded to ``Loader`` as ``distributed``.
        gpu: Index used to form the ``cuda:{gpu}`` device for ``ToDevice``.
        shuffled: ``OrderOption.RANDOM`` when truthy, otherwise
            ``OrderOption.SEQUENTIAL``.

    Returns:
        The constructed ``Loader`` (``drop_last=False``, ``seed=0``).

    Raises:
        AssertionError: If ``val_dataset`` is not an existing file.
    """
    # Validate the path before constructing any pipeline stage.
    assert Path(val_dataset).is_file()

    device_name = f'cuda:{gpu}'

    image_pipeline = [
        CenterCropRGBImageDecoder((resolution, resolution),
                                  ratio=DEFAULT_CROP_RATIO),
        ToTensor(),
        ToDevice(torch.device(device_name), non_blocking=True),
        ToTorchImage(),
        NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float16),
    ]

    label_pipeline = [
        IntDecoder(),
        ToTensor(),
        Squeeze(),
        ToDevice(torch.device(device_name), non_blocking=True),
    ]

    order = OrderOption.RANDOM if shuffled else OrderOption.SEQUENTIAL

    pipelines = {
        'image': image_pipeline,
        'label': label_pipeline,
    }
    return Loader(
        val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        order=order,
        drop_last=False,
        pipelines=pipelines,
        distributed=distributed,
        seed=0,
    )
