import time
from pathlib import Path
from typing import List, Tuple, Optional, Callable
import torch as ch
import numpy as np
from ffcv.pipeline.allocation_query import AllocationQuery
from ffcv.pipeline.compiler import Compiler

from ffcv.pipeline.operation import Operation
from ffcv.loader import Loader, OrderOption
from ffcv.pipeline.state import State
from ffcv.transforms import ToTensor, ToDevice, Squeeze, NormalizeImage, \
    RandomHorizontalFlip, ToTorchImage
from ffcv.fields.rgb_image import CenterCropRGBImageDecoder, \
    RandomResizedCropRGBImageDecoder
from ffcv.fields.basics import IntDecoder
from tqdm import tqdm
import torch.distributed as dist

# Per-channel mean and standard deviation, multiplied by 255 (presumably to
# express the usual [0, 1] ImageNet statistics on the 0-255 pixel scale).
IMAGENET_MEAN = 255 * np.array([0.485, 0.456, 0.406])
IMAGENET_STD = 255 * np.array([0.229, 0.224, 0.225])

# Default ratio handed to the centre-crop decoder of the validation loader.
DEFAULT_CROP_RATIO = 224 / 256


# Default eigenvalues (3,) and eigenvectors (3, 3) consumed by the Lighting
# operation defined in this file.
imagenet_pca = dict(
    eigval=np.asarray([0.2175, 0.0188, 0.0045]),
    eigvec=np.asarray([
        [-0.5675, 0.7192, 0.4009],
        [-0.5808, -0.0045, -0.8140],
        [-0.5836, -0.6948, 0.4203],
    ]),
)


class Lighting(Operation):
    """Pipeline operation that adds random PCA-based colour noise to images.

    For every image in the batch a random 3-vector is drawn from a normal
    distribution with standard deviation ``alphastd``, scaled element-wise by
    ``eigval`` and projected through ``eigvec``. The resulting 3-vector is
    added to every pixel (broadcast over the last axis), the result is
    clipped to ``[0, 255]`` and written to the destination buffer.

    Args:
        alphastd: Standard deviation of the random per-image coefficients.
        eigval: Array of shape ``(3,)``.
        eigvec: Array of shape ``(3, 3)``.
    """

    def __init__(self, alphastd=0.1,
                 eigval=imagenet_pca['eigval'],
                 eigvec=imagenet_pca['eigvec']):
        super().__init__()
        self.alphastd = alphastd
        assert eigval.shape == (3,), 'eigval must have shape (3,)'
        assert eigvec.shape == (3, 3), 'eigvec must have shape (3, 3)'
        self.eigval = eigval
        self.eigvec = eigvec

    def generate_code(self) -> Callable:
        """Build and return the per-batch function that applies the noise.

        The returned function takes ``(images, dst)``, fills ``dst`` and
        returns it. It is flagged with ``is_parallel = True``.
        """
        # Iterator supplied by the FFCV compiler (presumably a parallel range
        # when JIT compilation is enabled -- confirm against FFCV docs).
        parallel_range = Compiler.get_iterator()
        # Copy attributes into locals so the closure does not reference self.
        alphastd = self.alphastd
        eigval = self.eigval
        eigvec = self.eigvec

        def run(images, dst):
            for i in parallel_range(images.shape[0]):
                rnd = np.random.randn(3) * alphastd
                v = rnd * eigval
                v = v.reshape((3, 1))
                inc = np.dot(eigvec, v).reshape((3,))
                img = np.add(images[i], inc)
                img = np.clip(img, 0, 255)
                # Values are already clipped to [0, 255], so uint8 is wide
                # enough; casting to uint64 only created an 8x larger
                # temporary before the store narrowed it again.
                # NOTE(review): assumes the incoming images (and therefore
                # dst) are uint8 -- confirm against the preceding decoder.
                dst[i] = img.astype(np.uint8)
            return dst

        run.is_parallel = True
        return run

    def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
        """Keep the incoming state and request an output buffer like the input.

        Returns:
            ``previous_state`` unchanged, together with an ``AllocationQuery``
            using the same shape and dtype as the incoming state.
        """
        return previous_state, AllocationQuery(previous_state.shape, previous_state.dtype)


def create_train_loader(train_dataset, num_workers, batch_size,
                        distributed, in_memory, portion=None):
    """Build the FFCV training ``Loader``.

    Args:
        train_dataset: Path to the training dataset file; must exist.
        num_workers: Forwarded to ``Loader`` as ``num_workers``.
        batch_size: Batch size; when ``distributed`` is true it is
            floor-divided by the world size.
        distributed: Whether ``torch.distributed`` is in use. Selects the
            target device (``cuda:<rank>`` vs ``cuda``) and the sample order
            (``RANDOM`` vs ``QUASI_RANDOM``).
        in_memory: Forwarded to ``Loader`` as ``os_cache``.
        portion: If given, only the first ``portion`` sample indices are used.

    Returns:
        The configured ``Loader``.
    """
    if distributed:
        device_name = f'cuda:{dist.get_rank()}'
        batch_size //= dist.get_world_size()
    else:
        device_name = 'cuda'
    assert Path(train_dataset).is_file()

    target_device = ch.device(device_name)
    crop_size = 224

    # Image branch: decode with a random resized crop, augment on the host,
    # then move to the device and normalise.
    image_pipeline: List[Operation] = [
        RandomResizedCropRGBImageDecoder((crop_size, crop_size)),
        Lighting(),
        RandomHorizontalFlip(),
        ToTensor(),
        ToDevice(target_device, non_blocking=True),
        ToTorchImage(),
        NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float16),
    ]

    # Label branch: decode the integer label, squeeze it and move it over.
    label_pipeline: List[Operation] = [
        IntDecoder(),
        ToTensor(),
        Squeeze(),
        ToDevice(target_device, non_blocking=True),
    ]

    if distributed:
        sample_order = OrderOption.RANDOM
    else:
        sample_order = OrderOption.QUASI_RANDOM

    subset = None if portion is None else list(range(portion))

    return Loader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        indices=subset,
        order=sample_order,
        os_cache=in_memory,
        drop_last=True,
        pipelines={'image': image_pipeline, 'label': label_pipeline},
        distributed=distributed,
    )


def create_val_loader(val_dataset, num_workers, batch_size,
                      resolution, distributed, crop_ratio=DEFAULT_CROP_RATIO):
    """Build the FFCV validation ``Loader``.

    Args:
        val_dataset: Path to the validation dataset file; must exist.
        num_workers: Forwarded to ``Loader`` as ``num_workers``.
        batch_size: Batch size; when ``distributed`` is true it is
            floor-divided by the world size.
        resolution: Side length of the square centre crop.
        distributed: Whether ``torch.distributed`` is in use.
        crop_ratio: Ratio handed to the centre-crop decoder.

    Returns:
        The configured ``Loader`` (sequential order, last batch kept).
    """
    # Use the per-rank device when distributed, exactly like
    # create_train_loader; a bare 'cuda' would place every rank's validation
    # batches on the current default GPU.
    this_device = f'cuda:{dist.get_rank()}' if distributed else 'cuda'
    if distributed:
        batch_size //= dist.get_world_size()
    val_path = Path(val_dataset)
    assert val_path.is_file()

    res_tuple = (resolution, resolution)
    cropper = CenterCropRGBImageDecoder(res_tuple, ratio=crop_ratio)
    image_pipeline = [
        cropper,
        ToTensor(),
        ToDevice(ch.device(this_device), non_blocking=True),
        ToTorchImage(),
        NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float16)
    ]

    label_pipeline = [
        IntDecoder(),
        ToTensor(),
        Squeeze(),
        ToDevice(ch.device(this_device), non_blocking=True)
    ]

    loader = Loader(val_dataset,
                    batch_size=batch_size,
                    num_workers=num_workers,
                    order=OrderOption.SEQUENTIAL,
                    drop_last=False,
                    pipelines={
                        'image': image_pipeline,
                        'label': label_pipeline
                    },
                    distributed=distributed)
    return loader


if __name__ == "__main__":
    # Benchmark script: builds both loaders and times full passes over the
    # training loader.
    in_memory = False
    workers = 20
    batch_size = 256
    num_timed_epochs = 100
    train_loader = create_train_loader(
        '/d1/dataset/ILSVRC2012/imagenet_ffcv_train.ffcv',
        workers, batch_size, False, in_memory)
    # create_val_loader's fifth parameter is `distributed`, not `in_memory`;
    # pass it by keyword so the non-distributed setting is explicit instead
    # of relying on `in_memory` happening to be False.
    val_loader = create_val_loader(
        '/d1/dataset/ILSVRC2012/imagenet_ffcv_val.ffcv',
        workers, batch_size, 224, distributed=False
    )

    # First epoch includes compilation time, so run it once before timing.
    name = 'custom_lighting'
    for ims, labs in tqdm(train_loader):
        pass
    # perf_counter is monotonic, so the measurement cannot be skewed by
    # wall-clock adjustments.
    start_time = time.perf_counter()
    for _ in range(num_timed_epochs):
        for ims, labs in tqdm(train_loader):
            pass
    elapsed = time.perf_counter() - start_time
    print(f'Method: {name} | Shape: {ims.shape} | Time per epoch: {elapsed / num_timed_epochs:.4f}s')
