import sys
import os

import h5py

from chip.datasets.sinogram_dataset import SinogramsDS
from chip.datasets.superres_dataset import SuperresolutionDS

from argparse import ArgumentParser

import torch
from pytorch_base.experiment import PyTorchExperiment
from pytorch_base.base_loss import BaseLoss

import random
import torch.nn as nn
import matplotlib.pyplot as plt

from torch.utils.data import Dataset

import torch.nn.functional as F

from chip.utils.fourier import fft_2D, ifft_2D

# Batched variants of the project's 2-D Fourier helpers: torch.vmap maps each
# function over the leading dimension of its input.
b_fft_2D = torch.vmap(fft_2D)
b_ifft_2D = torch.vmap(ifft_2D)
class filter_model(nn.Module):
    """Learnable 1-D frequency filter applied to a batch of 2-D images.

    The only parameter, ``filter``, is initialised by
    :meth:`ram_lak_filter_torch`. In :meth:`forward` it is multiplied onto the
    batched 2-D transform of the input; broadcasting applies it along the
    last axis.

    Args:
        size: Number of filter taps. It has to broadcast against the last
            dimension of the transformed images. Defaults to 512.
        stride: Sample spacing passed to ``torch.fft.fftfreq``. Defaults to 1.
    """

    def __init__(self, size=512, stride=1.):
        super().__init__()
        # The attribute name is part of the state_dict layout used by saved
        # checkpoints, so it must stay ``filter``.
        self.filter = nn.Parameter(self.ram_lak_filter_torch(size, stride))

    def ram_lak_filter_torch(self, n, stride):
        """Build the initial filter weights using PyTorch operations.

        Computes ``| |f| - max(f) |`` for ``f = torch.fft.fftfreq(n, stride)``.

        NOTE(review): the weights are largest at index 0 and smallest around
        index ``n // 2``; this presumably matches a spectrum whose zero
        frequency sits in the centre -- confirm against ``fft_2D``.

        Args:
            n: Number of filter taps.
            stride: Sample spacing passed to ``torch.fft.fftfreq``.

        Returns:
            1-D float32 tensor of length ``n``.
        """
        freqs = torch.fft.fftfreq(n, stride)
        weights = torch.abs(torch.abs(freqs) - torch.max(freqs))
        return weights.float()

    def forward(self, image):
        """Filter ``image`` in the transform domain and return the real part.

        Args:
            image: Batch of 2-D images; assumed to be (batch, H, W) with W
                equal to the filter length -- TODO confirm.

        Returns:
            Real part of the inverse transform of the filtered spectrum.
        """
        spectrum = b_fft_2D(image)
        return b_ifft_2D(spectrum * self.filter).real

def batch_rotate(image, angles):
    """Resample every image of a batch through its own rotation grid.

    For each angle ``a`` the affine matrix ``[[cos a, -sin a, 0],
    [sin a, cos a, 0]]`` is turned into a sampling grid with
    ``F.affine_grid`` and applied with ``F.grid_sample`` (library defaults:
    bilinear interpolation, zero padding).

    The work runs on CUDA when available, otherwise on MPS when available,
    otherwise on the CPU. The result is moved back to the device ``image``
    came from.

    Args:
        image: Tensor of shape (N, H, W), one 2-D image per angle.
        angles: 1-D tensor with N angles in degrees.

    Returns:
        Tensor of shape (N, 1, H, W) on ``image``'s original device.
    """
    image_device = image.device
    # Fall back to the CPU when no accelerator exists; unconditionally
    # selecting 'mps' crashes on machines that have neither CUDA nor MPS.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')

    # Convert once instead of once per matrix entry.
    radians = torch.deg2rad(angles).to(device)
    cos_a = torch.cos(radians)
    sin_a = torch.sin(radians)
    no_shift = torch.zeros_like(cos_a)

    # Shape (N, 2, 3): one rotation without translation per angle.
    rotation_matrix = torch.stack([
        torch.stack([cos_a, -sin_a, no_shift], 1),
        torch.stack([sin_a, cos_a, no_shift], 1),
    ], 1)

    # affine_grid only uses the batch and spatial entries of the size; the
    # channel entry is 1 to match the unsqueezed image below.
    current_grid = F.affine_grid(
        rotation_matrix,
        (len(angles), 1, *image.shape[-2:]),
        align_corners=False
    )

    rotated_image = F.grid_sample(
        image.unsqueeze(1).to(device),
        current_grid,
        align_corners=False
    )
    return rotated_image.to(image_device)


class CombinedDataset(Dataset):
    """Index-aligned pairing of two datasets of equal length.

    Item ``i`` is the tuple ``(dataset1[i], dataset2[i][1])``: the complete
    sample of the first dataset together with the element at position 1 of
    the corresponding sample of the second dataset.
    """

    def __init__(self, dataset1, dataset2):
        # Pairing by index requires both datasets to have the same size.
        assert len(dataset1) == len(dataset2), "Datasets should have the same length"
        self.dataset1, self.dataset2 = dataset1, dataset2

    def __len__(self):
        # Both lengths are equal, so either dataset can report it.
        return len(self.dataset1)

    def __getitem__(self, index):
        # Only entry 1 of the second dataset's sample is kept.
        return self.dataset1[index], self.dataset2[index][1]

def FBP(sinogram, angles):
    """Smear each sinogram row into an image and rotate it by its angle.

    Every row of ``sinogram`` is tiled 512 times along a new axis, and the
    resulting stack is passed to ``batch_rotate`` with the negated angles.
    No filtering and no summation over angles happens here. The whole
    computation runs without gradient tracking.

    Args:
        sinogram: Assumed to be (n_angles, n_detectors) -- TODO confirm.
        angles: One angle per sinogram row, in degrees.

    Returns:
        The output of ``batch_rotate`` with dimension 1 squeezed away.
    """
    assert len(angles) == len(sinogram)
    with torch.no_grad():
        tiled = sinogram.repeat(512, 1, 1)
        # Bring the per-angle axis to the front: (n_angles, 512, ...).
        per_angle = tiled.permute(1, 0, 2)
        rotated = batch_rotate(per_angle, -angles)

    return rotated.squeeze(1)

# Batched FBP: torch.vmap maps FBP over the leading dimension of both the
# sinograms and the angle tensors.
b_FBP = torch.vmap(FBP)

class diffusion_loss(BaseLoss):
    """MSE between the summed, model-filtered back-projections and the target.

    Reports a single statistic named ``"loss"``.
    """

    def __init__(self):
        stats_names = ["loss"]
        super().__init__(stats_names)

    def compute_loss(self, instance, model):
        """Compute the reconstruction loss for one batch.

        Steps: back-project the input with ``b_FBP`` (no gradients), run every
        per-angle image through ``model``, sum over the angle axis, and
        compare with ``target`` using mean squared error. As a side effect
        the first prediction and first target of the batch are written to
        ``image.png`` and ``image_orig.png`` in the working directory.

        Args:
            instance: Pair ``(x_0, target)``. ``x_0`` is presumably a batch
                of sinograms shaped (batch, n_angles, n_detectors) -- TODO
                confirm against the dataset.
            model: Module applied to the stacked per-angle images.

        Returns:
            Tuple ``(loss, {"loss": loss})``.
        """
        x_0, target = instance

        # Take the device from the model instead of a module-level global
        # that only exists when this file is executed as a script.
        first_param = next(model.parameters(), None)
        if first_param is not None:
            x_0 = x_0.to(first_param.device)
            target = target.to(first_param.device)

        bs, n_angles = x_0.shape[0], x_0.shape[1]
        with torch.no_grad():
            # n_angles evenly spaced angles in [0, 180), end point excluded.
            theta = torch.linspace(
                0, 180 - 180 / n_angles, n_angles, device=x_0.device
            )
            fbp_x_0 = b_FBP(x_0, theta.repeat(bs, 1))

        model.zero_grad()
        # Use the real spatial size rather than a hard-coded 512 x 512.
        height, width = fbp_x_0.shape[-2:]
        output = model(fbp_x_0.reshape(-1, height, width))
        output = output.reshape(bs, -1, height, width)
        # Collapse the per-angle images into one reconstruction per sample.
        output = torch.sum(output, dim=1)

        prediction = output.reshape(target.shape)
        loss = torch.nn.functional.mse_loss(prediction, target)

        self._save_preview(prediction[0], 'image.png')
        self._save_preview(target[0], 'image_orig.png')

        return loss, {"loss": loss}

    @staticmethod
    def _save_preview(tensor, path):
        """Write ``tensor`` to ``path`` as a grayscale image."""
        plt.imshow(tensor.detach().cpu().numpy(), cmap='gray')
        plt.savefig(path, bbox_inches='tight', pad_inches=0.0)
        # Close the figure so repeated calls do not pile up image artists in
        # one ever-growing pyplot figure.
        plt.close()


def load_model(model, model_path):
    """Restore ``model`` in place from a checkpoint file.

    The checkpoint is deserialised onto the CPU, and its
    ``'model_state_dict'`` entry is passed to ``model.load_state_dict``.

    Args:
        model: Module whose weights are overwritten.
        model_path: Path of the checkpoint to read.
    """
    cpu = torch.device('cpu')
    state_dict = torch.load(model_path, map_location=cpu)['model_state_dict']
    model.load_state_dict(state_dict)
    print(f"models loaded from checkpoint {model_path}")


# Script entry point: parse the command line, build the paired dataset, and
# train the filter model.
if __name__ == '__main__':
    import lovely_tensors as lt

    lt.monkey_patch()

    parser = ArgumentParser(description="PyTorch experiments")
    parser.add_argument("--batch_size", default=50, type=int, help="batch size of every process")
    parser.add_argument("--epochs", default=50, type=int, help="number of epochs to train")
    parser.add_argument("--learning_rate", default=0.001, type=float, help="learning rate")
    parser.add_argument("--weight_decay", default=0.001, type=float, help="weight decay")
    parser.add_argument("--scheduler", default="[500]", type=str, help="scheduler decrease after epochs given")
    parser.add_argument("--lr_decay", default=0.1, type=float, help="Learning rate decay")
    parser.add_argument("--wandb_exp_name", default='random_experiments', type=str, help="Experiment name in wandb")
    parser.add_argument('--wandb', action='store_true', help="Use wandb")
    parser.add_argument("--load_checkpoint", default='', type=str, help="name of models in folder checkpoints to load")
    parser.add_argument("--seed", default=-1, type=int, help="Random seed")
    args = vars(parser.parse_args())

    # "--scheduler" arrives as text such as "[100, 200]". Strip the brackets
    # and skip empty pieces so that "[]" yields an empty milestone list
    # instead of failing on int('').
    temp = args["scheduler"].replace(" ", "").replace("[", "").replace("]", "").split(",")
    args["scheduler"] = [int(x) for x in temp if x]
    # A seed of -1 means "pick one at random".
    args["seed"] = random.randint(0, 20000) if args["seed"] == -1 else args["seed"]

    # NOTE(review): os.listdir returns entries in arbitrary order, while
    # CombinedDataset pairs samples by index -- confirm that this order
    # matches the order of the sinograms in the HDF5 file.
    files = os.listdir('data/imgs_synthetic')
    tomo_ds = SuperresolutionDS(files, data_path="data")

    sino_ds = SinogramsDS("data/synthetic_sinograms.h5")

    ds = CombinedDataset(sino_ds, tomo_ds)

    # Deterministic 90 % / 10 % split by index.
    split_index = round(0.9 * len(ds))
    trainSet = torch.utils.data.Subset(ds, range(0, split_index))
    testSet = torch.utils.data.Subset(ds, range(split_index, len(ds)))

    # Kept as a module-level name: other code in this file may read it.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print(device)

    model_path = "checkpoints/diffusion_model_tomogram.pt"

    model = filter_model().to(device)

    if args['load_checkpoint']:
        load_model(model, args['load_checkpoint'])

    exp = PyTorchExperiment(
        train_dataset=trainSet,
        test_dataset=testSet,
        batch_size=args['batch_size'],
        model=model,
        loss_fn=diffusion_loss(),
        checkpoint_path=model_path,
        experiment_name=args['wandb_exp_name'],
        with_wandb=args['wandb'],
        num_workers=0,
        seed=args["seed"],
        args=args
    )

    # Pass --weight_decay to the optimizer; it used to be parsed but ignored,
    # so AdamW silently ran with its library default.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args['learning_rate'],
        weight_decay=args['weight_decay'],
    )

    exp.train(args['epochs'], optimizer, milestones=args['scheduler'], gamma=args['lr_decay'])








