from common.cfg.config_parser import get_config
from pathlib import Path

import os, datetime, argparse

from configs.paths import *
def default_opt(args=None):
    """
    Parse the training command-line options.

    :param args: optional list of argument strings to parse. When ``None`` (the default)
        argparse reads ``sys.argv[1:]``, which matches the previous behaviour. Passing an
        explicit list makes the function usable from notebooks and tests.
    :return: ``argparse.Namespace`` holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--modelname", type=str, default="HOIBOT", help="model_name to use [ HOIBOT, STTranGaze ]")
    # Defaults are anchored at `project_path`, which comes from `configs.paths`.
    parser.add_argument("--cfg", type=str, default= project_path + "/hoi/configs/cfg/train_hyp_f0.yaml",
                        help="path to hyperparameter configs")
    parser.add_argument("--weights", type=str, default="weights", help="root folder for all pretrained weights")
    parser.add_argument("--data", type=str, default=project_path, help="dataset root path")
    parser.add_argument("--subset-train", type=int, default=-1, help="sub train dataset length")
    parser.add_argument("--subset-val", type=int, default=-1, help="sub val dataset length")
    parser.add_argument("--imgsz", "--img", "--img-size", type=int, default=224,
                        help="train, val image size (pixels)")
    parser.add_argument("--epochs", type=int, default=40, help="number of epochs")
    parser.add_argument("--warmup", type=int, default=3, help="number of warmup epochs")
    parser.add_argument("--project", default = project_path + "/hoi/runs/HOI4ABOT", help="save to project/name")
    # NOTE(review): the "{}" placeholder is not filled here; presumably the caller formats it.
    parser.add_argument("--name", default="exp_f{}", help="save to project/name")
    parser.add_argument("--save-period", type=int, default=1, help="Save checkpoint every x epochs (disabled if < 1)")
    parser.add_argument("--disable-wandb", action="store_true", default=False, help="disable wandb logging")
    parser.add_argument("--gaze", type=str, default="concat", help="how to use gaze features: no, concat, cross, cross_all")
    parser.add_argument("--global-token", action="store_true",
                        help="Use global token, only for cross-attention mode")
    parser.add_argument("--device", type=str, default="cuda:0", help="Used device")
    # -1 means "not set": keep the depth from the YAML config.
    parser.add_argument("--depth", type=int, default=-1, help="Change the depth to the transformer model")

    return parser.parse_args(args)

def cfg_to_info(opt):
    """
    Build the nested run-configuration dictionary from CLI options and the YAML config.

    Side effects:
        - ``opt.project`` is rewritten in place to ``<opt.project>/<date_str>``. Later
          consumers of ``opt`` therefore see the timestamped folder. Calling this twice
          on the same ``opt`` appends a second timestamp.
        - The directory ``<opt.project>/<opt.name>/weights`` is created on disk.

    :param opt: argparse namespace, as produced by ``default_opt``.
    :return: dictionary with the sections
        - LOGGER
        - PATHS
        - MODEL
        - FINETUNE
        - TRAINER
            - DATASET
                - AUGMENTATION
            - ENGINE
    :raises KeyError: if the YAML config lacks a key that is read unconditionally.
    """
    info = {}
    date_str = datetime.datetime.now().strftime('%y-%m-%d-%H%M')
    opt.project = os.path.join(opt.project, date_str)
    info["LOGGER"] = {"DATE_STR": date_str}

    output_path = Path(opt.project)
    # NOTE(review): opt.name defaults to "exp_f{}"; the placeholder is not filled here,
    # presumably the caller formats it before calling — confirm.
    run_id = opt.name
    run_path = output_path / run_id
    weights_path = Path(opt.weights)
    info["PATHS"] = {
        "annotations": annotations_dir,
        "weight_path": weights_path,
        "dataset_path": Path(opt.data),
        "backbone_model_path": weights_path / "backbone" / "resnet101-ag.pth",
        "sttran_word_vector_dir": weights_path / "semantic",
        "output_path": output_path,
        "run_id": run_id,
        "log_weight_path": run_path / "weights",
        "log_file_path": run_path / "train.log",
        "log_csv_path": run_path / "metrics.csv",
        "log_triplet_ap_path": run_path / "class_ap.csv",
        "log_clip_losses_path": run_path / "train_losses.json",
        "log_eval_path": run_path / "eval",
        "log_eval_file_path": run_path / "eval/eval.log",
        "eval_result_path": run_path / "eval/all_results.json",
    }
    info["PATHS"]["log_weight_path"].mkdir(parents=True, exist_ok=True)

    # Load hyperparameters from opt and the cfg file.
    cfg = get_config(opt.cfg)

    def cfg_or(key, default):
        # `cfg` is only known to support `in` and `[]`, so don't assume a `.get` method.
        return cfg[key] if key in cfg else default

    sampling_mode = cfg["sampling_mode"]
    is_anticipation = sampling_mode == "anticipation"

    # `--depth` defaults to -1 ("not set"); a positive value overrides the YAML depth.
    cli_depth = getattr(opt, "depth", None)
    depth = cli_depth if cli_depth is not None and cli_depth > 0 else cfg["depth"]

    info["MODEL"] = {
        "device": opt.device,
        "MODEL_NAME": opt.modelname,
        "embedding_dim": cfg["embedding_dim"],
        "hoi_feature": cfg["hoi_feature"],
        "depth": depth,
        "dual_transformer_type": cfg["dual_transformer_type"],
        "num_heads": cfg["num_heads"],
        "box_encoder_type": cfg["box_encoder_type"],
        "semantic_type": cfg["semantic_type"],
        "blender_type": cfg["blender_type"],
        "moa_eps": cfg["moa_eps"],
        "train_feature_extractor": cfg["train_feature_extractor"],
        "head_cls_type": cfg["head_cls_type"],
        "dim_transformer_ffn": cfg["dim_transformer_ffn"],
        "sttran_enc_layer_num": cfg["sttran_enc_layer_num"],
        "sttran_dec_layer_num": cfg["sttran_dec_layer_num"],
        "sttran_sliding_window": cfg["sttran_sliding_window"],
        "semantic_masking_prob": cfg["semantic_masking_prob"],
        "gaze_usage": opt.gaze,
        "global_token": opt.global_token,
        "mlp_projection": cfg["mlp_projection"],
        "sinusoidal_encoding": cfg["sinusoidal_encoding"],
        "dropout": cfg["dropout"],
        "do_regression": False,
        "use_feature_extractor": True,
        "use_semantic_extractor": False,
        "use_sam_prompt_encoder": False,
        "augmentation_semantic": cfg["augmentation_semantic"],
        "union_box": cfg["union_box"],
        "mlp_ratio": cfg["mlp_ratio"],
        "simple_semantics": cfg["simple_semantics"],
        "pos_embed_type": cfg["pos_embed_type"],
        "image_cls_type": cfg["image_cls_type"],
        "mainbranch": cfg["mainbranch"],
        "normalize_features": cfg["normalize_features"],
        "no_reduction": cfg["no_reduction"],
        "add_extension": cfg_or("add_extension", False),
        "concat_extension": cfg_or("concat_extension", False),
    }

    # `future_num` is only required by the config in anticipation mode.
    future_num = cfg["future_num"] if is_anticipation else 0

    # The backbone weight entry is a format template filled as (0, 0, extension).
    has_backbone_weight = "backbone_path_weight" in cfg
    backbone_weight = cfg["backbone_path_weight"] if has_backbone_weight else None
    # NOTE(review): the fallbacks below are literal placeholder strings, not paths — confirm intended.
    info["FINETUNE"] = {
        "backbone": cfg_or("backbone_model", "HOIBOT"),
        "weights_path": weights_path / backbone_weight.format(0, 0, "pt") if has_backbone_weight else "weights_path",
        "cfg": weights_path / backbone_weight.format(0, 0, "yaml") if has_backbone_weight else "backbone_path_weight",
        "future_num": future_num,
        "baseline_model": cfg_or("baseline_model", "scratch"),
        # Previously an unguarded cfg["future_num"] fallback raised KeyError for
        # non-anticipation configs without `future_num`; fall back to 0 instead.
        "anticipation_heads": cfg_or("anticipation_heads", cfg_or("future_num", 0)),
        "all_heads_path": weights_path / backbone_weight if has_backbone_weight else "weights_path",
    }

    uses_window = sampling_mode in ("window", "anticipation")
    info["TRAINER"] = {
        "DATASET": {
            "future_num": future_num,
            "future_type": cfg["future_type"] if is_anticipation else "all",
            "future_ratio": cfg["future_ratio"] if is_anticipation else 0,
            "subset_train_len": opt.subset_train,
            "subset_val_len": opt.subset_val,
            "min_clip_length": cfg["min_clip_length"],
            # Windowed sampling bounds clips by the sliding-window size.
            "max_clip_length": cfg["sttran_sliding_window"] if uses_window else cfg["max_clip_length"],
            "max_human_num": cfg["max_human_num"],
            "img_size": opt.imgsz,
            "sampling_mode": sampling_mode,
            "train_ratio": cfg["train_ratio"],
            "val_ratio": 0,
            "batch_size": cfg["batch_size"],
            "batch_size_val": cfg["batch_size_val"],
            # Shuffle only when a subset is requested, so the subset is random.
            "subset_train_shuffle": opt.subset_train > 0,
            "subset_val_shuffle": opt.subset_val > 0,
            "AUGMENTATION": {
                "hflip_p": cfg["hflip_p"],
                "color_jitter": cfg["color_jitter"],
            },
        },
        "ENGINE": {
            # Configured validation epochs plus every epoch from 10 onward.
            "epochs_validation": list(cfg["epochs_validation"]) + list(range(10, opt.epochs)),
            "separate_head": cfg["separate_head"],
            "split_window": cfg["split_window"],
            # Was copy-pasted from `split_window`; keep that as the fallback when unset.
            "max_interaction_pairs": cfg_or("max_interaction_pairs", cfg["split_window"]),
            "save_period": opt.save_period,
            "max_epochs": opt.epochs,
            "interaction_conf_threshold": cfg["interaction_conf_threshold"],
            "eval_k": cfg["eval_k"],
            "random_seed": cfg["random_seed"],
            "rare_limit": cfg["rare_limit"],
            "iou_threshold": cfg["iou_threshold"],
        },
    }
    return info
