#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import sys, json
import argparse
from pathlib import Path
import torch
import yaml
from configs.paths import project_path, annotations_dir, dataset_dir

def get_config(config_path):
    """Load a YAML file and return the parsed document.

    Args:
        config_path: path (``str`` or ``os.PathLike``) of the YAML file.

    Returns:
        Whatever PyYAML builds from the document (for the config files used
        by this module, presumably a dict — confirm per file).

    Note:
        The file is opened in binary mode so PyYAML detects the encoding
        itself (UTF-8, or UTF-16 via BOM) instead of relying on the
        platform's locale encoding. ``yaml.FullLoader`` is kept for
        compatibility with existing configs; it is not intended for untrusted
        input — switch to ``yaml.safe_load`` if config paths ever come from
        an untrusted source.
    """
    with Path(config_path).open("rb") as config_file:
        return yaml.load(config_file, Loader=yaml.FullLoader)

def parse_opt(argv=None):
    """Parse the demo's command-line options.

    Args:
        argv: optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before; passing a
            list allows programmatic use without touching ``sys.argv``.

    Returns:
        dict mapping each option's destination name (e.g. ``"hoi_thres"``,
        ``"show_detections"``) to its parsed value.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--modelname", type=str, default="HOIBOT", help="model_name to use [ HOIBOT, STTranGaze ]")
    parser.add_argument("--source", type=str, default="", help="path to source")
    parser.add_argument("--future", type=int, default=0, help="seconds in future")
    parser.add_argument("--cfg", type=str, default=project_path + "/demo/common/config/cfg", help="path to configs")
    parser.add_argument("--weights", type=str, default=project_path + "/weights", help="root folder for all pretrained weights")
    parser.add_argument("--imgsz", type=int, default=224, help="train, val image size (pixels)")
    parser.add_argument("--hoi-thres", type=float, default=0.3, help="threshold for HOI score")
    parser.add_argument("--out", type=str, default=project_path + "/output", help="output folder")
    parser.add_argument("--annotations", type=str, default=annotations_dir, help="annotations folder")
    parser.add_argument("--dataset_path", type=str, default=dataset_dir, help="dataset_path folder")
    # Both flags default to True; a store_true flag alone could never turn
    # them off, so each gets a matching --no-* switch writing the same dest.
    parser.add_argument("--print", action="store_true", default=True, help="print HOIs (default: on)")
    parser.add_argument("--no-print", dest="print", action="store_false", help="do not print HOIs")
    parser.add_argument("--show_detections", action="store_true", default=True,
                        help="show_detections in Visualization (default: on)")
    parser.add_argument("--no-show_detections", dest="show_detections", action="store_false",
                        help="hide detections in Visualization")
    opt = parser.parse_args(argv)
    return vars(opt)

def load_config(subst=None):
    """Parse the command line, apply overrides, and build the full config.

    Args:
        subst: optional mapping of option name -> value. Each entry replaces
            (or adds) the corresponding parsed command-line option before the
            config is assembled.

    Returns:
        The nested configuration dict built by ``prepare_config``.
    """
    options = parse_opt()
    if subst is not None:
        # Feed the (key, value) pairs straight in: later values win, as before.
        options.update(subst.items())
    return prepare_config(options)



def _read_json(path):
    """Return the parsed JSON document stored at *path*, read as UTF-8."""
    with Path(path).open("r", encoding="utf-8") as json_file:
        return json.load(json_file)


def _load_class_annotations(annotations):
    """Read the category files from the *annotations* folder.

    Returns:
        ``(object_classes, interaction_classes, spatial_class_idxes,
        action_class_idxes)``: the contents of ``obj_categories.json`` and
        ``pred_categories.json``, plus the ``"spatial"`` and ``"action"``
        entries of ``pred_split_categories.json``.
    """
    folder = Path(annotations)
    object_classes = _read_json(folder / "obj_categories.json")
    interaction_classes = _read_json(folder / "pred_categories.json")
    pred_split = _read_json(folder / "pred_split_categories.json")
    return object_classes, interaction_classes, pred_split["spatial"], pred_split["action"]


def prepare_config(opt):
    """Assemble the nested demo configuration from parsed options.

    Args:
        opt: dict of options as returned by ``parse_opt``. Keys read here:
            ``source``, ``future``, ``hoi_thres``, ``print``,
            ``show_detections``, ``annotations``, ``dataset_path``, ``cfg``,
            ``weights``, ``out`` and ``imgsz``.

    Returns:
        dict with the sections ``PATHS``, ``CLASSES``, ``MODEL``,
        ``FINETUNE``, ``TRAINER``, ``VERBOSE`` and ``PARAMS`` (in that order).

    Side effects:
        Creates the folder ``<out>/<source stem>``, including missing parents.

    Raises:
        KeyError: if an option or a required key of ``best_f{future}.yaml``
            is missing.
        OSError: if a config or annotation file cannot be read.
    """
    source = Path(opt["source"])
    video_name = source.stem
    future = opt["future"]
    hoi_thres = opt["hoi_thres"]
    print_hois = opt["print"]
    show_detections = opt["show_detections"]
    imgsz = opt["imgsz"]
    annotations = opt["annotations"]
    dataset_path = opt["dataset_path"]
    cfg_path = Path(opt["cfg"])

    # Root folder for all model weights.
    weight_path = Path(opt["weights"])
    model_path = weight_path / "hoi4abot" / "dual_hydra" / f"f{future}" / f"f{future}.pt"

    # Output files. parents/exist_ok: the --out folder may not exist yet, and
    # a separate exists() check would race with a concurrent run.
    output_folder = Path(opt["out"]) / video_name
    output_folder.mkdir(parents=True, exist_ok=True)
    hoi_file = output_folder / f"{video_name}_hoi.txt"
    result_file = output_folder / f"{video_name}_result.json"
    result_video_file = output_folder / f"{video_name}_result.mp4"

    # Model hyper-parameters from the per-horizon config file.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    cfg = get_config(str(cfg_path / f"best_f{future}.yaml"))
    sampling_mode = cfg["sampling_mode"]

    object_classes, interaction_classes, spatial_class_idxes, action_class_idxes = (
        _load_class_annotations(annotations)
    )
    objects_demo = get_config(project_path + "/demo/common/config/objects_demo.yaml")

    # NOTE(review): --modelname is parsed by parse_opt but not consulted here;
    # confirm whether it should drive MODEL_NAME.
    model_name = "HOIBOT"  # alternatives: HYDRA, DUAL

    # Backbone checkpoint paths. The key is optional: only touch it when it
    # exists so the placeholder fallbacks below are actually reachable.
    if "backbone_path_weight" in cfg:
        if model_name == "HYDRA":
            # HYDRA variant: swap "dual" for "hydra" in the path template.
            cfg["backbone_path_weight"] = cfg["backbone_path_weight"].replace("dual", "hydra")
        backbone_template = cfg["backbone_path_weight"]
        backbone_weights = weight_path / backbone_template.format(0, "pt")
        backbone_cfg = weight_path / backbone_template.format(0, "yaml")
        all_heads_path = weight_path / backbone_template
    else:
        backbone_weights = "weights_path"
        backbone_cfg = "backbone_path_weight"
        all_heads_path = "future_heads"

    config = {
        "PATHS": {
            "model_path": model_path,
            "annotations": annotations,
            "dataset_path": dataset_path,
            "result_video_file": result_video_file,
            "cfg_path": cfg_path,
            "hoi_file": hoi_file,
            "video_name": video_name,
            "result_file": result_file,
            "source": source,
        },
        "CLASSES": {
            "object_classes": object_classes,
            "interaction_classes": interaction_classes,
            "spatial_class_idxes": spatial_class_idxes,
            "action_class_idxes": action_class_idxes,
            "hoi2coco_mapper": objects_demo["hoi2coco_mapper"],
        },
        "MODEL": {
            "device": device,
            "MODEL_NAME": model_name,
            "embedding_dim": cfg["embedding_dim"],
            "hoi_feature": cfg["hoi_feature"],
            "depth": cfg["depth"],
            "dual_transformer_type": cfg["dual_transformer_type"],
            "num_heads": cfg["num_heads"],
            "box_encoder_type": cfg["box_encoder_type"],
            "semantic_type": cfg["semantic_type"],
            "blender_type": cfg["blender_type"],
            "moa_eps": cfg["moa_eps"],
            "train_feature_extractor": cfg["train_feature_extractor"],
            "head_cls_type": cfg["head_cls_type"],
            "dim_transformer_ffn": cfg["dim_transformer_ffn"],
            "sttran_enc_layer_num": cfg["sttran_enc_layer_num"],
            "sttran_dec_layer_num": cfg["sttran_dec_layer_num"],
            "sttran_sliding_window": cfg["sttran_sliding_window"],
            "semantic_masking_prob": cfg["semantic_masking_prob"],
            "mlp_projection": cfg["mlp_projection"],
            "sinusoidal_encoding": cfg["sinusoidal_encoding"],
            "dropout": cfg["dropout"],
            "do_regression": False,
            "use_feature_extractor": True,
            "use_semantic_extractor": False,
            "use_sam_prompt_encoder": False,
            "augmentation_semantic": cfg["augmentation_semantic"],
            "union_box": cfg["union_box"],
            "mlp_ratio": cfg["mlp_ratio"],
            "simple_semantics": cfg["simple_semantics"],
            "pos_embed_type": cfg["pos_embed_type"],
            "image_cls_type": cfg["image_cls_type"],
            "mainbranch": cfg["mainbranch"],
            "normalize_features": cfg["normalize_features"],
            "no_reduction": cfg["no_reduction"],
            "do_inference": True,
            "add_extension": cfg.get("add_extension", False),
            "concat_extension": cfg.get("concat_extension", False),
        },
        "FINETUNE": {
            "backbone": cfg.get("backbone_model", "HOIBOT"),
            "weights_path": backbone_weights,
            "cfg": backbone_cfg,
            "future_num": cfg["future_num"] if sampling_mode == "anticipation" else 0,
            "baseline_model": cfg.get("baseline_model", "scratch"),
            "anticipation_heads": cfg.get("anticipation_heads", [0, 3]),
            "all_heads_path": all_heads_path,
        },
        "TRAINER": {
            "device": device,
            "DATASET": {
                # NOTE(review): max_clip_length is set to the horizon `future`;
                # confirm this is intended.
                "max_clip_length": future,
                "num_spatial_classes": len(spatial_class_idxes),
                "num_action_classes": len(action_class_idxes),
                "object_classes": object_classes,
                "img_size": imgsz,
            },
        },
        "VERBOSE": {
            "print_hois": print_hois,
            "show_detections": show_detections,
        },
        "PARAMS": {
            "cfg": cfg,
            "imgsz": imgsz,
            "gaze": False,
            "global_token": False,
            "sampling_mode": sampling_mode,
            "dim_gaze_heatmap": cfg["dim_gaze_heatmap"],
            "device": device,
            "dim_transformer_ffn": cfg["dim_transformer_ffn"],
            "sttran_enc_layer_num": cfg["sttran_enc_layer_num"],
            "sttran_dec_layer_num": cfg["sttran_dec_layer_num"],
            "sttran_sliding_window": cfg["sttran_sliding_window"],
            "separate_head": cfg["separate_head"],
            "loss_type": cfg["loss_type"],
            "mlp_projection": cfg["mlp_projection"],
            "sinusoidal_encoding": cfg["sinusoidal_encoding"],
            "future": future,
            "hoi_thres": hoi_thres,
        },
    }
    return config


if __name__ == "__main__":
    # Build the configuration from the command line and dump it for inspection.
    config = load_config()
    print(config)
