import argparse
import os
import torch
from utilities.builder import build_proc_from_run_dir
from config import get_dataset_cfg
from baselines.model_factory import make_config
from exp.exp_basic import ExpConfigs, Exp_Basic
from exp.exp_autoregressive import Exp_Dynamic_Autoregressive


def _build_parser():
    """Return the argument parser for the evaluation CLI."""
    parser = argparse.ArgumentParser(
        description="Command-line evaluation for baseline models."
    )
    # Every later step (configs, processor, checkpoint, output dirs) is rooted
    # at this run directory, so it cannot be optional.
    parser.add_argument(
        "--model_path", type=str, required=True,
        help="Run directory containing the saved configs and model_tr_best.pth.",
    )
    parser.add_argument("--gpu", type=int, default=0, help="GPU id; ignored if no CUDA.")
    parser.add_argument("--dataset", type=str, default="ns_1e-3")
    parser.add_argument(
        "--seq_id", type=int, default=0,
        help="Test-split sequence index used for the fixed-sample rollout.",
    )
    parser.add_argument(
        "--rollout_steps", type=int, default=15,
        help="Number of rollout steps for the visualizations.",
    )
    parser.add_argument(
        "--eval_long_term", action=argparse.BooleanOptionalAction, default=False,
        help="Also run the long-term trajectory evaluation.",
    )
    parser.add_argument(
        "--lt_steps", type=int, default=10,
        help="Rollout steps for the long-term evaluation.",
    )
    parser.add_argument(
        "--lt_bs", type=int, default=4,
        help="Batch size for the long-term evaluation on the test split.",
    )
    return parser


def _load_experiment(args, device):
    """Rebuild the experiment saved in ``args.model_path`` and load its best checkpoint.

    Prints the restored experiment and model configs as a side effect.

    Args:
        args: Parsed CLI namespace (uses ``model_path`` and ``dataset``).
        device: ``torch.device`` selected for this evaluation run.

    Returns:
        The ``Exp_Dynamic_Autoregressive`` instance with weights loaded.
    """
    # The return value is not used here. The call is kept from the original
    # script (presumably it validates ``--dataset`` early) — TODO confirm.
    get_dataset_cfg(name=args.dataset)
    proc = build_proc_from_run_dir(run_dir=args.model_path, dataset=args.dataset)

    cfgs = Exp_Basic.load_all_configs(args.model_path)
    model_cfg = make_config(**cfgs["model_cfg"])
    # Copy so the device override below does not mutate the loaded mapping.
    exp_cfg_dict = dict(cfgs["exp_cfg"])
    # A device stored as a string (e.g. the training host's "cuda:N") is
    # replaced with the device chosen for this run; other values are kept.
    if isinstance(exp_cfg_dict.get("device"), str):
        exp_cfg_dict["device"] = device
    exp_cfg = ExpConfigs(**exp_cfg_dict)

    print(exp_cfg)
    print(model_cfg)

    exp = Exp_Dynamic_Autoregressive(
        args=None, exp_cfg=exp_cfg, model_cfg=model_cfg, data_processor=proc
    )
    exp.load_from_ckpt(
        ckpt_path=os.path.join(args.model_path, "model_tr_best.pth"),
        device=str(exp_cfg.device),
    )
    return exp


def _visualize(exp, args):
    """Run the rollout visualizations, directing output under ``<model_path>/vis``."""
    out_dir = os.path.join(args.model_path, "vis")
    os.makedirs(out_dir, exist_ok=True)
    # ``mode`` accepts "last" or "all".
    exp.visualize_random_rollout(
        group="test", batch_size=2, rollout_steps=args.rollout_steps,
        out_dir=os.path.join(out_dir, "test"), mode="all",
    )
    exp.visualize_random_rollout(
        group="train_eval", batch_size=2, rollout_steps=args.rollout_steps,
        out_dir=os.path.join(out_dir, "train"), mode="all",
    )
    exp.visualize_rollout_by_index(
        group="test", seq_index=args.seq_id, rollout_steps=args.rollout_steps,
        out_dir=os.path.join(out_dir, "test/fixed_sample"), mode="all",
    )


def _evaluate_long_term(exp, args):
    """Run the long-horizon evaluation, directing output under ``<model_path>/long_term``."""
    out_dir_lt = os.path.join(args.model_path, "long_term")
    summary = exp.evaluate_long_trajs(
        out_dir=os.path.join(out_dir_lt, "test_full"), group="test",
        rollout_steps=args.lt_steps, batch_size=args.lt_bs,
        save_pt=False, save_png=False,
    )
    print(summary)
    # NOTE(review): the per-index evaluation runs on the *train* group while
    # the summary above uses *test* — confirm this is intended.
    exp.evaluate_long_by_indices(
        out_dir=os.path.join(out_dir_lt, "selected"), group="train",
        rollout_steps=args.lt_steps, indices=[0, 1, 2],
        save_pt=True, save_png=True,
    )


def main(argv=None):
    """Evaluate a trained baseline run from the command line.

    Restores the experiment saved in ``--model_path`` and loads its
    ``model_tr_best.pth`` checkpoint. It then runs rollout visualizations
    with output directed under ``<model_path>/vis``. With
    ``--eval_long_term``, it also runs the long-horizon evaluation under
    ``<model_path>/long_term``.

    Args:
        argv: Argument list to parse. ``None`` (the default) reads ``sys.argv``.

    Raises:
        SystemExit: Raised by argparse on invalid arguments, including a
            missing or non-directory ``--model_path``.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)
    # Fail fast with a usage message rather than an obscure error deep inside
    # config/checkpoint loading.
    if not os.path.isdir(args.model_path):
        parser.error(f"--model_path is not a directory: {args.model_path}")

    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    exp = _load_experiment(args, device)
    _visualize(exp, args)
    if args.eval_long_term:
        _evaluate_long_term(exp, args)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()