import sys
import argparse
import cv2
import gc
import numpy as np
import os
import os.path as osp
import pdb
import torch
import torch.nn as nn
import torch.nn.functional as F
import glob
from sam2.build_sam import build_sam2_video_predictor
from tqdm import tqdm
from einops import rearrange


def load_gt(gt_path):
    """Read a ground-truth box file into per-frame box prompts.

    Each non-blank line holds ``x,y,w,h`` (top-left corner plus width and
    height), comma separated. Values may be written as integers or as floats
    (e.g. ``"12.0"``). They are converted to ``int``, so any fractional part
    is truncated.

    Args:
        gt_path: Path to the comma-separated annotation file.

    Returns:
        dict mapping a 0-based frame index to ``((x1, y1, x2, y2), 0)``.
        The index counts non-blank lines in file order. The value is the box
        in corner format plus a label of 0.

    Raises:
        ValueError: if a non-blank line does not contain exactly four
            numeric fields.
    """
    with open(gt_path, 'r') as f:
        # Skip blank lines (e.g. a trailing empty line) instead of failing
        # on int('').
        lines = [line for line in f if line.strip()]

    prompts = {}
    for fid, line in enumerate(lines):
        # Parse through float() so "12.0"-style values are accepted; int()
        # then truncates, which is identical to int() for integer strings.
        x, y, w, h = (int(float(v)) for v in line.split(','))
        prompts[fid] = ((x, y, x + w, y + h), 0)

    return prompts

# Colour palette with a single (255, 0, 0) entry. It is not referenced in the
# code shown here, and the channel order (RGB vs. BGR) is not established.
color = [(255, 0, 0)]

# ----------------------------------------
# Command-line options
# ----------------------------------------
parser = argparse.ArgumentParser(description="SAMITE Model Inference")
parser.add_argument(
    "--data", type=str, default="TrackingNet",
    choices=["TrackingNet"], help="Dataset name",
)
parser.add_argument(
    "--exp_name", type=str, default="samite",
    help="Experiment name",
)
parser.add_argument(
    "--model_name", type=str, default="base_plus",
    choices=["tiny", "small", "base_plus", "large"], help="Size of SAM 2.1",
)
parser.add_argument(
    "--max_num_indices", type=int, default=30,
    help="Number of candidate previous frames",
)
parser.add_argument(
    "--selection_strategy", type=str, default="s2l_pos_feat_v2",
    choices=["none", "s2l_pos_feat_v2"], help="Frame selection strategy",
)
parser.add_argument(
    "--bias_mode", type=str, default="kalman_pos",
    choices=["none", "kalman_pos"], help="Mode of Cross Attention Bias",
)
parser.add_argument(
    "--bias_type", default="v3",
    choices=["v3", "none"], help="Type of Cross Attention Bias",
)
parser.add_argument("--use_prior_prompt", default=False, action="store_true")
parser.add_argument(
    "--alpha", type=float, default=0.3,
    help="Fusion weight in frame selection.",
)
args = parser.parse_args()
print(args)

# ----------------------------------------
# Derived paths
# ----------------------------------------
data, exp_name, model_name = args.data, args.exp_name, args.model_name
video_folder = f"data/{data}/TEST/zips"

checkpoint = f"sam2/checkpoints/sam2.1_hiera_{model_name}.pt"
# Config file suffix: "b+" for base_plus, otherwise the first letter of the
# model size (t / s / l).
model_cfg = (
    "configs/samite/samite_hiera_b+.yaml"
    if model_name == "base_plus"
    else f"configs/samite/samite_hiera_{model_name[0]}.yaml"
)

pred_folder = f"results/{data}/{exp_name}/{exp_name}_{model_name}"

# ========================================
# Glob TrackingNet's testing videos
# ========================================
# Keep only directories: every entry is opened as a frame folder below, and a
# stray file inside `video_folder` would otherwise crash os.listdir().
test_videos = sorted(p for p in glob.glob(f"{video_folder}/*") if osp.isdir(p))
# Annotation files live in an "anno" folder beside `video_folder`, one
# "<video_name>.txt" per video.
anno_folder = osp.join(osp.dirname(video_folder), "anno")

for vid, video in enumerate(test_videos):
    # osp.basename honours the platform path separator (glob can return
    # backslash-joined paths on Windows), unlike video.split("/").
    video_basename = osp.basename(video)
    frame_folder = video
    # propagate_in_video() below receives `args`, so the current frame
    # folder is exposed through it.
    args.frame_folder = frame_folder
    num_frames = len(os.listdir(frame_folder))

    print(f"Running video [{vid+1}/{len(test_videos)}]: {video} with {num_frames} frames")

    # A fresh predictor is built for every video and released at the end of
    # the iteration.
    predictor = build_sam2_video_predictor(model_cfg, checkpoint, device="cuda:0")
    predictions = []  # one {obj_id: [x, y, w, h]} dict per propagated frame
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
        # ========================================
        # Init state
        # - async load video
        # - initialize variables
        # - extract features for <frame #0>
        # ========================================
        state = predictor.init_state(frame_folder, offload_video_to_cpu=True, offload_state_to_cpu=True, async_loading_frames=True)

        # ========================================
        # Add prompt
        # - add gt bbox to <frame #0>
        # ========================================
        # Build the annotation path from the folder layout rather than
        # video.replace("zips", "anno"), which would also rewrite any "zips"
        # substring occurring inside the video name itself.
        prompts = load_gt(osp.join(anno_folder, f"{video_basename}.txt"))
        prompt_box, _ = prompts[0]
        predictor.add_new_points_or_box(state, box=prompt_box, frame_idx=0, obj_id=0)

        for _, object_ids, masks in predictor.propagate_in_video(state, args=args):
            bbox_to_vis = {}

            for obj_id, mask in zip(object_ids, masks):
                if obj_id != 0:
                    continue
                # Take mask[0] and binarize at 0.0. Presumably these are mask
                # logits of shape (1, H, W); TODO confirm.
                mask = mask[0].cpu().numpy() > 0.0
                non_zero_indices = np.argwhere(mask)  # rows of (y, x)
                if len(non_zero_indices) == 0:
                    bbox = [0, 0, 0, 0]
                else:
                    y_min, x_min = non_zero_indices.min(axis=0).tolist()
                    y_max, x_max = non_zero_indices.max(axis=0).tolist()
                    # NOTE(review): w/h are max - min with no +1, so the box is
                    # one pixel smaller than the inclusive pixel extent. Left
                    # unchanged so outputs match previous runs; confirm that
                    # this matches the evaluation convention.
                    bbox = [x_min, y_min, x_max - x_min, y_max - y_min]
                bbox_to_vis[obj_id] = bbox
            predictions.append(bbox_to_vis)

    os.makedirs(pred_folder, exist_ok=True)
    with open(osp.join(pred_folder, f'{video_basename}.txt'), 'w') as f:
        # One "x,y,w,h" line per propagated frame, taken from object 0.
        for pred in predictions:
            x, y, w, h = pred[0]
            f.write(f"{x},{y},{w},{h}\n")

    # Free the per-video predictor/state and cached GPU memory before the
    # next video.
    del predictor
    del state
    gc.collect()
    torch.clear_autocast_cache()
    torch.cuda.empty_cache()
