import os, sys
import numpy as np
import torch
import cv2
import copy
import shutil
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
from plot.utils.io import load_frames, write_video, read_images, convert_to_flow_image
from plot.utils.viz import draw_point_tracks
from plot.utils.geometry import meshgrid2d
from plot.utils.processing import erode_mask
from plot.utils.misc import find_target_frame_idx, limit_frames

import sys
sys.path.insert(0, "alltracker")
from nets.alltracker import Net


class ColorMap2d:
    """Look up RGB colors for normalized 2D coordinates in a colormap image.

    A point ``(x, y)`` with ``x, y`` in ``[0, 1]`` is mapped to the pixel at
    column ``int((width - 1) * x)`` and row ``int((height - 1) * y)`` of the
    image; out-of-range coordinates are clamped to the image border.
    """

    def __init__(self, filename=None):
        self._colormap_file = filename
        img = np.asarray(plt.imread(self._colormap_file))
        # plt.imread returns floats in [0, 1] for PNGs but integer arrays for other
        # formats (e.g. JPEG); only rescale float images so integer input is not overflowed.
        if np.issubdtype(img.dtype, np.floating):
            img = img * 255
        img = img.astype(np.uint8)
        if img.ndim == 2:
            # Grayscale colormap: replicate to three channels so every lookup yields RGB.
            img = np.repeat(img[:, :, None], 3, axis=2)
        # Drop an alpha channel if present; __call__ returns (N, 3) RGB colors.
        self._img = img[:, :, :3]

        self._height = self._img.shape[0]
        self._width = self._img.shape[1]

    def __call__(self, X):
        """Return an (N, 3) uint8 array of colors for an (N, 2) array of (x, y) points.

        Raises:
            ValueError: if ``X`` is not a two-column 2D array.
        """
        X = np.asarray(X, dtype=np.float64)
        if X.ndim != 2 or X.shape[1] != 2:
            raise ValueError(f"expected an (N, 2) array of points, got shape {X.shape}")
        # Vectorized form of int() truncation (toward zero) followed by clamping to the image.
        xp = np.clip(((self._width - 1) * X[:, 0]).astype(np.int64), 0, self._width - 1)
        yp = np.clip(((self._height - 1) * X[:, 1]).astype(np.int64), 0, self._height - 1)
        # Fancy indexing returns a new (N, 3) array, so callers never alias the colormap image.
        return self._img[yp, xp].astype(np.uint8, copy=False)
    

# Colormaps already loaded from disk, keyed by image path (filled by _get_colormap).
_COLORMAP_CACHE = {}


def _get_colormap(path):
    """Return a ColorMap2d for ``path``, reading the image from disk only on first use."""
    colormap = _COLORMAP_CACHE.get(path)
    if colormap is None:
        colormap = ColorMap2d(path)
        _COLORMAP_CACHE[path] = colormap
    return colormap


def get_2d_colors(xys, H, W, colormap=None):
    """Assign a color to each 2D point according to its position in an H x W image.

    Args:
        xys: (N, 2) array of (x, y) pixel coordinates. It is not modified.
        H: image height; y is normalized by ``H - 1``.
        W: image width; x is normalized by ``W - 1``.
        colormap: optional callable mapping an (N, 2) array of normalized
            coordinates to colors. Defaults to the bremm 2D colormap image,
            which is loaded once and then reused.

    Returns:
        Whatever ``colormap`` returns; for the default colormap an (N, 3)
        uint8 RGB array.

    Raises:
        ValueError: if ``xys`` is not an (N, 2) array.
    """
    xys = np.asarray(xys)
    if xys.ndim != 2 or xys.shape[1] != 2:
        raise ValueError(f"expected an (N, 2) array of points, got shape {xys.shape}")
    if colormap is None:
        colormap = _get_colormap("./alltracker/utils/bremm.png")
    # astype() returns a float64 copy: the caller's array is never modified, and integer
    # or float16 inputs are normalized without in-place casting errors or precision loss.
    new_xys = xys.astype(np.float64)
    # Guard 1-pixel-wide/tall images, where W - 1 or H - 1 would be zero.
    new_xys[:, 0] /= float(max(W - 1, 1))
    new_xys[:, 1] /= float(max(H - 1, 1))
    return colormap(new_xys)




class AllTracker:
    """Dense point tracker built on the pretrained AllTracker ``Net``.

    ``__call__(frames, query_idx)`` tracks every pixel of frame ``query_idx``
    forward and backward through ``frames``. It returns the tracks plus
    visibility and confidence maps at the original frame resolution.
    """

    def __init__(self, device) -> None:
        self.device = device
        self.infer_iters = 4             # number of inference steps per forward
        self.subsample_rate = 4          # vis hyperparameter: draw every N-th track along H and W

        window_len = 16             # window length, S < video length, T
        self.model = Net(window_len)
        url = "https://huggingface.co/aharley/alltracker/resolve/main/alltracker.pth"
        state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu')
        self.model.load_state_dict(state_dict['model'], strict=True)
        self.model = self.model.to(self.device)
        self.model.eval()

    def preprocess_frames(self, frames, target_width=None):
        """Resize HWC frames for the model and stack them into a (1, T, C, H, W) float tensor.

        Frames are scaled by ``min(target_width / H, target_width / W)``.
        ``target_width`` defaults to the frame width, so frames with H <= W only
        get the rounding step. Both sides are then rounded down to a multiple of 8.
        The method records ``self.original_size`` and ``self.new_size`` as (H, W)
        tuples; ``forward_model`` uses ``original_size`` to map results back.
        """
        H, W = frames[0].shape[:2]
        if target_width is None:
            target_width = W
        scale = min(target_width / H, target_width / W)
        new_H, new_W = int(H * scale), int(W * scale)
        new_H, new_W = new_H // 8 * 8, new_W // 8 * 8   # make it divisible by 8
        new_frames = [cv2.resize(frame, dsize=(new_W, new_H), interpolation=cv2.INTER_LINEAR) for frame in frames]

        frames_torch = torch.stack(
            [torch.from_numpy(frame).permute(2, 0, 1) for frame in new_frames], dim=0
        ).unsqueeze(0).float()    # 1,T,C,H,W

        self.original_size = (H, W)
        self.new_size = (new_H, new_W)
        return frames_torch

    @staticmethod
    def _velocity_draw_order(tracks):
        """Return point indices sorted by mean per-frame speed (slowest first).

        Args:
            tracks: (T, N, 2) array of point positions.

        Returns:
            An (N,) index array, or None when T < 2 (no velocity to sort by).
        """
        if tracks.shape[0] < 2:
            return None
        speeds = np.linalg.norm(tracks[1:] - tracks[:-1], axis=-1).mean(axis=0)
        return np.argsort(speeds)

    def visualize_tracks(self, frames, tracks, tracks_vis, tracks_confs):
        """Draw a spatially subsampled set of dense tracks onto ``frames``.

        Args:
            frames: sequence of T images; nothing is returned, so they are presumably
                drawn on in place by ``draw_point_tracks``.
            tracks: (T, H, W, 2) array of (x, y) positions.
            tracks_vis: (T, H, W) visibility map.
            tracks_confs: (T, H, W) confidence map.
        """
        T, H, W, _ = tracks.shape
        step = self.subsample_rate
        # subsample to make the vis more readable
        tracks2 = tracks[:, ::step, ::step, :].reshape(T, -1, 2)    # T, N, 2
        tracks_vis2 = tracks_vis[:, ::step, ::step].reshape(T, -1)   # T, N
        tracks_confs2 = tracks_confs[:, ::step, ::step].reshape(T, -1)    # T, N

        colors = get_2d_colors(tracks2[0], H, W)
        # Draw slow points first so moving points end up on top.
        # NOTE(review): assumes draw_point_tracks draws points in the order given by ``inds``.
        inds = self._velocity_draw_order(tracks2)
        radius = max(int(step // 2), 1)

        for t in range(T):
            draw_point_tracks(
                frames[t],
                tracks2[t],
                tracks_vis2[t],
                tracks_confs2[t],
                colors,
                radius=radius,
                inds=inds,
            )

    def visualize_obj_tracks(self, frames, tracks, tracks_vis, tracks_confs, query_frame_idx):
        """Draw per-object point tracks onto ``frames``, colored by query-frame position.

        Args:
            frames: sequence of images, one per entry of ``tracks``.
            tracks: per-frame (N, 2) point positions, indexed as ``tracks[t]``.
            tracks_vis: per-frame visibility values for the points.
            tracks_confs: per-frame confidence values for the points.
            query_frame_idx: frame whose point positions determine each point's color.
        """
        H, W = self.original_size
        colors = get_2d_colors(tracks[query_frame_idx], H, W)
        radius = max(int(self.subsample_rate // 2), 1)

        for t in range(len(tracks)):
            draw_point_tracks(frames[t], tracks[t], tracks_vis[t], tracks_confs[t], colors, radius=radius, inds=None)

    @staticmethod
    def _resize_to_original(trajs_maps, visconf_maps, original_size):
        """Resize (T, 2, h, w) track and vis/conf maps to ``original_size`` = (H, W).

        Track values are (x, y) pixel coordinates at model resolution. Resizing
        the maps alone is not enough, so the values are also mapped into
        original-resolution coordinates. The mapping uses the half-pixel-center
        convention of cv2.resize: ``x_orig = (x + 0.5) * s - 0.5``, written as
        ``x * s + 0.5 * (s - 1)`` so that ``s == 1`` is an exact identity.
        """
        h_in, w_in = trajs_maps.shape[-2:]
        h_out, w_out = original_size
        trajs = torch.nn.functional.interpolate(trajs_maps, (h_out, w_out), mode='nearest')
        visconf = torch.nn.functional.interpolate(visconf_maps, (h_out, w_out), mode='bilinear')
        scale = torch.tensor(
            [w_out / w_in, h_out / h_in], dtype=trajs.dtype, device=trajs.device
        ).view(1, 2, 1, 1)    # channel 0 is x (scaled by width), channel 1 is y (by height)
        trajs = trajs * scale + 0.5 * (scale - 1)
        return trajs, visconf

    @torch.inference_mode()
    def forward_model(self, inputs, query_frame):
        """Track every query-frame pixel forward and backward through the video.

        Args:
            inputs: (B, T, C, H, W) frames at model resolution, as produced by
                ``preprocess_frames``. Only batch element 0 is returned.
            query_frame: index of the frame whose pixels are tracked.

        Returns:
            tracks: (T, H0, W0, 2) numpy array of (x, y) positions in
                original-resolution pixel coordinates. (H0, W0) is
                ``self.original_size``.
            tracks_vis: (T, H0, W0) numpy visibility map (vis/conf channel 0).
            tracks_confs: (T, H0, W0) numpy confidence map (vis/conf channel 1).
        """
        # The model's weights live on self.device, so make sure its inputs do too.
        inputs = inputs.to(self.device)
        B, _, _, H, W = inputs.shape
        grid_xy = meshgrid2d(W, H, self.device).unsqueeze(0).permute(0, 2, 1).reshape(1, 1, 2, H, W)

        torch.cuda.empty_cache()

        future_inputs = inputs[:, query_frame:]        # query frame and every frame after it
        past_inputs = inputs[:, :query_frame + 1]      # every frame up to and including the query frame

        if future_inputs.shape[1] >= 2:  # query frame plus at least one future frame
            flows_e, visconf_maps_e, _, _ = self.model(future_inputs, iters=self.infer_iters, sw=None, is_training=False)

            if flows_e.ndim < 5:
                # Two-frame input yields a single flow field: add a time axis and
                # prepend the query frame itself (identity positions, all-ones vis/conf).
                flows_e = flows_e.unsqueeze(1)
                visconf_maps_e = visconf_maps_e.unsqueeze(1)
                trajs_maps_e = torch.cat([grid_xy, flows_e + grid_xy], dim=1)
                ones = torch.ones([B, 1, 2, H, W], device=self.device)
                visconf_maps_e = torch.cat([ones, visconf_maps_e], dim=1)
            else:
                trajs_maps_e = flows_e + grid_xy    # B, T_future, 2, H, W
        else:  # the query frame is the last frame
            trajs_maps_e = grid_xy
            visconf_maps_e = torch.ones([B, 1, 2, H, W], device=self.device)

        if past_inputs.shape[1] >= 2:  # at least one frame before the query frame
            backward_flows_e, backward_visconf_maps_e, _, _ = self.model(
                past_inputs.flip([1]), iters=self.infer_iters, sw=None, is_training=False
            )

            if backward_flows_e.ndim < 5:
                # Two-frame input: the single flow field is the frame just before the query.
                backward_flows_e = backward_flows_e.unsqueeze(1)
                backward_visconf_maps_e = backward_visconf_maps_e.unsqueeze(1)
                backward_trajs_maps_e = backward_flows_e + grid_xy
            else:
                backward_trajs_maps_e = backward_flows_e + grid_xy
                # Restore chronological order and drop the overlapped query frame,
                # which the forward pass already covers.
                backward_trajs_maps_e = backward_trajs_maps_e.flip([1])[:, :-1]
                backward_visconf_maps_e = backward_visconf_maps_e.flip([1])[:, :-1]

            trajs_maps_e = torch.cat([backward_trajs_maps_e, trajs_maps_e], dim=1)   # B, T, 2, H, W
            visconf_maps_e = torch.cat([backward_visconf_maps_e, visconf_maps_e], dim=1)

        trajs_maps_e, visconf_maps_e = self._resize_to_original(
            trajs_maps_e[0], visconf_maps_e[0], self.original_size
        )

        tracks = trajs_maps_e.detach().cpu().numpy().transpose(0, 2, 3, 1)
        tracks_vis = visconf_maps_e[:, 0, :, :].detach().cpu().numpy()
        tracks_confs = visconf_maps_e[:, 1, :, :].detach().cpu().numpy()
        return tracks, tracks_vis, tracks_confs

    def get_object_tracks(self, tracks, tracks_vis, tracks_confs, obj_masks, labels):
        """Extract the tracks of the pixels inside each (eroded) object mask.

        Args:
            tracks: (T, H, W, 2) dense (x, y) tracks.
            tracks_vis: (T, H, W) visibility map.
            tracks_confs: (T, H, W) confidence map.
            obj_masks: per-object (H, W) masks over the query frame.
            labels: class name per mask. Masks whose label contains "pedestrian"
                are eroded with (5, 5), all others with (3, 3).

        Returns:
            Three lists with one float16 entry per object: (T, N_i, 2) tracks,
            clipped to [0, W] x [0, H]; (T, N_i) visibility; and (T, N_i) confidence.

        Raises:
            ValueError: if there are fewer labels than masks.
        """
        _, H, W, _ = tracks.shape
        if len(labels) < len(obj_masks):
            raise ValueError(f"got {len(obj_masks)} masks but only {len(labels)} labels")

        all_obj_tracks = []
        all_obj_tracks_vis = []
        all_obj_tracks_confs = []
        for obj_mask, class_name in zip(obj_masks, labels):
            erosion = 5 if "pedestrian" in class_name else 3
            # Coerce to bool: an integer mask would trigger integer fancy indexing instead of selection.
            obj_mask = np.asarray(erode_mask(obj_mask, erosion, erosion), dtype=bool)

            # Boolean indexing returns copies, so the in-place clipping below leaves `tracks` intact.
            obj_tracks = tracks[:, obj_mask, :]
            obj_tracks_vis = tracks_vis[:, obj_mask]
            obj_tracks_confs = tracks_confs[:, obj_mask]

            obj_tracks[..., 0] = np.clip(obj_tracks[..., 0], 0, W)
            obj_tracks[..., 1] = np.clip(obj_tracks[..., 1], 0, H)

            # float16 keeps stored results small at the cost of sub-pixel precision.
            all_obj_tracks.append(obj_tracks.astype(np.float16))
            all_obj_tracks_vis.append(obj_tracks_vis.astype(np.float16))
            all_obj_tracks_confs.append(obj_tracks_confs.astype(np.float16))

        return all_obj_tracks, all_obj_tracks_vis, all_obj_tracks_confs

    def __call__(self, frames, query_idx):
        """Preprocess ``frames`` and return ``(tracks, tracks_vis, tracks_confs)`` for ``query_idx``."""
        frames_torch = self.preprocess_frames(frames)
        tracks, tracks_vis, tracks_confs = self.forward_model(frames_torch, query_idx)
        return tracks, tracks_vis, tracks_confs





if __name__ == '__main__':
    # The dataset root can be given as the first command-line argument.
    root = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("PATH_TO_DATA")

    # Fall back to CPU so the script still runs on machines without CUDA.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    with open(root / "ImageSets" / "val.txt", encoding="utf-8") as f:
        # Ignore blank lines: an empty scene name would make save_video_dir equal
        # to save_dir, and the rmtree below would delete every scene's results.
        scene_names = [line.strip() for line in f if line.strip()]

    save_dir = root / "alltracker"
    save_dir.mkdir(parents=True, exist_ok=True)

    det_dir = root / "gsam_frames"
    frame_limit = 10        # total frames = frame_limit * 2 + 1
    query_frame_idx = 20

    tracker = AllTracker(device)

    for scene_name in tqdm(scene_names):
        save_video_dir = save_dir / scene_name
        if save_video_dir.exists():
            # Overwrite results from a previous run.
            shutil.rmtree(save_video_dir)
        save_video_dir.mkdir(parents=True, exist_ok=True)

        _, frame_lists = load_frames(root / "frames" / scene_name)
        # Trim the sequence around the query frame (frame_limit * 2 + 1 frames per the
        # setting above), then re-locate the query frame inside the trimmed list.
        real_query_frame_idx = find_target_frame_idx(frame_lists, query_frame_idx)
        frame_lists = limit_frames(frame_lists, real_query_frame_idx, frame_limit)
        real_query_frame_idx = find_target_frame_idx(frame_lists, query_frame_idx)
        frames = read_images(frame_lists)

        all_tracks, all_tracks_vis, all_tracks_confs = tracker(frames, real_query_frame_idx)

        # Confidences (all_tracks_confs) are computed but intentionally not saved.
        np.savez_compressed(
            save_video_dir / f"{scene_name}_dense",
            tracks=all_tracks,
            visibility=all_tracks_vis,
        )