"""
source init_env.sh
python scripts/run.py --dryrun

or with interactive

WANDB_CONSOLE="wrap" python scripts/run.py --dryrun

WANDB_CONSOLE="wrap" python scripts/run.py --config="config/wikisection.json"  --dryrun
"""

import os
import platform
import wandb
import getpass
from copy import deepcopy
import random, torch, numpy
from src import system
from src.utils import load_json
from src.setup import process_config
# from pytorch_lightning.callbacks import LearningRateMonitor
import pytorch_lightning as pl

# Turn on cuDNN benchmark mode for the whole process as soon as this module loads.
# NOTE(review): per the PyTorch docs this lets cuDNN auto-select kernels, which
# can make runs differ even when seed_everything() is used — confirm intended.
torch.backends.cudnn.benchmark = True

# Registry mapping a system name to its class in src.system; run() indexes it
# with config.system to pick the class to instantiate.
# NOTE(review): the 'DeltaBinaryWikiSystem' key maps to BinaryDeltaWikiSystem
# (word order differs) — confirm the key matches what the config files use.
SYSTEM = {
    'TCLSystem': system.TCLSystem,
    'WikiTCLSystem': system.WikiTCLSystem,
    'BinaryWikiSystem': system.BinaryWikiSystem,
    'DeltaBinaryWikiSystem': system.BinaryDeltaWikiSystem,
}

# Pairs of step counts keyed by an integer id: key 1 holds 5000/5000, keys 2-5
# hold 10000/10000. Each key gets its own list object.
# NOTE(review): not referenced anywhere in this file; the meaning of the key and
# of the two slots is presumably defined by whoever imports it — confirm.
stepDict = {1: [5000, 5000]}
stepDict.update((key, [10000, 10000]) for key in range(2, 6))


def _make_trainer(config, num_epochs):
    """Build a trainer whose min and max epoch counts both equal ``num_epochs``."""
    num_epochs = int(num_epochs)
    # NOTE(review): gpus=1 is hard-coded even though seeding honours
    # config.cuda — confirm CPU-only runs are not expected here.
    return pl.Trainer(
        default_root_dir=config.exp_dir,
        gpus=1,
        max_epochs=num_epochs,
        min_epochs=num_epochs,
    )


def run(args):
    """Train in two phases, then save and evaluate the configured system.

    Phase 1 fits the system for ``config.num_epochs_init`` epochs. Then
    ``set_to_train()`` is called on it, its ``num_train_step`` counter is
    reset to 0, and phase 2 fits the same instance for ``config.num_epochs``
    epochs with a fresh trainer. Finally the system is saved and
    ``trainer.test`` is run with the phase-2 trainer.

    The save directory is ``<root>/<config.exp_name>``, where ``<root>`` is
    the ``JUICE_DIR`` environment variable if set, otherwise
    ``config.exp_dir``.

    Args:
        args: Parsed command-line namespace. This function reads ``dryrun``,
            ``config``, ``p``, ``encoder``, ``exp_name`` and ``project``; the
            whole namespace is also forwarded to ``process_config``.

    Raises:
        KeyError: If ``config.system`` is not a key of ``SYSTEM``.
    """
    if args.dryrun:
        print("Running in dryrun mode")
        os.environ['WANDB_MODE'] = 'dryrun'

    config = process_config(args.config, args=args)
    seed_everything(config.seed, use_cuda=config.cuda)

    if args.p is not None:
        config.data_params.p = float(args.p)

    # Convert the step budgets from the config into epoch counts.
    num_data = config.data_params.n_obs_per_seg * config.data_params.n_segments
    num_steps_per_epoch = int(num_data / config.optim_params.batch_size)
    if num_steps_per_epoch:
        config.num_epochs_init = int(config.num_steps_init / float(num_steps_per_epoch))
        config.num_epochs = int(config.num_steps / float(num_steps_per_epoch))
    else:  # wikisection
        # Zero steps per epoch: fall back to fixed epoch counts.
        config.num_epochs_init = 500
        config.num_epochs = 1000

    config.exp_dir = 'experiments'
    config.model_params.language_encoder = args.encoder

    if args.exp_name:
        config.exp_name = args.exp_name

    wandb.init(
        project=args.project,
        entity=getpass.getuser(),
        # name=config.exp_name,
        config=config,
    )

    # Named ``model_system`` so the imported ``src.system`` module is not
    # shadowed inside this function.
    model_system = SYSTEM[config.system](config)

    # The original code read an undefined global JUICE_DIR here (NameError).
    # Resolve the root from the environment and fall back to the experiment
    # directory so the script works without extra setup.
    save_root = os.environ.get('JUICE_DIR', config.exp_dir)
    save_directory = os.path.join(save_root, config.exp_name)
    # exist_ok avoids the check-then-create race when runs start concurrently.
    os.makedirs(save_directory, exist_ok=True)
    print("Save directory is {}".format(save_directory))

    ## Phase 1: Train only prediction head
    trainer = _make_trainer(config, config.num_epochs_init)
    trainer.fit(model_system)

    ## Phase 2: train everything. Reset LR rescheduler. Train everything.
    model_system.set_to_train()
    model_system.num_train_step = 0
    trainer = _make_trainer(config, config.num_epochs)  # NOTE different epoch budget from phase 1
    trainer.fit(model_system)

    ## Save the model
    model_system.save(directory=save_directory)

    ## Evaluation:
    trainer.test(model_system)


def seed_everything(seed, use_cuda=True):
    """Seed the random number generators used by this script.

    Seeds Python's ``random``, PyTorch and NumPy with the same value, then
    exports that value as the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Seed value passed to every generator.
        use_cuda: When true, also call ``torch.cuda.manual_seed_all``.
    """
    # Order matters only for which generators are seeded if one call raises.
    seeders = [random.seed, torch.manual_seed]
    if use_cuda:
        seeders.append(torch.cuda.manual_seed_all)
    seeders.append(numpy.random.seed)

    for apply_seed in seeders:
        apply_seed(seed)

    os.environ['PYTHONHASHSEED'] = str(seed)


if __name__ == "__main__":
    import argparse

    # Command-line flags, registered in this order. Each entry is
    # (flag, keyword options for ArgumentParser.add_argument).
    cli_options = (
        # Path of the config file passed to process_config().
        ('--config', {'type': str, 'default': "./config/tcl.json",
                      'help': 'path to config file'}),
        # When set, run() switches wandb to dryrun mode.
        ('--dryrun', {'default': False, 'action': 'store_true'}),
        # NOTE(review): not read directly in this file; presumably consumed by
        # process_config(), which receives the whole namespace — confirm.
        ('--seed', {'default': 0, 'type': int}),
        # Optional override for config.data_params.p (converted with float()).
        ('--p', {'default': None}),
        # Experiment name; also the name of the save sub-directory.
        ('--exp-name', {'type': str, 'required': True}),
        # wandb project name.
        ('--project', {'default': "unsupervised", 'type': str}),
        # Stored as config.model_params.language_encoder.
        ('--encoder', {'default': "GPT2", 'type': str}),
    )

    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    run(args)



