#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import sys

sys.path.insert(0, "../../modules/object_tracking/yolov5")

import argparse
import logging
import numpy as np
import random
import torch
from configs.paths import project_path

from pathlib import Path

from common.cfg.config_parser import get_config
from common.dataset.vidhoi_dataset import VidHOIDataset, dataset_collate_fn
from common.dataset.transforms import STTranTransform, ClipRandomHorizontalFlipping

from modules.hoi4abot.ModelWrapper import Model_Wrapper
import glob
import os
import csv

from time import process_time_ns, time
from thop import profile
#from fvcore.nn import FlopCountAnalysis
import tqdm

from configs.cfg_to_info import cfg_to_info
from configs.paths import project_path,dataset_dir

# NOTE: a `global` statement at module scope has no effect in Python; these two
# lines only signal that the names are meant to be module-level state.
# `wandb_finished` is assigned inside train_gt_bbox(); `wandb_logger` is not
# assigned anywhere in the visible source.
global wandb_logger
global wandb_finished

import copy


def _sync_device(device):
    """Wait for all queued CUDA work on *device* to finish; do nothing on CPU."""
    target = torch.device(device)
    if target.type == "cuda":
        torch.cuda.synchronize(target)


def train_gt_bbox(opt, device):
    """Benchmark the model's compute cost and inference latency on saved batches.

    Despite its name, this function performs no training: the model is put in
    evaluation mode and, for every batch file matched by ``opt.batch_path``,
    the function

    1. counts MACs/parameters of ``model.model`` with ``thop.profile``,
    2. times 1000 forward passes of ``model``,
    3. collects one result row per batch file.

    All rows are finally written to ``results.csv`` in the current working
    directory.

    Args:
        opt: Parsed options. Attributes read directly here: ``data``,
            ``project``, ``name``, ``cfg``, ``imgsz`` and ``batch_path``.
            ``opt`` is also forwarded to ``cfg_to_info``, which may read
            further attributes.
        device: Device (string or ``torch.device``) on which the model and the
            batch tensors are placed.

    Side effects:
        Sets the module-level ``wandb_finished`` flag to ``False``, seeds the
        Python/NumPy/PyTorch RNGs, creates ``<project>/<name>/weights`` and
        overwrites ``results.csv``.
    """
    global wandb_finished
    # Monotonic high-resolution clock; wall-clock time() can jump if the
    # system clock is adjusted while the benchmark is running.
    from time import perf_counter

    wandb_finished = False
    num_timed_runs = 1000

    ## Init hyperparameters
    # paths
    vidhoi_dataset_path = Path(opt.data)
    output_path = Path(opt.project)
    run_id = opt.name
    log_weight_suffix = "weights"
    log_weight_path = output_path / run_id / log_weight_suffix
    log_weight_path.mkdir(parents=True, exist_ok=True)
    # load hyperparameters from opt, and cfg file
    cfg = get_config(opt.cfg)
    # net hyperparameters
    img_size = opt.imgsz  # for yolov5 and feature backbone preprocessing
    sampling_mode = cfg["sampling_mode"]
    min_clip_length = cfg["min_clip_length"]
    max_clip_length = cfg["max_clip_length"]
    max_human_num = cfg["max_human_num"]
    sttran_sliding_window = cfg["sttran_sliding_window"]
    if sampling_mode in ("window", "anticipation"):
        max_clip_length = sttran_sliding_window  # max length for dataset
    future_num = cfg["future_num"] if sampling_mode == "anticipation" else 0

    # set RNG seed for reproducibility
    random_seed = cfg["random_seed"]
    random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)

    # The validation split is instantiated only to obtain the class metadata
    # that is passed to the model below; no samples are drawn from it here.
    vidhoi_test_dataset = VidHOIDataset(
        annotations_file=vidhoi_dataset_path / "annotations/val_frame_annots.json",
        frames_dir=vidhoi_dataset_path / "images",
        min_length=min_clip_length,
        max_length=max_clip_length,
        max_human_num=max_human_num,
        train_ratio=-1,
        subset_len=-1,
        subset_shuffle=False,
        transform=STTranTransform(img_size=img_size),
        annotation_mode=sampling_mode,
        logger=None,
        future_num=future_num,
        future_type="all",
        future_ratio=0.,
    )

    # transformer model
    info = cfg_to_info(opt)
    info["TRAINER"]["DATASET"].update({
        "object_classes": vidhoi_test_dataset.object_classes,
        "interaction_classes": vidhoi_test_dataset.interaction_classes,
        "spatial_class_idxes": vidhoi_test_dataset.spatial_class_idxes,
        "action_class_idxes": vidhoi_test_dataset.action_class_idxes,
        "num_object_classes": len(vidhoi_test_dataset.object_classes),
        "num_interaction_classes": len(vidhoi_test_dataset.interaction_classes),
        "num_action_classes": len(vidhoi_test_dataset.action_class_idxes),
        "num_spatial_classes": len(vidhoi_test_dataset.spatial_class_idxes),
    })
    info["MODEL"]["do_inference"] = True
    model = Model_Wrapper(info)
    model.to(device)

    # set to evaluation mode
    model.eval()

    # Sort so that the CSV row order is reproducible; glob order is arbitrary.
    paths = sorted(glob.glob(opt.batch_path))
    if not paths:
        print(f"No batch files matched {opt.batch_path!r}; results.csv will only contain the header")

    header = ["num_interaction", "GFLOPs", "mean inference ms", "std inference ms", "params"]
    data = []

    for path in paths:
        # The first CSV column is the file name up to its first dot.
        num = os.path.basename(path).split(".")[0]

        # NOTE: torch.load unpickles its input; only point this at trusted files.
        batch = torch.load(path)
        batch["frames"] = batch["frames"].to(device)
        batch["bboxes"] = batch["bboxes"].to(device)
        # All-ones masks, one per bbox row.
        # NOTE(review): the 224x224 size is hard-coded and not tied to opt.imgsz — confirm.
        batch["binary_masks"] = torch.ones(batch["bboxes"].shape[0], 224, 224, device=device)

        # Every forward pass receives its own deep copy of the batch, so a
        # model that modifies its input cannot affect later runs.
        bc = copy.deepcopy(batch)
        macs, params = profile(model.model, inputs=(bc,), verbose=False)
        gflops = 2*macs*1e-9  # one multiply-accumulate is counted as two FLOPs

        print("Giga FLOPs from MACs", gflops)
        print("Params", params)

        times = []
        # Inference does not need autograd; disabling it keeps graph
        # construction out of the measured time and memory.
        with torch.no_grad():
            for _ in tqdm.tqdm(range(num_timed_runs)):
                bc = copy.deepcopy(batch)  # copied outside the timed region
                # CUDA kernels are launched asynchronously: without a
                # synchronisation before reading the clock, the measurement
                # would not wait for the forward pass to actually finish.
                _sync_device(device)
                start = perf_counter()
                model(bc)
                _sync_device(device)
                times.append(perf_counter() - start)

        times = np.array(times)

        # seconds -> milliseconds
        mean = np.mean(times)*1e3
        std = np.std(times)*1e3

        print(f"mean: {mean} ms, std: {std} ms")

        data.append([num, gflops, mean, std, params])

    with open('results.csv', 'w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)

        writer.writerow(header)
        writer.writerows(data)



def parse_opt():
    """Define the command-line interface and return the parsed options.

    Path-like defaults are derived from ``project_path`` and ``dataset_dir``.

    Returns:
        The ``argparse.Namespace`` produced by parsing ``sys.argv``.
    """
    default_cfg = project_path + "/weights/hoi4abot/stacked/best_f0.yaml"
    default_weights = project_path + "/weights"
    default_project = project_path + "/hoi/runs/HOI4ABOT"

    cli = argparse.ArgumentParser()
    add = cli.add_argument

    # configuration, weights and data
    add("--cfg", type=str, default=default_cfg,
        help="path to hyperparameter configs")
    add("--weights", type=str, default=default_weights,
        help="root folder for all pretrained weights")
    add("--finetune-backbone", action="store_true",
        help="also finetune the ResNet backbone during training")
    add("--data", type=str, default=dataset_dir,
        help="dataset root path")
    add("--subset-train", type=int, default=-1,
        help="sub train dataset length")
    add("--subset-val", type=int, default=-1,
        help="sub val dataset length")

    # image size, schedule, output location and logging
    add("--imgsz", "--img", "--img-size", type=int, default=224,
        help="train, val image size (pixels)")
    add("--epochs", type=int, default=40,
        help="number of epochs")
    add("--warmup", type=int, default=3,
        help="number of warmup epochs")
    add("--project", default=default_project,
        help="save to project/name")
    add("--name", default="exp",
        help="save to project/name")
    add("--save-period", type=int, default=1,
        help="Save checkpoint every x epochs (disabled if < 1)")
    add("--disable-wandb", action="store_true", default=False,
        help="disable wandb logging")

    # gaze-related options
    add("--gaze", type=str, default="no",
        help="how to use gaze features: no, concat, cross, cross_all")
    add("--global-token", action="store_true",
        help="Use global token, only for cross-attention mode")

    return cli.parse_args()


def main(opt):
    """Select the device and batch files, then run the benchmark.

    Sets ``opt.device`` to ``"cuda:0"`` when CUDA is available and to
    ``"cpu"`` otherwise, sets ``opt.batch_path`` to the glob pattern for the
    saved batches under ``dataset_dir``, and calls ``train_gt_bbox``.

    Args:
        opt: Options namespace as returned by ``parse_opt``; it is modified
            in place.
    """
    if torch.cuda.is_available():
        opt.device = "cuda:0"
    else:
        # A hard-coded CUDA device makes model.to(device) fail on hosts
        # without CUDA, so fall back to the CPU and say so.
        print("CUDA is not available; running the benchmark on CPU instead")
        opt.device = "cpu"
    opt.batch_path = dataset_dir + "/batches_filtered/*.pt"
    train_gt_bbox(opt, opt.device)


# Script entry point: parse the CLI options, tag them with the model name and
# run the benchmark.
if __name__ == "__main__":
    # Ask PyTorch to release any cached CUDA memory before starting.
    torch.cuda.empty_cache()
    opt = parse_opt()
    # Set programmatically rather than through a CLI flag. It is not read in
    # this file; presumably consumed by cfg_to_info(opt) — TODO confirm.
    opt.modelname = "HOIBOT"
    main(opt)
