import re
import os
import copy
import torch
import argparse
import cv2

import numpy as np
import imageio.v3 as iio

from PIL import Image
from cotracker.utils.visualizer import Visualizer

from matplotlib import pyplot as plt
from tqdm import tqdm

def get_fps(video_path):
    """Extract FPS using OpenCV.

    Args:
        video_path: Path of the video file to probe.

    Returns:
        The frame rate reported by ``cv2.CAP_PROP_FPS`` as a float. OpenCV may
        report ``0.0`` when the file cannot be opened or has no FPS metadata,
        so callers should validate the value before dividing by it.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        return cap.get(cv2.CAP_PROP_FPS)
    finally:
        # Always release the capture handle, even if querying it raises.
        cap.release()

def _build_arg_parser():
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser(description="Process video tracking with CoTracker.")

    parser.add_argument(
        "--url",
        type=str,
        required=True,
        help="Path to the input video file. e.g., /home/video_cauasl/video_data/spring_mass/synthetic_data_damped/0_(93.0,64.0).mp4"
    )

    parser.add_argument(
        "--tracking_point_x",
        type=float,
        nargs='+',
        required=True,
        help="93.0 from file name 0_(93.0,64.0).mp4"
    )

    parser.add_argument(
        "--tracking_point_y",
        type=float,
        nargs='+',
        required=True,
        help="64.0 from file name 0_(93.0,64.0).mp4"
    )

    parser.add_argument(
        "--save_dir",
        type=str,
        required=True,
        help="Directory to save the processed tracking output. e.g., /home/video_cauasl/video_data/spring_mass/synthetic_data_damped_co_tracker"
    )

    parser.add_argument(
        "--grid_size",
        type=int,
        default=10,
        help="Grid size for CoTracker."
    )

    parser.add_argument(
        "--device",
        type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu',
        help="Device to run the model on ('cuda' or 'cpu')."
    )

    # Opt-in so the default (grid tracking) behavior is unchanged.
    parser.add_argument(
        "--use_query_points",
        action="store_true",
        help="Track the --tracking_point_x/--tracking_point_y points (queried on frame 0) "
             "instead of a --grid_size x --grid_size grid.",
    )
    return parser


def _build_queries(xs, ys):
    """Return CoTracker query rows ``[t, x, y]`` for points queried on frame 0.

    Raises:
        ValueError: if ``xs`` and ``ys`` differ in length (``zip`` would
            otherwise silently drop the unmatched points).
    """
    if len(xs) != len(ys):
        raise ValueError(
            f"--tracking_point_x has {len(xs)} value(s) but "
            f"--tracking_point_y has {len(ys)}"
        )
    return [[0.0, float(x), float(y)] for x, y in zip(xs, ys)]


def _frame_timestamps(num_frames, fps):
    """Return a ``(num_frames, 1)`` float32 array with each frame's time in seconds.

    Frame ``i`` is displayed at ``i / fps``. Spreading ``num_frames`` points
    evenly over ``[0, num_frames / fps]`` would stretch the time axis by
    ``num_frames / (num_frames - 1)``.

    Raises:
        ValueError: if ``fps`` is not a positive number (including NaN).
    """
    if not fps > 0:
        raise ValueError(f"Invalid FPS {fps!r}; cannot compute frame timestamps.")
    return (np.arange(num_frames) / fps).astype(np.float32)[:, None]


def _run_online_tracker(cotracker, video, grid_size, queries=None):
    """Run the online CoTracker predictor over the whole video in chunks.

    Args:
        cotracker: Online predictor loaded from torch.hub (exposes ``.step``).
        video: Video tensor laid out as B T C H W.
        grid_size: Grid size used when ``queries`` is None.
        queries: Optional query tensor (B, N, 3) of ``[t, x, y]`` rows.

    Returns:
        ``(pred_tracks, pred_visibility)`` returned by the final chunk call.

    Raises:
        ValueError: if the video is too short for the online sliding window.
    """
    num_frames = video.shape[1]
    step = cotracker.step
    if num_frames <= step:
        # The loop below would not run and no predictions would be produced.
        raise ValueError(
            f"Video has {num_frames} frame(s); online CoTracker needs more than {step}."
        )

    # The first call only initialises the online state (queries or grid).
    cotracker(video_chunk=video, is_first_step=True, grid_size=grid_size, queries=queries)

    pred_tracks = pred_visibility = None
    for ind in tqdm(range(0, num_frames - step, step)):
        pred_tracks, pred_visibility = cotracker(
            video_chunk=video[:, ind : ind + step * 2]
        )  # B T N 2,  B T N 1
    return pred_tracks, pred_visibility


def main():
    """Track points through a video with online CoTracker and save the results.

    Writes into ``--save_dir``:
      * the ``Visualizer`` rendering of the tracks, named after the input file;
      * ``<name>.npy``: the predicted tracks transposed to (N, T, 2);
      * ``<name>_time.npy``: (T, 1) float32 per-frame timestamps in seconds.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    url = args.url
    save_dir = args.save_dir
    device = args.device
    grid_size = args.grid_size
    read_filename = os.path.splitext(os.path.basename(url))[0]

    start_points = None
    if args.use_query_points:
        try:
            start_points = _build_queries(args.tracking_point_x, args.tracking_point_y)
        except ValueError as err:
            parser.error(str(err))

    # Validate FPS before the expensive video load / model download.
    fps = get_fps(url)
    print(fps)
    if not fps > 0:
        parser.error(f"Could not determine a valid FPS for {url!r} (got {fps!r}).")

    frames = iio.imread(url, plugin="FFMPEG")  # plugin="pyav"
    # Presumably (T, H, W, C) frames -> (1, T, C, H, W) float tensor for CoTracker.
    video = torch.tensor(frames).permute(0, 3, 1, 2)[None].float().to(device)  # B T C H W
    total_frames = frames.shape[0]
    print(total_frames)

    # The offline model ("cotracker3_offline") has a different call API and is not used here.
    cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker3_online").to(device)

    queries = None
    if start_points is not None:
        # (B=1, N, 3) on the requested device, so it matches the model and video.
        queries = torch.tensor(start_points, dtype=torch.float32, device=device)[None]

    pred_tracks, pred_visibility = _run_online_tracker(cotracker, video, grid_size, queries)

    os.makedirs(save_dir, exist_ok=True)
    # Render at the source frame rate so the visualisation plays at real speed.
    vis = Visualizer(save_dir=save_dir, pad_value=0, linewidth=3, fps=fps)
    vis.visualize(video, pred_tracks, pred_visibility, filename=read_filename)

    tracks = pred_tracks[0]  # T N 2 (batch dimension dropped)
    print(tracks.shape)
    tracks_np = tracks.cpu().numpy().transpose(1, 0, 2)  # N T 2
    print(tracks_np.shape)

    time_np = _frame_timestamps(tracks.shape[0], fps)

    print(tracks_np.shape, time_np.shape)
    np.save(os.path.join(save_dir, f"{read_filename}.npy"), tracks_np)
    np.save(os.path.join(save_dir, f"{read_filename}_time.npy"), time_np)

if __name__ == "__main__":
    # Run the tracking pipeline only when executed as a script, not on import.
    main()