import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../chip-project')))
from argparse import ArgumentParser

import torch
from pytorch_base.experiment import PyTorchExperiment
from pytorch_base.base_loss import BaseLoss

import h5py
import hdf5plugin

import random

from diffusers import UNet2DModel
from diffusers import DDPMScheduler
from diffusers.optimization import get_cosine_schedule_with_warmup

from torch.utils.data import Dataset
from torchvision import transforms

from chip.datasets.sinogram_dataset import SinogramsDS


from torch.utils.data import Dataset
from torchvision import transforms

from chip.training.iterative_reconstruction import total_variation_loss

# Batched wrapper around total_variation_loss built with torch.vmap, which maps
# over dim 0 of its inputs by default.
# NOTE(review): not referenced anywhere else in the code shown here; confirm it is still needed.
batched_total_variation = torch.vmap(total_variation_loss)
class MyCustomDataset(Dataset):
    """Wrap an indexable collection of 2-D tensors with augmentation and normalization.

    With augmentation enabled (the default), each item is cyclically shifted along
    dim 0 by a random amount, the rows that wrapped around are mirrored along
    dim 1, the result is randomly flipped horizontally/vertically, and finally
    passed through ``normalize``.
    """

    def __init__(self, data, augment=True):
        """
        Args:
            data (Dataset or sequence): Indexable collection whose items are 2-D
                tensors (they must support ``.shape``, ``torch.roll`` and slicing).
            augment (bool, optional): When True (default) apply the random
                roll/mirror and flip augmentations in ``__getitem__``. When False,
                items are only normalized, which gives a deterministic split
                (e.g. for evaluation).
        """
        self.data = data
        self.augment = augment
        self.train_transform = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
            ]
        )

    def test_transform(self):
        """Return a ``ToTensor`` pipeline (not used by ``__getitem__``)."""
        return transforms.Compose([
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx]
        if not self.augment:
            # Clone so an in-place normalization can never write into ``data``.
            return normalize(x.clone())

        # Roll along dim 0 and mirror the wrapped-around rows along dim 1.
        # Presumably this exploits the parallel-beam sinogram symmetry
        # p(theta + pi, s) = p(theta, -s), with dim 0 = projection angle and
        # dim 1 = detector position -- TODO confirm against SinogramsDS.
        shift = random.randint(0, x.shape[0] - 1)
        sample = torch.roll(x, shift, dims=0)  # new tensor, so the in-place edit below is safe
        sample[:shift, :] = torch.flip(sample[:shift, :], dims=[1])
        # The flips operate on (C, H, W), so add and then drop a channel axis.
        sample = self.train_transform(sample.unsqueeze(0)).squeeze(0)

        return normalize(sample)

def reverse_normalization(x):
    """Map a model-space value back to data space.

    This is currently the identity: the very same object ``x`` is returned.
    ``normalize`` divides by the per-sample mean and does not keep that mean,
    so its scaling cannot be undone here.
    """
    # A log-space inverse, exp(6.239 * (x + 1) / 2) - 1, was used with an
    # earlier normalization and is disabled.
    result = x
    return result

def normalize(x):
    """Rescale ``x`` so that its mean is 1.

    A new tensor is returned and the input is left untouched. An in-place
    division would silently modify the caller's tensor, for example a cached
    dataset item.

    Args:
        x: floating-point tensor (``torch.mean`` rejects integer dtypes).

    Returns:
        ``x / x.mean()``. When the mean is exactly 0 the input cannot be scaled
        to mean 1, so an unchanged copy is returned instead of a tensor full of
        inf/NaN.
    """
    mean = torch.mean(x)
    if mean == 0:
        return x.clone()
    return x / mean

class diffusion_loss(BaseLoss):
    """DDPM noise-prediction loss plus a mean-consistency regularizer.

    The total loss is
    ``MSE(noise_pred, noise) + consistency_weight * MSE(mean(x_0_pred), 1)``,
    where ``x_0_pred`` is the clean sample reconstructed from the model's noise
    prediction. The target mean of 1 matches ``normalize``, which rescales every
    training sample to mean 1.
    """

    def __init__(self, noise_scheduler=None, device=None, consistency_weight=0.1):
        """
        Args:
            noise_scheduler: diffusers scheduler providing
                ``config.num_train_timesteps``, ``add_noise`` and
                ``alphas_cumprod``. If None, the module-level ``noise_scheduler``
                (created in the ``__main__`` section) is looked up at call time.
            device: device for the computation. If None, the device of the
                model's first parameter is used.
            consistency_weight: weight of the consistency term in the total loss.
        """
        stats_names = ["loss", "loss_consistency"]
        super(diffusion_loss, self).__init__(stats_names)
        self.noise_scheduler = noise_scheduler
        self.device = device
        self.consistency_weight = consistency_weight

    def _resolve_scheduler(self):
        """Return the configured scheduler, or the module-level global as a fallback."""
        if self.noise_scheduler is not None:
            return self.noise_scheduler
        return noise_scheduler

    def _resolve_device(self, model):
        """Return the configured device, or the device the model lives on."""
        if self.device is not None:
            return self.device
        return next(model.parameters()).device

    def compute_loss(self, x_0, model):
        """Compute the training loss for one batch.

        Args:
            x_0: batch of clean samples without a channel axis; one is inserted
                at dim 1 (presumably (B, H, W) -> (B, 1, H, W) -- TODO confirm).
            model: noise-prediction network called as
                ``model(x_t, timesteps, return_dict=False)[0]``.

        Returns:
            ``(loss, stats)``. ``loss`` is the weighted total to back-propagate.
            ``stats`` maps ``"loss"`` to the noise-prediction MSE and
            ``"loss_consistency"`` to the consistency MSE.
        """
        scheduler = self._resolve_scheduler()
        device = self._resolve_device(model)

        x_0 = x_0.to(device).unsqueeze(1)
        noise = torch.randn_like(x_0)
        bs = x_0.shape[0]

        # Sample a random timestep for each image.
        timesteps = torch.randint(
            0, scheduler.config.num_train_timesteps, (bs,), device=x_0.device
        ).long()

        # Forward diffusion: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * noise.
        x_t = scheduler.add_noise(x_0, noise, timesteps)

        # Clears gradients before the forward pass (kept from the original;
        # possibly redundant with the training loop).
        model.zero_grad()
        noise_pred = model(x_t, timesteps, return_dict=False)[0]

        loss_noise = torch.nn.functional.mse_loss(noise_pred, noise)

        # The same coefficients add_noise applied above, computed directly on the
        # target device. This avoids moving the scheduler's tables to CPU and back.
        alpha_bar = scheduler.alphas_cumprod.to(device=x_0.device, dtype=x_0.dtype)[timesteps]
        fading_factor = alpha_bar.sqrt().view(-1, 1, 1, 1)
        noise_factor = (1 - alpha_bar).sqrt().view(-1, 1, 1, 1)

        # Clean sample implied by the noise prediction.
        x_0_pred = (x_t - noise_factor * noise_pred) / fading_factor
        # NOTE(review): the mean is taken over the whole batch, not per sample.
        # fading_factor shrinks toward 0 as t grows, so high-noise timesteps
        # dominate this term. Both are kept to preserve the training objective.
        mean_pred = torch.mean(x_0_pred).view(-1)
        loss_consistency = torch.nn.functional.mse_loss(mean_pred, torch.ones_like(mean_pred))

        loss = loss_noise + self.consistency_weight * loss_consistency
        # Each stat reports its own component.
        return loss, {"loss": loss_noise, "loss_consistency": loss_consistency}


def load_model(model, model_path):
    """Restore ``model``'s weights in place from the checkpoint at ``model_path``.

    The checkpoint is deserialized onto the CPU and must be a mapping that holds
    the weights under ``'model_state_dict'``. They are applied with
    ``model.load_state_dict`` and a confirmation line is printed. Returns None.
    """
    cpu = torch.device('cpu')
    state_dict = torch.load(model_path, map_location=cpu)['model_state_dict']
    model.load_state_dict(state_dict)
    print(f"models loaded from checkpoint {model_path}")


if __name__ == '__main__':
    import math

    import lovely_tensors as lt

    # Install lovely_tensors' compact tensor repr (debugging aid).
    lt.monkey_patch()

    parser = ArgumentParser(description="PyTorch experiments")
    parser.add_argument("--batch_size", default=15, type=int, help="batch size of every process")
    parser.add_argument("--epochs", default=250, type=int, help="number of epochs to train")
    parser.add_argument("--learning_rate", default=0.00005, type=float, help="learning rate")
    parser.add_argument("--scheduler", default="[500]", type=str, help="scheduler decrease after epochs given")
    parser.add_argument("--lr_decay", default=0.1, type=float, help="Learning rate decay")
    parser.add_argument("--wandb_exp_name", default='random_experiments', type=str, help="Experiment name in wandb")
    parser.add_argument('--wandb', action='store_true', help="Use wandb")
    parser.add_argument("--load_checkpoint", default='', type=str, help="name of models in folder checkpoints to load")
    parser.add_argument("--seed", default=-1, type=int, help="Random seed")
    args = vars(parser.parse_args())
    # "--scheduler" is a list literal such as "[100, 200]". Empty entries are
    # skipped so "[]" yields no milestones instead of crashing in int('').
    temp = args["scheduler"].replace(" ", "").replace("[", "").replace("]", "").split(",")
    args["scheduler"] = [int(x) for x in temp if x]
    args["seed"] = random.randint(0, 20000) if args["seed"] == -1 else args["seed"]

    ds = SinogramsDS("data/synthetic_sinograms.h5")

    # Contiguous 90% / 10% train/test split.
    split = round(0.9 * len(ds))
    trainSet_raw = torch.utils.data.Subset(ds, range(0, split))
    testSet_raw = torch.utils.data.Subset(ds, range(split, len(ds)))
    trainSet = MyCustomDataset(trainSet_raw)
    testSet = MyCustomDataset(testSet_raw)

    # Prefer CUDA, then Apple MPS, and fall back to CPU rather than failing on
    # machines that have neither.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')
    print(device)

    model_path = "checkpoints/diffusion_model_sinogram.pt"

    model = UNet2DModel(
        sample_size=(256, 512),  # the target image resolution
        in_channels=1,  # the number of input channels, 3 for RGB images
        out_channels=1,  # the number of output channels
        layers_per_block=2,  # how many ResNet layers to use per UNet block
        block_out_channels=(64, 64, 128, 128, 256, 256),  # the number of output channels for each UNet block
        down_block_types=(
            "DownBlock2D",  # a regular ResNet downsampling block
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",  # a regular ResNet upsampling block
            "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    ).to(device)

    if args['load_checkpoint'] != "":
        try:
            load_model(model, args['load_checkpoint'])
        except Exception as err:
            # Best effort: fall back to random init, but report the real cause
            # (missing file, wrong keys, shape mismatch, ...). Unlike a bare
            # except, this does not swallow KeyboardInterrupt/SystemExit.
            print(f"could not load checkpoint {args['load_checkpoint']!r} ({type(err).__name__}: {err}); initializing randomly")

    noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

    exp = PyTorchExperiment(
        train_dataset=trainSet,
        test_dataset=testSet,
        batch_size=args['batch_size'],
        model=model,
        loss_fn=diffusion_loss(),
        checkpoint_path=model_path,
        experiment_name=args['wandb_exp_name'],
        with_wandb=args['wandb'],
        num_workers=torch.get_num_threads() if torch.cuda.is_available() else 0,
        seed=args["seed"],
        args=args
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=args['learning_rate'])

    num_epochs = args['epochs']
    # The cosine schedule is measured in optimizer steps, not samples. Using
    # len(trainSet) would stretch it by a factor of batch_size, so the LR would
    # barely decay. NOTE(review): this assumes exp.train steps the scheduler
    # once per batch and keeps the last partial batch (as in the diffusers
    # training examples), and that each process sees the full train set.
    # Confirm against PyTorchExperiment.train.
    steps_per_epoch = math.ceil(len(trainSet) / args['batch_size'])
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=500,
        num_training_steps=steps_per_epoch * num_epochs,
    )

    exp.train(num_epochs, optimizer, milestones=args['scheduler'], gamma=args['lr_decay'], scheduler=lr_scheduler)








