import sys, os, pathlib, yaml
sys.path.insert(0, pathlib.Path(__file__).parent.parent.absolute() / 'bkse')
import bkse
import torch
from torch import nn
import scipy.ndimage
import numpy as np

class NoiseScheduler:
    """Gaussian noise whose std follows a geometric schedule in ``t``.

    ``std(t) = sigma_min * (sigma_max / sigma_min) ** t``, i.e. ``sigma_min`` at
    ``t = 0`` and ``sigma_max`` at ``t = 1``.
    """

    def __init__(self, sigma_min, sigma_max):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def __call__(self, t, noise_shape, seed=None):
        """Draw noise of shape ``noise_shape`` scaled by ``get_std(t)``.

        Args:
            t: scalar time in ``[0, 1]``.
            noise_shape: shape passed to ``torch.randn``.
            seed: optional int. When given, the draw is reproducible, and no
                global RNG (CPU or CUDA) is reseeded or otherwise modified.

        Returns:
            ``(noise, std)`` -- a CPU tensor of shape ``noise_shape`` and the
            std that was used to scale it.
        """
        assert 0.0 <= t <= 1.0
        std = self.get_std(t)
        if seed is None:
            return torch.randn(noise_shape) * std, std
        # A private generator seeded with ``seed`` yields the same CPU stream as
        # ``torch.manual_seed(seed)``, but without reseeding the global CPU/CUDA
        # generators (torch.manual_seed seeds CUDA too, which the old
        # save/restore of the CPU state did not undo) and without leaving the
        # global state altered if ``randn`` raises.
        generator = torch.Generator().manual_seed(seed)
        return torch.randn(noise_shape, generator=generator) * std, std

    def get_std(self, t):
        """Return the noise std at time ``t`` (geometric interpolation)."""
        return self.sigma_min * (self.sigma_max / self.sigma_min) ** t

class Blurkernel(nn.Module):
    """Fixed (non-trainable) depthwise blur for 3-channel inputs.

    The input is reflection-padded by ``kernel_size // 2`` on every side and
    convolved per channel (``groups=3``) with one shared 2-D kernel, so the
    spatial size is preserved when ``kernel_size`` is odd.
    """

    def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0):
        super().__init__()
        self.blur_type = blur_type
        self.kernel_size = kernel_size
        self.std = std
        self.seq = nn.Sequential(
            nn.ReflectionPad2d(self.kernel_size // 2),
            nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3),
        )

        self.weights_init()

    def forward(self, x):
        return self.seq(x)

    def weights_init(self):
        """Build the initial kernel for ``blur_type`` and load it into the conv.

        Raises:
            ValueError: if ``blur_type`` is not ``"gaussian"``. Previously only
                ``"motion"`` was rejected. Any other value silently kept random
                conv weights and left ``self.k`` undefined.
        """
        if self.blur_type != "gaussian":
            raise ValueError('Unsupported blur type.')
        impulse = np.zeros((self.kernel_size, self.kernel_size))
        impulse[self.kernel_size // 2, self.kernel_size // 2] = 1
        # Filtering a centred unit impulse yields the sampled Gaussian kernel
        # (float64, shape (kernel_size, kernel_size)).
        k = scipy.ndimage.gaussian_filter(impulse, sigma=self.std, truncate=6.0)
        self.k = torch.from_numpy(k)
        self._load_kernel(self.k)

    def update_weights(self, k):
        """Replace the blur kernel with ``k`` (tensor or ndarray).

        ``get_kernel`` afterwards returns ``k``, so it stays in sync with the
        weights actually used by ``forward``.
        """
        if not torch.is_tensor(k):
            k = torch.from_numpy(k)
        self._load_kernel(k)
        self.k = k

    def _load_kernel(self, k):
        """Copy ``k`` into every parameter (broadcast/cast by ``copy_``) and freeze it."""
        # ``copy_`` handles dtype and cross-device copies. Parameters stay where
        # they are; callers move the module with ``.to(...)``.
        for param in self.parameters():
            param.data.copy_(k)
            param.requires_grad_(False)

    def get_kernel(self):
        """Return the current 2-D kernel tensor."""
        return self.k
    
class GaussianBlurOperator:
    """Time-dependent Gaussian blur whose std at time ``t`` comes from a schedule.

    Args:
        kernel_size: side length of the square blur kernel.
        std_schedule: callable mapping ``t`` to a blur std. It is ignored when
            ``from_file`` is given.
        from_file: optional path to a whitespace-separated text file. Its first
            column holds sample times and its second column the matching stds.
            Stds are linearly interpolated between samples by ``lerp_std``.
            NOTE(review): times are assumed sorted ascending -- confirm for new
            schedule files.
    """

    def __init__(self,
                 kernel_size,
                 std_schedule,
                 from_file=None,
                ):
        self.kernel_size = kernel_size
        if from_file is None:
            self.std_schedule = std_schedule
        else:
            # Parse the file once; ndmin=2 keeps a single-row file 2-D.
            schedule = np.loadtxt(from_file, ndmin=2)
            self.t_vals = torch.from_numpy(np.ascontiguousarray(schedule[:, 0]))
            self.std_vals = torch.from_numpy(np.ascontiguousarray(schedule[:, 1]))
            self.std_schedule = self.lerp_std
        self.conv = None

    def update_kernel(self, t):
        """Rebuild ``self.conv`` and ``self.kernel`` for the blur std at time ``t``."""
        # Accept plain floats as well as tensors; floats run on the CPU.
        device = t.device if torch.is_tensor(t) else torch.device('cpu')
        self.conv = Blurkernel(blur_type='gaussian',
                               kernel_size=self.kernel_size,
                               std=self.std_from_t(t),
                               )
        self.kernel = self.conv.get_kernel().to(device)
        self.conv.update_weights(self.kernel.type(torch.float32))
        self.conv.to(device)

    def __call__(self, data, t, **kwargs):
        return self.forward(data, t, **kwargs)

    def forward(self, data, t, **kwargs):
        """Blur ``data`` with the kernel for time ``t`` in ``[0, 1]``.

        NOTE(review): ``data`` presumably must be a 3-channel NCHW tensor,
        matching the conv built by ``Blurkernel`` -- confirm.
        """
        assert 0.0 <= t <= 1.0
        self.update_kernel(t)
        return self.conv(data)

    def get_kernel(self, t):
        """Return the kernel for time ``t`` as a ``(1, 1, K, K)`` view."""
        self.update_kernel(t)
        return self.kernel.view(1, 1, self.kernel_size, self.kernel_size)

    def std_from_t(self, t):
        """Evaluate the schedule at ``t`` and return the std as a Python float."""
        return float(self.std_schedule(t))

    def lerp_std(self, t):
        """Linearly interpolate the file-based std schedule at time ``t``.

        ``t = 0`` maps to the first std and ``t = 1`` to the last. Times
        outside ``[t_vals[0], t_vals[-1]]`` are clamped to the nearest
        endpoint. Previously they wrapped around to index -1 or raised
        IndexError.
        """
        assert 0.0 <= t <= 1.0
        if torch.is_tensor(t):
            self.std_vals = self.std_vals.to(t.device)
            self.t_vals = self.t_vals.to(t.device)
        if t == 0.0:
            return self.std_vals[0]
        if t == 1.0:
            return self.std_vals[-1]
        if t <= self.t_vals[0]:
            return self.std_vals[0]
        if t >= self.t_vals[-1]:
            return self.std_vals[-1]
        # Here t_vals[0] < t < t_vals[-1], so the first sample >= t has index
        # >= 1. The bracketing interval has strictly positive width.
        end = int((self.t_vals >= t).nonzero(as_tuple=False)[0])
        start = end - 1
        t0, t1 = self.t_vals[start], self.t_vals[end]
        s0, s1 = self.std_vals[start], self.std_vals[end]
        return s0 + (s1 - s0) * (t - t0) / (t1 - t0)

class NonlinearBlurOperator:
    """Nonlinear blur from a pretrained bkse ``KernelWizard`` driven by random kernel codes.

    NOTE(review): the option file and checkpoint paths below are relative to
    the current working directory, not to this source file.
    """

    def __init__(self):
        self.blur_model = self.prepare_nonlinear_blur_model('./bkse/options/generate_blur/default.yml')
        # Private generator: seeding it in forward() leaves the global RNG alone.
        self.rng = torch.Generator()

    def prepare_nonlinear_blur_model(self, opt_yml_path):
        '''
        Nonlinear deblur requires external codes (bkse).

        Reads the ``KernelWizard`` section of the YAML option file, builds the
        model in eval mode and loads the pretrained weights named by its
        ``pretrained`` entry (resolved under ``bkse/``).
        '''
        from bkse.models.kernel_encoding.kernel_wizard import KernelWizard

        with open(opt_yml_path, "r") as f:
            opt = yaml.safe_load(f)["KernelWizard"]
        model_path = 'bkse/' + opt["pretrained"]
        blur_model = KernelWizard(opt)
        blur_model.eval()
        # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only
        # hosts; forward() later moves the model to the input's device.
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        blur_model.load_state_dict(torch.load(model_path, map_location='cpu'))
        return blur_model

    def forward(self, data, **kwargs):
        """Blur ``data`` with a random kernel code per batch element.

        ``data`` is expected in ``[-1, 1]``. The output is clamped to ``[-1, 1]``.
        Pass ``seed=<int>`` for a reproducible kernel draw.
        """
        if data.device != next(self.blur_model.parameters()).device:
            self.blur_model.to(data.device)
        b = data.shape[0]

        if 'seed' in kwargs and kwargs['seed'] is not None:
            self.rng.manual_seed(kwargs['seed'])
        else:
            self.rng.seed()
        random_kernel = torch.randn(b, 512, 2, 2, generator=self.rng).to(data.device) * 1.2

        data_scaled = (data + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        blurred = self.blur_model.adaptKernel(data_scaled, kernel=random_kernel)
        blurred = (blurred * 2.0 - 1.0).clamp(-1, 1)  # [0, 1] -> [-1, 1]
        return blurred

    def __call__(self, data, t, **kwargs):
        # ``t`` is accepted for interface parity with GaussianBlurOperator but unused.
        return self.forward(data, **kwargs)
        
def create_operator(config):
    """Build a forward (degradation) operator from a config dict.

    Keys read:
        ``type``: ``'gaussian_blur'`` or ``'nonlinear_blur'``.
        For ``'gaussian_blur'``: ``kernel_size`` and ``scheduling``, which is one
        of ``'linear'``, ``'from_file'`` or ``'fixed'``. ``'linear'`` and
        ``'fixed'`` also need ``max_std``; ``'from_file'`` needs ``schedule_path``.

    Raises:
        ValueError: if ``type`` is unknown or, for ``'gaussian_blur'``,
            ``scheduling`` is unknown (previously this silently returned None).
    """
    op_type = config['type']
    if op_type == 'nonlinear_blur':
        return NonlinearBlurOperator()
    if op_type != 'gaussian_blur':
        raise ValueError('Unsupported operator in config.')

    MIN_STD = 0.3  # below this the filter is truncated and we get identity mapping
    scheduling = config['scheduling']
    if scheduling == 'linear':
        # Bind max_std now so later mutation of ``config`` cannot alter the schedule.
        max_std = config['max_std']
        std_schedule = lambda t: (max_std - MIN_STD) * t + MIN_STD
        return GaussianBlurOperator(config['kernel_size'], std_schedule)
    if scheduling == 'from_file':
        return GaussianBlurOperator(config['kernel_size'], std_schedule=None, from_file=config['schedule_path'])
    if scheduling == 'fixed':
        max_std = config['max_std']
        std_schedule = lambda t: max_std
        return GaussianBlurOperator(config['kernel_size'], std_schedule)
    raise ValueError(f"Unsupported gaussian_blur scheduling: {scheduling!r}")
        
def create_noise_schedule(config):
    """Return a ``NoiseScheduler`` for ``config``, or ``None`` if no config is given.

    ``config`` must provide the keys ``'sigma_min'`` and ``'sigma_max'``.
    """
    if config is not None:
        return NoiseScheduler(config['sigma_min'], config['sigma_max'])
    return None
