from copy import deepcopy
import torch
from torch import optim
from tqdm.auto import tqdm
import torch.nn.functional as F

from chip.models.iterative_model import TomographicReconstruction
from chip.utils.sinogram import Sinogram
from chip.utils.fourier import fft_2D
from chip.utils.metrics import PSNR, RMSE


def total_variation_loss(image):
    """
    Compute an anisotropic total variation (TV) loss with diagonal terms.

    The loss is the sum of absolute differences between neighbouring entries
    along axis 0, along axis 1 and along both diagonals of the (axis 0, axis 1)
    plane.

    image: Tensor whose first two axes are the ones differences are taken
        along, e.g. a 2D tensor of shape (H, W).
    Returns: scalar tensor holding the summed absolute differences.
    """
    # Differences between adjacent entries
    diff_rows = image[1:, :] - image[:-1, :]            # along axis 0
    diff_cols = image[:, 1:] - image[:, :-1]            # along axis 1
    diff_diag = image[1:, 1:] - image[:-1, :-1]         # main diagonal
    # Anti-diagonal: entry (i + 1, j) against entry (i, j + 1). The first
    # operand must start at row 1; slicing `:1` would broadcast row 0 only.
    diff_anti_diag = image[1:, :-1] - image[:-1, 1:]

    return (
        torch.sum(torch.abs(diff_rows))
        + torch.sum(torch.abs(diff_cols))
        + torch.sum(torch.abs(diff_diag))
        + torch.sum(torch.abs(diff_anti_diag))
    )


def _sinogram_mse(tr_model, target_sinogram, batch_size, criterion, **forward_kwargs):
    """
    Compare the model's projections with a measured sinogram.

    Returns a tuple (loss, predicted) where `predicted` is the raw output of
    `tr_model.forward` (before any resizing) and `loss` is `criterion` applied
    to the (possibly resized) prediction and the selected target rows.
    """
    num_projections = len(target_sinogram)
    if batch_size:
        # Random mini-batch of projection indices, drawn without replacement.
        indices = torch.randperm(num_projections)[:batch_size].tolist()
    else:
        # Full batch: use every projection, including the last one.
        indices = torch.arange(num_projections, dtype=torch.long, device=tr_model.device)

    predicted = tr_model.forward(sinogram_angles=target_sinogram.angles[indices], **forward_kwargs)

    # Resize the predicted sinogram to the target size if they differ.
    target = target_sinogram.sinogram[indices]
    if predicted.shape != target.shape:
        resized = F.interpolate(
            predicted.unsqueeze(0).unsqueeze(0), size=target.shape, mode='bilinear', align_corners=True
        )[0, 0]
    else:
        resized = predicted
    return criterion(resized, target), predicted


def finetune_sinogram_consistency(
        tr_model: TomographicReconstruction,
        target_sinogram_hr: Sinogram = None,
        target_sinogram_lr: Sinogram = None,
        lr_forward_function=None,
        batch_size=10,
        lr_image=None,
        alpha_lr=1.,
        verbose=False,
        steps=(201, 51),
        l2_reg=0,
        lr=0.5,
        tv_loss: bool = False,
        alpha_tv=0.0001,
        optimizer_name='AdamW',
        project_parameters=False,
        lr_image_loss='mse',
        return_samples=False,
        ground_truth=None
):
    """
    Optimise the parameters of `tr_model` so that its projections agree with
    the given sinograms and/or its image agrees with a low resolution image.

    The optimisation runs in stages: stage i performs `steps[i]` updates with
    a fresh optimizer at learning rate `lr[i]`, decayed by a factor of 10 for
    the last 10 updates of the stage.

    Args:
        tr_model: model to optimise (modified in place).
        target_sinogram_hr: optional sinogram compared against
            `tr_model.forward(sinogram_angles=...)`.
        target_sinogram_lr: optional sinogram compared against
            `tr_model.forward(filter=lr_forward_function, sinogram_angles=...)`.
        lr_forward_function: filter passed to the model for the low resolution
            terms.
        batch_size: number of randomly chosen projections per update; a falsy
            value uses all projections.
        lr_image: optional image the model image is pulled towards.
        alpha_lr: weight of the `lr_image` term.
        verbose: show a progress bar with the current loss (and PSNR / RMSE
            when `ground_truth` is given).
        steps: number of updates per stage.
        l2_reg: weight decay handed to the optimizer.
        lr: learning rate, either a scalar (shared by all stages) or one value
            per stage.
        tv_loss: add a total variation penalty on the predicted high
            resolution sinogram; requires `target_sinogram_hr`.
        alpha_tv: weight of the total variation penalty.
        optimizer_name: 'AdamW' or 'SGD'.
        project_parameters: clamp `tr_model.img` to
            [-tr_model.prior, 1 - tr_model.prior] after every update.
        lr_image_loss: 'mse', 'fft' or 'fft_masked'; how `lr_image` is compared
            with the model image.
        return_samples: falsy, or an integer interval; when set, the model
            image is recorded whenever the stage-local step index is a
            multiple of it.
        ground_truth: optional reference image used only for progress metrics.

    Returns:
        The final loss as a float, or `(final_loss, stacked_images)` when
        `return_samples` is set.

    Raises:
        ValueError: on an unknown `optimizer_name`, when `tv_loss` is requested
            without `target_sinogram_hr`, or when no loss term is active.
    """
    optimizer_classes = {'AdamW': optim.AdamW, 'SGD': optim.SGD}
    if optimizer_name not in optimizer_classes:
        raise ValueError(
            f"Unknown optimizer_name {optimizer_name!r}; expected one of {sorted(optimizer_classes)}."
        )
    if tv_loss and target_sinogram_hr is None:
        # The TV penalty is computed on the predicted high resolution sinogram.
        raise ValueError("tv_loss=True requires target_sinogram_hr to be provided.")

    mse = torch.nn.MSELoss()
    img_iterates = []

    # A scalar learning rate is shared by every stage.
    learning_rates = lr if isinstance(lr, (list, tuple)) else [lr] * len(steps)

    for num_steps, stage_lr in zip(steps, learning_rates):
        optimizer = optimizer_classes[optimizer_name](
            tr_model.parameters(), lr=stage_lr, weight_decay=l2_reg
        )
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[num_steps - 10], gamma=0.1)

        iterator = tqdm(range(num_steps)) if verbose else range(num_steps)
        for step in iterator:
            tr_model.zero_grad()
            loss = 0

            # Consistency with the low resolution sinogram.
            if target_sinogram_lr is not None:
                term, _ = _sinogram_mse(
                    tr_model, target_sinogram_lr, batch_size, mse, filter=lr_forward_function
                )
                loss += term

            # Consistency with the high resolution sinogram.
            sinogram_hr = None
            if target_sinogram_hr is not None:
                term, sinogram_hr = _sinogram_mse(tr_model, target_sinogram_hr, batch_size, mse)
                loss += term

            # Consistency with the low resolution image.
            if lr_image is not None:
                if lr_image_loss == 'mse':
                    loss += alpha_lr * mse(lr_image, tr_model.get_img(filter=lr_forward_function))

                elif lr_image_loss == 'fft':
                    fft_lr_image = fft_2D(lr_image)
                    fft_prediction = fft_2D(tr_model.get_img())
                    loss += alpha_lr * mse(fft_lr_image.real, fft_prediction.real)
                    loss += alpha_lr * mse(fft_lr_image.imag, fft_prediction.imag)

                elif lr_image_loss == 'fft_masked':
                    fft_lr_image = fft_2D(lr_image)
                    # NOTE(review): the thresholded mask `|FFT| > 1` used to be
                    # computed here and then immediately overwritten by an
                    # all-True mask. That effective behaviour is kept; confirm
                    # whether the threshold was meant to be active.
                    mask = torch.ones_like(fft_lr_image, dtype=torch.bool)
                    fft_prediction = fft_2D(tr_model.get_img())
                    loss += alpha_lr * mse(fft_lr_image[mask].real, fft_prediction[mask].real)
                    loss += alpha_lr * mse(fft_lr_image[mask].imag, fft_prediction[mask].imag)

            if tv_loss:
                loss += alpha_tv * total_variation_loss(sinogram_hr)

            if not torch.is_tensor(loss):
                # Nothing was added to the loss, so there is nothing to optimise.
                raise ValueError(
                    "No loss term is active: provide target_sinogram_hr, target_sinogram_lr, "
                    "or lr_image with a supported lr_image_loss ('mse', 'fft', 'fft_masked')."
                )

            loss.backward()
            optimizer.step()
            scheduler.step()

            if return_samples and (step % return_samples) == 0:
                img_iterates.append(tr_model.get_img().detach().clone())

            if project_parameters:
                tr_model.img.data.clamp_(-tr_model.prior, 1 - tr_model.prior)

            if verbose:
                postfixes = {"loss": f"{loss.item():.4f}"}
                if ground_truth is not None:
                    # Metrics are for display only; no gradients are needed.
                    with torch.no_grad():
                        current_img = tr_model.get_img()
                        postfixes['PSNR'] = PSNR(ground_truth, current_img).item()
                        postfixes['RMSE'] = RMSE(ground_truth, current_img).item()
                iterator.set_postfix(postfixes)

    if return_samples:
        return loss.item(), torch.stack(img_iterates)
    return loss.item()


def sample_swag(model, lr_sinogram, hr_sinogram, lr_forward_function, lr_image=None, num_samples=10, learning_rate=0.1, reset_parameters=True, batch_size=2, subsample=1, verbose=False):
    """
    SWAG-style sampling: collect image iterates along an SGD trajectory.

    A single SGD stage of `num_samples * subsample` updates is run through
    `finetune_sinogram_consistency`, which is asked to record the model image
    at an interval of `subsample` steps. When `reset_parameters` is true the
    model's state dict is snapshotted beforehand and restored afterwards, so
    the caller's model is left as it was.

    Returns the recorded images as returned by `finetune_sinogram_consistency`.
    """
    # Snapshot the weights up front so the run can be undone afterwards.
    if reset_parameters:
        initial_state = deepcopy(model.state_dict())

    finetune_options = dict(
        lr_forward_function=lr_forward_function,  # low resolution forward model
        lr_image=lr_image,                        # low resolution image
        target_sinogram_lr=lr_sinogram,           # low resolution sinogram
        target_sinogram_hr=hr_sinogram,           # high resolution sinogram
        steps=[num_samples * subsample],
        lr=[learning_rate],
        optimizer_name='SGD',
        verbose=verbose,
        batch_size=batch_size,
        return_samples=subsample,
    )
    # Only the recorded iterates matter here; the final loss is discarded.
    trajectory = finetune_sinogram_consistency(model, **finetune_options)[1]

    if reset_parameters:
        model.load_state_dict(initial_state)
    return trajectory


def resample_gaussian(samples, num_samples=10):
    """
    Fit a Gaussian to the samples and generate new samples from the Gaussian.

    The Gaussian has the empirical mean of `samples` and their (unbiased)
    sample covariance, represented in low-rank form through the centred
    samples, so no covariance matrix is ever materialised.

    samples: Tensor of shape (k, *shape) holding k samples.
    num_samples: number of new samples to draw.
    Returns: Tensor of shape (num_samples, *shape).
    """
    mean = torch.mean(samples, dim=0)
    k = samples.shape[0]

    # Centred samples as a (k, d) matrix. Dividing by sqrt(k - 1) makes
    # deviations.T @ deviations equal to the unbiased sample covariance;
    # without it the drawn samples would be (k - 1) times too dispersed.
    # For k == 1 the deviations are all zero, so the divisor is clamped to 1.
    deviations = (samples - mean).reshape(k, -1) / (max(k - 1, 1) ** 0.5)

    # Standard normal coefficients, one column per new sample, matching the
    # dtype of the inputs so the product below is well defined.
    noise = torch.randn(k, num_samples, device=samples.device, dtype=samples.dtype)

    offsets = torch.matmul(deviations.T, noise).T.reshape(num_samples, *mean.shape)
    return mean + offsets

def sample_ensemble(lr_sinogram, hr_sinogram, lr_forward_function, im_shape=None, lr_image=None, 
                    num_samples=10, input_std = 1.0, batch_size=10, device=None, verbose=False, steps=None, lr=None):
    """
    Train an ensemble of iterative models.

    Each ensemble member is a fresh `TomographicReconstruction` whose prior is
    Gaussian noise with standard deviation `input_std` (plus `lr_image` when it
    is given), optimised independently with `finetune_sinogram_consistency`.

    im_shape: shape of the prior; defaults to the shape of `lr_image.squeeze()`.
    steps / lr: per-stage update counts and learning rates; default to
        [400, 400, 200] and [0.1, 0.01, 0.001].
    Returns: Tensor stacking the final image of every ensemble member.
    Raises: ValueError if neither `im_shape` nor `lr_image` is provided.
    """
    samples = []
    if im_shape is None:
        if lr_image is None:
            raise ValueError("Either im_shape or lr_image must be provided.")
        im_shape = lr_image.squeeze().shape

    if steps is None:
        steps = [400, 400, 200]
    if lr is None:
        lr = [0.1, 0.01, 0.001]

    for _ in tqdm(range(num_samples), disable=not verbose):
        prior = input_std * torch.randn(im_shape, device=device)
        # Compare against None explicitly: the truth value of a tensor with
        # more than one element is ambiguous and raises a RuntimeError.
        if lr_image is not None:
            prior += lr_image
        # here is the parametric model, note that it is initialized with the low resolution image
        model = TomographicReconstruction(prior=prior, use_sigmoid=True)

        # optimize model
        finetune_sinogram_consistency(model, 
                                        lr_forward_function=lr_forward_function, # low resolution forward model
                                        lr_image=lr_image, # low resolution image
                                        target_sinogram_lr=lr_sinogram,  # low resolution sinogram
                                        target_sinogram_hr=hr_sinogram,  # high resolution sinogram
                                        steps=steps, 
                                        lr=lr,
                                        batch_size=batch_size,
                                        verbose=verbose)
        
        samples.append(model.get_img().detach().clone())
    return torch.stack(samples)


if __name__ == '__main__':
    # Smoke test: fit a random model to a random high resolution sinogram.
    import lovely_tensors as lt

    lt.monkey_patch()

    tr = TomographicReconstruction(prior=torch.randn(512, 512), use_sigmoid=True)

    # finetune_sinogram_consistency takes a single `batch_size`; the former
    # `batch_size_lr` / `batch_size_hr` keywords are not part of its signature
    # and raised a TypeError. Only the high resolution sinogram is used here,
    # so its batch size is the one that is kept.
    loss = finetune_sinogram_consistency(
        tr,
        target_sinogram_hr=Sinogram(torch.randn(180, 512), torch.arange(180)),
        target_sinogram_lr=None,
        batch_size=40,
        verbose=True,
        steps=[100],
        lr=[0.25],
        tv_loss=True
    )
