import sys
import os

from chip.datasets.superres_dataset import SuperresolutionDS
from chip.datasets.tomogram_dataset import TomogramDataset
from chip.datasets.tiff_tomogram_dataset import TIFFDataset
from chip.datasets.nii_tomogram_dataset import NiiDataset

from argparse import ArgumentParser

import torch
from pytorch_base.experiment import PyTorchExperiment
from pytorch_base.base_loss import BaseLoss

import random
from tqdm.auto import tqdm

from diffusers import UNet2DModel
from diffusers import DDPMScheduler
from diffusers.optimization import get_cosine_schedule_with_warmup

import h5py
import hdf5plugin

from torch.utils.data import Dataset

def get_dataset(kwargs, dataset_type):
    """Build the dataset for ``dataset_type`` and split it into train/test subsets.

    Args:
        kwargs: Keyword arguments forwarded to the dataset constructor. The
            mapping is copied before defaults are filled in, so the caller's
            dict is left untouched.
        dataset_type: ``'h5'`` selects ``TomogramDataset``, ``'nii'`` selects
            ``NiiDataset``; every other value selects ``TIFFDataset``.

    Returns:
        A ``(trainSet, testSet)`` tuple of ``torch.utils.data.Subset`` objects
        over the same dataset: the first 95% (rounded) of a seeded random
        permutation for training, the remainder for testing.
    """
    if dataset_type == 'nii':
        dataset_class = NiiDataset
    elif dataset_type == 'h5':
        dataset_class = TomogramDataset
    else:
        dataset_class = TIFFDataset

    # Work on a copy so the defaults below never leak into the caller's dict.
    kwargs = dict(kwargs)

    if dataset_type == 'tiff' and 'im_size' not in kwargs:
        kwargs['im_size'] = 512

    if dataset_type == 'nii' and 'clip_range' not in kwargs:
        # NOTE(review): im_size is overwritten here even when the caller
        # supplied one (the guard only checks clip_range) -- confirm intended.
        kwargs['im_size'] = 512
        kwargs['clip_range'] = [3e4, 5e4]

    dataset = dataset_class(**kwargs)

    # Fixed seed so the train/test split is the same on every run. This
    # reseeds torch's global RNG as a side effect.
    torch.manual_seed(0)
    n_items = len(dataset)
    perm = torch.randperm(n_items)
    n_train = round(0.95 * n_items)
    trainSet = torch.utils.data.Subset(dataset, perm[:n_train])
    testSet = torch.utils.data.Subset(dataset, perm[n_train:])
    return trainSet, testSet

class diffusion_loss(BaseLoss):
    """Noise-prediction MSE loss for training a diffusion model.

    A batch of clean images is noised at random timesteps by the noise
    scheduler; the model predicts the added noise and the loss is the mean
    squared error between the prediction and the true noise.

    Args:
        noise_scheduler: Object exposing ``config.num_train_timesteps`` and
            ``add_noise(x_0, noise, timesteps)``. When ``None`` (default), the
            module-level ``noise_scheduler`` defined by the training script
            is used.
        device: Device the image batch is moved to. When ``None`` (default),
            the module-level ``device`` defined by the training script is
            used.
    """

    def __init__(self, noise_scheduler=None, device=None):
        stats_names = ["loss"]
        super(diffusion_loss, self).__init__(stats_names)
        # Stored after the base-class init; None means "use the script global".
        self._scheduler_override = noise_scheduler
        self._device_override = device

    def compute_loss(self, instance, model):
        """Return ``(loss, stats)`` for one batch.

        Args:
            instance: A 3-item batch; only the second item (the clean images
                ``x_0``) is used.
            model: Called as ``model(x_t, timesteps, return_dict=False)``; the
                first element of its output is taken as the predicted noise.

        Returns:
            The scalar MSE loss and a dict ``{"loss": loss}``.
        """
        # Fall back to the script-level globals so a plain `diffusion_loss()`
        # keeps behaving exactly as before.
        scheduler = noise_scheduler if self._scheduler_override is None else self._scheduler_override
        target_device = device if self._device_override is None else self._device_override

        _, x_0, _ = instance
        x_0 = x_0.to(target_device)
        # randn_like already allocates on x_0's device and dtype.
        noise = torch.randn_like(x_0)
        bs = x_0.shape[0]

        # Sample a random timestep for each image
        timesteps = torch.randint(
            0, scheduler.config.num_train_timesteps, (bs,), device=x_0.device
        ).long()

        # Add noise to the clean images according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        x_t = scheduler.add_noise(x_0, noise, timesteps)

        # NOTE(review): gradients are cleared on every loss evaluation;
        # presumably the training loop relies on this -- confirm before moving.
        model.zero_grad()
        noise_pred = model(x_t, timesteps, return_dict=False)[0]

        loss = torch.nn.functional.mse_loss(noise_pred, noise)
        return loss, {"loss": loss}


def load_model(model, model_path):
    """Restore ``model``'s weights in place from the checkpoint at ``model_path``.

    The checkpoint is deserialized onto the CPU and must store the weights
    under the ``'model_state_dict'`` key. A confirmation line is printed once
    the weights are loaded.
    """
    cpu = torch.device('cpu')
    state_dict = torch.load(model_path, map_location=cpu)['model_state_dict']
    model.load_state_dict(state_dict)
    print(f"models loaded from checkpoint {model_path}")


if __name__ == '__main__':
    # Training entry point: parse CLI options, build the dataset and the UNet,
    # optionally restore a checkpoint, then train with PyTorchExperiment.
    #
    # NOTE: `device` and `noise_scheduler` assigned below must stay
    # module-level names; diffusion_loss.compute_loss looks them up as globals.
    import lovely_tensors as lt

    lt.monkey_patch()

    parser = ArgumentParser(description="PyTorch experiments")
    parser.add_argument("--batch_size", default=50, type=int, help="batch size of every process")
    parser.add_argument("--epochs", default=50, type=int, help="number of epochs to train")
    parser.add_argument("--learning_rate", default=0.0001, type=float, help="learning rate")
    parser.add_argument("--weight_decay", default=0.001, type=float, help="weight decay")
    parser.add_argument("--scheduler", default="[500]", type=str, help="scheduler decrease after epochs given")
    parser.add_argument("--lr_decay", default=0.1, type=float, help="Learning rate decay")
    parser.add_argument("--exp_name", default='random_experiment', type=str, help="Experiment name")
    parser.add_argument('--wandb', action='store_true', help="Use wandb")
    parser.add_argument('--gray_bkg', action='store_true', help="Train with images having gray background")
    parser.add_argument('--tiny', action='store_true', help="Train small model")
    parser.add_argument("--load_checkpoint", default='', type=str, help="name of models in folder checkpoints to load")
    parser.add_argument("--seed", default=-1, type=int, help="Random seed")

    parser.add_argument("--dataset_path", type=str, help="Path to the dataset file or folder")
    parser.add_argument("--dataset_type", type=str,  choices=['h5', 'tiff', 'nii'], help="Type of dataset")
    parser.add_argument("--im_size", type=int, default=512, help="In the case of tiff, the size of the crops to split the tiff files")
    parser.add_argument("--rescale", type=int, default=512, help="The side length of the images in the dataset")

    args = vars(parser.parse_args())

    # "--scheduler" arrives as a string such as "[100, 200]"; turn it into a
    # list of ints. Empty tokens are skipped so "[]" (or a trailing comma)
    # yields no milestone instead of raising ValueError on int('').
    milestone_tokens = args["scheduler"].replace(" ", "").replace("[", "").replace("]", "").split(",")
    args["scheduler"] = [int(token) for token in milestone_tokens if token]
    # A seed of -1 means "pick one at random".
    args["seed"] = random.randint(0, 20000) if args["seed"] == -1 else args["seed"]

    kwargs = {
        "path": args['dataset_path'],
        "lr_forward_function": lambda x: x,
        "train_transform": True,
        "to_gray": args['gray_bkg'],
        "rescale": args['rescale'],
        "normalize_range": True,
    }
    if args['dataset_type'] == 'tiff':
        kwargs['im_size'] = args['im_size']
    if args['dataset_type'] == 'nii':
        kwargs['im_size'] = args['im_size']
        kwargs['clip_range'] = [3e4, 5e4]
        kwargs['file_range'] = [20, 360]
    trainSet, testSet = get_dataset(kwargs, args['dataset_type'])

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print(device)

    # The checkpoint name depends only on the experiment name and model size
    # (the gray-background flag never changed it).
    tiny_suffix = '_tiny' if args['tiny'] else ''
    model_path = f"checkpoints/ddpm_{args['exp_name']}{tiny_suffix}.pt"

    channels = (32, 32, 32, 32, 64, 64) if args['tiny'] else (64, 64, 128, 128, 256, 256)
    model = UNet2DModel(
        sample_size=args['rescale'],  # the target image resolution
        in_channels=1,  # the number of input channels, 3 for RGB images
        out_channels=1,  # the number of output channels
        layers_per_block=2,  # how many ResNet layers to use per UNet block
        block_out_channels=channels,  # the number of output channels for each UNet block
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",  # a regular ResNet downsampling block
            "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",  # a regular ResNet upsampling block
            "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    ).to(device)

    if args['load_checkpoint']:
        try:
            load_model(model, args['load_checkpoint'])
        except Exception as err:
            # Best effort: keep going with random weights, but report why the
            # load failed. Catching Exception (not a bare except) lets
            # Ctrl-C / SystemExit propagate.
            print(f"model not found, initializing randomly ({type(err).__name__}: {err})")

    noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

    exp = PyTorchExperiment(
        train_dataset=trainSet,
        test_dataset=testSet,
        batch_size=args['batch_size'],
        model=model,
        loss_fn=diffusion_loss(),
        checkpoint_path=model_path,
        experiment_name=args['exp_name'],
        with_wandb=args['wandb'],
        num_workers=torch.get_num_threads() if torch.cuda.is_available() else 0,
        seed=args["seed"],
        args=args,
        save_always=True
    )

    # NOTE(review): --weight_decay is parsed but not passed to AdamW, so the
    # optimizer uses its own default -- confirm whether that is intended.
    optimizer = torch.optim.AdamW(model.parameters(), lr=args['learning_rate'])

    num_epochs = args['epochs']
    # NOTE(review): len(trainSet) counts samples, not batches; if the scheduler
    # is stepped once per batch the cosine horizon is batch_size times longer
    # than the run -- verify against PyTorchExperiment.train.
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=500,
        num_training_steps=(len(trainSet) * num_epochs),
    )

    exp.train(num_epochs, optimizer, milestones=args['scheduler'], gamma=args['lr_decay'], scheduler=lr_scheduler)








