import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
from argparse import ArgumentParser
import random
import torch
from pytorch_base.experiment import PyTorchExperiment
from pytorch_base.base_loss import BaseLoss
from chip.datasets.superres_dataset import SuperresolutionDS, TomogramsMAT
from chip.models.unet import build_unet
import h5py
import hdf5plugin

import lovely_tensors as lt
lt.monkey_patch()

if __name__ == '__main__':
    # Command-line interface. All options end up in the plain dict `args`.
    parser = ArgumentParser(description="PyTorch experiments")
    parser.add_argument("--batch_size", default=50, type=int, help="batch size of every process")
    parser.add_argument("--epochs", default=25, type=int, help="number of epochs to train")
    parser.add_argument("--learning_rate", default=0.001, type=float, help="learning rate")
    parser.add_argument("--weight_decay", default=0., type=float, help="weight decay")
    parser.add_argument("--scheduler", default="[18,22]", type=str, help="scheduler decrease after epochs given")
    parser.add_argument("--lr_decay", default=0.1, type=float, help="Learning rate decay")
    parser.add_argument("--wandb_exp_name", default='random_experiments', type=str, help="Experiment name in wandb")
    parser.add_argument('--wandb', action='store_true', help="Use wandb")
    parser.add_argument("--load_checkpoint", default='', type=str, help="name of model in folder checkpoints to load")
    parser.add_argument("--seed", default=-1, type=int, help="Random seed")
    args = vars(parser.parse_args())

    # "--scheduler" arrives as text such as "[18,22]". Spaces and square
    # brackets are removed in one pass, then every comma-separated token is
    # converted to int. Empty tokens are skipped so that "[]", "" or a
    # trailing comma yield a valid (possibly empty) list instead of raising
    # ValueError on int("").
    milestone_text = args["scheduler"].translate(str.maketrans("", "", " []"))
    args["scheduler"] = [int(token) for token in milestone_text.split(",") if token]

    # -1 is the "no seed given" sentinel: draw one so that `args` always
    # carries a concrete seed value.
    if args["seed"] == -1:
        args["seed"] = random.randint(0, 20000)

    # Alternative (synthetic images) dataset, kept for reference:
    # files = os.listdir('../../data/imgs_synthetic')
    # ds = SuperresolutionDS(files, data_path="../../data")

    # CUDA availability selects both the data location (absolute path when
    # CUDA is available, relative path otherwise) and the compute device.
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        DATA_PATH = "/mydata/chip/shared/ra.psi.ch/p17299/data_for_SDSC"
    else:
        DATA_PATH = "../../data/p17299"
    h5filepath = f"{DATA_PATH}/tomogram_delta.mat"

    # The HDF5 handle is passed straight to the dataset and is not closed here.
    mat_file = h5py.File(h5filepath, "r")

    # Keep only items 250..449 of the tomogram dataset.
    ds = torch.utils.data.Subset(
        TomogramsMAT(
            mat_file,
            crop_xoffset=138,
            crop_yoffset=138,
            crop_width=1024,
            normalize_range=True,
        ),
        range(250, 450),
    )

    # Ordered 90% / 10% split (no shuffling): leading part trains, tail tests.
    split_index = round(0.9 * len(ds))
    trainSet = torch.utils.data.Subset(ds, range(split_index))
    testSet = torch.utils.data.Subset(ds, range(split_index, len(ds)))

    device = torch.device('cuda' if cuda_available else 'cpu')
    print(device)

    class SRLoss(BaseLoss):
        """Cross-entropy loss against targets binarised at a 0.4 threshold.

        Registers the single stats name ``"loss"`` with ``BaseLoss``.
        """

        def __init__(self):
            super().__init__(["loss"])

        def compute_loss(self, instance, model):
            """Run `model` on one instance and return ``(loss, stats)``.

            `instance` is unpacked as ``(source, target, _)``; the third item
            is ignored. Tensors are moved to the script-level `device`.
            `stats` is a dict holding the same loss under the key ``"loss"``.
            """
            inputs, raw_target, _ = instance
            # Values above 0.4 become class 1, everything else class 0.
            labels = (raw_target > 0.4).long().to(device)
            inputs = inputs.to(device)

            model.zero_grad()
            # unsqueeze(1) inserts a size-1 axis at position 1
            # (presumably the channel axis — TODO confirm).
            logits = model(inputs.unsqueeze(1))
            loss = torch.nn.functional.cross_entropy(logits, labels)
            return loss, {"loss": loss}


    # NOTE(review): the earlier comment here read "Using first the forward
    # model for the prior"; how it relates to this line is unclear — confirm.
    # U-Net built with first argument 1 (presumably the input channel count —
    # TODO confirm) and two output classes, then moved to `device`.
    model = build_unet(1, num_classes=2)
    model = model.to(device)
    def load_model(model, model_path):
        """Restore `model`'s weights in place from the checkpoint at `model_path`.

        The file is read with ``torch.load`` mapped onto the CPU and must
        contain a ``'model_state_dict'`` entry. Prints a confirmation message
        and returns None.
        """
        cpu = torch.device('cpu')
        state_dict = torch.load(model_path, map_location=cpu)['model_state_dict']
        model.load_state_dict(state_dict)
        print(f"model loaded from checkpoint {model_path}")


    # Project root: two directory levels above the directory of this script.
    script_dir = os.path.dirname(__file__)
    project_path = os.path.abspath(os.path.join(script_dir, '../..'))

    # Path handed to the experiment as `checkpoint_path`.
    # Alternative: f"{project_path}/checkpoints/super_resolution_unet.pt"
    model_path = f"{project_path}/checkpoints/real_data_unet.pt"

    # Optionally start from weights stored in the checkpoints folder.
    if args['load_checkpoint']:
        load_model(model, f"{project_path}/checkpoints/{args['load_checkpoint']}")

    experiment_kwargs = dict(
        args=args,
        train_dataset=trainSet,
        test_dataset=testSet,
        batch_size=args['batch_size'],
        model=model,
        loss_fn=SRLoss(),
        checkpoint_path=model_path,
        experiment_name=args['wandb_exp_name'],
        with_wandb=args['wandb'],
        # Data loading workers are fixed at 0; previously considered:
        # os.cpu_count() if torch.cuda.is_available() else 0
        num_workers=0,
        seed=args["seed"],
    )
    exp = PyTorchExperiment(**experiment_kwargs)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args['learning_rate'],
        weight_decay=args['weight_decay'],
    )

    # Milestone epochs and decay factor come from --scheduler / --lr_decay.
    exp.train(
        args['epochs'],
        optimizer,
        milestones=args['scheduler'],
        gamma=args['lr_decay'],
    )