"""
Provides :class:`ParallelBeam2DRayTrafo`, as well as getters
for its matrix representation and a :class:`MatmulRayTrafo` implementation.
"""

import os
from itertools import product
from typing import Tuple
import numpy as np
import torch
from odl.contrib.torch import OperatorModule
import odl
from tqdm import tqdm
from copy import deepcopy
from bayes_dip.data.trafo.base_ray_trafo import BaseRayTrafo
from bayes_dip.data.trafo.matmul_ray_trafo import MatmulRayTrafo


def get_odl_ray_trafo_parallel_beam_2d(
        im_shape: Tuple[int, int],
        num_angles: int,
        num_det_pixels : int = None,
        first_angle_zero: bool = True,
        circular: bool = False,
        impl: str = 'astra_cuda') -> odl.tomo.RayTransform:
    """
    Return an ODL 2D parallel beam ray transform.

    Parameters
    ----------
    im_shape : 2-tuple of int
        Image shape, ``(im_0, im_1)``.
    num_angles : int
        Number of angles (to distribute from ``0`` to ``pi``).
    num_det_pixels : int, optional
        Number of detector pixels, passed as ``det_shape`` to
        :func:`odl.tomo.parallel_beam_geometry`. If ``None`` (the default),
        ODL chooses the number of detector pixels.
    first_angle_zero : bool, optional
        Whether to shift all angles such that the first angle becomes ``0.``.
        If ``False``, the default angles from ODL are used, where the first angle
        is at half an angle step.
        The default is ``True``.
    circular : bool, optional
        If ``True``, the extent of the image domain of the returned transform is
        enlarged by a factor of ``sqrt(2)`` (keeping ``im_shape`` pixels), while
        the geometry, including the detector, is still derived from the
        unscaled ``im_shape``-sized domain.
        The default is ``False``.
    impl : str, optional
        Backend for :class:`odl.tomo.RayTransform`.
        The default is ``'astra_cuda'``.

    Returns
    -------
    odl_ray_trafo : :class:`odl.tomo.RayTransform`
        The ray transform.
    """
    # Unscaled domain (one length unit per pixel, centred at the origin); it is
    # only used to derive the default geometry.
    space = odl.uniform_discr(
            [-im_shape[0] / 2, -im_shape[1] / 2],
            [im_shape[0] / 2, im_shape[1] / 2],
            im_shape,
            dtype='float32')

    default_odl_geometry = odl.tomo.parallel_beam_geometry(
            space, num_angles=num_angles, det_shape=num_det_pixels)

    if first_angle_zero:
        # ODL places the first angle at half an angle step; shift the whole
        # angle grid so that it starts at exactly 0.
        angle_coords = default_odl_geometry.motion_grid.coord_vectors[0]
        angle_partition = odl.uniform_partition_fromgrid(
                odl.discr.grid.RectGrid(angle_coords - angle_coords[0]))
        geometry = odl.tomo.Parallel2dGeometry(
                apart=angle_partition,
                dpart=default_odl_geometry.det_partition)
    else:
        geometry = default_odl_geometry

    # The geometry (and thus the detector extent) was derived from the unscaled
    # ``space`` above; with ``circular=True`` only the image domain is enlarged.
    # NOTE(review): presumably this makes the detector cover just the disc
    # inscribed in the enlarged image — confirm against ODL's
    # ``parallel_beam_geometry`` detector sizing.
    correction_factor = np.sqrt(2) if circular else 1

    image_space = odl.uniform_discr(
            [-im_shape[0] / 2 * correction_factor, -im_shape[1] / 2 * correction_factor],
            [im_shape[0] / 2 * correction_factor, im_shape[1] / 2 * correction_factor],
            im_shape,
            dtype='float32')

    odl_ray_trafo = odl.tomo.RayTransform(
            image_space, geometry, impl=impl)

    return odl_ray_trafo


class ParallelBeam2DRayTrafo(BaseRayTrafo):
    """
    Ray transform implemented via ODL.

    Adjoint computations use the back-projection (might be slightly inaccurate).
    """

    def __init__(self,
            im_shape: Tuple[int, int],
            num_angles: int,
            first_angle_zero: bool = True,
            angular_sub_sampling: int = 1,
            num_det_pixels: int = None,
            circular: bool = False,
            impl: str = 'astra_cuda'):
        """
        Parameters
        ----------
        im_shape : 2-tuple of int
            Image shape, ``(im_0, im_1)``.
        num_angles : int
            Number of angles (to distribute from ``0`` to ``pi``), before
            sub-sampling.
        first_angle_zero : bool, optional
            Whether to shift all angles such that the first angle becomes ``0.``.
            If ``False``, the default angles from ODL are used, where the first angle
            is at half an angle step.
            The default is ``True``.
        angular_sub_sampling : int, optional
            Sub-sampling factor for the angles; every ``angular_sub_sampling``-th
            angle of the full geometry is kept.
            The default is ``1`` (no sub-sampling).
        num_det_pixels : int, optional
            Number of detector pixels. If ``None`` (the default), ODL chooses
            the number (see :func:`odl.tomo.parallel_beam_geometry`).
        circular : bool, optional
            Passed to :func:`get_odl_ray_trafo_parallel_beam_2d`: if ``True``,
            the image domain extent is enlarged by a factor of ``sqrt(2)`` while
            the geometry is derived from the unscaled domain.
            The default is ``False``.
        impl : str, optional
            Backend for :class:`odl.tomo.RayTransform`.
            The default is ``'astra_cuda'``.
        """
        odl_ray_trafo_full = get_odl_ray_trafo_parallel_beam_2d(
                im_shape, num_angles, first_angle_zero=first_angle_zero,
                impl=impl, num_det_pixels=num_det_pixels, circular=circular)
        # Build the (possibly) sub-sampled transform on the same domain by
        # slicing the angles of the full geometry.
        odl_ray_trafo = odl.tomo.RayTransform(
                odl_ray_trafo_full.domain,
                odl_ray_trafo_full.geometry[::angular_sub_sampling], impl=impl)
        odl_fbp = odl.tomo.fbp_op(odl_ray_trafo)

        obs_shape = odl_ray_trafo.range.shape

        super().__init__(im_shape=im_shape, obs_shape=obs_shape)

        self.odl_ray_trafo = odl_ray_trafo
        self._angles = odl_ray_trafo.geometry.angles

        # torch wrappers around the ODL operators
        self.ray_trafo_module = OperatorModule(odl_ray_trafo)
        self.ray_trafo_module_adj = OperatorModule(odl_ray_trafo.adjoint)
        self.fbp_module = OperatorModule(odl_fbp)

    @property
    def angles(self) -> np.ndarray:
        """:class:`np.ndarray` : The angles (in radian)."""
        return self._angles

    def trafo(self, x):
        return self.ray_trafo_module(x)

    def trafo_adjoint(self, observation):
        return self.ray_trafo_module_adj(observation)

    # Flat variants reuse generic helpers from BaseRayTrafo that (per their
    # names) are implemented via ``trafo`` / ``trafo_adjoint``.
    trafo_flat = BaseRayTrafo._trafo_flat_via_trafo
    trafo_adjoint_flat = BaseRayTrafo._trafo_adjoint_flat_via_trafo_adjoint

    def fbp(self, observation):
        return self.fbp_module(observation)

def get_odl_ray_trafo_parallel_beam_2d_matrix(
        im_shape: Tuple[int, int],
        num_angles: int,
        num_det_pixels: int = None,
        first_angle_zero: bool = True,
        angular_sub_sampling: int = 1,
        circular: bool = False,
        impl: str = 'astra_cuda',
        flatten: bool = True) -> np.ndarray:
    """
    Return the matrix representation of an ODL 2D parallel beam ray transform.

    The matrix is assembled column by column by applying the full-angle ray
    transform to each unit image; angular sub-sampling is applied afterwards by
    keeping every ``angular_sub_sampling``-th angle.

    See documentation of :class:`ParallelBeam2DRayTrafo` for
    documentation of the parameters not documented here.

    Parameters
    ----------
    flatten : bool, optional
        If ``True``, the observation dimensions and image dimensions are flattened,
        the resulting shape is ``(np.prod(obs_shape), np.prod(im_shape))``;
        if ``False``, the shape is ``obs_shape + im_shape``.
        The default is ``True``.

    Returns
    -------
    matrix : :class:`np.ndarray`
        Matrix representation with dtype ``float32``; ``obs_shape`` refers to
        the observation shape after angular sub-sampling.
    """
    # Tuple concatenation below (``obs_shape + im_shape``) requires a tuple;
    # accept any 2-sequence (e.g. a list coming from a config).
    im_shape = tuple(im_shape)

    odl_ray_trafo_full = get_odl_ray_trafo_parallel_beam_2d(
            im_shape, num_angles, num_det_pixels=num_det_pixels,
            first_angle_zero=first_angle_zero, circular=circular, impl=impl)
    obs_shape = odl_ray_trafo_full.range.shape

    matrix = np.zeros(obs_shape + im_shape, dtype=np.float32)
    # Reuse a single unit-image buffer: set one pixel, project, reset.
    x = np.zeros(im_shape, dtype=np.float32)
    for i0, i1 in tqdm(product(range(im_shape[0]), range(im_shape[1])),
            total=im_shape[0] * im_shape[1],
            desc='generating ray transform matrix'):
        x[i0, i1] = 1.
        matrix[:, :, i0, i1] = odl_ray_trafo_full(x)
        x[i0, i1] = 0.

    if angular_sub_sampling != 1:
        # Copy into a contiguous array so the returned matrix does not keep the
        # full-angle buffer alive as a strided view (relevant for flatten=False).
        matrix = np.ascontiguousarray(matrix[::angular_sub_sampling])

    if flatten:
        matrix = matrix.reshape(-1, im_shape[0] * im_shape[1])

    return matrix


def get_parallel_beam_2d_matmul_ray_trafo(
        im_shape: Tuple[int, int],
        num_angles: int,
        num_det_pixels : int  =None,
        first_angle_zero: bool = True,
        angular_sub_sampling: int = 1,
        circular: bool = False,
        matrix_ray_trafo: bool = True,
        matrix_path: str = None,
        impl: str = 'astra_cuda') -> BaseRayTrafo:
    """
    Return a 2D parallel beam ray transform, by default a
    :class:`bayes_dip.data.MatmulRayTrafo` with the matrix
    representation of an ODL 2D parallel beam ray transform.

    See documentation of :class:`ParallelBeam2DRayTrafo` for
    documentation of the parameters not documented here.

    Parameters
    ----------
    matrix_ray_trafo : bool, optional
        If ``True`` (the default), return a :class:`MatmulRayTrafo`; if
        ``False``, return a :class:`ParallelBeam2DRayTrafo` (applying the ODL
        operator directly) with the same geometry, including angular
        sub-sampling.
    matrix_path : str, optional
        Path to a matrix file loaded with :func:`torch.load`, or to a directory
        containing a file named
        ``ray_trafo_matrix_{im_0}_{im_1}_{obs_0}_{obs_1}.pt``.
        If ``None`` (the default), the matrix is computed via
        :func:`get_odl_ray_trafo_parallel_beam_2d_matrix`.

    Returns
    -------
    ray_trafo : :class:`MatmulRayTrafo` or :class:`ParallelBeam2DRayTrafo`
        The ray transform.

    Raises
    ------
    FileNotFoundError
        If ``matrix_path`` is given but the (resolved) file does not exist.
    """

    if not matrix_ray_trafo:
        # Forward ``angular_sub_sampling`` so both code paths use the same angles.
        return ParallelBeam2DRayTrafo(
                im_shape, num_angles, num_det_pixels=num_det_pixels,
                first_angle_zero=first_angle_zero,
                angular_sub_sampling=angular_sub_sampling, circular=circular,
                impl=impl)

    odl_ray_trafo_full = get_odl_ray_trafo_parallel_beam_2d(
            im_shape, num_angles, num_det_pixels=num_det_pixels,
            first_angle_zero=first_angle_zero, circular=circular, impl=impl)
    odl_ray_trafo = odl.tomo.RayTransform(
            odl_ray_trafo_full.domain,
            odl_ray_trafo_full.geometry[::angular_sub_sampling], impl=impl)

    odl_fbp = odl.tomo.fbp_op(odl_ray_trafo)

    obs_shape = odl_ray_trafo.range.shape
    angles = odl_ray_trafo.geometry.angles

    fbp_module = OperatorModule(odl_fbp)

    if matrix_path is None:
        matrix = get_odl_ray_trafo_parallel_beam_2d_matrix(
                im_shape, num_angles, num_det_pixels=num_det_pixels,
                first_angle_zero=first_angle_zero, circular=circular,
                angular_sub_sampling=angular_sub_sampling, impl=impl, flatten=True)
    else:
        if os.path.isdir(matrix_path):
            # The file name encodes image and (sub-sampled) observation shapes.
            matrix_path = os.path.join(matrix_path, f'ray_trafo_matrix_{im_shape[0]}_{im_shape[1]}_{obs_shape[0]}_{obs_shape[1]}.pt')
        if not os.path.exists(matrix_path):
            raise FileNotFoundError(f'Ray trafo matrix file {matrix_path} does not exist.')
        # NOTE: torch.load unpickles the file; only load trusted matrix files.
        matrix = torch.load(matrix_path)
        print(f'Loaded ray trafo matrix from {matrix_path}.')

    ray_trafo = MatmulRayTrafo(im_shape, obs_shape, matrix, fbp_fun=fbp_module, angles=angles)

    return ray_trafo

def get_parallel_beam_2d_matmul_ray_trafos_bayesian_exp_design(
        im_shape: Tuple[int, int],
        num_angles: int,
        angular_sub_sampling: int = 1,
        num_det_pixels: int = None,
        circular: bool = False,
        impl: str = 'astra_cuda') -> Tuple[MatmulRayTrafo,MatmulRayTrafo]:
    """
    Return a pair of :class:`bayes_dip.data.MatmulRayTrafo`, the first one
    with the full matrix representation of an ODL 2D parallel beam ray transform,
    the second one restricted to every ``angular_sub_sampling``-th angle.

    The matrix of the second transform is obtained by selecting rows of the
    first one, so it is computed only once.

    See documentation of :class:`ParallelBeam2DRayTrafo` for
    documentation of the parameters.

    Returns
    -------
    ray_trafo_full : :class:`MatmulRayTrafo`
        Ray transform with all ``num_angles`` angles.
    ray_trafo : :class:`MatmulRayTrafo`
        Angularly sub-sampled ray transform.
    """
    # Tuple concatenation below (``obs_shape_full + im_shape``) requires a tuple.
    im_shape = tuple(im_shape)

    # `first_angle_zero = True` to ensure matching angles
    odl_ray_trafo_full = get_odl_ray_trafo_parallel_beam_2d(
            im_shape, num_angles, num_det_pixels=num_det_pixels,
            first_angle_zero=True, circular=circular, impl=impl)
    odl_fbp_full = odl.tomo.fbp_op(odl_ray_trafo_full)
    fbp_full_module = OperatorModule(odl_fbp_full)

    obs_shape_full = odl_ray_trafo_full.range.shape
    angles_full = odl_ray_trafo_full.geometry.angles

    # `first_angle_zero = True` to ensure matching angles
    matrix_full = get_odl_ray_trafo_parallel_beam_2d_matrix(
            im_shape, num_angles, num_det_pixels=num_det_pixels,
            first_angle_zero=True, angular_sub_sampling=1, circular=circular,
            impl=impl, flatten=True)

    ray_trafo_full = MatmulRayTrafo(im_shape, obs_shape_full,
            matrix_full, fbp_fun=fbp_full_module, angles=angles_full)

    odl_ray_trafo = odl.tomo.RayTransform(
            odl_ray_trafo_full.domain,
            odl_ray_trafo_full.geometry[::angular_sub_sampling], impl=impl)
    odl_fbp = odl.tomo.fbp_op(odl_ray_trafo)
    fbp_module = OperatorModule(odl_fbp)

    obs_shape = odl_ray_trafo.range.shape
    angles = odl_ray_trafo.geometry.angles

    # Keep the rows of every `angular_sub_sampling`-th angle. The reshapes may
    # return views of `matrix_full` (e.g. for `angular_sub_sampling == 1`), so
    # copy to make the sub-sampled matrix independent of the full one.
    matrix = matrix_full.reshape(obs_shape_full + im_shape)[::angular_sub_sampling]
    matrix = matrix.reshape(-1, im_shape[0] * im_shape[1]).copy()

    ray_trafo = MatmulRayTrafo(im_shape, obs_shape, matrix, fbp_fun=fbp_module, angles=angles)

    return ray_trafo_full, ray_trafo