import argparse
import torch
import os
from bayes_dip.data.trafo.parallel_beam_2d_ray_trafo import get_odl_ray_trafo_parallel_beam_2d_matrix


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--im_width', type=int, help='Image width')
    parser.add_argument('--im_height', type=int, default=None, help='Image height')
    parser.add_argument('--num_angles', type=int, default=180, help='Number of angles')
    parser.add_argument('--num_det_pixels', type=int, default=None, help='Number of detector pixels')
    parser.add_argument('--output_path', type=str, help='Output path')

    args = parser.parse_args()

    num_angles = args.num_angles
    num_det_pixels = args.num_det_pixels if args.num_det_pixels is not None else args.im_width
    output_path = args.output_path

    im_shape = (args.im_width, args.im_height if args.im_height is not None else args.im_width)


    matrix = get_odl_ray_trafo_parallel_beam_2d_matrix(
                im_shape, num_angles, num_det_pixels=num_det_pixels, first_angle_zero=True, circular=True,
                angular_sub_sampling=1, impl='astra_cuda', flatten=True)

    path = os.path.join(output_path, f'ray_trafo_matrix_{im_shape[0]}_{im_shape[1]}_{num_angles}_{num_det_pixels}.pt')
    torch.save(matrix, path)
    print(f"Matrix saved to {path}")