import torch
import numpy as np
import deepinv as dinv
from odl.contrib.torch import OperatorModule
import odl

# Select the compute device once, at import time: CUDA whenever torch reports
# it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def get_blurring_operator(n, sigma, angle):
    """
    Build a Gaussian-blur forward model together with its adjoint and norm.

    Args:
        n (int): Side length of the square (n x n) image used when estimating
            the operator norm.
        sigma (float): Standard deviation forwarded to
            ``dinv.physics.blur.gaussian_blur``.
        angle (float): Angle forwarded to ``dinv.physics.blur.gaussian_blur``.

    Returns:
        tuple: ``(forward_operator, adjoint_operator, operator_norm, init_recon)``
            forward_operator: the ``dinv.physics.Blur`` physics object itself.
            adjoint_operator: that object's bound ``A_adjoint`` method.
            operator_norm: square root of ``physics.compute_norm(...)``.
            init_recon: function that returns its input unchanged.
    """
    kernel = dinv.physics.blur.gaussian_blur(sigma=sigma, angle=angle)
    blur = dinv.physics.Blur(kernel, device=device)

    # A zero tensor carries no information for an iterative norm estimate, so it
    # presumably serves only as a shape/device template for compute_norm
    # (single-channel n x n image) -- confirm against the deepinv docs.
    template = torch.zeros((1, 1, n, n), device=device)
    norm = blur.compute_norm(template, tol=1e-10) ** 0.5

    def init_recon(x):
        # Identity map: the observation itself is the starting reconstruction.
        return x

    return blur, blur.A_adjoint, norm, init_recon

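# Usage sketch (not executed here): with the gradient operator W and its adjoint
# W_adj returned by get_ct_operator, the gradient of a Huber penalty applied to
# the image gradient can be formed as
#     huber_grad = W_adj(huber_grad(W(xn), Huber_param))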

def get_ct_operator(n, n_angles, n_detectors=None):
    """
    Create a normalized 2D parallel-beam CT operator plus an image-gradient operator.

    The image space is an n x n grid on ``[-n//2, n//2]^2`` (float32, weighting 1.0).
    The projection angles are ``n_angles`` uniform samples of ``[0, pi)``. The
    detector covers ``[-n, n]``.

    Args:
        n (int): The size of the image (assumed to be square, n x n).
        n_angles (int): The number of projection angles for the CT scan.
        n_detectors (int, optional): Number of detector pixels on ``[-n, n]``.
            Defaults to ``n_angles``, which reproduces the original behavior
            and therefore the original sinogram shape.

    Returns:
        tuple: ``(forward_operator, adjoint_operator, operator_norm, init_recon, W, W_adj)``
            forward_operator: torch module applying the ray transform divided by
                its estimated operator norm.
            adjoint_operator: torch module applying the adjoint of that scaled
                ray transform.
            operator_norm: ``1.0``, because the transform was divided by its
                (estimated) norm.
            init_recon: function returning ``0 * adjoint_operator(x)``, i.e. a
                zero image with the back-projection's shape.
            W: torch module applying ``odl.Gradient`` on the image space.
            W_adj: torch module applying the adjoint of that gradient.
    """
    if n_detectors is None:
        # NOTE(review): using n_angles as the detector count ties two unrelated
        # sampling parameters together. It is kept as the default only so that
        # existing callers see unchanged sinogram shapes. Pass n_detectors
        # explicitly (e.g. 2 * n gives detector pixels of width 1 on [-n, n])
        # to decouple them.
        n_detectors = n_angles

    # Image space and parallel-beam acquisition geometry.
    space = odl.uniform_discr([-n//2, -n//2], [n//2, n//2], [n, n], dtype='float32', weighting=1.0)
    angle_partition = odl.uniform_partition(0, np.pi, n_angles)
    detector_partition = odl.uniform_partition(-n, n, n_detectors)
    geometry = odl.tomo.Parallel2dGeometry(angle_partition, detector_partition)

    # Scale the ray transform so that its (estimated) operator norm is 1.
    ray_transform = odl.tomo.RayTransform(space, geometry)
    ray_transform = ray_transform / ray_transform.norm(estimate=True)

    # Wrap the ODL operators as PyTorch modules on the selected device.
    forward_operator = OperatorModule(ray_transform).to(device)
    adjoint_operator = OperatorModule(ray_transform.adjoint).to(device)

    def init_recon(x):
        # Zero initialization: the adjoint is used only to obtain an output
        # with the image-space shape, dtype and device of the back-projection.
        return 0*adjoint_operator(x)

    # The transform was normalized above, so its norm is taken to be 1.
    operator_norm = 1.0

    # Image gradient operator and its adjoint (e.g. for TV-type regularizers).
    gradient = odl.Gradient(space)
    W = OperatorModule(gradient).to(device)
    W_adj = OperatorModule(gradient.adjoint).to(device)

    return forward_operator, adjoint_operator, operator_norm, init_recon, W, W_adj
