import numpy as np
import matplotlib.pyplot as plt
import torch
from PIL import Image
from typing import Iterator, Tuple, Optional, List
from src.utils.video_io import resize_image_to_patch_size, unroll_video
from torch.nn.functional import interpolate
import torchvision.transforms as transforms
from math import ceil
from einops import rearrange
from sklearn.decomposition import PCA
from src.utils.video_io import load_image_from_url, load_array_from_url
from tqdm import tqdm
from multiprocessing.pool import ThreadPool, Pool
from src.utils.video_io import frame_stream_ffmpeg_python, frame_stream_decord

import itertools, math, functools, contextlib, time
from collections import deque
import time
import torchvision.transforms.functional as F
from torch.amp import autocast

# Select which scaled-dot-product-attention implementations PyTorch may use on
# CUDA. These calls run at import time and set global backend flags: the flash
# and memory-efficient kernels are enabled, the plain math implementation is
# disabled.
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True) # Triton kernel fallback
torch.backends.cuda.enable_math_sdp(False)

def load_files_parallel(paths):
    """Read every file in ``paths`` with ``plt.imread`` on a 10-worker thread pool.

    :param paths: Iterable of image file paths.
    :return: List with one decoded image per path, as produced by ``pool.map``.
    """
    with ThreadPool(10) as workers:
        return workers.map(plt.imread, paths)


def load_files_parallel_multiproc(paths):
    """Read every file in ``paths`` with ``plt.imread`` on a 10-worker process pool.

    :param paths: Iterable of image file paths.
    :return: List with one decoded image per path, as produced by ``pool.map``.
    """
    with Pool(10) as workers:
        return workers.map(plt.imread, paths)


def make_transform(size: Tuple[int, int]) -> transforms.Compose:
    """Build the resize -> ``ToTensor`` -> normalise pipeline for ``size``.

    :param size: Target size handed to ``transforms.Resize``.
    :return: Composed transform: bicubic, antialiased resize, conversion to a
        tensor, then normalisation with the ImageNet default mean and std.
    """
    resize = transforms.Resize(
        size=size,
        interpolation=transforms.InterpolationMode.BICUBIC,
        antialias=True,
    )
    normalize = transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    )
    return transforms.Compose([resize, transforms.ToTensor(), normalize])

def _stream_batches(iterable, batch_size: int):
    it = iter(iterable)
    while True:
        batch = list(itertools.islice(it, batch_size))
        if not batch:
            break
        yield batch

class DinoV2FeatureExtractor:

    def __init__(
        self,
        device,
        *,
        model_name: str = "dinov2_vitl14_reg",
        maintain_aspect_ratio: bool = True,
    ):
        """
        Load a DINOv2 model from torch hub and put it on ``device`` in eval mode.

        :param device: Device the model is moved to; prepared images are sent
            to the same device
        :param model_name: Entry point requested from the
            ``facebookresearch/dinov2`` hub repository
        :param maintain_aspect_ratio: Stored on the instance and read by the
            image-preparation methods to choose between scale-and-crop and a
            direct resize
        """
        super().__init__()
        self.device = device
        self.maintain_aspect_ratio = maintain_aspect_ratio
        backbone = torch.hub.load("facebookresearch/dinov2", model_name)
        self.model = backbone.to(device).eval()
        # Images are padded up to a multiple of this before going in the model.
        self.patch_size = self.model.patch_size

    def get_feature(self, image: torch.Tensor):
        """
        Run the model on one image and return its output for that image.

        :param image: Image tensor of shape (C, H, W), values are in range 0...1
        :return: First entry of the model output for the one-image batch
        """
        batch = self._prepare_image(image)
        with torch.inference_mode():
            return self.model(batch)[0]

    def get_feature_grid(self, image: torch.Tensor):
        """
        Compute the grid of patch tokens for one image.

        :param image: Image tensor of shape (C, H, W), values are in range 0...1
        :return: Tokens of the single image arranged as (h, w, d), where h is
            the prepared image height divided by the patch size
        """
        batch = self._prepare_image(image)
        rows = int(batch.shape[2] / self.patch_size)
        with torch.inference_mode():
            patch_tokens = self.model.get_intermediate_layers(batch)[0]
        return rearrange(patch_tokens, "b (h w) d -> b h w d", h=rows)[0]

    def extract_features_from_filelist(
        self,
        image_fnames: List[str],
        *,
        batch_size: int = 8,
        output_dtype: str = "float16",
        use_tqdm: bool = True,
        label=None,
    ) -> np.ndarray:
        """
        Extract the model output features for a list of image files.

        Thin wrapper around ``_extract_features_from_filelist`` that requests
        features only (``return_features=True``, ``return_grid=False``).

        :param image_fnames: List of image filenames
        :param batch_size: Number of files loaded and passed to the model per step
        :param output_dtype: dtype each batch result is cast to
        :param use_tqdm: Show a progress bar when True
        :param label: Optional description shown on the progress bar
        :return: Per-batch model outputs concatenated along axis 0
        :raises ValueError: If no images were loaded
        """
        # The second tuple element is the grid, which is None for this request.
        feat, _ = self._extract_features_from_filelist(
            image_fnames=image_fnames,
            batch_size=batch_size,
            return_grid=False,
            return_features=True,
            output_dtype=output_dtype,
            use_tqdm=use_tqdm,
            label=label,
        )
        return feat

    def extract_feature_grids_from_filelist(
        self,
        image_fnames: List[str],
        *,
        batch_size: int = 8,
        output_dtype: str = "float16",
        use_tqdm: bool = True,
        label=None,
    ) -> np.ndarray:
        """
        Extract the patch-token grids for a list of image files.

        Thin wrapper around ``_extract_features_from_filelist`` that requests
        grids only (``return_grid=True``, ``return_features=False``).

        :param image_fnames: List of image filenames
        :param batch_size: Number of files loaded and passed to the model per step
        :param output_dtype: dtype each batch result is cast to
        :param use_tqdm: Show a progress bar when True
        :param label: Optional description shown on the progress bar
        :return: Per-batch token grids (b, h, w, d) concatenated along axis 0
        :raises ValueError: If no images were loaded
        """
        # The first tuple element is the feature array, None for this request.
        _, grid = self._extract_features_from_filelist(
            image_fnames=image_fnames,
            batch_size=batch_size,
            return_grid=True,
            return_features=False,
            output_dtype=output_dtype,
            use_tqdm=use_tqdm,
            label=label,
        )
        return grid

    def extract_features_and_feature_grids_from_filelist(
        self,
        image_fnames: List[str],
        *,
        batch_size: int = 8,
        output_dtype: str = "float16",
        use_tqdm: bool = True,
        label=None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Extract both model output features and patch-token grids for image files.

        Thin wrapper around ``_extract_features_from_filelist`` with both
        ``return_grid`` and ``return_features`` enabled.

        :param image_fnames: List of image filenames
        :param batch_size: Number of files loaded and passed to the model per step
        :param output_dtype: dtype each batch result is cast to
        :param use_tqdm: Show a progress bar when True
        :param label: Optional description shown on the progress bar
        :return: Tuple ``(features, grids)``, each concatenated along axis 0
        :raises ValueError: If no images were loaded
        """
        return self._extract_features_from_filelist(
            image_fnames=image_fnames,
            batch_size=batch_size,
            return_grid=True,
            return_features=True,
            output_dtype=output_dtype,
            use_tqdm=use_tqdm,
            label=label,
        )

    def _extract_features_from_filelist(
        self,
        image_fnames: List[str],
        return_grid: bool,
        return_features: bool,
        *,
        batch_size=8,
        output_dtype: str = "float16",
        use_tqdm: bool = True,
        use_multiprocessing: bool = False,
        load_all_frames_at_once: bool = False,
        label=None,
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """
        Load image files batch by batch and run them through the model.

        :param image_fnames: List of image filenames
        :param return_grid: Collect the patch-token grid (b, h, w, d) of each batch
        :param return_features: Collect the plain model output of each batch
        :param batch_size: Number of files loaded and passed to the model per step
        :param output_dtype: dtype each batch result is cast to
        :param use_tqdm: Show a progress bar when True
        :param use_multiprocessing: Load files with the process-pool loader
            instead of the thread-pool loader
        :param load_all_frames_at_once: Delegate to
            ``_extract_features_from_filelist_all_at_once``
        :param label: Optional description shown on the progress bar
        :return: Tuple ``(features, grid)``; an entry is None when it was not
            requested, otherwise the batch results concatenated along axis 0
        :raises ValueError: If a requested output ended up with no batches
        :raises AssertionError: If the images do not all share one (H, W) size
        """
        if load_all_frames_at_once:
            return self._extract_features_from_filelist_all_at_once(
                image_fnames,
                return_grid,
                return_features,
                batch_size=batch_size,
                output_dtype=output_dtype,
                use_tqdm=use_tqdm,
                use_multiprocessing=use_multiprocessing,
                label=label,
            )

        load_batch = (
            load_files_parallel_multiproc
            if use_multiprocessing
            else load_files_parallel
        )
        features = []
        features_grid = []
        expected_size = None

        starts = range(0, len(image_fnames), batch_size)
        pbar = tqdm(starts, total=len(starts), disable=not use_tqdm)
        if label is not None:
            pbar.set_description(f"{label}")

        for start in pbar:
            prepared = []
            # Slicing clamps at the end of the list, so the last batch is
            # simply shorter.
            for img in load_batch(image_fnames[start : start + batch_size]):
                if expected_size is None:
                    expected_size = img.shape[:2]
                else:
                    assert (
                        img.shape[:2] == expected_size
                    ), f"Expected all images to have the same size, but got {img.shape[:2]} and {expected_size}"
                # Heuristic: values clearly above 1 are treated as 0..255 data.
                if np.max(img) > 1.1:
                    img = img / 255.0
                img_tn = (
                    torch.from_numpy(img).permute(2, 0, 1).float().to(self.device)
                )
                prepared.append(self._prepare_image(img_tn))
            batch = torch.cat(prepared, dim=0)
            # Derive the grid height from this batch rather than caching the
            # first image's value.
            grid_rows = int(batch.shape[2] / self.patch_size)

            with torch.inference_mode():
                if return_grid:
                    tokens = self.model.get_intermediate_layers(batch)[0]
                    grid = rearrange(tokens, "b (h w) d -> b h w d", h=grid_rows)
                    features_grid.append(grid.cpu().numpy().astype(output_dtype))
                if return_features:
                    tokens = self.model(batch)
                    features.append(tokens.cpu().numpy().astype(output_dtype))

        features_out = None
        grid_out = None
        if return_features:
            if not features:
                raise ValueError("No images were loaded", image_fnames)
            features_out = np.concatenate(features, axis=0)
        if return_grid:
            if not features_grid:
                raise ValueError("No images were loaded", image_fnames)
            grid_out = np.concatenate(features_grid, axis=0)

        return features_out, grid_out

    def _extract_features_from_filelist_all_at_once(
        self,
        image_fnames: List[str],
        return_grid: bool,
        return_features: bool,
        *,
        batch_size=8,
        output_dtype: str = "float16",
        use_tqdm: bool = True,
        use_multiprocessing: bool = False,
        label=None,
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """
        Load and prepare every image up front, then run the model in batches.

        :param image_fnames: List of image filenames
        :param return_grid: Collect the patch-token grid (b, h, w, d) of each batch
        :param return_features: Collect the plain model output of each batch
        :param batch_size: Number of prepared images passed to the model per step
        :param output_dtype: dtype each batch result is cast to
        :param use_tqdm: Show a progress bar when True
        :param use_multiprocessing: Load files with the process-pool loader
            instead of the thread-pool loader
        :param label: Optional description shown on the progress bar
        :return: Tuple ``(features, grid)``; an entry is None when it was not
            requested, otherwise the batch results concatenated along axis 0
        :raises ValueError: If ``image_fnames`` is empty or a requested output
            ended up with no batches
        """
        if len(image_fnames) == 0:
            # Report an empty list with this method's own error instead of
            # failing inside the loader on an empty array.
            raise ValueError("No images were loaded", image_fnames)

        all_images_tn = self._load_and_prepare_all_images(
            image_fnames=image_fnames, use_multiprocessing=use_multiprocessing
        )
        grid_rows = int(all_images_tn.shape[2] / self.patch_size)
        features = []
        features_grid = []

        starts = range(0, len(image_fnames), batch_size)
        pbar = tqdm(starts, total=len(starts), disable=not use_tqdm)
        if label is not None:
            pbar.set_description(f"{label}")

        with torch.inference_mode():
            for start in pbar:
                # Slicing clamps at the end, so the last batch may be shorter.
                batch = all_images_tn[start : start + batch_size]
                if return_grid:
                    tokens = self.model.get_intermediate_layers(batch)[0]
                    grid = rearrange(tokens, "b (h w) d -> b h w d", h=grid_rows)
                    features_grid.append(grid.cpu().numpy().astype(output_dtype))
                if return_features:
                    tokens = self.model(batch)
                    features.append(tokens.cpu().numpy().astype(output_dtype))

        features_out = None
        grid_out = None
        if return_features:
            if not features:
                raise ValueError("No images were loaded", image_fnames)
            features_out = np.concatenate(features, axis=0)
        if return_grid:
            if not features_grid:
                raise ValueError("No images were loaded", image_fnames)
            grid_out = np.concatenate(features_grid, axis=0)

        return features_out, grid_out

    def _load_and_prepare_all_images(
        self, image_fnames: List[str], use_multiprocessing: bool = False
    ) -> torch.Tensor:
        """
        Load all files into one array and prepare them as a single batch.

        :param image_fnames: List of image filenames
        :param use_multiprocessing: Use the process-pool loader instead of the
            thread-pool loader
        :return: Result of ``_prepare_images`` on the stacked (b, d, h, w) tensor
        """
        loader = (
            load_files_parallel_multiproc
            if use_multiprocessing
            else load_files_parallel
        )
        stacked = np.array(loader(image_fnames))
        # Heuristic: values clearly above 1 are treated as 0..255 data.
        if np.max(stacked) > 1.1:
            stacked = stacked / 255.0
        channels_first = rearrange(torch.from_numpy(stacked), "b h w d -> b d h w")
        return self._prepare_images(channels_first.float().to(self.device))

    def get_feature_grids(self, images: torch.Tensor):
        """
        Compute the grids of patch tokens for a batch of images.

        :param images: {B x 3 x H x W}
        :return: Tokens arranged as {B x h x w x D}, where h is the prepared
            image height divided by the patch size
        """
        batch = self._prepare_images(images)
        rows = int(batch.shape[2] / self.patch_size)
        with torch.inference_mode():
            patch_tokens = self.model.get_intermediate_layers(batch)[0]
            return rearrange(patch_tokens, "b (h w) d -> b h w d", h=rows)

    def _prepare_images(self, images: torch.Tensor):
        if not torch.is_tensor(images):
            images = torch.from_numpy(images).float()
        assert torch.is_tensor(
            images
        ), f"Expected image to be a tensor, got {type(images)}"
        assert (
            images.dim() == 4
        ), f"Expected image to have 4 dimensions, got {images.dim()}"
        assert (
            images.shape[1] == 3
        ), f"Expected image to have 3 channels, got {images.shape[2]}"
        assert (
            torch.min(images) >= 0 and torch.max(images) <= 1
        ), f"Expected image to have values in range 0...1 but got {torch.min(images).item()}...{torch.max(images).item()}"
        IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
        IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

        _, _, h, w = images.shape
        new_w = ((w + self.patch_size - 1) // self.patch_size) * self.patch_size
        new_h = ((h + self.patch_size - 1) // self.patch_size) * self.patch_size

        if self.maintain_aspect_ratio:
            scale = max(new_w / w, new_h / h)
            resize_transform = transforms.Resize(
                size=(int(ceil(h * scale)), int(ceil(w * scale))),
                interpolation=transforms.InterpolationMode.BICUBIC,
                antialias=True,
            )
        else:
            resize_transform = transforms.Resize(
                size=(new_h, new_w),
                interpolation=transforms.InterpolationMode.BICUBIC,
                antialias=True,
            )

        transform = transforms.Compose(
            [
                resize_transform,
                transforms.Normalize(
                    mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
                ),
            ]
        )

        images = transform(images).to(self.device)  # b x c x h x w
        if self.maintain_aspect_ratio:
            cut_w = images.shape[3] - new_w
            cut_h = images.shape[2] - new_h
            start_x = cut_w // 2
            start_y = cut_h // 2
            images = images[:, :, start_y : start_y + new_h, start_x : start_x + new_w]
        return images

    def _prepare_image(self, image: torch.Tensor):
        if not torch.is_tensor(image):
            image = torch.from_numpy(image).float()
        assert torch.is_tensor(
            image
        ), f"Expected image to be a tensor, got {type(image)}"
        assert (
            image.dim() == 3
        ), f"Expected image to have 3 dimensions, got {image.dim()}"
        assert (
            image.shape[0] == 3
        ), f"Expected image to have 3 channels, got {image.shape[0]}"
        assert (
            torch.min(image) >= 0 and torch.max(image) <= 1
        ), f"Expected image to have values in range 0...1 but got {torch.min(image).item()}...{torch.max(image).item()}"
        IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
        IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

        # h, w, _ = image.shape
        _, h, w = image.shape
        new_w = ((w + self.patch_size - 1) // self.patch_size) * self.patch_size
        new_h = ((h + self.patch_size - 1) // self.patch_size) * self.patch_size

        if self.maintain_aspect_ratio:
            scale = max(new_w / w, new_h / h)
            resize_transform = transforms.Resize(
                size=(int(ceil(h * scale)), int(ceil(w * scale))),
                interpolation=transforms.InterpolationMode.BICUBIC,
                antialias=True,
            )
        else:
            resize_transform = transforms.Resize(
                size=(new_h, new_w),
                interpolation=transforms.InterpolationMode.BICUBIC,
                antialias=True,
            )

        transform = transforms.Compose(
            [
                resize_transform,
                transforms.Normalize(
                    mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
                ),
            ]
        )

        image = transform(image).unsqueeze(0).to(self.device)  # b x c x h x w
        if self.maintain_aspect_ratio:
            cut_w = image.shape[3] - new_w
            cut_h = image.shape[2] - new_h
            start_x = cut_w // 2
            start_y = cut_h // 2
            image = image[:, :, start_y : start_y + new_h, start_x : start_x + new_w]
        return image


# = = = VISUALIZATION = = =
from scipy.ndimage import binary_closing, binary_opening


def make_foreground_mask(
    tokens,
    grid_size: Tuple[int, int],
    device: torch.device,
    background_threshold: float = 0.0,
    apply_opening: bool = True,
    apply_closing: bool = True,
):
    """Threshold the projection of ``tokens`` onto a downloaded reference vector.

    The reference array is fetched from ``STANDARD_ARRAY_URL`` on every call.

    :param tokens: Token matrix that is right-multiplied by the reference
        array; it has to live on ``device`` for the product to work.
    :param grid_size: Shape the mask is reshaped to before the optional
        morphological clean-up.
    :param device: Device the reference array is moved to.
    :param background_threshold: Projections strictly above this value count
        as foreground.
    :param apply_opening: Apply ``binary_opening`` to the reshaped mask.
    :param apply_closing: Apply ``binary_closing`` to the reshaped mask.
    :return: Flattened boolean mask. It is a NumPy array when opening or
        closing was applied and a torch tensor otherwise.
    """
    STANDARD_ARRAY_URL = "https://dl.fbaipublicfiles.com/dinov2/arrays/standard.npy"
    standard_array = load_array_from_url(STANDARD_ARRAY_URL)

    projection = tokens @ torch.from_numpy(standard_array).float().to(device)
    mask = (projection > background_threshold).reshape(*grid_size)
    if apply_opening or apply_closing:
        # scipy needs a host array; a CUDA tensor cannot be converted
        # implicitly, so move the mask to the CPU explicitly first.
        mask = mask.cpu().numpy()
        if apply_opening:
            mask = binary_opening(mask)
        if apply_closing:
            mask = binary_closing(mask)
    return mask.flatten()


def render_patch_pca(
    tokens,  # h w c
    image_size,
    device,
    *,
    background_threshold: float = 0.05,
    apply_opening: bool = False,
    apply_closing: bool = False,
    do_masking: bool = True,
) -> Image.Image:
    """Render a patch-token grid as an RGB image of its first three PCA components.

    :param tokens: Token grid of shape (h, w, c), torch tensor or NumPy array.
    :param image_size: ``(h, w)`` of the returned image.
    :param device: Device used for the foreground projection.
    :param background_threshold: Threshold passed to ``make_foreground_mask``.
    :param apply_opening: Passed to ``make_foreground_mask``.
    :param apply_closing: Passed to ``make_foreground_mask``.
    :param do_masking: Fit the PCA on foreground tokens only and paint
        background patches black. When False no mask is computed at all.
    :return: PIL image resized to ``image_size`` with nearest-neighbour sampling.
    """
    if not torch.is_tensor(tokens):
        tokens = torch.from_numpy(tokens)
    grid_size = tokens.shape[:2]
    tokens = rearrange(tokens, "h w c -> (h w) c").float()
    tokens_np = tokens.cpu().numpy()

    mask_np = None
    if do_masking:
        mask = make_foreground_mask(
            tokens.to(device),
            grid_size,
            background_threshold=background_threshold,
            apply_opening=apply_opening,
            apply_closing=apply_closing,
            device=device,
        )
        # The helper returns a tensor or an ndarray depending on whether a
        # morphological op ran; normalise to a boolean ndarray.
        if torch.is_tensor(mask):
            mask = mask.cpu().numpy()
        mask_np = np.asarray(mask).astype(bool)

    pca = PCA(n_components=3)
    pca.fit(tokens_np[mask_np] if do_masking else tokens_np)
    projected_tokens = pca.transform(tokens_np)

    # Min-max scale each component to 0..1.
    t = torch.tensor(projected_tokens)
    t_min = t.min(dim=0, keepdim=True).values
    t_max = t.max(dim=0, keepdim=True).values
    span = t_max - t_min
    span[span == 0] = 1.0  # constant component: avoid 0/0 -> NaN
    normalized_t = (t - t_min) / span

    array = (normalized_t * 255).byte().numpy()
    if do_masking:
        array[~mask_np] = 0
    array = array.reshape(*grid_size, 3)

    h, w = image_size
    # PIL takes (width, height); resample code 0 is nearest neighbour.
    return Image.fromarray(array).resize((w, h), 0)


def render_patch_pca_batched(
    tokens,  # t h w c
    image_size,
    device,
    *,
    background_threshold: float = 0.05,
    apply_opening: bool = False,
    apply_closing: bool = False,
    do_masking: bool = True,
) -> np.ndarray:
    """Render a stack of token grids as RGB frames sharing one PCA basis.

    A single 3-component PCA is fitted over the tokens of all frames, so the
    colours are comparable across frames.

    :param tokens: Token grids of shape (t, h, w, c), torch tensor or NumPy array.
    :param image_size: ``(h, w)`` of every returned frame.
    :param device: Device used for the foreground projection and the resize.
    :param background_threshold: Threshold passed to ``make_foreground_mask``.
    :param apply_opening: Passed to ``make_foreground_mask``.
    :param apply_closing: Passed to ``make_foreground_mask``.
    :param do_masking: Fit the PCA on foreground tokens only and paint
        background patches black. When False no mask is computed at all.
    :return: uint8 array of shape (t, h, w, 3), bilinearly resized to
        ``image_size``.
    """
    assert len(tokens.shape) == 4
    grid_size = tokens.shape[:3]
    if not torch.is_tensor(tokens):
        tokens = torch.from_numpy(tokens)
    tokens = rearrange(tokens, "b h w c -> (b h w) c").float()
    tokens_np = tokens.cpu().numpy()

    mask_np = None
    if do_masking:
        mask = make_foreground_mask(
            tokens.to(device),
            grid_size,
            background_threshold=background_threshold,
            apply_opening=apply_opening,
            apply_closing=apply_closing,
            device=device,
        )
        # The helper returns a tensor or an ndarray depending on whether a
        # morphological op ran; normalise to a boolean ndarray.
        if torch.is_tensor(mask):
            mask = mask.cpu().numpy()
        mask_np = np.asarray(mask).astype(bool)

    pca = PCA(n_components=3)
    pca.fit(tokens_np[mask_np] if do_masking else tokens_np)
    projected_tokens = pca.transform(tokens_np)

    # Min-max scale each component to 0..1.
    t = torch.tensor(projected_tokens)
    t_min = t.min(dim=0, keepdim=True).values
    t_max = t.max(dim=0, keepdim=True).values
    span = t_max - t_min
    span[span == 0] = 1.0  # constant component: avoid 0/0 -> NaN
    normalized_t = (t - t_min) / span

    array = (normalized_t * 255).byte().numpy()
    if do_masking:
        array[~mask_np] = 0
    array = array.reshape(*grid_size, 3)

    h, w = image_size
    frames = torch.from_numpy(rearrange(array, "t h w c -> t c h w")).to(device)
    # ``interpolate`` expects (out_height, out_width). Interpolate in float and
    # convert back so the result does not depend on uint8 kernel support.
    vid_resized = interpolate(
        frames.float(),
        size=(h, w),
        mode="bilinear",
        align_corners=False,
    )
    vid_resized = vid_resized.round().clamp(0, 255).to(torch.uint8)
    return rearrange(vid_resized.cpu().numpy(), "t c h w -> t h w c")

class DinoV2FeatureExtractorStream(DinoV2FeatureExtractor):

    def __init__(
        self, 
        device: str,
        model_name: str = "dinov2_vits14_reg", 
        maintain_aspect_ratio: bool = True, 
        compile: bool = False, 
        **kw
    ):
        """
        Set up the base extractor for streaming use.

        After the base class has loaded the model, the model is switched to
        channels-last memory format, the normalisation constants are cached as
        (3, 1, 1) tensors on the device, and the model is optionally wrapped
        with ``torch.compile``.

        :param device: Device for the model and the cached constants
        :param model_name: Hub entry point forwarded to the base class
        :param maintain_aspect_ratio: Forwarded to the base class
        :param compile: Wrap the model with ``torch.compile`` when True; if
            that call raises, the error is printed and the uncompiled model is kept
        :param kw: Extra keyword arguments forwarded to the base class
        """
        super().__init__(
            device,
            model_name=model_name,
            maintain_aspect_ratio=maintain_aspect_ratio,
            **kw
        )
        self.model.to(self.device, memory_format=torch.channels_last).eval()

        imagenet_mean = [0.485,0.456,0.406]
        imagenet_std = [0.229,0.224,0.225]
        # Shaped (3, 1, 1) so they broadcast over (B, 3, H, W) batches.
        self.mean = torch.tensor(imagenet_mean, device=self.device)[:,None,None]
        self.std  = torch.tensor(imagenet_std, device=self.device)[:,None,None]

        if not compile:
            return
        try:
            self.model = torch.compile(
                self.model, mode="max-autotune",
                fullgraph=True, dynamic=True,
            )
        except Exception as e:
            print("torch.compile failed: fallback ->", e)


    def _quick_preprocess(
        self,
        imgs: torch.Tensor,          # (B,3,H,W) 0-1 float32
        patch_size: int, 
        maintain_aspect_ratio: bool = True
    ) -> torch.Tensor:
        """
        Resize a batch to a patch-size multiple and normalise it.

        Uses a bilinear resize without antialiasing. The normalisation with
        ``self.mean`` / ``self.std`` is done in place on the resized tensor.
        NOTE(review): if ``F.resize`` returns its input unchanged when the size
        already matches, the in-place step also modifies the caller's tensor;
        confirm against torchvision.

        :param imgs: Batch of shape (B, 3, H, W)
        :param patch_size: Both sides are rounded up to a multiple of this
        :param maintain_aspect_ratio: Scale uniformly and centre-crop when
            True, resize directly to the target size otherwise
        :return: Normalised batch, contiguous in channels-last memory format
        """
        # NOTE: make sure bilinear and antialias=false work properly
        mode = F.InterpolationMode.BILINEAR
        _, channels, src_h, src_w = imgs.shape
        assert channels == 3, "expects RGB"

        target_h = math.ceil(src_h / patch_size) * patch_size
        target_w = math.ceil(src_w / patch_size) * patch_size

        if not maintain_aspect_ratio:
            imgs = F.resize(
                imgs, [target_h, target_w], interpolation=mode, antialias=False
            )
        else:
            # One uniform scale, large enough that both sides cover the target.
            factor = max(target_h / src_h, target_w / src_w)
            scaled_h = math.ceil(src_h * factor)
            scaled_w = math.ceil(src_w * factor)
            imgs = F.resize(
                imgs, [scaled_h, scaled_w], interpolation=mode, antialias=False
            )
            # Centre-crop the overshoot.
            y0 = (scaled_h - target_h) // 2
            x0 = (scaled_w - target_w) // 2
            imgs = imgs[:, :, y0: y0 + target_h, x0: x0 + target_w]

        imgs.sub_(self.mean).div_(self.std)
        return imgs.contiguous(memory_format=torch.channels_last)


    def stream_feature_generator(
        self,
        frame_iter: itertools.count,
        *,
        batch_size: int = 1,
        amp: bool = True,
        resize_grid_factor: int = 2, 
        patch_multiple: int | None = None,
        max_queue: int = 32 
    ):
        """
        Yield one patch-token grid per frame, processing frames in batches.

        Frames are queued until ``batch_size`` (or ``max_queue``) is reached,
        then stacked, copied to ``self.device``, preprocessed and run through
        the model on a side CUDA stream. Leftover frames are flushed when the
        iterable ends. Requires a CUDA device.

        :param frame_iter: Iterable of frame tensors that can be stacked into a
            (B, 3, H, W) batch. Presumably 0-1 float images, as expected by
            ``_quick_preprocess``; confirm against the frame source.
        :param batch_size: Number of queued frames that triggers a model call
        :param amp: Passed as ``enabled`` to the autocast context.
            NOTE(review): the context requests ``dtype=torch.float32``, which
            is not a reduced-precision dtype, so this flag probably has no
            effect; confirm the intended dtype.
        :param resize_grid_factor: When not 1, the grid is bilinearly
            downscaled by this factor and yielded as float16 (H, W, D).
            NOTE(review): with a factor of 1 the grid is yielded unpermuted as
            (D, H, W) in the model's output dtype; this looks unintended but is
            kept for compatibility.
        :param patch_multiple: Currently unused; kept for interface compatibility
        :param max_queue: Upper bound on queued frames before a flush is forced
        :return: Generator of per-frame grids
        """
        patch = self.patch_size
        cuda_device = torch.device(self.device)
        pending: deque[torch.Tensor] = deque()
        # The side stream must live on the model's device, not a fixed GPU.
        aux_stream = torch.cuda.Stream(device=cuda_device)

        def flush():
            if not pending:
                return
            batch = torch.stack(list(pending), dim=0)
            pending.clear()
            # Pinning only applies to host memory; frames may already be on GPU.
            if batch.device.type == "cpu":
                batch = batch.pin_memory()
            batch = batch.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)

            # The copy above is queued on the current stream, so the side
            # stream has to wait for it before reading the batch.
            main_stream = torch.cuda.current_stream(cuda_device)
            aux_stream.wait_stream(main_stream)
            # Inference only: do not build an autograd graph.
            with torch.cuda.stream(aux_stream), torch.no_grad():
                # Tell the caching allocator that this memory is used on the
                # side stream.
                batch.record_stream(aux_stream)
                # NOTE(review): aspect ratio is always preserved here,
                # regardless of self.maintain_aspect_ratio; confirm intent.
                batch = self._quick_preprocess(
                    batch,
                    patch_size=patch,
                    maintain_aspect_ratio=True
                )
                new_h = batch.shape[2]

                with autocast(device_type="cuda", enabled=amp, dtype=torch.float32):
                    tokens = self.model.get_intermediate_layers(batch, n=1)[0]
                grid = rearrange(tokens, "b (h w) d -> b d h w", h=int(new_h / patch))
                if resize_grid_factor != 1:
                    _, _, H, W = grid.shape
                    # Never request an empty output, even for large factors.
                    out_h = max(1, int(round(H / resize_grid_factor)))
                    out_w = max(1, int(round(W / resize_grid_factor)))

                    grid = interpolate(
                        grid,
                        size=(out_h, out_w),
                        mode="bilinear",
                        align_corners=False,
                    ).permute(0, 2, 3, 1).to(torch.float16).contiguous()
            # Consumers run on the main stream: wait for the side stream and
            # register the cross-stream use of the result.
            main_stream.wait_stream(aux_stream)
            grid.record_stream(main_stream)
            for g in grid:
                yield g

        for frame in frame_iter:
            pending.append(frame)
            if len(pending) >= batch_size or len(pending) >= max_queue:
                yield from flush()
        # Process whatever is left once the source is exhausted.
        yield from flush()


if __name__ == "__main__":
    # Ad-hoc benchmark: decode a video and time the streaming extractor frame
    # by frame. Device index and video path are hard-coded for local use.
    device = "cuda:1"
    extractor = DinoV2FeatureExtractorStream(device, compile=True)

    frame_iter = frame_stream_decord(
        "/example/video.mp4",
        scale=(640, 360),
        target_fps=30,
        gpu_id=1,
    )
    total_start = time.perf_counter()
    per_frame_times = []
    count_frames = 0
    for frame in frame_iter:
        t0 = time.perf_counter()
        # Drain the generator; only the elapsed time matters here.
        # NOTE(review): there is no explicit CUDA synchronisation, so queued
        # GPU work may not be fully included in the measurement.
        for _grid in extractor.stream_feature_generator(
            [frame],
            batch_size=1,
            amp=True,
        ):
            pass
        # The first frame is left out of the average (presumably warm-up cost).
        if count_frames >= 1:
            per_frame_times.append(time.perf_counter() - t0)
        count_frames += 1

    total_elapsed = time.perf_counter() - total_start

    if not count_frames:
        print("No frames processed.")
    elif per_frame_times:
        # Average over the samples actually collected; dividing by
        # count_frames - 1 raised ZeroDivisionError for a single frame.
        avg = sum(per_frame_times) / len(per_frame_times)
        print(f"processed {count_frames} frames in {total_elapsed:.2f} s "
              f"(avg {avg*1000:.2f} ms / frame)")
    else:
        print(f"processed {count_frames} frames in {total_elapsed:.2f} s "
              f"(no per-frame timing: only the first frame was seen)")