#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import sys

sys.path.insert(0, "../../modules/object_tracking/yolov5")

import argparse
import logging
import numpy as np
import random
import torch
from configs.paths import dataset_dir

from pathlib import Path

import pickle

from modules.sthoip_transformer.STTRAN_Wrapper import STTRAN_Wrapper
import glob
import os
import csv
from copy import deepcopy

from time import process_time_ns, time
from thop import profile
#from fvcore.nn import FlopCountAnalysis
import tqdm

from configs.cfg_to_info import cfg_to_info
from configs.paths import project_path,dataset_dir

# NOTE(review): a `global` statement at module scope has no effect, and neither
# name is assigned or read anywhere in the visible part of this file.
global wandb_logger
global wandb_finished

from modules.object_tracking.transforms.transforms import STTranTransform, YOLOv5Transform
from torchvision.transforms.functional import resize


def convert_to_yolofromat(original_images):
    """Run every image through ``YOLOv5Transform`` and stack the results.

    Args:
        original_images: Iterable of images accepted by ``YOLOv5Transform``.

    Returns:
        The transformed images concatenated along axis 0 with
        ``np.concatenate``.
    """
    # presumably (384, 640) is the target image size and 5 a transform
    # setting -- confirm against YOLOv5Transform's signature.
    to_yolo = YOLOv5Transform((384, 640), 5)
    converted = [to_yolo(image, do_resize=True) for image in original_images]
    return np.concatenate(converted, axis=0)

def _load_benchmark_batch(path, device):
    """Load one saved batch and prepare it the way the benchmark feeds the model.

    Adds ``batch["yolo_frames"]`` (built from ``batch["original_frames"]``),
    resizes ``batch["frames"]`` to 640x640 and moves it to ``device``, and
    divides all but the first column of ``batch["bboxes"]`` by the ratio of
    the shorter side of the resized frames to the shorter side of the YOLO
    frames.
    """
    # NOTE(review): torch.load unpickles the file -- only point this at
    # batch files from a trusted source.
    batch = torch.load(path)

    batch["yolo_frames"] = convert_to_yolofromat(batch["original_frames"])
    batch["frames"] = resize(batch["frames"], (640, 640)).to(device)

    frame_size = batch["frames"].shape[2:]
    scale = np.min(frame_size) / np.min(batch["yolo_frames"].shape[2:])
    batch["bboxes"][:, 1:] = batch["bboxes"][:, 1:] / scale
    return batch


def _time_forward_passes(model, batch, device, num_runs):
    """Return an array with the wall-clock duration (seconds) of each forward pass."""
    from time import perf_counter

    # CUDA kernels are launched asynchronously; without synchronizing, the
    # timer would mostly measure launch overhead rather than execution.
    sync_cuda = torch.cuda.is_available() and torch.device(device).type == "cuda"

    durations = []
    # Inference timing does not need autograd bookkeeping.
    with torch.no_grad():
        for _ in tqdm.tqdm(range(num_runs)):
            # Fresh copy per run, made outside the timed region, in case the
            # model modifies its input.
            inputs = deepcopy(batch)
            if sync_cuda:
                torch.cuda.synchronize(device)
            start = perf_counter()
            model(inputs)
            if sync_cuda:
                torch.cuda.synchronize(device)
            durations.append(perf_counter() - start)
    return np.array(durations)


def _benchmark_batch(model, path, device, measure_flops, num_runs):
    """Benchmark one batch file; return its CSV row, or ``None`` if it was skipped."""
    # Row label: the file name up to its first dot.
    label = os.path.split(path)[1].split(".")[0]

    batch = _load_benchmark_batch(path, device)

    # Both columns stay 0 when FLOP counting is disabled, so that every row
    # has exactly as many fields as the CSV header.
    gflops = 0
    params = 0
    if measure_flops:
        try:
            macs, params = profile(model, inputs=(deepcopy(batch),), verbose=False)
        except Exception as exc:
            print("Error with ", (path), repr(exc))
            return None
        gflops = 2 * macs * 1e-9
        print("Giga FLOPs from MACs", gflops)
        print("Params", params)

    try:
        times = _time_forward_passes(model, batch, device, num_runs)
    except Exception as exc:
        # Best effort: skip this batch, but say why instead of dropping it silently.
        print(f"Skipping {path}: inference failed with {exc!r}")
        return None

    mean = np.mean(times) * 1e3
    std = np.std(times) * 1e3
    print(f"mean: {mean} ms, std: {std} ms")

    return [label, gflops, mean, std, params]


def compute_params(opt, cfg, device, *, measure_flops=False, num_runs=1000,
                   results_path="results.csv"):
    """Benchmark the STTRAN model on saved batches and write the results to CSV.

    Builds ``STTRAN_Wrapper`` from ``cfg``, puts it in eval mode on ``device``,
    and for every file matching ``opt.batch_path`` (in sorted order) measures
    the forward-pass latency and, optionally, MACs/parameters via ``thop``.
    One row per successfully benchmarked batch is written to ``results_path``.

    Args:
        opt: Namespace; only ``opt.batch_path`` (a glob pattern) is read.
        cfg: Model configuration. Modified in place:
            ``cfg["MODEL"]["global_token"]`` is set to ``True``.
        device: Device the model and the frames are moved to.
        measure_flops: When true, also profile MACs and parameters with
            ``thop.profile``; otherwise those CSV columns are written as 0.
        num_runs: Number of timed forward passes per batch.
        results_path: Path of the CSV file to write (overwritten).

    Raises:
        ValueError: If ``num_runs`` is smaller than 1.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    cfg["MODEL"]["global_token"] = True

    model = STTRAN_Wrapper(cfg, modelname="STTRAN", gaze_usage="cross")
    model.to(device)
    model.info_model()

    # set to evaluation mode
    model.eval()

    paths = sorted(glob.glob(opt.batch_path))
    if not paths:
        print(f"No batch files matched {opt.batch_path}")

    header = ["num_interaction", "GFLOPs", "mean inference ms", "std inference ms", "params"]
    rows = []
    for path in paths:
        print(path)
        row = _benchmark_batch(model, path, device, measure_flops, num_runs)
        if row is not None:
            rows.append(row)

    with open(results_path, 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(header)
        writer.writerows(rows)



def parse_opt():
    """Define this script's command-line options and parse the command line.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Configuration, weights and dataset locations.
    add(
        "--cfg",
        type=str,
        default=project_path + "/configs/ablation/train_hyp_f3_ours.yaml",
        help="path to hyperparameter configs",
    )
    add(
        "--weights",
        type=str,
        default=project_path + "/weights",
        help="root folder for all pretrained weights",
    )
    add(
        "--finetune-backbone",
        action="store_true",
        help="also finetune the ResNet backbone during training",
    )
    add("--data", type=str, default=dataset_dir, help="dataset root path")
    add("--subset-train", type=int, default=-1, help="sub train dataset length")
    add("--subset-val", type=int, default=-1, help="sub val dataset length")

    # Schedule, output and logging options.
    add(
        "--imgsz",
        "--img",
        "--img-size",
        type=int,
        default=224,
        help="train, val image size (pixels)",
    )
    add("--epochs", type=int, default=40, help="number of epochs")
    add("--warmup", type=int, default=3, help="number of warmup epochs")
    add(
        "--project",
        default=project_path + "/hoi/runs/HOI4ABOT",
        help="save to project/name",
    )
    add("--name", default="exp", help="save to project/name")
    add(
        "--save-period",
        type=int,
        default=1,
        help="Save checkpoint every x epochs (disabled if < 1)",
    )
    add(
        "--disable-wandb",
        action="store_true",
        default=False,
        help="disable wandb logging",
    )

    # Gaze-related model options.
    add(
        "--gaze",
        type=str,
        default="no",
        help="how to use gaze features: no, concat, cross, cross_all",
    )
    add(
        "--global-token",
        action="store_true",
        help="Use global token, only for cross-attention mode",
    )

    return parser.parse_args()


def main(opt):
    """Load the pickled config and run the benchmark on the filtered batches.

    Sets ``opt.device`` and ``opt.batch_path`` on the given namespace, loads
    the configuration from ``/VidHOI/cfg_pickle.yaml`` and hands everything to
    ``compute_params``.

    Args:
        opt: Namespace of parsed options; modified in place.
    """
    opt.device = "cuda:0"
    opt.batch_path = dataset_dir + "/batches_filtered/*.pt"

    # NOTE(review): despite the ".yaml" suffix this file is read with pickle;
    # unpickling executes arbitrary code, so only load a trusted file here.
    # The context manager closes the handle the original left open.
    with open("/VidHOI/cfg_pickle.yaml", "rb") as cfg_file:
        cfg = pickle.load(cfg_file)

    compute_params(opt, cfg, opt.device)


if __name__ == "__main__":
    # Release cached CUDA memory before the benchmark starts, then parse the
    # command line and run.
    torch.cuda.empty_cache()
    main(parse_opt())
