import csv
import os
import time

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from arguments import parser
from datamodules.dmc_datamodule import DMCBCDataModule, DMCDataModule
from datamodules.dummy_datamodule import RandomDataset
from models.dt_module import DTLitModule

# Root directory holding one output folder per experiment when args.ICLR_ANON is
# set; the __main__ block rebinds it to args.reset_path when that is given, and
# stat() lists it to find experiment folders.
ICLR_ANON_MODEL_DIR = "iclr_anon/outputs_dmc/"
# Directory in which stat() places its CSV summary when args.ICLR_ANON is set.
ICLR_ANON_STAT_DIR = "iclr_anon/results_dmc/"


def pipeline(args):
    """Run the whole pipeline: pretrain, behavior-cloning finetune, then test.

    Stages:
      1. Pretrain a ``DTLitModule`` on ``DMCDataModule`` and keep the checkpoint
         with the lowest ``val/avg_loss`` in ``<output_dir>/checkpoints_pretrain``.
      2. Reload that checkpoint with behavior-cloning settings and finetune it on
         ``DMCBCDataModule`` for at most 30 epochs, keeping the ``args.save_k``
         checkpoints with the highest ``val/interactive_reward`` in
         ``<output_dir>/checkpoints_bc``.
      3. Run ``Trainer.test`` on the best finetuned checkpoint.

    Args:
        args: parsed command-line namespace. The finetune stage edits
            ``vars(args)`` in place, so ``args`` itself is modified by this call.
    """
    # set seed for reproducibility, although the trainer does not allow deterministic for this implementation
    pl.seed_everything(args.seed, workers=True)

    ##############################################
    ## pretrain
    ##############################################

    # init data module
    pretrain_data = DMCDataModule.from_argparse_args(args)

    # init training module
    # NOTE: vars() hands back the namespace's own attribute dict, so every later
    # write to dict_args is also visible through args.
    dict_args = vars(args)
    pretrain_model = DTLitModule(**dict_args)

    # init root dir
    os.makedirs(args.output_dir, exist_ok=True)
    print("output dir", args.output_dir)

    # checkpoint saving metrics
    pretrain_checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(args.output_dir, "checkpoints_pretrain"),
        filename="best_model",
        mode="min",
        save_top_k=1,
        monitor="val/avg_loss",
        save_last=True,
    )

    # loggers
    if args.ICLR_ANON:
        print(
            "logging to",
            os.path.join(os.environ["ICLR_ANON_OUTPUT_DIR"], "tb_logs"),
            "pretrain",
        )
        pretrain_logger = TensorBoardLogger(os.path.join(os.environ["ICLR_ANON_OUTPUT_DIR"], "tb_logs"), name="pretrain")
    else:
        pretrain_logger = TensorBoardLogger(os.path.join(args.output_dir, "tb_logs"), name="pretrain")

    # init trainer
    pretrain_trainer = pl.Trainer(
        accelerator=args.accelerator,
        devices=args.devices,
        num_nodes=args.nodes,
        default_root_dir=args.output_dir,
        min_epochs=1,
        max_epochs=args.epochs,
        callbacks=[pretrain_checkpoint_callback],
        strategy="ddp",
        # strategy='ddp_find_unused_parameters_false',
        fast_dev_run=False,
        logger=pretrain_logger,
    )

    # start pre-training
    pretrain_trainer.fit(pretrain_model, datamodule=pretrain_data)

    ##############################################
    ## finetune
    ##############################################

    # init finetune bc datamodule and models
    # (the datamodule is built before the replay ids are changed below)
    bc_data = DMCBCDataModule.from_argparse_args(args)
    dict_args["bc"] = True
    dict_args["forward"] = False
    dict_args["reward"] = False
    dict_args["inverse"] = False
    dict_args["rand_inverse"] = False
    dict_args["unsupervise"] = False
    dict_args["model_type"] = "naive"
    # the former validation replay becomes the finetune training replay and
    # validation moves on to the next replay id (modulo 5)
    dict_args["train_replay_id"] = dict_args["val_replay_id"]
    dict_args["val_replay_id"] = (dict_args["train_replay_id"] + 1) % 5

    print("new args", dict_args)

    time.sleep(10)

    # Ask the callback which file it actually wrote: when the output directory
    # already holds a best_model.ckpt from an earlier run, ModelCheckpoint saves
    # under a versioned name (best_model-v1.ckpt, ...) and the fixed path would
    # silently load the stale file. Fall back to the fixed path if the callback
    # did not record a best checkpoint.
    pretrain_ckpt_path = pretrain_checkpoint_callback.best_model_path or os.path.join(
        args.output_dir, "checkpoints_pretrain/best_model.ckpt"
    )
    print("loading pretrained model from", pretrain_ckpt_path)
    bc_model = DTLitModule.load_from_checkpoint(pretrain_ckpt_path, **dict_args)

    bc_checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(args.output_dir, "checkpoints_bc"),
        filename="best_reward_model",
        mode="max",
        save_top_k=args.save_k,
        monitor="val/interactive_reward",
        save_last=True,
    )

    # loggers
    if args.ICLR_ANON:
        print(
            "logging to",
            os.path.join(os.environ["ICLR_ANON_OUTPUT_DIR"], "tb_logs"),
            "finetune",
        )
        bc_logger = TensorBoardLogger(os.path.join(os.environ["ICLR_ANON_OUTPUT_DIR"], "tb_logs"), name="finetune")
    else:
        bc_logger = TensorBoardLogger(os.path.join(args.output_dir, "tb_logs"), name="finetune")

    # init trainer
    bc_trainer = pl.Trainer(
        accelerator=args.accelerator,
        devices=args.devices,
        num_nodes=args.nodes,
        default_root_dir=args.output_dir,
        min_epochs=1,
        max_epochs=30,
        callbacks=[bc_checkpoint_callback],
        strategy="ddp",
        # strategy='ddp_find_unused_parameters_false',
        fast_dev_run=False,
        logger=bc_logger,
    )

    # start finetuning
    bc_trainer.fit(bc_model, datamodule=bc_data)

    ##############################################
    ## evaluation
    ##############################################

    # testing: the dummy dataset is what gets handed to Trainer.test as its data
    dummy = RandomDataset()
    bc_trainer.test(bc_model, dummy, ckpt_path="best")

    # summarize results from all saved bc models
    # if args.ICLR_ANON:
    #     args.stat_file = os.environ["ICLR_ANON_EXPERIMENT_NAME"] + "-" + os.environ["ICLR_ANON_JOB_NAME"]
    # stat(os.path.join(args.output_dir, 'checkpoints'), args)


def test(args):
    """Run ``Trainer.test`` on a ``DTLitModule`` without any training.

    The model is restored from ``args.load_model_from`` when that is set;
    otherwise a freshly initialised module is built from the arguments.

    Args:
        args: parsed command-line namespace; its attributes are forwarded to the
            model as keyword arguments.
    """
    # seed everything for reproducibility, although the trainer does not allow deterministic for this implementation
    pl.seed_everything(args.seed, workers=True)

    # build the module, from a checkpoint if one was requested
    hparams = vars(args)
    checkpoint_path = args.load_model_from
    if not checkpoint_path:
        model = DTLitModule(**hparams)
    else:
        model = DTLitModule.load_from_checkpoint(checkpoint_path, **hparams)
        print("loaded model from", checkpoint_path)

    # make sure the root dir exists
    root_dir = args.output_dir
    os.makedirs(root_dir, exist_ok=True)
    print("output dir", root_dir)

    # trainer configuration (alternative strategy: 'ddp_find_unused_parameters_false')
    trainer_kwargs = {
        "accelerator": args.accelerator,
        "devices": args.devices,
        "num_nodes": args.nodes,
        "default_root_dir": root_dir,
        "min_epochs": 0,
        "max_epochs": args.epochs,
        "strategy": "ddp",
        "fast_dev_run": False,
    }
    trainer = pl.Trainer(**trainer_kwargs)

    # the dummy dataset is what gets handed to Trainer.test as its data
    trainer.test(model, RandomDataset())


def main(args):
    """Train (or only test) a single stage: pretraining or behavior-cloning finetuning.

    ``args.bc`` selects the stage:
      * behavior cloning: uses ``DMCBCDataModule``; optional weights from
        ``args.load_model_from`` are loaded through ``model.load_my_checkpoint``;
        the ``args.save_k`` checkpoints with the highest
        ``val/interactive_reward`` are kept in ``<output_dir>/checkpoints_bc``.
      * pretraining: uses ``DMCDataModule``; an optional checkpoint from
        ``args.load_model_from`` is restored via ``load_from_checkpoint``; the
        checkpoint with the lowest ``train/avg_loss`` is kept in
        ``<output_dir>/checkpoints_pretrain``.

    When ``args.test`` is set, training is skipped and the model as constructed
    or loaded above is tested directly.

    Args:
        args: parsed command-line namespace; its attributes are forwarded to the
            model as keyword arguments.
    """
    # set seed for reproducibility, although the trainer does not allow deterministic for this implementation
    pl.seed_everything(args.seed, workers=True)

    # init data module
    if args.bc:
        dmc_data = DMCBCDataModule.from_argparse_args(args)
    else:
        dmc_data = DMCDataModule.from_argparse_args(args)

    # init training module
    dict_args = vars(args)

    if not args.bc and args.load_model_from:
        model = DTLitModule.load_from_checkpoint(args.load_model_from, strict=not args.no_strict, **dict_args)
        print("loaded model from", args.load_model_from)
    else:
        model = DTLitModule(**dict_args)

    if args.bc and args.load_model_from:
        model.load_my_checkpoint(
            args.load_model_from,
            no_action=args.no_load_action,
            no_action_head=args.no_action_head,
        )
        print("loaded model from", args.load_model_from)

    # init root dir
    os.makedirs(args.output_dir, exist_ok=True)
    print("output dir", args.output_dir)

    # checkpoint saving metrics
    if args.bc:
        checkpoint_callback = ModelCheckpoint(
            dirpath=os.path.join(args.output_dir, "checkpoints_bc"),
            filename="best_reward_model",
            mode="max",
            save_top_k=args.save_k,
            monitor="val/interactive_reward",
            save_last=False,
        )
    else:
        checkpoint_callback = ModelCheckpoint(
            dirpath=os.path.join(args.output_dir, "checkpoints_pretrain"),
            filename="best_model",
            mode="min",
            save_top_k=1,
            monitor="train/avg_loss",
            save_last=True,
        )

    if args.ICLR_ANON:
        print(
            "logging to",
            os.path.join(os.environ["ICLR_ANON_OUTPUT_DIR"], "tb_logs"),
            args.exp_name,
        )
        logger = TensorBoardLogger(os.path.join(os.environ["ICLR_ANON_OUTPUT_DIR"], "tb_logs"), name=args.exp_name)
    else:
        logger = TensorBoardLogger(os.path.join(args.output_dir, "tb_logs"), name="train")

    # init trainer
    trainer = pl.Trainer(
        accelerator=args.accelerator,
        devices=args.devices,
        num_nodes=args.nodes,
        default_root_dir=args.output_dir,
        min_epochs=1,
        max_epochs=args.epochs,
        callbacks=[checkpoint_callback],
        strategy="ddp",
        # strategy='ddp_find_unused_parameters_false',
        fast_dev_run=False,
        logger=logger,
    )

    # the dummy dataset is what gets handed to Trainer.test as its data
    dummy = RandomDataset()

    if args.test:
        # Nothing was fitted in this run, so the checkpoint callback has no
        # "best" checkpoint to resolve; asking for ckpt_path="best" would make
        # the trainer raise. Test the in-memory model instead.
        trainer.test(model, dummy)
    else:
        # start training, then test the best checkpoint found during this fit
        trainer.fit(model, datamodule=dmc_data)
        trainer.test(model, dummy, ckpt_path="best")


def stat(checkpoint_folder, args):
    """Evaluate saved ``best_reward`` checkpoints and record the results in a CSV file.

    The CSV lives at ``ICLR_ANON_STAT_DIR/<args.stat_file>`` when
    ``args.ICLR_ANON`` is set, otherwise at ``<args.output_dir>/<args.stat_file>``.
    Its columns are Domain, Task, Exp, Checkpoint, Reward, Std; rows are appended
    if the file already exists.

    Args:
        checkpoint_folder: folder whose checkpoints should be evaluated. When
            falsy, every experiment folder under the model root
            (``ICLR_ANON_MODEL_DIR`` or ``args.output_dir``) whose name contains
            ``"<args.domain>_<args.task>"`` is scanned instead, skipping folders
            already listed in the CSV.
        args: parsed command-line namespace (reads ``ICLR_ANON``, ``stat_file``,
            ``output_dir``, ``eval_epochs``, ``domain`` and ``task``).
    """
    if args.ICLR_ANON:
        save_file = os.path.join(ICLR_ANON_STAT_DIR, args.stat_file)
    else:
        save_file = os.path.join(args.output_dir, args.stat_file)

    def get_stat_in_folder(checkpoint_folder, args):
        """Evaluate each ``best_reward`` checkpoint in one folder and append the rows."""
        stats = []
        for model_name in os.listdir(checkpoint_folder):
            if "best_reward" not in model_name:
                continue
            bc_model = DTLitModule.load_from_checkpoint(os.path.join(checkpoint_folder, model_name))
            reward, std = bc_model.get_return_dmc(args.eval_epochs)
            print(checkpoint_folder, model_name, "reward", reward, "+-", std)
            stats.append(
                [
                    bc_model.hparams.domain,
                    bc_model.hparams.task,
                    checkpoint_folder,
                    model_name,
                    str(reward),
                    str(std),
                ]
            )

        # Create the target directory up front: a missing directory would raise
        # FileNotFoundError, which the scanning loop below reports as "ignoring"
        # and the results would be lost.
        save_dir = os.path.dirname(save_file)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        # newline="" is what the csv module asks for on every file it writes;
        # without it extra blank rows can appear on some platforms.
        write_header = not os.path.exists(save_file)
        with open(save_file, "a", newline="") as f:
            writer = csv.writer(f)
            if write_header:
                writer.writerow(["Domain", "Task", "Exp", "Checkpoint", "Reward", "Std"])
            writer.writerows(stats)

    if checkpoint_folder:
        get_stat_in_folder(checkpoint_folder, args)
        return

    ## collect all results from the domain and task
    # List the experiment folders of whichever root is in use, and build the
    # checkpoint paths from that same root.
    root_dir = ICLR_ANON_MODEL_DIR if args.ICLR_ANON else args.output_dir
    dmc_files = os.listdir(root_dir)
    print("all files", dmc_files)

    # column 2 ("Exp") holds the checkpoint folder of every row written so far
    done_files = set()
    if os.path.exists(save_file):
        with open(save_file, newline="") as file:
            for row in csv.reader(file):
                if len(row) > 2:  # tolerate blank or truncated rows
                    done_files.add(row[2])
    print("files already read", sorted(done_files))

    experiment_tag = f"{args.domain}_{args.task}"
    for d in dmc_files:
        if experiment_tag not in d:
            continue
        # NOTE(review): the training entry points in this file write to
        # "checkpoints_bc" / "checkpoints_pretrain"; confirm "checkpoints" is
        # still the folder that should be scanned here.
        checkpoint_folder = os.path.join(root_dir, d, "checkpoints")
        if checkpoint_folder in done_files:
            print("folder already read before", checkpoint_folder)
            continue
        print("checking out result folder", checkpoint_folder)
        try:
            get_stat_in_folder(checkpoint_folder, args)
        except (AttributeError, FileNotFoundError) as error:
            # best effort: skip folders that are missing or hold incompatible checkpoints
            print("ignoring", error)


if __name__ == "__main__":
    # mode switches: registered first so they can be peeked at before the
    # rest of the argument set is known
    mode_flags = (
        ("--bc", "behavior clone finetuning"),
        ("--pipeline", "the whole pretrain and finetune pipeline"),
        ("--test", "only test"),
        ("--stat", "only do statistics"),
    )
    for flag_name, flag_help in mode_flags:
        parser.add_argument(flag_name, default=False, action="store_true", help=flag_help)

    # first pass over the command line: only the mode switches matter here,
    # unknown options are tolerated
    temp_args, _ = parser.parse_known_args()

    # the selected mode decides which datamodules contribute their arguments
    if temp_args.pipeline:
        datamodule_classes = (DMCBCDataModule, DMCDataModule)
    elif temp_args.bc:
        datamodule_classes = (DMCBCDataModule,)
    else:
        datamodule_classes = (DMCDataModule,)
    for datamodule_cls in datamodule_classes:
        parser = datamodule_cls.add_argparse_args(parser)

    parser = DTLitModule.add_model_specific_args(parser)

    # second, full pass now that every argument is registered
    args = parser.parse_args()

    # optional override of the model root directory
    if args.reset_path:
        ICLR_ANON_MODEL_DIR = args.reset_path

    # in ICLR_ANON mode the data and output locations come from the environment
    if args.ICLR_ANON:
        environ = os.environ
        args.data_dir_prefix = environ["ICLR_ANON_DATA_DIR"]
        if not args.exp_name:
            args.exp_name = "-".join((environ["ICLR_ANON_EXPERIMENT_NAME"], environ["ICLR_ANON_JOB_NAME"]))
        args.output_dir = os.path.join(ICLR_ANON_MODEL_DIR, args.exp_name)
        print("===== data dir is", args.data_dir_prefix)
        print("===== save dir is", args.output_dir)

    # settings fixed by this entry script
    args.dmc = True
    args.timesteps = 250

    # dispatch on the mode switches; --pipeline wins over --test, which wins over --stat
    if temp_args.pipeline:
        pipeline(args)
    elif temp_args.test:
        test(args)
    elif temp_args.stat:
        stat(args.load_model_from, args)
    else:
        main(args)
