from typing import Dict, Optional, Sequence, Tuple, List, Union, NamedTuple
import numpy as np
import torch
from torchvision.transforms.functional import resize, center_crop
from pl_modules.utils import str2int
from data.operators import create_operator, create_noise_schedule

def rescale(x):
    """Affinely map ``x`` from the [0, 1] range to [-1, 1] via ``2 * x - 1``.

    Works element-wise on Python numbers, NumPy arrays and torch tensors.
    """
    return x * 2. - 1.

class ImageDataTransform:
    """Preprocess an image and synthesize its degraded / noisy counterparts.

    Designed for ImageNet data with a U-Net. Each call center-crops the input
    image to a square, resizes it to ``image_size`` x ``image_size``, scales
    pixels to [0, 1], applies the forward (degradation) operator at level
    ``t`` and, if a noise scheduler is configured, adds its noise.

    See typical ImageNet preprocessing here:
    https://github.com/pytorch/examples/blob/97304e232807082c2e7b54c597615dc0ad8f6173/imagenet/main.py#L197-L198
    """

    def __init__(self,
                 is_train,
                 operator_schedule,
                 noise_schedule=None,
                 fixed_t=None,
                 range_zero_one=False,
                 image_size=256,
                ):
        """
        Args:
            is_train: If True, ``t`` is drawn without a seed and ``seed=None``
                is passed to the operator / noise scheduler. If False, a seed
                derived from ``fname`` is used so validation samples are
                reproducible.
            operator_schedule: Either a dict handed to ``create_operator`` or
                an already-built operator called as ``op(image, t, seed=...)``.
            noise_schedule: ``None`` (no noise), a dict handed to
                ``create_noise_schedule``, or an already-built scheduler called
                as ``sched(t, shape, seed=...)`` returning ``(z, noise_std)``.
            fixed_t: If not None, use this degradation level for every sample
                instead of drawing ``t`` uniformly from [0, 1). ``0`` is a
                valid value.
            range_zero_one: If True, image outputs are returned on the [0, 1]
                pixel scale; otherwise they are passed through ``rescale``
                (x -> 2x - 1).
            image_size: Side length of the square output image (default 256).
        """
        self.is_train = is_train
        self.range_zero_one = range_zero_one
        self.image_size = image_size

        if isinstance(operator_schedule, dict):
            self.fwd_operator = create_operator(operator_schedule)
        else:
            self.fwd_operator = operator_schedule

        if noise_schedule is None:
            self.noise_scheduler = None
        elif isinstance(noise_schedule, dict):
            self.noise_scheduler = create_noise_schedule(noise_schedule)
        else:
            self.noise_scheduler = noise_schedule
        self.fixed_t = fixed_t

    @torch.no_grad()
    def __call__(self,
                 image,
                 fname=None
                ):
        """Build one sample from a PIL image.

        Args:
            image: A PIL image. Non-RGB modes (grayscale, palette, CMYK, ...)
                are converted to RGB first.
            fname: File name of the image. Required when ``is_train`` is
                False: it is turned into an integer seed via ``str2int`` so
                that ``t``, the operator and the noise are deterministic.

        Returns:
            dict with keys
              'clean': (C, H, W) float32 tensor of the preprocessed image,
              'degraded': operator output (without noise),
              'degraded_noisy': ``degraded`` plus scheduler noise (identical
                  to ``degraded`` when no noise scheduler is configured),
              'noise_std': noise level returned by the scheduler, or 0.0,
              't': degradation level used,
              'fname': ``fname``, passed through.
            The three image entries are passed through ``rescale`` unless
            ``range_zero_one`` is True.

        Raises:
            ValueError: if ``fname`` is None while ``is_train`` is False.
        """
        if not self.is_train and fname is None:
            # Validation samples are seeded from the file name; without it the
            # forward model cannot be made deterministic.
            raise ValueError("fname is required when is_train is False")

        # Grayscale/palette images decode to 2-D arrays (breaking the channel
        # transpose below) and CMYK decodes to 4 non-RGB channels; normalize
        # every input to RGB first.
        if getattr(image, 'mode', 'RGB') != 'RGB':
            image = image.convert('RGB')

        # Crop image to square, then resize to a uniform size
        shorter = min(image.size)
        image = center_crop(image, shorter)
        image = resize(image, (self.image_size, self.image_size))

        # Convert to ndarray, permute dimensions to C, H, W, normalize to [0, 1]
        image = np.array(image).transpose(2, 0, 1)
        image = image / 255.
        image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0)

        # Deterministic forward model for validation
        seed = None if self.is_train else str2int(fname)

        # Pick the degradation level t.
        if self.fixed_t is not None:
            # ``is not None`` so that fixed_t=0 is honored instead of silently
            # falling back to a random t; default float dtype so an integer
            # fixed_t does not produce an int64 tensor.
            # NOTE(review): this yields a 0-d tensor while the random branches
            # yield shape (1,); kept as-is since downstream collation may
            # depend on it — confirm whether that difference is intended.
            t = torch.as_tensor(self.fixed_t, dtype=torch.get_default_dtype())
        elif self.is_train:
            t = torch.rand(1)
        else:
            g = torch.Generator()
            g.manual_seed(seed)
            t = torch.rand(1, generator=g)

        # Generate degraded noisy images
        degraded = self.fwd_operator(image, t, seed=seed).squeeze(0)

        if self.noise_scheduler is not None:
            z, noise_std = self.noise_scheduler(t, image.shape, seed=seed)
            degraded_noisy = degraded + z.to(image.device)
        else:
            degraded_noisy = degraded
            noise_std = 0.0

        image = image.squeeze(0)
        # z is requested with the batched shape of ``image``, so the sum above
        # presumably carries a leading singleton dim; drop it.
        degraded_noisy = degraded_noisy.squeeze(0)

        def _to_output_range(x):
            # Optionally map from [0, 1] scale to [-1, 1] via rescale.
            return x if self.range_zero_one else rescale(x)

        return {
                'clean': _to_output_range(image),
                'degraded': _to_output_range(degraded),
                'degraded_noisy': _to_output_range(degraded_noisy),
                'noise_std': noise_std,
                't': t,
                'fname': fname,
               }