# Modified by https://github.com/zhaoshan2/pangu-pytorch
#
# Original repository license: Apache License 2.0
# Original author: zhaoshan2
# These changes were made for the purpose of integrating PINNs into Transformer-based ocean modeling.

import sys
sys.path.append("/here/is/code/M2F-PINN")
from util_data import utils, utils_data
from util_data.utils_dist import get_dist_info, init_dist
from util_data.config import cfg
from models.var2_model import PanguModel
import torch
import os
from torch.utils import data
from models.var2_sample_fourier import test, train
import argparse
import time
import logging
from tensorboardX import SummaryWriter
from torch.utils.data.distributed import DistributedSampler
from accelerate import Accelerator


# Entry point: train PanguModel (M2F-PINN variant) with Accelerate, then test it.
# Command-line parsing is factored into small helpers so it can be exercised
# without the data/model stack; the training run itself stays under the
# ``__main__`` guard.

def str2bool(value):
    """Parse a command-line boolean value.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"``) into True, so boolean flags are parsed explicitly here.

    Args:
        value: The raw string from the command line (a bool is passed through).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the train-then-test run."""
    parser = argparse.ArgumentParser()
    # Run name; also names the output sub-directory and the logger.
    # Alternative previously used: var2_TrainFF_1e5optlr
    parser.add_argument('--type_net', type=str, default="train_M2Fpinn")
    parser.add_argument('--load_my_best', type=str2bool, default=True)
    parser.add_argument('--launcher', default='pytorch', help='job launcher')
    parser.add_argument('--local-rank', type=int, default=0)
    parser.add_argument('--dist', type=str2bool, default=False)
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    starts = time.time()

    PATH = cfg.PG_INPUT_PATH

    # GPUs exposed to this process. The global batch sizes below are divided by
    # this count. CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been
    # initialised yet, so it is set before the Accelerator is created.
    opt = {"gpu_ids": [0, 1]}
    num_gpus = len(opt['gpu_ids'])
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(x) for x in opt['gpu_ids'])
    accelerator = Accelerator()
    device = accelerator.device
    # device = torch.device('cpu')

    # ----------------------------------------
    # distributed settings
    # ----------------------------------------
    if args.dist:
        init_dist(args.launcher)  # --launcher defaults to 'pytorch'
    rank, world_size = get_dist_info()
    print("The rank and world size is", rank, world_size)
    if rank == 0:
        print(f"Predicting on {device}")

    output_path = os.path.join(cfg.PG_OUT_PATH, args.type_net, str(cfg.PG.HORIZON))
    utils.mkdirs(output_path)

    writer_path = os.path.join(output_path, "writer")
    os.makedirs(writer_path, exist_ok=True)
    writer = SummaryWriter(writer_path)

    logger_name = args.type_net + str(cfg.PG.HORIZON)
    utils.logger_info(logger_name, os.path.join(output_path, logger_name + '.log'))
    logger = logging.getLogger(logger_name)

    # ---- training data ----
    train_dataset = utils_data.NetCDFDataset(nc_path=PATH,
                                             data_transform=None,
                                             training=True,
                                             validation=False,
                                             startDate=cfg.PG.TRAIN.START_TIME,
                                             endDate=cfg.PG.TRAIN.END_TIME,
                                             freq=cfg.PG.TRAIN.FREQUENCY,
                                             horizon=cfg.PG.HORIZON)
    if args.dist:
        train_sampler = DistributedSampler(train_dataset, shuffle=True, drop_last=True)
        train_dataloader = data.DataLoader(dataset=train_dataset,
                                           batch_size=cfg.PG.TRAIN.BATCH_SIZE // num_gpus,
                                           num_workers=4, pin_memory=True, sampler=train_sampler)
    else:
        train_dataloader = data.DataLoader(dataset=train_dataset,
                                           batch_size=cfg.PG.TRAIN.BATCH_SIZE // num_gpus,
                                           drop_last=True, shuffle=True, num_workers=6, pin_memory=True)

    dataset_length = len(train_dataloader)
    if rank == 0:
        print("dataset_length", dataset_length)

    # ---- validation data ----
    val_dataset = utils_data.NetCDFDataset(nc_path=PATH,
                                           data_transform=None,
                                           training=False,
                                           validation=True,
                                           startDate=cfg.PG.VAL.START_TIME,
                                           endDate=cfg.PG.VAL.END_TIME,
                                           freq=cfg.PG.VAL.FREQUENCY,
                                           horizon=cfg.PG.HORIZON)
    if args.dist:
        val_sampler = DistributedSampler(val_dataset, shuffle=False, drop_last=True)
        # Divide by the GPU count like the other loaders (was a hard-coded 2).
        val_dataloader = data.DataLoader(dataset=val_dataset,
                                         batch_size=cfg.PG.VAL.BATCH_SIZE // num_gpus,
                                         num_workers=4, pin_memory=True, sampler=val_sampler)
    else:
        val_dataloader = data.DataLoader(dataset=val_dataset,
                                         batch_size=cfg.PG.VAL.BATCH_SIZE // num_gpus,
                                         drop_last=True, shuffle=False, num_workers=6, pin_memory=True)

    # ---- test data ----
    test_dataset = utils_data.NetCDFDataset(nc_path=PATH,
                                            data_transform=None,
                                            training=False,
                                            validation=False,
                                            startDate=cfg.PG.TEST.START_TIME,
                                            endDate=cfg.PG.TEST.END_TIME,
                                            freq=cfg.PG.TEST.FREQUENCY,
                                            horizon=cfg.PG.HORIZON)
    if args.dist:
        test_sampler = DistributedSampler(test_dataset, shuffle=False, drop_last=True)
        test_dataloader = data.DataLoader(dataset=test_dataset,
                                          batch_size=cfg.PG.TEST.BATCH_SIZE // num_gpus,
                                          num_workers=9, pin_memory=True, sampler=test_sampler)
    else:
        # Not divided by the GPU count: the test loader is not passed through
        # accelerator.prepare below.
        test_dataloader = data.DataLoader(dataset=test_dataset,
                                          batch_size=cfg.PG.TEST.BATCH_SIZE,
                                          drop_last=True, shuffle=False, num_workers=6,
                                          pin_memory=True, persistent_workers=True)

    # ---- model / optimiser ----
    model = PanguModel(device=device).to(device)
    for param in model.parameters():
        param.requires_grad = True

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                                 lr=cfg.PG.TRAIN.LR, weight_decay=cfg.PG.TRAIN.WEIGHT_DECAY)

    if rank == 0:
        msg = '\n' + utils.torch_summarize(model, show_weights=False)
        logger.info(msg)
        print("Starting training from scratch!")
    torch.set_num_threads(cfg.GLOBAL.NUM_STREADS)

    # Cosine schedule over 200 scheduler steps.
    # NOTE(review): `verbose` is deprecated in recent torch releases.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 200, eta_min=0,
                                                              last_epoch=-1, verbose=True)
    start_epoch = 1
    model, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, val_dataloader, lr_scheduler)

    model = train(model, accelerator, train_loader=train_dataloader,
                  val_loader=val_dataloader,
                  optimizer=optimizer,
                  lr_scheduler=lr_scheduler,
                  res_path=output_path,
                  device=device,
                  writer=writer, logger=logger, start_epoch=start_epoch)

    # The original only bound `best_model` when --load_my_best was true, so
    # `--load_my_best False` crashed with NameError here. Both settings now test
    # the model returned by `train`.
    # TODO(review): if --load_my_best was meant to reload a saved best checkpoint
    # from `output_path`, that loading step still needs to be added.
    best_model = model

    logger.info("Begin testing...")
    test(test_loader=test_dataloader,
         model=best_model,
         device=device,
         res_path=output_path)

    writer.close()  # flush pending TensorBoard events
    if rank == 0:
        logger.info("Total run time: %.1f s", time.time() - starts)