import os
import random
import sys
from argparse import ArgumentParser

# Make the project root importable BEFORE any project (``chip`` / ``pytorch_base``) import below.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))

import h5py
import hdf5plugin  # imported for its side effect: registers extra HDF5 compression filters with h5py
import lovely_tensors as lt
import torch
import torch.nn
import wandb
from torch.utils.data import Subset

from pytorch_base.experiment import PyTorchExperiment
from pytorch_base.base_loss import BaseLoss
from chip.datasets.superres_dataset import TomogramsMAT
from chip.datasets.synthetic_to_real_dataset import Synthetic2Real
from chip.models.unet import build_unet
from chip.models.Pix2Pix import GAN
from chip.models.iterative_model import TomographicReconstruction

lt.monkey_patch()

if __name__ == '__main__':
    parser = ArgumentParser(description="PyTorch experiments")
    parser.add_argument("--batch_size", default=5, type=int, help="batch size of every process")
    parser.add_argument("--epochs", default=25, type=int, help="number of epochs to train")
    parser.add_argument("--learning_rate", default=0.001, type=float, help="learning rate")
    parser.add_argument("--weight_decay", default=0., type=float, help="weight decay")
    parser.add_argument("--l1_lambda", default=100, type=float, help="L1_LAMBDA")
    parser.add_argument("--scheduler", default="[18,22]", type=str, help="scheduler decrease after epochs given")
    parser.add_argument("--lr_decay", default=0.1, type=float, help="Learning rate decay")
    parser.add_argument("--wandb_exp_name", default='random_experiments', type=str, help="Experiment name in wandb")
    parser.add_argument('--wandb', action='store_true', help="Use wandb")
    parser.add_argument('--with_gan_loss', action='store_true', help="Use gan loss when training")
    parser.add_argument("--load_checkpoint", default='', type=str, help="name of model in folder checkpoints to load")
    parser.add_argument("--seed", default=-1, type=int, help="Random seed")
    args = vars(parser.parse_args())
    temp = args["scheduler"].replace(" ", "").replace("[", "").replace("]", "").split(",")
    args["scheduler"] = [int(x) for x in temp]
    args["seed"] = random.randint(0, 20000) if args["seed"] == -1 else args["seed"]

    DATA_PATH = "/mydata/chip/shared/ra.psi.ch/p17299/data_for_SDSC_v0" if torch.cuda.is_available() else "../../data/p17299"
    h5filepath = f"{DATA_PATH}/tomogram_delta.mat" if torch.cuda.is_available() else f"{DATA_PATH}/tomogram_delta_v0.mat"
    mat_file = h5py.File(h5filepath, "r")
    tomogram = mat_file.get('tomogram_delta')
    ds = Synthetic2Real(tomogram)
    ds = Subset(ds, torch.arange(23, 181))

    trainSet = Subset(ds, range(0, round(0.9 * len(ds))))
    testSet = Subset(ds, range(round(0.9 * len(ds)), len(ds)))

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print(device)

    # Using first the forward model for the prior
    model = GAN(1).to(device)


    def load_model(model, model_path):
        """Restore ``model``'s weights from the checkpoint file at ``model_path``.

        The checkpoint must be a dict that stores the weights under ``'model_state_dict'``.
        Tensors are deserialized onto the CPU first; ``load_state_dict`` then copies them
        into the model's existing parameters, whichever device those are on.
        """
        cpu = torch.device('cpu')
        weights = torch.load(model_path, map_location=cpu)['model_state_dict']
        model.load_state_dict(weights)
        print(f"model loaded from checkpoint {model_path}")


    # Checkpoints live in <project root>/checkpoints. model_path is the file handed to the
    # experiment as its checkpoint_path.
    project_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
    checkpoint_dir = f"{project_path}/checkpoints"
    model_path = f"{checkpoint_dir}/gan_model.pt"

    # Optional warm start: --load_checkpoint names a file inside the same checkpoints folder.
    checkpoint_name = args['load_checkpoint']
    if checkpoint_name:
        load_model(model, f"{checkpoint_dir}/{checkpoint_name}")


    class GANLoss(BaseLoss):
        """Pix2Pix-style loss for a ``GAN`` with ``unet`` (generator) and ``discriminator``.

        ``compute_loss`` returns the generator loss for the caller to back-propagate. When the
        adversarial term is enabled it also performs one discriminator update itself with
        ``optimizer_disc`` (only while autograd is enabled).

        Args:
            optimizer_disc: optimizer over the discriminator parameters; stepped inside
                ``compute_loss``.
            L1_LAMBDA: weight of the L1 reconstruction term in the generator loss.
            with_gan_loss: whether to add the adversarial term and train the discriminator.
                ``None`` (default) falls back to the ``--with_gan_loss`` command-line flag.
        """

        def __init__(self, optimizer_disc, L1_LAMBDA=100, with_gan_loss=None):
            stats_names = [
                "D_real_loss",
                "D_fake_loss", "G_fake_loss", "L1"
            ]
            super().__init__(stats_names)
            self.optimizer_disc = optimizer_disc
            self.L1_LAMBDA = L1_LAMBDA
            self.with_gan_loss = args['with_gan_loss'] if with_gan_loss is None else bool(with_gan_loss)
            # Criteria are stateless, so build them once instead of on every batch.
            self.bce_loss = torch.nn.BCEWithLogitsLoss()
            self.l1_loss = torch.nn.L1Loss()

        @staticmethod
        def _prepare_batch(instance):
            """Return ``(x, y)`` from the first two batch entries, moved to ``device``.

            A channel axis is inserted at dim 1 (assumes per-sample tensors come without one —
            confirm against the dataset). Extra batch entries, if any, are ignored.
            """
            x, y = instance[0], instance[1]
            return x.unsqueeze(1).to(device), y.unsqueeze(1).to(device)

        @staticmethod
        def _generate(netG, x):
            """Run the generator on ``x``, always passing 0 as its second (step/timestep-like) input."""
            step = torch.zeros(1, dtype=torch.long, device=x.device)
            return netG(x, step, return_dict=False)[0]

        def log_epoch_summary(self, instance, model, epoch):
            """Log a preview image (generated row above target row) to wandb.

            Does nothing unless a wandb run is active. Up to two samples of ``instance`` are
            tiled left-to-right in each row.
            """
            if not wandb.run:
                return
            x, y = self._prepare_batch(instance)
            # Inference only: a logging preview needs no autograd graph.
            with torch.no_grad():
                y_fake = self._generate(model.unet, x)
            # (B, 1, H, W) -> (H, k*W): place the first k <= 2 samples side by side.
            images_fake = torch.cat(list(y_fake[:2, 0]), dim=1)
            images_real = torch.cat(list(y[:2, 0]), dim=1)
            summary_image = torch.cat([images_fake, images_real], 0).cpu()
            wandb.log({"epoch": epoch, "sample_image": [wandb.Image(summary_image, caption=f"Summary epoch {epoch}")]})

        def compute_loss(self, instance, model: GAN):
            """Compute the generator loss for one batch and, if enabled, update the discriminator.

            Args:
                instance: batch whose first two entries are the input ``x`` and target ``y``.
                model: ``GAN`` exposing ``unet`` (generator) and ``discriminator``.

            Returns:
                ``(G_loss, stats)``: ``G_loss`` is the L1 term scaled by ``L1_LAMBDA`` plus, when
                enabled, the adversarial term. ``stats`` holds the discriminator losses (zeros
                when the adversarial term is disabled), the generator's adversarial loss and the
                unscaled L1 error.
            """
            x, y = self._prepare_batch(instance)

            model.zero_grad()

            netG = model.unet
            netD = model.discriminator

            y_fake = self._generate(netG, x)

            ############## Train Discriminator ##############
            if self.with_gan_loss:
                D_real = netD(x, y)
                D_real_loss = self.bce_loss(D_real, torch.ones_like(D_real))
                # Detached: the discriminator update must not back-propagate into the generator.
                D_fake = netD(x, y_fake.detach())
                D_fake_loss = self.bce_loss(D_fake, torch.zeros_like(D_fake))
                D_loss = (D_real_loss + D_fake_loss) / 2

                # Only step D when autograd is on (i.e. not inside a no_grad evaluation pass).
                if torch.is_grad_enabled():
                    netD.zero_grad()
                    D_loss.backward()
                    self.optimizer_disc.step()

            ############## Train Generator ##############
            L1 = self.l1_loss(y_fake, y) * self.L1_LAMBDA
            if self.with_gan_loss:
                # Re-evaluate the (just updated) discriminator; gradients reach G through y_fake.
                D_fake = netD(x, y_fake)
                G_fake_loss = self.bce_loss(D_fake, torch.ones_like(D_fake))
                G_loss = G_fake_loss + L1
            else:
                # The adversarial term is only monitored here, so skip building its autograd graph.
                with torch.no_grad():
                    D_fake = netD(x, y_fake)
                    G_fake_loss = self.bce_loss(D_fake, torch.ones_like(D_fake))
                G_loss = L1
                D_real_loss = D_fake_loss = torch.zeros_like(G_loss)

            return G_loss, {
                "D_real_loss": D_real_loss, "D_fake_loss": D_fake_loss,
                "G_fake_loss": G_fake_loss, "L1": L1 / self.L1_LAMBDA
            }


    # Two optimizers: optimizer_gen (generator only) is handed to exp.train below, while
    # optimizer_disc (discriminator only) is stepped inside GANLoss.compute_loss.
    optimizer_gen = torch.optim.AdamW(
        model.unet.parameters(),
        lr=args['learning_rate'],
        weight_decay=args['weight_decay'])

    optimizer_disc = torch.optim.AdamW(
        model.discriminator.parameters(),
        lr=args['learning_rate'],
        weight_decay=args['weight_decay'])

    # os.cpu_count() returns None when the CPU count cannot be determined; fall back to 0
    # (load data in the main process) rather than passing None as a worker count.
    # NOTE(review): worker processes inherit the h5py file opened above; h5py handles are not
    # generally fork-safe — confirm the dataset reopens the file per worker.
    num_workers = (os.cpu_count() or 0) if torch.cuda.is_available() else 0

    exp = PyTorchExperiment(
        args=args,
        train_dataset=trainSet,
        test_dataset=testSet,
        batch_size=args['batch_size'],
        model=model,
        loss_fn=GANLoss(optimizer_disc, L1_LAMBDA=args['l1_lambda']),
        checkpoint_path=model_path,
        experiment_name=args['wandb_exp_name'],
        with_wandb=args['wandb'],
        num_workers=num_workers,
        seed=args["seed"],
        loss_to_track="L1"
    )

    # NOTE(review): milestones/gamma presumably drive an LR schedule for optimizer_gen only;
    # optimizer_disc keeps a constant learning rate — confirm this is intended.
    exp.train(args['epochs'], optimizer_gen, milestones=args['scheduler'], gamma=args['lr_decay'])
