import argparse
import os
import sys
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

from TSPD.datasets import get_dataset
from TSPD.utils import get_transform, set_random_seed
from TSPD.setup_cfg import setup_cfg, print_args
from TSPD.TSPD import TSPD



def run_exp(cfg):
    """Build the train/val/test datasets and a TSPD trainer, then train and evaluate.

    Parameters
    ----------
    cfg : config object
        Experiment configuration (as produced by ``setup_cfg``). This function
        reads ``cfg.gpu_id`` (comma-separated integer device indices, e.g.
        ``"0,1"``) and ``cfg.load_file``. It sets ``cfg.nb_task`` from the
        length of the test split before building the trainer.

    Raises
    ------
    ValueError
        If ``cfg.gpu_id`` contains no device index, or contains an entry that
        is not an integer.
    """
    # "0,1" -> [0, 1]. Blank entries (e.g. from a trailing comma) are skipped
    # rather than crashing on int('').
    device = [int(s) for s in cfg.gpu_id.split(',') if s.strip()]
    if not device:
        raise ValueError(
            "cfg.gpu_id must list at least one device index, got {!r}".format(cfg.gpu_id)
        )

    # get_dataset returns (dataset, class_names, templates). Class names and
    # templates are taken from the train split only. NOTE(review): presumably
    # they are identical across splits — confirm.
    train_dataset, classes_names, templates = get_dataset(cfg, split='train', transforms=get_transform(cfg))
    val_dataset, _, _ = get_dataset(cfg, split='val', transforms=get_transform(cfg))
    eval_dataset, _, _ = get_dataset(cfg, split='test', transforms=get_transform(cfg))
    # The number of tasks is taken from the length of the test split.
    # Presumably get_dataset returns one entry per task — TODO confirm.
    cfg.nb_task = len(eval_dataset)

    # Normalise any falsy value (e.g. "") to None so TSPD sees "no checkpoint".
    load_file = cfg.load_file or None
    trainer = TSPD(cfg, device, classes_names, templates, load_file=load_file)

    datasets = {'train': train_dataset, 'val': val_dataset, 'test': eval_dataset}
    trainer.train_and_eval(cfg, datasets)


def main(args):
    """Build the experiment config, save it to the log directory, and run the experiment.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments. The whole namespace is passed to
        ``setup_cfg`` and ``print_args``. ``args.tips`` is used as the last
        component of the log-directory name.

    Side effects
    ------------
    Creates ``experiments/<threshold>_<energy>_<tips>/`` if needed and writes
    ``config.yaml`` into it. Then seeds the RNGs via ``set_random_seed`` and
    calls ``run_exp``.
    """
    cfg = setup_cfg(args)
    # Record the exact invocation alongside the config for traceability.
    cfg.command = ' '.join(sys.argv)
    print("The learning rate for per task is: {}".format(cfg.TSPD.optim.lr))
    # One log directory per (threshold, energy, tips) combination.
    cfg.log_path = os.path.join('experiments', "{}_{}_{}".format(cfg.threshold, cfg.energy, args.tips))
    # exist_ok=True avoids the race between an exists() check and makedirs().
    os.makedirs(cfg.log_path, exist_ok=True)
    # Explicit encoding so the dump does not depend on the platform locale.
    with open(os.path.join(cfg.log_path, 'config.yaml'), 'w', encoding='utf-8') as f:
        f.write(cfg.dump())
    print_args(args, cfg)
    # Seed before run_exp so dataset construction and training are reproducible.
    set_random_seed(cfg.seed)
    run_exp(cfg)


if __name__ == "__main__":
    # Scalar command-line options, as (flag, type, default, help), in the
    # order they are registered with the parser.
    cli_options = (
        ("--config-path", str, "configs/MTIL.yaml", "path to config"),
        ("--gpu_id", str, "3", "gpu id"),
        ("--tips", str, "aaa", ""),
        ("--temperature", float, 1.0, "gumbel temperature"),
        ("--threshold", float, 0.96, ""),
        ("--energy", float, 0.95, ""),
        ("--lr", str, "0.9,2.9,1.8,1.0,0.9,1.6,0.5,0.9,1.6,0.8,0.5", ""),
        ("--gumbel_lr", float, 2.0, ""),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    # Everything after the recognised options is collected verbatim as config overrides.
    parser.add_argument(
        "opts",
        default=None,
        nargs=argparse.REMAINDER,
        help="modify config options using the command-line",
    )

    args = parser.parse_args()
    main(args)