import h5py
from chip.datasets.synthetic_to_real_dataset import Synthetic2Real
from torch.utils.data import Subset

from argparse import ArgumentParser

import torch
from pytorch_base.experiment import PyTorchExperiment
from pytorch_base.base_loss import BaseLoss

import random
from tqdm.auto import tqdm

from diffusers import UNet2DModel
from diffusers import DDPMScheduler
from diffusers.optimization import get_cosine_schedule_with_warmup

class diffusion_loss(BaseLoss):
    """DDPM noise-prediction loss for a denoising UNet.

    For each batch, the clean target ``x_0`` (the second element of the
    batch) is noised at a uniformly random timestep using the scheduler's
    ``add_noise``. The model is asked to predict the added Gaussian noise.
    The loss is the mean squared error between predicted and true noise.

    Args:
        noise_scheduler: object exposing ``config.num_train_timesteps`` and
            ``add_noise(x_0, noise, timesteps)`` (e.g. diffusers'
            ``DDPMScheduler``). If None, the module-level global
            ``noise_scheduler`` is looked up at call time. That is the original
            behaviour and only works when this file is run as a script.
        device: device the batch target is moved to. If None, the
            module-level global ``device`` is looked up at call time.
    """

    def __init__(self, noise_scheduler=None, device=None):
        stats_names = ["loss"]
        super(diffusion_loss, self).__init__(stats_names)
        self._noise_scheduler = noise_scheduler
        self._device = device

    def compute_loss(self, instance, model):
        """Return ``(loss, {"loss": loss})`` for one batch.

        Args:
            instance: a 2-item batch ``(x, y)``. Only ``y`` is used as the
                clean sample; ``x`` is ignored. Assumes ``y`` is
                (batch, H, W) — TODO confirm. A channel axis is inserted at
                dim 1 to match the 1-channel UNet.
            model: callable as ``model(x_t, timesteps, return_dict=False)``
                whose first returned element is the noise prediction.
        """
        # Fall back to the script-level globals for backward compatibility.
        scheduler = self._noise_scheduler if self._noise_scheduler is not None else noise_scheduler
        target_device = self._device if self._device is not None else device

        _, x_0 = instance
        x_0 = x_0.to(target_device).unsqueeze(1)
        # randn_like already allocates on x_0's device.
        noise = torch.randn_like(x_0)
        batch_size = x_0.shape[0]

        # One random timestep per sample; randint returns int64 by default.
        timesteps = torch.randint(
            0, scheduler.config.num_train_timesteps, (batch_size,), device=x_0.device
        )

        # Forward diffusion: noise the clean samples according to each timestep.
        x_t = scheduler.add_noise(x_0, noise, timesteps)

        # Kept from the original. Presumably the training loop does not zero
        # gradients itself — NOTE(review): confirm against PyTorchExperiment.
        model.zero_grad()
        noise_pred = model(x_t, timesteps, return_dict=False)[0]

        loss = torch.nn.functional.mse_loss(noise_pred, noise)
        return loss, {"loss": loss}


def load_model(model, model_path):
    """Restore ``model``'s weights in place from a checkpoint file.

    The checkpoint is deserialised onto the CPU. It must be a mapping holding
    the weights under the ``'model_state_dict'`` key. A confirmation line is
    printed once the weights are loaded.
    """
    cpu = torch.device('cpu')
    state_dict = torch.load(model_path, map_location=cpu)['model_state_dict']
    model.load_state_dict(state_dict)
    print(f"models loaded from checkpoint {model_path}")


if __name__ == '__main__':
    # Local imports: only needed when this file is run as a training script.
    import math

    import lovely_tensors as lt

    lt.monkey_patch()

    parser = ArgumentParser(description="PyTorch experiments")
    parser.add_argument("--batch_size", default=50, type=int, help="batch size of every process")
    parser.add_argument("--epochs", default=50, type=int, help="number of epochs to train")
    parser.add_argument("--learning_rate", default=0.0001, type=float, help="learning rate")
    parser.add_argument("--weight_decay", default=0.001, type=float, help="weight decay")
    parser.add_argument("--scheduler", default="[500]", type=str, help="scheduler decrease after epochs given")
    parser.add_argument("--lr_decay", default=0.1, type=float, help="Learning rate decay")
    parser.add_argument("--wandb_exp_name", default='random_experiments', type=str, help="Experiment name in wandb")
    parser.add_argument('--wandb', action='store_true', help="Use wandb")
    parser.add_argument("--load_checkpoint", default='', type=str, help="name of models in folder checkpoints to load")
    parser.add_argument("--seed", default=-1, type=int, help="Random seed")
    args = vars(parser.parse_args())
    # Parse e.g. "[100, 200]" into [100, 200]. Empty entries (e.g. from "[]")
    # are skipped instead of crashing in int("").
    temp = args["scheduler"].replace(" ", "").replace("[", "").replace("]", "").split(",")
    args["scheduler"] = [int(x) for x in temp if x]
    args["seed"] = random.randint(0, 20000) if args["seed"] == -1 else args["seed"]

    # CUDA availability is used to choose between the remote and local data locations.
    DATA_PATH = "/mydata/chip/shared/ra.psi.ch/p17299/data_for_SDSC_v0" if torch.cuda.is_available() else "../../data/p17299"
    h5filepath = f"{DATA_PATH}/tomogram_delta.mat" if torch.cuda.is_available() else f"{DATA_PATH}/tomogram_delta_v0.mat"
    mat_file = h5py.File(h5filepath, "r")
    tomogram = mat_file.get('tomogram_delta')
    ds = Synthetic2Real(tomogram)
    ds = Subset(ds, torch.arange(23, 181))

    # NOTE(review): train and test are the same subset, so test metrics are not held-out.
    trainSet = ds
    testSet = ds

    # Prefer CUDA, then Apple MPS, then CPU. Previously MPS was assumed whenever
    # CUDA was missing, which fails on CPU-only machines.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')
    print(device)

    # The "syntehtc2real" spelling is kept so existing checkpoints still resolve.
    model_path = "checkpoints/diffusion_model_syntehtc2real.pt"

    model = UNet2DModel(
        sample_size=(512, 512),  # the target image resolution
        in_channels=1,  # the number of input channels, 3 for RGB images
        out_channels=1,  # the number of output channels
        layers_per_block=2,  # how many ResNet layers to use per UNet block
        block_out_channels=(64, 64, 128, 128, 256, 256),  # the number of output channels for each UNet block
        down_block_types=(
            "DownBlock2D",  # a regular ResNet downsampling block
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",  # a regular ResNet upsampling block
            "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    ).to(device)

    if args['load_checkpoint'] != "":
        try:
            load_model(model, args['load_checkpoint'])
        except FileNotFoundError:
            print("model not found, initializing randomly")
        except Exception as err:
            # Still best-effort, but report why an existing checkpoint was
            # rejected (missing key, shape mismatch, ...). Unlike the previous
            # bare `except:`, Ctrl-C and SystemExit are no longer swallowed.
            print(f"could not load checkpoint {args['load_checkpoint']} ({err!r}), initializing randomly")

    # Module-level name: diffusion_loss.compute_loss reads `noise_scheduler`
    # and `device` from the script globals.
    noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

    exp = PyTorchExperiment(
        train_dataset=trainSet,
        test_dataset=testSet,
        batch_size=args['batch_size'],
        model=model,
        loss_fn=diffusion_loss(),
        checkpoint_path=model_path,
        experiment_name=args['wandb_exp_name'],
        with_wandb=args['wandb'],
        num_workers=torch.get_num_threads() if torch.cuda.is_available() else 0,
        seed=args["seed"],
        args=args
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=args['learning_rate'])

    num_epochs = args['epochs']
    # Total steps = batches per epoch * epochs (what len(dataloader) * epochs
    # gives). The previous code used the number of *samples* per epoch.
    # Assumes PyTorchExperiment steps lr_scheduler once per batch — TODO confirm.
    # NOTE(review): with the defaults (158 samples, batch 50, 50 epochs) there
    # are only ~200 steps, so the 500-step warmup never completes.
    steps_per_epoch = math.ceil(len(trainSet) / args['batch_size'])
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=500,
        num_training_steps=steps_per_epoch * num_epochs,
    )

    try:
        exp.train(args['epochs'], optimizer, milestones=args['scheduler'], gamma=args['lr_decay'], scheduler=lr_scheduler)
    finally:
        # The dataset reads lazily from this file, so close it only after training.
        mat_file.close()








