import json
import os

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from arguments import parser
from datamodules.dmc_datamodule import DMCMultiDomainDataModule
from datamodules.dummy_datamodule import RandomDataset
from models.multitask_dt_module import MultiTaskDTLitModule

# Default root for model outputs when `args.ICLR_ANON` is set: each experiment
# writes to <root>/<exp_name>. A non-empty `args.reset_path` takes precedence.
ICLR_ANON_MODEL_DIR = "iclr_anon/outputs_dmc/"
# NOTE(review): not referenced anywhere in this script — confirm whether it is
# still needed.
ICLR_ANON_STAT_DIR = "iclr_anon/results_dmc/"


def _set_dataset_types(args):
    """Derive ``args.dataset_types`` from ``args.source_data_type`` (in place).

    ``"full"`` -> ``["fullcollect"]``, ``"rand"`` -> ``["randcollect"]`` and
    ``"mix"`` -> both, in which case ``args.num_steps`` is also halved
    (integer division).
    """
    if args.source_data_type == "full":
        args.dataset_types = ["fullcollect"]
    elif args.source_data_type == "rand":
        args.dataset_types = ["randcollect"]
    elif args.source_data_type == "mix":
        args.dataset_types = ["fullcollect", "randcollect"]
        # Two dataset types are loaded here; presumably halving keeps the total
        # step budget comparable to the single-type settings — confirm.
        args.num_steps = args.num_steps // 2
    # NOTE(review): any other value leaves args.dataset_types untouched
    # (unset unless the parser provides a default) — consider validating.


def _build_logger(args):
    """Create the TensorBoard logger for this run.

    With ``args.ICLR_ANON`` set, logs go under ``$ICLR_ANON_OUTPUT_DIR/tb_logs``
    with ``name=args.exp_name``; otherwise under ``<args.output_dir>/tb_logs``
    with ``name="train"``.

    Raises:
        KeyError: if ``args.ICLR_ANON`` is set but ``ICLR_ANON_OUTPUT_DIR`` is
            not in the environment.
    """
    if args.ICLR_ANON:
        log_root = os.path.join(os.environ["ICLR_ANON_OUTPUT_DIR"], "tb_logs")
        print("logging to", log_root, args.exp_name)
        return TensorBoardLogger(log_root, name=args.exp_name)
    return TensorBoardLogger(os.path.join(args.output_dir, "tb_logs"), name="train")


def main(args):
    """Train a multi-task Decision Transformer on DMC data, then run ``trainer.test``.

    Mutates ``args`` in place: ``args.source_envs`` is loaded from the JSON file
    at ``args.multi_config`` and ``args.dataset_types`` (and, for ``"mix"``,
    ``args.num_steps``) are derived from ``args.source_data_type``. The model is
    either restored from ``args.load_model_from`` or freshly constructed, with
    every entry of ``vars(args)`` forwarded as a keyword argument.

    Args:
        args: parsed ``argparse.Namespace`` holding the data-module, model and
            trainer options referenced below.
    """
    # Seed for reproducibility; note the trainer below is not put in
    # deterministic mode for this implementation.
    pl.seed_everything(args.seed, workers=True)

    # JSON is UTF-8 by spec; be explicit so the platform default encoding
    # cannot break decoding of the config.
    with open(args.multi_config, encoding="utf-8") as jsonfile:
        args.source_envs = json.load(jsonfile)

    _set_dataset_types(args)

    # init data module
    dmc_data = DMCMultiDomainDataModule.from_argparse_args(args)

    # init training module; all CLI options are passed through as kwargs
    dict_args = vars(args)
    if args.load_model_from:
        model = MultiTaskDTLitModule.load_from_checkpoint(args.load_model_from, **dict_args)
        print("loaded model from", args.load_model_from)
    else:
        model = MultiTaskDTLitModule(**dict_args)

    # init root dir
    os.makedirs(args.output_dir, exist_ok=True)
    print("output dir", args.output_dir)

    # Keep the args.save_k checkpoints with the lowest validation loss, plus
    # the most recent one.
    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(args.output_dir, "checkpoints"),
        filename="checkpoint_{epoch:02d}",
        mode="min",
        save_top_k=args.save_k,
        monitor="val/avg_loss",
        save_last=True,
    )

    logger = _build_logger(args)

    # init trainer
    trainer = pl.Trainer(
        accelerator=args.accelerator,
        devices=args.devices,
        num_nodes=args.nodes,
        default_root_dir=args.output_dir,
        min_epochs=1,
        max_epochs=args.epochs,
        callbacks=[checkpoint_callback],
        strategy="ddp",
        fast_dev_run=False,
        logger=logger,
    )

    # start training
    trainer.fit(model, datamodule=dmc_data)

    # testing: trainer.test is run with a RandomDataset instance
    dummy = RandomDataset()
    trainer.test(model, dummy)


if __name__ == "__main__":
    # Extend the shared CLI parser with data-module and model options.
    parser = DMCMultiDomainDataModule.add_argparse_args(parser)
    parser = MultiTaskDTLitModule.add_model_specific_args(parser)

    args = parser.parse_args()

    # A non-empty --reset_path overrides the default model root. Use a separate
    # name rather than rebinding the module-level constant ICLR_ANON_MODEL_DIR.
    model_dir = args.reset_path or ICLR_ANON_MODEL_DIR

    if args.ICLR_ANON:
        # Data location and (default) experiment name come from the environment.
        args.data_dir_prefix = os.environ["ICLR_ANON_DATA_DIR"]
        if not args.exp_name:
            args.exp_name = os.environ["ICLR_ANON_EXPERIMENT_NAME"] + "-" + os.environ["ICLR_ANON_JOB_NAME"]
        args.output_dir = os.path.join(model_dir, args.exp_name)
        print("===== data dir is", args.data_dir_prefix)
        print("===== save dir is", args.output_dir)

    # Fixed overrides for this DMC script; any CLI value for timesteps is ignored.
    args.dmc = True
    args.timesteps = 250

    main(args)
