import time
from pathlib import Path
import random
from typing import Dict, Union, Any, Optional
import contextlib

import math
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
try:
    import deepspeed
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
except:
    pass 
from transformers import get_scheduler as _get_scheduler
from transformers.optimization import SchedulerType
from omegaconf import OmegaConf, open_dict, DictConfig
import hydra
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import random
from einops import rearrange, repeat
from torchvision.transforms import _functional_tensor as F_t


@contextlib.contextmanager
def eval_mode(module):
    """Context manager that runs ``module`` in evaluation mode.

    The module is switched with ``module.eval()`` on entry and yielded to the
    ``with`` body. On exit (including when the body raises) ``module.train()``
    is called, but only if the module was in training mode on entry.
    """
    was_training = module.training
    try:
        module.eval()
        yield module
    finally:
        # A module that was already in eval mode on entry is left untouched.
        if was_training:
            module.train()

def load_model_inference(
    path: str | Path,
    dtype=torch.bfloat16,
    device="cuda",
    port=None,
    strict=True,
) -> tuple[nn.Module, DictConfig]:
    """Load the newest checkpoint of a training run for single-process inference.

    The run directory must contain ``.hydra/config.yaml`` and a ``checkpoints``
    directory whose sub-directories are named with integers. The numerically
    largest one is selected; its ``latest`` file names the tag directory that
    holds ``inference.pt``.

    Args:
        path: Root directory of the training run.
        dtype: Parameter dtype for the loaded model; must be ``torch.float32``
            or ``torch.bfloat16``.
        device: Device the model is moved to.
        port: Value passed to ``deepspeed.initialize`` as ``distributed_port``.
            When ``None`` (default) a random port in ``[10000, 65535]`` is
            drawn for this call.
        strict: Forwarded to ``load_state_dict``.

    Returns:
        ``(model, cfg)`` where ``model`` is ``model_engine.module`` as returned
        by ``deepspeed.initialize`` and ``cfg`` is the resolved run config
        with the inference overrides applied.

    Raises:
        ValueError: If no integer-named run directory exists under
            ``checkpoints`` or the ``latest`` file is empty.
        AssertionError: If ``dtype`` is not one of the supported dtypes.
    """
    import diffusion  # Needed for some interpolations

    path = Path(path)
    if port is None:
        # Drawn per call from OS entropy so the choice neither depends on nor
        # disturbs the global `random` state (which callers may have seeded).
        port = random.SystemRandom().randint(10000, 65535)

    ckpt_path = path / "checkpoints"

    # Only integer-named entries are run directories; stray files are ignored
    # instead of crashing the int() conversion below.
    runs = [entry.name for entry in ckpt_path.iterdir() if entry.name.isdigit()]
    if not runs:
        raise ValueError(f"No checkpoints found in {ckpt_path}")
    latest_run = max(runs, key=int)
    ckpt_path = ckpt_path / latest_run

    with open(ckpt_path / "latest", mode="r") as f:
        # Strip the line terminator: the tag is used as a path component below.
        tag = f.readline().strip()
    if not tag:
        raise ValueError(f"Empty 'latest' tag file in {ckpt_path}")
    print(f"Latest tag for {ckpt_path}: {tag}")

    cfg: DictConfig = OmegaConf.load(path / ".hydra" / "config.yaml")
    with open_dict(cfg):
        # Override the training-time settings for a single local process
        # without logging or ZeRO partitioning.
        cfg.deepspeed.local_world_size = 1
        assert dtype in [torch.float32, torch.bfloat16]
        cfg.deepspeed.config.bfloat16.enabled = dtype == torch.bfloat16
        cfg.deepspeed_cli_args.local_rank = 0
        cfg.wandb.enabled = False
        cfg.deepspeed.config.wandb.enabled = False
        cfg.deepspeed.config.zero_optimization.stage = 0
        if hasattr(cfg.deepspeed.config, "csv_monitor"):
            cfg.deepspeed.config.csv_monitor.enabled = False

    OmegaConf.resolve(cfg)
    model = hydra.utils.instantiate(cfg.model)
    model.eval()
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(ckpt_path / tag / "inference.pt", map_location="cpu")
    model.load_state_dict(state_dict, strict=strict)
    model.to(dtype)
    model.to(device)
    model.eval()
    model_engine, *_ = deepspeed.initialize(
        model=model,
        config_params=OmegaConf.to_container(cfg.deepspeed.config),
        distributed_port=port,
    )
    model = model_engine.module

    return model, cfg


def load_model_inference_direct(
    config_path: str | Path,
    checkpoint_path: str | Path,
    dtype=torch.bfloat16,
    device="cuda",
    port=None,
    strict=False,
    config_callback=None,
    initialize_deepspeed=True,
) -> tuple[nn.Module, DictConfig]:
    """Build a model from a run config and load an explicit checkpoint file.

    Args:
        config_path: Directory that contains ``.hydra/config.yaml``.
        checkpoint_path: File passed to ``torch.load`` to obtain the state dict.
        dtype: Parameter dtype for the loaded model; must be ``torch.float32``
            or ``torch.bfloat16``.
        device: Device the model is moved to.
        port: Value passed to ``deepspeed.initialize`` as ``distributed_port``.
            When ``None`` (default) a random port in ``[10000, 65535]`` is
            drawn for this call. Unused when ``initialize_deepspeed`` is false.
        strict: Forwarded to ``load_state_dict``. Defaults to ``False``, which
            matches the previous effective behaviour: the argument used to be
            ignored and non-strict loading was always performed.
        config_callback: Optional callable invoked with the (still open)
            config after the inference overrides were applied, so the caller
            can make further changes before the config is resolved.
        initialize_deepspeed: When true, the model is passed through
            ``deepspeed.initialize`` and the engine's ``module`` is returned;
            otherwise the plain module is returned.

    Returns:
        ``(model, cfg)`` with the loaded model and the resolved config.

    Raises:
        AssertionError: If ``dtype`` is not one of the supported dtypes.
    """
    import diffusion  # Needed for some interpolations

    config_path = Path(config_path)
    checkpoint_path = Path(checkpoint_path)

    cfg: DictConfig = OmegaConf.load(config_path / ".hydra" / "config.yaml")
    with open_dict(cfg):
        # Override the training-time settings for a single local process
        # without logging or ZeRO partitioning.
        cfg.deepspeed.local_world_size = 1
        assert dtype in [torch.float32, torch.bfloat16]
        cfg.deepspeed.config.bfloat16.enabled = dtype == torch.bfloat16
        cfg.deepspeed_cli_args.local_rank = 0
        cfg.wandb.enabled = False
        cfg.deepspeed.config.wandb.enabled = False
        cfg.deepspeed.config.zero_optimization.stage = 0
        if hasattr(cfg.deepspeed.config, "csv_monitor"):
            cfg.deepspeed.config.csv_monitor.enabled = False
        if config_callback is not None:
            config_callback(cfg)

    OmegaConf.resolve(cfg)
    model = hydra.utils.instantiate(cfg.model)
    model.eval()
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict, strict=strict)
    model.to(dtype)
    model.to(device)
    model.eval()
    if initialize_deepspeed:
        if port is None:
            # Drawn per call from OS entropy so the choice neither depends on
            # nor disturbs the global `random` state.
            port = random.SystemRandom().randint(10000, 65535)
        model_engine, *_ = deepspeed.initialize(
            model=model,
            config_params=OmegaConf.to_container(cfg.deepspeed.config),
            distributed_port=port,
        )
        model = model_engine.module

    return model, cfg


class TimeMeasurement:
    """Context manager that tracks an exponential moving average of durations.

    Every ``with`` block is timed using ``time.time()``. The first measured
    duration initialises ``ema``; each later one is blended in as
    ``alpha * duration + (1 - alpha) * ema``. The most recent timestamps are
    kept in ``start_time`` and ``end_time``.
    """

    def __init__(self, alpha=0.1):
        # ema stays None until the first block has been timed.
        self.ema = None
        self.alpha = alpha

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end_time = time.time()
        duration = self.end_time - self.start_time
        if self.ema is not None:
            self.ema = self.alpha * duration + (1 - self.alpha) * self.ema
        else:
            self.ema = duration

    def reset(self):
        """Forget the running average so the next measurement re-initialises it."""
        self.ema = None


class NullObject:
    """Do-nothing stand-in that absorbs attribute access, calls and ``with``.

    Looking up any attribute that is not otherwise defined, or calling the
    object, yields a fresh ``NullObject``. Used as a context manager it yields
    itself and does not suppress exceptions.
    """

    def __getattr__(self, name) -> "NullObject":
        # Only reached for attributes that normal lookup did not find.
        replacement = NullObject()
        return replacement

    def __call__(self, *args: Any, **kwds: Any) -> "NullObject":
        replacement = NullObject()
        return replacement

    def __enter__(self) -> "NullObject":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Returning None lets any exception from the with-body propagate.
        return None


def set_seed(seed=42, cuda=True):
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    Args:
        seed: Seed passed to every generator.
        cuda: When true, ``torch.cuda.manual_seed_all`` is called as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not cuda:
        return
    torch.cuda.manual_seed_all(seed)


def dict_to(d: Dict[str, Union[torch.Tensor, Any]], **to_kwargs) -> Dict[str, Union[torch.Tensor, Any]]:
    """Return a new dict in which every tensor value has had ``.to(**to_kwargs)`` applied.

    Values that are not ``torch.Tensor`` instances are carried over unchanged;
    the input dict itself is not modified.
    """
    converted = {}
    for key, value in d.items():
        if isinstance(value, torch.Tensor):
            value = value.to(**to_kwargs)
        converted[key] = value
    return converted


# Taken from https://github.com/cloneofsimo/minRF/blob/main/advanced/main_t2i.py
# Thanks
# Apache 2.0 License
def _z3_params_to_fetch(param_list):
    return [p for p in param_list if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE]


def get_zero_three_statedict(model, global_rank, zero_stage=0):
    """Collect a state dict on rank 0, gathering ZeRO stage-3 parameters if needed.

    Args:
        model: Module, or a wrapper that exposes the module as ``.module``.
        global_rank: Rank of the calling process; only rank 0 receives a dict.
        zero_stage: ZeRO stage in use; the gathering path is taken only for 3.

    Returns:
        On rank 0: ``state_dict()`` of the module when ``zero_stage != 3``,
        otherwise a dict of CPU copies of the named parameters whose name does
        not contain ``"lora"``. On every other rank: ``None``.
    """
    target = getattr(model, "module", model)
    is_main = global_rank == 0
    use_z3 = zero_stage == 3

    if not use_z3:
        return target.state_dict() if is_main else None

    collected = {}
    # Every rank walks all parameters (the gather context is entered on all of
    # them); only rank 0 keeps the resulting CPU copies.
    for name, param in target.named_parameters():
        if not hasattr(param, "ds_id"):
            cpu_copy = param.cpu()
        else:
            with deepspeed.zero.GatheredParameters(_z3_params_to_fetch([param]), enabled=use_z3):
                cpu_copy = param.data.cpu()
        if is_main and "lora" not in name:
            collected[name] = cpu_copy
    return collected if is_main else None


def save_zero_three_model(model, global_rank, output_model_file, zero_stage=0):
    """Write the model weights with ``torch.save`` from rank 0.

    Args:
        model: Module, or a wrapper that exposes the module as ``.module``.
        global_rank: Rank of the calling process; only rank 0 writes.
        output_model_file: Destination handed to ``torch.save``.
        zero_stage: ZeRO stage in use. For stage 3 the named parameters are
            gathered one by one, copied to CPU, and those whose name contains
            ``"lora"`` are left out; for any other stage the module's
            ``state_dict()`` is saved as is.
    """
    target = getattr(model, "module", model)
    is_main = global_rank == 0
    use_z3 = zero_stage == 3

    if not use_z3:
        if is_main:
            torch.save(target.state_dict(), output_model_file)
        return

    collected = {}
    # Every rank walks all parameters (the gather context is entered on all of
    # them); only rank 0 keeps the resulting CPU copies.
    for name, param in target.named_parameters():
        if not hasattr(param, "ds_id"):
            cpu_copy = param.cpu()
        else:
            with deepspeed.zero.GatheredParameters(_z3_params_to_fetch([param]), enabled=use_z3):
                cpu_copy = param.data.cpu()
        if is_main and "lora" not in name:
            collected[name] = cpu_copy
    if is_main:
        torch.save(collected, output_model_file)


# Taken from https://github.com/mhamilton723/FeatUp/blob/main/featup/util.py#L155
# Thanks
# MIT License
class TorchPCA:
    """Small PCA helper built on ``torch.pca_lowrank``.

    After ``fit`` the instance exposes ``mean_`` (mean over dim 0 of the
    fitted data), ``components_`` (one row per component) and
    ``singular_values_``.
    """

    def __init__(self, n_components):
        self.n_components = n_components

    def fit(self, X):
        """Estimate mean and principal components from ``X`` and return ``self``."""
        feature_mean = X.mean(dim=0)
        self.mean_ = feature_mean
        centered = X - feature_mean.unsqueeze(0)
        # The data is centred by hand above, hence center=False.
        _, singular_values, right_vectors = torch.pca_lowrank(
            centered, q=self.n_components, center=False, niter=4
        )
        self.components_ = right_vectors.T
        self.singular_values_ = singular_values
        return self

    def transform(self, X):
        """Centre ``X`` with the fitted mean and project it onto the components."""
        centered = X - self.mean_.unsqueeze(0)
        return centered @ self.components_.T


def pca(image_feats_list, dim=3, fit_pca=None, max_samples=None):
    """Project feature maps onto ``dim`` PCA components scaled to ``[0, 1]``.

    Args:
        image_feats_list: Non-empty list of 4-D tensors unpacked as
            ``(B, C, H, W)``. When a PCA has to be fitted and several maps are
            given, they are bilinearly resized to a square whose side is the
            height of the first map before fitting.
        dim: Number of components to fit, and the size of the channel axis of
            each returned tensor.
        fit_pca: Optional already-fitted object with a ``transform`` method.
            When ``None`` a ``TorchPCA`` is fitted on the pooled features.
        max_samples: If set and more feature vectors than this are available,
            a random subset of that size is used for fitting.

    Returns:
        ``(reduced_feats, fit_pca)``: one ``(B, dim, H, W)`` tensor per input
        (on the device of the first input), min-max normalised per component,
        and the PCA object that was used.
    """
    device = image_feats_list[0].device

    def flatten(tensor, target_size=None):
        # (B, C, H, W) -> (B*H*W, C) on the CPU, detached from autograd.
        if target_size is not None and fit_pca is None:
            tensor = F.interpolate(tensor, (target_size, target_size), mode="bilinear")
        B, C, H, W = tensor.shape
        return tensor.permute(1, 0, 2, 3).reshape(C, B * H * W).permute(1, 0).detach().cpu()

    if len(image_feats_list) > 1 and fit_pca is None:
        target_size = image_feats_list[0].shape[2]
    else:
        target_size = None

    x = torch.cat([flatten(feats, target_size) for feats in image_feats_list], dim=0)

    # Subsample the data if max_samples is set and the number of samples exceeds max_samples
    if max_samples is not None and x.shape[0] > max_samples:
        indices = torch.randperm(x.shape[0])[:max_samples]
        x = x[indices]

    if fit_pca is None:
        fit_pca = TorchPCA(n_components=dim).fit(x)

    reduced_feats = []
    for feats in image_feats_list:
        x_red = fit_pca.transform(flatten(feats))
        if isinstance(x_red, np.ndarray):
            x_red = torch.from_numpy(x_red)
        x_red -= x_red.min(dim=0, keepdim=True).values
        peak = x_red.max(dim=0, keepdim=True).values
        # A component that is constant over all positions has peak == 0 after
        # the shift; dividing by it would yield NaN (0 / 0). Divide by 1
        # instead so such a component simply stays at 0.
        x_red /= torch.where(peak > 0, peak, torch.ones_like(peak))
        B, C, H, W = feats.shape
        reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2).to(device))

    return reduced_feats, fit_pca


def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    scheduler_specific_kwargs: Optional[dict] = None,
):
    """Create a learning-rate scheduler by name.

    ``"exponential_decay"`` is handled here: it returns a ``LambdaLR`` whose
    multiplier is ``0.5 ** (step / t_decay)``, with ``t_decay`` read from
    ``scheduler_specific_kwargs`` each time the multiplier is evaluated. The
    warmup and training step counts are not used for this scheduler.

    Any other name is delegated unchanged to ``transformers.get_scheduler``.
    """
    if name != "exponential_decay":
        return _get_scheduler(
            name=name,
            optimizer=optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            scheduler_specific_kwargs=scheduler_specific_kwargs,
        )

    def halving_factor(current_step: int):
        # The multiplier halves every `t_decay` steps.
        return 0.5 ** ((current_step) / scheduler_specific_kwargs["t_decay"])

    return torch.optim.lr_scheduler.LambdaLR(optimizer, halving_factor)


def normalize(frames):
    """Cast ``frames`` to float32 and divide by 255."""
    as_float = frames.float()
    return as_float / 255.0


def apply_color_jitter(frames, params):
    """Apply randomly parameterised colour adjustments in a random order.

    For each of ``brightness``, ``contrast``, ``saturation`` and ``hue`` that
    is present in ``params``, one factor is drawn with ``_sample_param`` from
    the range stored under that key. The selected adjustments are then applied
    to ``frames`` in an order given by a random permutation. All randomness
    comes from a generator seeded with the current time in milliseconds.
    """
    rng = torch.Generator().manual_seed(int(time.time() * 1e3) % 2**32)

    candidates = (
        ("brightness", TF.adjust_brightness),
        ("contrast", TF.adjust_contrast),
        ("saturation", TF.adjust_saturation),
        ("hue", TF.adjust_hue),
    )

    # Draw every factor up front, in the fixed order above, before shuffling.
    planned = []
    for key, adjust in candidates:
        if key in params:
            planned.append((adjust, _sample_param(rng, params[key], None)))

    # The application order comes from the same generator as the factors.
    order = torch.randperm(len(planned), generator=rng)
    for position in order:
        adjust, amount = planned[position]
        frames = adjust(frames, amount)

    return frames


def affine(
    img, angle, translate, scale, shear, padding_mode="zeros"
):  # use border for replicating the last valid pixel
    """Warp a 4-D image batch with a single affine transform via ``grid_sample``.

    The inverse affine matrix is obtained from torchvision's private
    ``_get_inverse_affine_matrix`` with a centre of ``[0.0, 0.0]``, turned into
    a sampling grid of the input's own height and width, and shared by every
    item of the batch. Sampling is bilinear with ``align_corners=False``;
    ``padding_mode`` is forwarded to ``grid_sample``.
    """
    batch, _, height, width = img.shape
    offsets = [float(t) for t in translate]
    inverse = TF._get_inverse_affine_matrix([0.0, 0.0], angle, offsets, scale, shear)
    theta = torch.tensor(inverse, dtype=img.dtype, device=img.device).view(1, 2, 3)
    base_grid = F_t._gen_affine_grid(theta, w=width, h=height, ow=width, oh=height)
    # One grid, broadcast over the batch axis.
    sampling_grid = base_grid.expand(batch, -1, -1, -1)
    return F.grid_sample(
        img,
        sampling_grid,
        mode="bilinear",
        padding_mode=padding_mode,
        align_corners=False,
    )


def perspective(img, startpoints, endpoints, padding_mode="zeros"):  # use border for replicating the last valid pixel
    """Warp an image batch with one perspective transform via ``grid_sample``.

    The coefficients come from torchvision's private ``_get_perspective_coeffs``
    for the given ``startpoints``/``endpoints``; the resulting grid has the
    input's own height and width and is shared by every item along dim 0.
    Sampling is bilinear with ``align_corners=False``; ``padding_mode`` is
    forwarded to ``grid_sample``.
    """
    out_width = img.shape[-1]
    out_height = img.shape[-2]
    coeffs = TF._get_perspective_coeffs(startpoints, endpoints)
    base_grid = F_t._perspective_grid(coeffs, ow=out_width, oh=out_height, dtype=img.dtype, device=img.device)
    # One grid, broadcast over the batch axis.
    sampling_grid = base_grid.expand(img.shape[0], -1, -1, -1)
    return F.grid_sample(
        img,
        sampling_grid,
        mode="bilinear",
        padding_mode=padding_mode,
        align_corners=False,
    )


def _sample_param(rng, rng_spec, default, n=1):
    """
    Draws * n * values uniformly from (lo, hi); returns default(s) if spec is None.
    """
    if rng_spec is None:
        return (default,) * n if n > 1 else default
    lo, hi = rng_spec
    draw = lambda: float(torch.rand((), generator=rng) * (hi - lo) + lo)
    return tuple(draw() for _ in range(n)) if n > 1 else draw()


def apply_geometric_transformations(frames, params):
    """Apply one random geometric augmentation (perspective or affine) plus an optional flip.

    Args:
        frames: 4-D tensor unpacked as ``(t, c, h, w)``; the first axis is
            iterated frame by frame in the smoothed perspective path.
        params: Mapping-like object read with ``.get``. Keys used here:
            ``perspective``, ``smooth``, ``angle``, ``scale``, ``shear``,
            ``translate`` (each a ``(lo, hi)`` range or missing/None),
            ``remove_padding`` and ``flip`` (truthy flags).

    Returns:
        The warped (and possibly horizontally flipped) frames.
    """
    # Generator seeded from the current time in milliseconds.
    # NOTE(review): two calls within the same millisecond get the same seed.
    rng = torch.Generator().manual_seed(int(time.time() * 1e3) % 2**32)
    t, c, h, w = frames.shape

    if params.get("perspective", False):
        # Perspective branch: jitter the four image corners with Gaussian noise
        # whose standard deviation is `strength` times the mean of h and w.
        strength = _sample_param(rng, params.get("perspective"), 0.0)
        smooth = _sample_param(rng, params.get("smooth"), 0.0)
        std = strength * (h + w) / 2
        startpoints = torch.tensor([[0, 0], [w, 0], [0, h], [w, h]])

        if smooth > 0:
            # Per-frame corner offsets, low-pass filtered along the frame axis
            # with a normalised Gaussian kernel of standard deviation `smooth`.
            startpoints = repeat(startpoints, "p c -> t p c", t=t)
            variation = std * torch.randn((t, 4, 2), generator=rng)
            # Kernel length is forced to be odd (about 6 * smooth taps).
            kernel_size = (int(round(smooth * 6)) // 2) * 2 + 1
            gaussian_kernel = torch.exp(-((torch.arange(kernel_size) - (kernel_size - 1) / 2) ** 2) / (2 * smooth**2))
            gaussian_kernel /= gaussian_kernel.sum()
            # Each of the 8 corner coordinates is filtered as its own length-t
            # sequence; padding=kernel_size // 2 keeps the output length at t.
            variation = rearrange(
                F.conv1d(
                    rearrange(variation, "t p c -> (p c) 1 t"),
                    repeat(gaussian_kernel, "l -> 1 1 l"),
                    padding=kernel_size // 2,
                ),
                "(p c) 1 t -> t p c",
                p=4,
            )
            endpoints = startpoints + variation
            # Every frame gets its own warp, so frames are processed one at a
            # time and concatenated back along dim 0.
            frames = torch.cat(
                [
                    perspective(
                        f.unsqueeze(0),
                        startpoints[i],
                        endpoints[i],
                        padding_mode="border" if params.get("remove_padding") else "zeros",
                    )
                    for i, f in enumerate(frames)
                ],
                dim=0,
            )
        else:
            # No temporal smoothing: a single corner jitter shared by all frames.
            variation = std * torch.randn((4, 2), generator=rng)
            endpoints = startpoints + variation
            frames = perspective(
                frames, startpoints, endpoints, padding_mode="border" if params.get("remove_padding") else "zeros"
            )
    else:
        # Affine branch: each parameter falls back to its identity value when
        # the corresponding range is missing from params.
        angle = _sample_param(rng, params.get("angle"), 0.0)
        scale = _sample_param(rng, params.get("scale"), 1.0)
        shear = _sample_param(rng, params.get("shear"), 0.0)
        tx, ty = _sample_param(rng, params.get("translate"), 0.0, n=2)

        # The sampled translation is a fraction of the frame width / height.
        frames = affine(
            frames,
            angle=angle,
            translate=(tx * w, ty * h),
            scale=scale,
            shear=[shear, 0],  # skew in x-direction only
            padding_mode="border" if params.get("remove_padding") else "zeros",
        )

    # The flip decision uses Python's global `random` module, not `rng`.
    if params.get("flip", False) and random.random() > 0.5:
        frames = TF.hflip(frames)

    return frames


def transform(frames, params_jitter=None, params_geometric=None):
    """Augment frames given in ``[-1, 1]`` and return them in ``[-1, 1]``.

    The frames are mapped to ``[0, 1]``, optionally colour-jittered and
    geometrically transformed, and mapped back.

    Args:
        frames: Tensor with values in ``[-1, 1]``.
        params_jitter: Optional config with an ``enabled`` attribute; when
            enabled it is passed to ``apply_color_jitter``.
        params_geometric: Optional config with an ``enabled`` attribute; when
            enabled it is passed to ``apply_geometric_transformations``.

    Returns:
        The augmented frames, rescaled to ``[-1, 1]``.
    """
    # [-1, 1] to [0, 1]
    frames = (frames + 1) / 2

    if params_jitter is not None and params_jitter.enabled:
        frames = apply_color_jitter(frames, params_jitter)

    if params_geometric is not None and params_geometric.enabled:
        # Applied exactly once. The call used to be duplicated on two
        # consecutive lines, which warped (and resampled) every frame twice.
        frames = apply_geometric_transformations(frames, params_geometric)

    # [0, 1] to [-1, 1]
    frames = (frames * 2) - 1
    return frames


def sample_mask(
    max_tries=10, max_temporal_keep=1.0, mask_ratio=0.9, mask_type="tube", fs=1, hs=16, ws=16, batch_size=1
):
    """Sample, per batch item, the indices of the tokens that are kept (not masked).

    Tokens are addressed in a flattened ``fs * hs * ws`` layout (frame-major).

    Args:
        max_tries: ``"vjepa"`` only: number of attempts at drawing a block mask
            with a non-empty context before falling back to a random mask.
        max_temporal_keep: ``"vjepa"`` only: fraction of the ``fs`` frames the
            context block mask may span while it is being sampled.
        mask_ratio: Fraction of tokens to mask; ``1 - mask_ratio`` is kept.
        mask_type: ``"random"`` (independent tokens), ``"tube"`` (random
            spatial positions repeated over all frames) or ``"vjepa"``
            (block-based context; kept positions are again repeated over all
            frames).
        fs: Number of frames.
        hs: Token grid height.
        ws: Token grid width.
        batch_size: Number of masks to sample.

    Returns:
        A list with ``batch_size`` 1-D index tensors, or ``None`` when
        ``mask_type`` is none of the supported values.
    """
    keep_ratio = 1 - mask_ratio

    if mask_type == "random":
        # Create different masks for each sample in the batch
        return [torch.randperm(fs * hs * ws)[: int(keep_ratio * fs * hs * ws)] for _ in range(batch_size)]

    if mask_type == "tube":
        masks = []
        for _ in range(batch_size):
            mask = torch.ones(hs * ws)
            mask[torch.randperm(hs * ws)[: int(mask_ratio * hs * ws)]] = 0
            # The same spatial pattern is used for every frame.
            mask = mask[None, :].repeat(fs, 1)
            masks.append(mask.reshape(fs * hs * ws).nonzero()[:, 0])
        return masks

    if mask_type == "vjepa":
        masks = []
        for _ in range(batch_size):
            # decide if we have a long or short scale_mask (take values from VJEPA)
            p = np.random.random()
            short_range = p < 0.5
            spatial_scale = (0.15, 0.15) if short_range else (0.7, 0.7)
            temporal_scale = (1.0, 1.0)
            aspect_ratio_scale = (0.75, 1.5)
            nblocks = 8 if short_range else 2

            # -- Sample temporal block mask scale
            _rand = np.random.random()
            min_t, max_t = temporal_scale
            temporal_mask_scale = min_t + _rand * (max_t - min_t)
            t = max(1, int(fs * temporal_mask_scale))

            # -- Sample spatial block mask scale
            _rand = np.random.random()
            min_s, max_s = spatial_scale
            spatial_mask_scale = min_s + _rand * (max_s - min_s)
            spatial_num_keep = int(hs * ws * spatial_mask_scale)

            # -- Sample block aspect-ratio
            _rand = np.random.random()
            min_ar, max_ar = aspect_ratio_scale
            aspect_ratio = min_ar + _rand * (max_ar - min_ar)

            # -- Compute block height and width (given scale and aspect-ratio)
            h = int(round(np.sqrt(spatial_num_keep * aspect_ratio)))
            w = int(round(np.sqrt(spatial_num_keep / aspect_ratio)))
            h = min(h, hs)
            w = min(w, ws)

            # now sample masks
            max_context_duration = max(1, int(fs * max_temporal_keep))

            num_tries = 0
            while True:
                # 1 = context (visible), 0 = covered by one of the blocks.
                mask = np.ones((1, fs, hs, ws))
                for _block in range(nblocks):
                    top = np.random.randint(0, hs - h + 1)
                    left = np.random.randint(0, ws - w + 1)
                    start = np.random.randint(0, fs - t + 1)

                    mask_temp = np.ones((1, fs, hs, ws))
                    mask_temp[:, start : start + t, top : top + h, left : left + w] = 0

                    # Context mask will only span the first X frames
                    # (X=max_context_duration). The frame axis is axis 1 here;
                    # slicing axis 0 (size 1) would select nothing.
                    if max_context_duration < fs:
                        mask_temp[:, max_context_duration:, :, :] = 0
                    mask *= mask_temp

                # check if mask is valid
                if int(np.sum(mask)) != 0:
                    break

                num_tries += 1
                if num_tries == max_tries:
                    # Every attempt left an empty context: fall back to a random mask.
                    mask = np.random.binomial(1, 0.5, size=(1, fs, hs, ws))
                    break

            mask = torch.from_numpy(mask)
            mask = mask.reshape(fs, hs * ws)
            # Context positions of the first frame get the smallest sort key
            # (-1), so they are selected first; the remainder is random.
            keep_idcs_hw = mask[0].nonzero()[:, 0]
            r = torch.rand(hs * ws)
            r[keep_idcs_hw] = -1
            keep_idcs_hw = r.sort()[1][: int(keep_ratio * hs * ws)]
            # The selected spatial positions are kept in every frame.
            mask = mask.new_zeros((fs, hs * ws))
            mask[:, keep_idcs_hw] = 1
            masks.append(mask.reshape(fs * hs * ws).nonzero()[:, 0])
        return masks

    # Unsupported mask_type: no mask is produced.
    return None
