import warnings
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn


class GaussianRandomField(nn.Module):
    """Sample Gaussian random fields on a regular grid over [0, 1]^input_dim.

    The covariance is separable: one 1-D squared-exponential kernel over
    ``grid_size`` evenly spaced points in [0, 1] is applied along every
    spatial axis. A sample is produced by multiplying white noise by the 1-D
    Cholesky factor along each spatial axis in turn.

    The kernel matrices are registered as non-persistent buffers, so they
    follow ``.to()``, ``.cuda()`` and ``.double()`` but are not stored in
    ``state_dict``.

    Args:
        input_dim: number of spatial dimensions of the field.
        grid_size: number of grid points per spatial dimension.
        lengthscale: lengthscale of the squared-exponential kernel.
        output_scale: multiplicative scale of the kernel.
        device: device on which the kernel matrices are built.
    """

    def __init__(self, input_dim=2, grid_size=10, lengthscale=0.2,
                 output_scale=1.0, device='cpu'):
        super().__init__()

        self.input_dim = input_dim
        self.grid_size = grid_size
        self.device = device
        # Added to the kernel diagonal so the Cholesky factorisation succeeds.
        self.jitter_value = 1e-3

        coords = torch.linspace(0, 1, grid_size, device=device)

        diff = coords.unsqueeze(0) - coords.unsqueeze(1)
        kernel = output_scale * torch.exp(-(diff ** 2) / (2 * lengthscale ** 2))
        kernel = kernel + self.jitter_value * torch.eye(grid_size, device=device)

        # The factorisation runs on CPU, then the factor is moved back to `device`.
        cholesky = torch.linalg.cholesky(kernel.to('cpu')).to(device)

        # Buffers (not plain attributes) so module-wide device/dtype moves apply;
        # non-persistent so state_dict stays exactly as before.
        self.register_buffer('coords', coords, persistent=False)
        self.register_buffer('kernel', kernel, persistent=False)
        self.register_buffer('cholesky', cholesky, persistent=False)
        self.register_buffer('cholesky_inv', torch.linalg.inv(cholesky), persistent=False)

    def forward(self, input_tensor=None, shape=None):
        """Draw white noise and its correlated counterpart.

        Args:
            input_tensor: tensor whose shape ``(batch, channels, *spatial)`` is used.
            shape: explicit ``(batch, channels, *spatial)`` shape, used when
                ``input_tensor`` is None.

        Returns:
            ``(Z, sample)``: the white-noise tensor and the field obtained by
            applying the Cholesky factor along every spatial axis of ``Z``.

        Raises:
            ValueError: if neither argument is given, or the spatial dims are
                not ``[grid_size] * input_dim``.
        """
        if input_tensor is not None:
            batch_size, channels, *spatial_dims = input_tensor.shape
        elif shape is not None:
            batch_size, channels, *spatial_dims = shape
        else:
            raise ValueError("Either input_tensor or shape must be provided.")

        expected_dims = [self.grid_size] * self.input_dim
        if spatial_dims != expected_dims:
            raise ValueError(f"Expected spatial dims {expected_dims}, got {spatial_dims}")

        # Match the factor's current device/dtype, which may differ from the
        # construction-time `self.device` after `.to()`/`.double()`.
        Z = torch.randn(
            batch_size, channels, *spatial_dims,
            device=self.cholesky.device, dtype=self.cholesky.dtype,
        )

        sample = Z
        for dim in range(self.input_dim):
            # Contract the factor with spatial axis `dim`, then restore axis order.
            sample = torch.tensordot(self.cholesky, sample, dims=([1], [2 + dim]))
            sample = sample.movedim(0, 2 + dim)

        return Z, sample

    def sample_as(self, tensor):
        """Return only the correlated sample shaped like ``tensor``."""
        _, sample = self(tensor)
        return sample


class PatchBasedGPRegressor(nn.Module):
    """Gaussian-process regression over an image, fitted patch by patch.

    The image is tiled into (optionally overlapping) patches. In each patch an
    exact GP with a squared-exponential kernel on pixel coordinates
    (normalised by the patch height/width) is conditioned on the observed
    pixels (non-zero mask entries) and evaluated at every pixel. The patch
    predictions are stitched back together with a linear taper over the
    overlap, and the accumulated weights are normalised.

    Args:
        patch_size: ``(height, width)`` of each patch.
        overlap: pixels shared by neighbouring patches; must satisfy
            ``0 <= overlap < min(patch_size)``.
        kernel_length_scale: kernel length scale in normalised patch units.
        kernel_variance: prior signal variance of the kernel.
        noise_variance: observation-noise variance added to the diagonal.
        device: device for the buffers and all intermediate tensors.
        use_cholesky: solve via a Cholesky factorisation (falling back to a
            general solve if it fails) instead of a general solve.
        jitter: extra diagonal term for numerical stability.

    Raises:
        ValueError: if ``overlap`` is negative or not smaller than both patch
            dimensions (the tiling stride would be non-positive).
    """

    def __init__(
        self,
        patch_size: Tuple[int, int] = (64, 64),
        overlap: int = 8,
        kernel_length_scale: float = 1.0,
        kernel_variance: float = 1.0,
        noise_variance: float = 0.01,
        device: str = 'cuda',
        use_cholesky: bool = True,
        jitter: float = 1e-6
    ):
        super().__init__()

        if overlap < 0 or overlap >= min(patch_size):
            raise ValueError(
                f"overlap must satisfy 0 <= overlap < min(patch_size), "
                f"got overlap={overlap} with patch_size={tuple(patch_size)}"
            )

        self.patch_size = patch_size
        self.overlap = overlap
        self.device = torch.device(device)
        self.use_cholesky = use_cholesky
        self.jitter = jitter

        self.register_buffer('kernel_length_scale', torch.tensor(kernel_length_scale, device=self.device))
        self.register_buffer('kernel_variance', torch.tensor(kernel_variance, device=self.device))
        self.register_buffer('noise_variance', torch.tensor(noise_variance, device=self.device))

    def gaussian_kernel(self, X1: torch.Tensor, X2: torch.Tensor) -> torch.Tensor:
        """Squared-exponential kernel ``variance * exp(-|x1 - x2|^2 / (2 * l^2))``.

        Args:
            X1: ``(N, D)`` points.
            X2: ``(M, D)`` points.

        Returns:
            ``(N, M)`` kernel matrix.
        """
        dists = torch.cdist(X1, X2, p=2) ** 2
        return self.kernel_variance * torch.exp(-0.5 * dists / self.kernel_length_scale ** 2)

    def _patch_spans(self, length: int, patch_len: int) -> list:
        """Return ``(start, end)`` spans of patches that tile ``[0, length)``."""
        stride = patch_len - self.overlap
        # ceil((length - overlap) / stride), but always at least one patch so
        # inputs no larger than the overlap are still covered.
        n_patches = max(1, -(-(length - self.overlap) // stride))
        spans = []
        for i in range(n_patches):
            start = i * stride
            end = min(start + patch_len, length)
            if end - start < patch_len:
                # Shift a truncated trailing patch back so it stays full-size.
                start = max(0, end - patch_len)
            spans.append((start, end))
        return spans

    def extract_patches(
        self,
        image: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[list, list, list]:
        """Tile an image into patches that cover every pixel.

        Args:
            image: ``(C, H, W)`` or ``(H, W)`` tensor.
            mask: optional ``(H, W)`` observation mask; defaults to all ones.

        Returns:
            ``(patches, masks, positions)``: ``(C, h, w)`` patch views, the
            matching ``(h, w)`` mask patches, and each patch's top-left
            ``(row, col)``, ordered row-major.
        """
        if image.dim() == 2:
            image = image.unsqueeze(0)

        _, H, W = image.shape
        patch_h, patch_w = self.patch_size
        row_spans = self._patch_spans(H, patch_h)
        col_spans = self._patch_spans(W, patch_w)

        patches, masks_list, positions = [], [], []
        for start_h, end_h in row_spans:
            for start_w, end_w in col_spans:
                patch = image[:, start_h:end_h, start_w:end_w]
                patches.append(patch)
                positions.append((start_h, start_w))
                if mask is not None:
                    masks_list.append(mask[start_h:end_h, start_w:end_w])
                else:
                    masks_list.append(torch.ones(patch.shape[1:], device=self.device))

        return patches, masks_list, positions

    def _blend_window(self, length: int, dtype: torch.dtype) -> torch.Tensor:
        """1-D blend weights with a linear taper over ``overlap`` pixels per end.

        The k-th pixel from either edge (k = 1, 2, ...) gets ``k / (overlap + 1)``,
        capped at 1. All weights are strictly positive, so pixels covered by a
        single patch (e.g. the image border) are reconstructed exactly. For
        patches at least ``2 * overlap`` long, the two tapers over an aligned
        overlap sum to one.
        """
        idx = torch.arange(length, device=self.device, dtype=dtype)
        denom = self.overlap + 1
        rising = (idx + 1) / denom
        falling = (length - idx) / denom
        return torch.clamp(torch.minimum(rising, falling), max=1.0)

    def reconstruct_from_patches(
        self,
        patches: list,
        positions: list,
        image_shape: Tuple[int, int, int]
    ) -> torch.Tensor:
        """Blend ``(C, h, w)`` patches placed at ``positions`` into one image.

        Each patch is multiplied by a separable tapered window; the sum is then
        divided by the accumulated window weights.

        Args:
            patches: list of ``(C, h, w)`` tensors.
            positions: matching list of top-left ``(row, col)`` offsets.
            image_shape: ``(C, H, W)`` of the output.

        Returns:
            ``(C, H, W)`` tensor; pixels covered by no patch are 0.
        """
        C, H, W = image_shape
        if patches and patches[0].is_floating_point():
            dtype = patches[0].dtype
        else:
            dtype = torch.get_default_dtype()

        output = torch.zeros((C, H, W), device=self.device, dtype=dtype)
        weights = torch.zeros((H, W), device=self.device, dtype=dtype)

        window_cache = {}
        for patch, (start_h, start_w) in zip(patches, positions):
            ph, pw = patch.shape[1], patch.shape[2]
            if (ph, pw) not in window_cache:
                window_cache[(ph, pw)] = (
                    self._blend_window(ph, dtype).unsqueeze(1)
                    * self._blend_window(pw, dtype).unsqueeze(0)
                )
            window = window_cache[(ph, pw)]

            output[:, start_h:start_h + ph, start_w:start_w + pw] += patch * window
            weights[start_h:start_h + ph, start_w:start_w + pw] += window

        # Guard against division by zero for pixels that no patch covers.
        return output / torch.clamp(weights, min=1e-8).unsqueeze(0)

    def fit_predict_patch(
        self,
        patch: torch.Tensor,
        mask: torch.Tensor,
        return_std: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Condition a GP on one patch's observed pixels and predict all pixels.

        All channels share the kernel and the observation pattern. The kernel
        matrix is therefore factorised once and every channel is solved in a
        single multi-right-hand-side solve. The posterior std is also
        identical across channels.

        Args:
            patch: ``(C, H, W)`` pixel values. A floating dtype is kept;
                other dtypes are promoted to the default float dtype.
            mask: tensor with ``H * W`` elements; non-zero marks observed.
            return_std: also return the posterior standard deviation (of the
                latent function, excluding observation noise).

        Returns:
            ``(C, H, W)`` posterior mean, plus the ``(C, H, W)`` posterior std
            when ``return_std``. With no observed pixels the prior (zero mean,
            ``sqrt(kernel_variance)`` std) is returned.
        """
        C, H, W = patch.shape
        dtype = patch.dtype if patch.is_floating_point() else torch.get_default_dtype()
        patch = patch.to(dtype)

        # Pixel coordinates normalised by the patch size, flattened row-major.
        rows = torch.arange(H, device=self.device, dtype=dtype) / H
        cols = torch.arange(W, device=self.device, dtype=dtype) / W
        grid_r, grid_c = torch.meshgrid(rows, cols, indexing='ij')
        coords = torch.stack([grid_r, grid_c], dim=-1).reshape(-1, 2)  # (H*W, 2)

        observed = mask.reshape(-1).bool()
        n_obs = int(observed.sum())

        if n_obs == 0:
            mean = torch.zeros_like(patch)
            if return_std:
                prior_std = torch.sqrt(self.kernel_variance).to(dtype)
                return mean, torch.ones_like(patch) * prior_std
            return mean

        X_obs = coords[observed]  # (N_obs, 2)
        K_obs = self.gaussian_kernel(X_obs, X_obs)
        K_obs = K_obs + (self.noise_variance + self.jitter) * torch.eye(
            n_obs, device=self.device, dtype=K_obs.dtype
        )
        K_cross = self.gaussian_kernel(coords, X_obs)  # (N_all, N_obs)

        # Factorise once; None means "use the general solver".
        L = None
        if self.use_cholesky:
            factor, info = torch.linalg.cholesky_ex(K_obs)
            if int(info) == 0:
                L = factor
            else:
                warnings.warn('Fallback to regular solve, Cholesky fails', RuntimeWarning)

        Y_obs = patch.reshape(C, -1)[:, observed].T.to(K_obs.dtype)  # (N_obs, C)
        if L is not None:
            alpha = torch.cholesky_solve(Y_obs, L)
        else:
            alpha = torch.linalg.solve(K_obs, Y_obs)
        mean = (K_cross @ alpha).T.reshape(C, H, W)

        if not return_std:
            return mean

        # Posterior variance k(x, x) - k_*^T K^{-1} k_*, shared by all channels.
        if L is not None:
            v = torch.linalg.solve_triangular(L, K_cross.T, upper=False)
            explained = (v ** 2).sum(dim=0)
        else:
            explained = (K_cross * torch.linalg.solve(K_obs, K_cross.T).T).sum(dim=1)
        var = torch.clamp(self.kernel_variance - explained, min=0.0)  # numerical floor
        std = torch.sqrt(var).reshape(1, H, W).repeat(C, 1, 1)
        return mean, std

    def predict(
        self,
        image: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        return_std: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Predict every pixel of ``image`` from its observed pixels.

        Args:
            image: ``(C, H, W)`` or ``(H, W)`` tensor.
            mask: optional ``(H, W)`` mask, non-zero = observed; defaults to
                all pixels observed.
            return_std: also return the stitched posterior std.

        Returns:
            Posterior mean (and std) with the same rank as ``image``.
        """
        squeeze_channel = image.dim() == 2
        if squeeze_channel:
            image = image.unsqueeze(0)

        if mask is None:
            mask = torch.ones(image.shape[1:], device=self.device)

        image = image.to(self.device)
        mask = mask.to(self.device)

        patches, masks, positions = self.extract_patches(image, mask)
        results = [
            self.fit_predict_patch(patch, mask_patch, return_std=return_std)
            for patch, mask_patch in zip(patches, masks)
        ]

        if return_std:
            mean_patches = [mean for mean, _ in results]
            std_patches = [std for _, std in results]
        else:
            mean_patches = results

        mean_image = self.reconstruct_from_patches(mean_patches, positions, image.shape)

        if not return_std:
            return mean_image.squeeze(0) if squeeze_channel else mean_image

        std_image = self.reconstruct_from_patches(std_patches, positions, image.shape)
        if squeeze_channel:
            return mean_image.squeeze(0), std_image.squeeze(0)
        return mean_image, std_image

    def sample_posterior(
        self,
        image: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        n_samples: int = 1,
        seed: Optional[int] = None
    ) -> torch.Tensor:
        """Draw ``mean + std * eps`` samples using the per-pixel posterior marginals.

        Pixels are sampled independently (marginal std only; no cross-pixel
        correlation).

        Args:
            image: ``(C, H, W)`` or ``(H, W)`` tensor.
            mask: optional ``(H, W)`` observation mask.
            n_samples: number of samples to draw.
            seed: seeds a private generator; the global RNG state is untouched.

        Returns:
            ``(n_samples, C, H, W)``, or ``(n_samples, H, W)`` for 2-D input.
        """
        original_shape = image.shape
        if image.dim() == 2:
            image = image.unsqueeze(0)

        mean, std = self.predict(image, mask, return_std=True)

        generator = None
        if seed is not None:
            generator = torch.Generator(device=mean.device)
            generator.manual_seed(seed)

        noise = torch.randn(
            (n_samples, *mean.shape),
            generator=generator, device=mean.device, dtype=mean.dtype,
        )
        samples = mean.unsqueeze(0) + std.unsqueeze(0) * noise

        if len(original_shape) == 2:
            samples = samples.squeeze(1)

        return samples

    def forward(
        self,
        image: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Forward pass for nn.Module compatibility."""
        return self.predict(image, mask, return_std=False)

