import sys
import time
import argparse
import cv2
import gc
import numpy as np
import os
import os.path as osp
import pdb
import torch
import torch.nn as nn
import torch.nn.functional as F
from sam2.build_sam import build_sam2_video_predictor
from tqdm import tqdm
from einops import rearrange


def load_lasot_gt(gt_path):
    """Load a LaSOT-style ``groundtruth.txt`` into per-frame box prompts.

    Each non-blank line holds one comma-separated ``x,y,w,h`` box, and line
    ``i`` (0-based) describes frame ``i``. Values are parsed as numbers and
    truncated to ``int``, so both integer and decimal annotations are accepted.
    Blank lines (e.g. a trailing empty line) are skipped without shifting the
    frame indices of later lines.

    Args:
        gt_path: Path to the groundtruth text file.

    Returns:
        Dict mapping frame index to ``((x1, y1, x2, y2), label)``, where the
        box is in corner form (``x2 = x + w``, ``y2 = y + h``) and ``label``
        is always ``0``.

    Raises:
        ValueError: If a non-blank line is not exactly four comma-separated
            numbers.
    """
    with open(gt_path, 'r') as f:
        lines = f.readlines()

    prompts = {}
    for fid, line in enumerate(lines):
        if not line.strip():
            # Skip empty lines but keep `fid` tied to the physical line number
            # so later boxes still map to the right frame.
            continue
        x, y, w, h = (int(float(v)) for v in line.split(','))
        prompts[fid] = ((x, y, x + w, y + h), 0)

    return prompts

# Single-entry RGB colour list. It is not referenced anywhere else in this
# file as shown (presumably kept for optional visualisation).
color = [(255, 0, 0)]

# Command-line options. The resulting `args` namespace is later passed as a
# whole to `predictor.propagate_in_video(..., args=args)`, so every option
# defined here is visible to the predictor.
parser = argparse.ArgumentParser(description="SAMITE Model Inference")
parser.add_argument("--data", default="LaSOT", type=str, choices=["LaSOT"],
                    help="Dataset name")
parser.add_argument("--exp_name", default="samite", type=str,
                    help="Experiment name")
parser.add_argument("--model_name", default="base_plus", type=str,
                    choices=["tiny", "small", "base_plus", "large"],
                    help="Size of SAM 2.1")
parser.add_argument("--max_num_indices", default=30, type=int,
                    help="Number of candidate previous frames")
parser.add_argument("--selection_strategy", default="s2l_pos_feat_v2", type=str,
                    choices=["none", "s2l_pos_feat_v2"],
                    help="Frame selection strategy")
parser.add_argument("--bias_mode", default="kalman_pos", type=str,
                    choices=["none", "kalman_pos"],
                    help="Mode of Cross Attention Bias")
parser.add_argument("--bias_type", default="v3", type=str, choices=["v3", "none"],
                    help="Type of Cross Attention Bias")
# store_true already defaults to False.
parser.add_argument("--use_prior_prompt", action="store_true",
                    help="Enable the prior-prompt option (forwarded to the predictor via args)")
parser.add_argument("--alpha", default=0.3, type=float,
                    help="Fusion weight in frame selection.")
args = parser.parse_args()
print(args)

# Resolve dataset, experiment and model settings from the parsed CLI args.
data = args.data
exp_name = args.exp_name
model_name = args.model_name
video_folder = f"data/{data}"
testing_set = f"data/{data}/testing_set.txt"

# One video name per line; lines still carry their trailing newline here.
with open(testing_set, 'r') as f:
    test_videos = list(f)

# SAM 2.1 checkpoint plus the matching SAMITE config. Config files use the
# first letter of the size name, except base_plus which uses "b+".
checkpoint = f"sam2/checkpoints/sam2.1_hiera_{model_name}.pt"
cfg_suffix = "b+" if model_name == "base_plus" else model_name[0]
model_cfg = f"configs/samite/samite_hiera_{cfg_suffix}.yaml"

# Per-video prediction files are written under this folder.
pred_folder = f"results/{data}/{exp_name}/{exp_name}_{model_name}"

def _mask_to_xywh(mask):
    """Return the tight ``[x, y, w, h]`` box around the True pixels of a 2-D mask.

    ``w``/``h`` are ``max - min`` (no ``+ 1``), matching the corner form
    ``(x, y, x + w, y + h)`` that `load_lasot_gt` produces for the prompt box.
    An all-False mask yields ``[0, 0, 0, 0]``.
    """
    rows = np.flatnonzero(mask.any(axis=1))
    if rows.size == 0:
        return [0, 0, 0, 0]
    cols = np.flatnonzero(mask.any(axis=0))
    y_min, y_max = int(rows[0]), int(rows[-1])
    x_min, x_max = int(cols[0]), int(cols[-1])
    return [x_min, y_min, x_max - x_min, y_max - y_min]


# Strip newlines and drop blank entries once, so names can be used directly in
# paths and log messages (a blank line used to raise IndexError on the '-' split).
test_videos = sorted(v.strip() for v in test_videos if v.strip())
for vid, video in enumerate(test_videos):
    # The category directory is the part of the video name before the first '-'.
    cat_name = video.split('-')[0]
    video_dir = osp.join(video_folder, cat_name, video)
    frame_folder = osp.join(video_dir, "img")
    # `args` is handed to propagate_in_video below; expose this video's frames on it.
    args.frame_folder = frame_folder

    num_frames = len(os.listdir(frame_folder))
    print(f"Running video [{vid+1}/{len(test_videos)}]: {video} with {num_frames} frames")

    # A fresh predictor is built for every video (as before). Presumably this
    # guarantees no tracker state leaks between videos. TODO: confirm whether
    # a state reset would suffice, since reloading the checkpoint is costly.
    predictor = build_sam2_video_predictor(model_cfg, checkpoint, device="cuda:0")

    # One [x, y, w, h] box per frame yielded by propagate_in_video.
    predictions = []

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
        # Init state: frames are loaded asynchronously, and video/state are
        # offloaded to CPU.
        state = predictor.init_state(frame_folder, offload_video_to_cpu=True, offload_state_to_cpu=True, async_loading_frames=True)

        # Prompt object 0 on frame 0 with the groundtruth box in corner form.
        prompts = load_lasot_gt(osp.join(video_dir, "groundtruth.txt"))
        bbox, _ = prompts[0]
        predictor.add_new_points_or_box(state, box=bbox, frame_idx=0, obj_id=0)

        for _, object_ids, masks in predictor.propagate_in_video(state, args=args):
            # Only object 0 is tracked. If it is absent from this frame, emit
            # an empty box instead of failing later when writing results.
            bbox = [0, 0, 0, 0]
            for obj_id, mask in zip(object_ids, masks):
                if obj_id == 0:
                    bbox = _mask_to_xywh(mask[0].cpu().numpy() > 0.0)
            predictions.append(bbox)

    os.makedirs(pred_folder, exist_ok=True)
    with open(osp.join(pred_folder, f'{video}.txt'), 'w') as f:
        f.writelines(f"{x},{y},{w},{h}\n" for x, y, w, h in predictions)

    # Release the model and per-video state before the next video builds a
    # new predictor.
    del predictor
    del state
    gc.collect()
    torch.clear_autocast_cache()
    torch.cuda.empty_cache()
    