import math
from typing import Callable
import numpy as np
import torch
from einops import rearrange, repeat
from torch import Tensor

from .model import Flux
from .modules.conditioner import HFEmbedder


def norm2(img, iimg, mean=None):
    """Return the elementwise half squared error 0.5 * (img - target) ** 2.

    The target is ``iimg`` (plus ``mean`` when it is given). Both operands are
    flattened to column vectors before subtracting, and the result is reshaped
    back to ``img.shape``.
    """
    target = iimg if mean is None else iimg + mean
    diff = img.reshape(-1, 1) - target.reshape(-1, 1)
    return (0.5 * torch.pow(diff, 2)).reshape(img.shape)

def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
    """Pack an image batch into tokens and encode the prompt(s) for the model.

    ``img`` is unpacked as (batch, channels, height, width) and cut into 2x2
    patches, giving (batch, (height/2)*(width/2), channels*4). When ``img``
    has batch 1 and ``prompt`` is a list, the batch size becomes
    ``len(prompt)``. Any batch-of-one tensor is then repeated to that size.

    Returns a dict with keys ``img``, ``img_ids``, ``txt``, ``txt_ids`` and
    ``vec``, all on ``img.device``:

    * ``img_ids`` has channel 0 zero, channel 1 the patch row index and
      channel 2 the patch column index.
    * ``txt_ids`` is all zeros.
    """
    batch, _, height, width = img.shape
    single_prompt = isinstance(prompt, str)
    if batch == 1 and not single_prompt:
        batch = len(prompt)
    prompts = [prompt] if single_prompt else prompt

    def _expand(x: Tensor) -> Tensor:
        # Repeat a batch-of-one tensor to `batch` copies; leave others untouched.
        if x.shape[0] == 1 and batch > 1:
            return repeat(x, "1 ... -> bs ...", bs=batch)
        return x

    packed = _expand(rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2))

    grid_h, grid_w = height // 2, width // 2
    ids = torch.zeros(grid_h, grid_w, 3)
    ids[..., 1] = ids[..., 1] + torch.arange(grid_h)[:, None]
    ids[..., 2] = ids[..., 2] + torch.arange(grid_w)[None, :]
    ids = repeat(ids, "h w c -> b (h w) c", b=batch)

    txt = _expand(t5(prompts))
    txt_ids = torch.zeros(batch, txt.shape[1], 3)
    vec = _expand(clip(prompts))

    device = packed.device
    return {
        "img": packed,
        "img_ids": ids.to(device),
        "txt": txt.to(device),
        "txt_ids": txt_ids.to(device),
        "vec": vec.to(device),
    }


def time_shift(mu: float, sigma: float, t: Tensor):
    """Warp timesteps: e^mu / (e^mu + (1/t - 1) ** sigma).

    The map fixes t = 1. With a tensor input, t = 0 maps to 0 because 1/0
    becomes inf.
    """
    scale = math.exp(mu)
    return scale / (scale + (1 / t - 1) ** sigma)


def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    """Return the straight line through (x1, y1) and (x2, y2) as a function of x."""
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    def line(x):
        return slope * x + intercept

    return line


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    """Return ``num_steps + 1`` timesteps running from 1 down to 0.

    When ``shift`` is true, the evenly spaced grid is warped by ``time_shift``.
    The shift ``mu`` is read off the line through (256, base_shift) and
    (4096, max_shift), evaluated at ``image_seq_len``. This spends more steps
    near t = 1 for longer sequences.
    """
    # num_steps intervals need num_steps + 1 endpoints; the last one is t = 0.
    grid = torch.linspace(1, 0, num_steps + 1)
    if not shift:
        return grid.tolist()
    mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
    return time_shift(mu, 1.0, grid).tolist()


def denoise(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    inverse,
    info,
    guidance: float = 4.0
):
    """Integrate the flow with one Euler step per timestep interval.

    Args:
        model: called as ``model(img=..., img_ids=..., txt=..., txt_ids=...,
            y=..., timesteps=..., guidance=..., info=...)``. It must return
            ``(prediction, info)``; the prediction is used as dx/dt.
        timesteps: time grid; walked in reverse when ``inverse`` is true.
        inverse: walk the grid backwards (inversion) instead of forwards.
        info: mutable dict threaded through every model call. This function
            reads ``'inject_step'`` and writes ``'t'``, ``'inverse'``,
            ``'second_order'`` and ``'inject'`` before each call.
        guidance: value filled into the per-sample guidance vector.

    Returns:
        ``(img, info)``: the integrated latent and the last info from the model.
    """
    # The first `inject_step` intervals (in un-reversed order) are flagged for injection.
    inject_step = info['inject_step']
    inject_flags = [True] * inject_step + [False] * (len(timesteps[:-1]) - inject_step)
    if inverse:
        timesteps = timesteps[::-1]
        inject_flags = inject_flags[::-1]
    # NOTE(review): upstream Flux comments that guidance is ignored by schnell; confirm.
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

    for step, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        info['t'] = t_prev if inverse else t_curr
        info['inverse'] = inverse
        info['second_order'] = False
        info['inject'] = inject_flags[step]
        with torch.no_grad():
            velocity, info = model(
                img=img,
                img_ids=img_ids,
                txt=txt,
                txt_ids=txt_ids,
                y=vec,
                timesteps=t_vec,
                guidance=guidance_vec,
                info=info,
            )
        img = img + (t_prev - t_curr) * velocity

    return img, info


def denoise_m(
        model: Flux,
        # model input
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        vec: Tensor,
        # sampling parameters
        timesteps: list[float],
        inverse,
        info,
        guidance: float = 4.0
):
    """Euler integration with mean-velocity guidance on each step's prediction.

    Each step does the following:

    1. Query the model at (img, t_curr).
    2. Add ``dt * v`` of the raw prediction to a running displacement.
    3. Turn that displacement into an average velocity ``v_mean``.
    4. Adjust ``v`` as below, then take the Euler step ``img += dt * v``.

    The adjustment depends on the mode:

    * ``inverse``, or ``info['recon'] is True``: subtract a fixed-length step
      along the unit-normalized gradient of
      ``sum(|v - v_prev| + 0.5 * (v - v_mean) ** 2)``. Here ``v_prev`` is the
      previous step's adjusted velocity; the ``|.|`` term is dropped on the
      first step.
    * otherwise: ``v = (1 - w) * v_mean * s + w * v`` with ``w = 0.94``.

    Args:
        model: called with keyword inputs and ``info=``; must return
            ``(prediction, info)``.
        timesteps: time grid; walked in reverse when ``inverse`` is true.
        inverse: walk the grid backwards (inversion) instead of forwards.
        info: mutable dict threaded through every model call. This function
            reads ``'inject_step'`` and ``'recon'`` and writes ``'t'``,
            ``'inverse'``, ``'second_order'`` and ``'inject'``.
        guidance: value filled into the per-sample guidance vector.

    Returns:
        ``(img, info)``: the integrated latent and the last info from the model.
    """
    inject_list = [True] * info['inject_step'] + [False] * (len(timesteps[:-1]) - info['inject_step'])

    if inverse:
        timesteps = timesteps[::-1]
        inject_list = inject_list[::-1]
    # NOTE(review): upstream Flux comments that guidance is ignored by schnell; confirm.
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

    # Weight of the raw velocity in the forward (non-recon) blend.
    w = 0.94
    # Length of the unit gradient step, multiplied by dt below.
    # NOTE(review): the origin of sqrt(2*22 + 3*sqrt(2*22)) is not documented here; confirm.
    grad_step_scale = math.sqrt(2 * 22 + 3 * math.sqrt(2 * 22))

    mean_velocity = None
    next_step_velocity = None
    for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
        dt = t_prev - t_curr
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        info['t'] = t_prev if inverse else t_curr
        info['inverse'] = inverse
        info['second_order'] = False
        info['inject'] = inject_list[i]
        with torch.no_grad():
            pred, info = model(
                img=img,
                img_ids=img_ids,
                txt=txt,
                txt_ids=txt_ids,
                y=vec,
                timesteps=t_vec,
                guidance=guidance_vec,
                info=info
            )

        # Running displacement sum(dt * v) of the raw predictions. It is
        # accumulated exactly once per step; previously the guided velocity was
        # added a second time after the Euler update, roughly doubling v_mean.
        if mean_velocity is None:
            mean_velocity = dt * pred
        else:
            mean_velocity = mean_velocity + dt * pred
        # Average velocity: divide by the elapsed time, assuming the walk starts
        # at t=0 when inverting and t=1 when sampling (as get_schedule yields).
        pred_mean = (1 / (t_prev if inverse else t_prev - 1)) * mean_velocity

        if info['recon'] is True or inverse:
            with torch.enable_grad():
                v = pred.detach().requires_grad_(True)
                if next_step_velocity is not None:
                    fx = (v - next_step_velocity).abs() + norm2(v, pred_mean)
                else:
                    fx = norm2(v, pred_mean)
                grad = torch.autograd.grad(fx.sum(), v)[0]
            # keepdim makes the per-sample norm broadcast over the batch axis
            # (a bare (B,) norm would align with the last axis instead).
            grad_norm = torch.norm(grad, 2, dim=[1, 2], keepdim=True)
            # An all-zero gradient would give 0/0 = NaN. Divide such samples
            # by 1 so they are left unchanged.
            grad_norm = torch.where(grad_norm > 0, grad_norm, torch.ones_like(grad_norm))
            pred = pred - grad_step_scale * dt * (grad / grad_norm)
        else:
            # NOTE(review): s is an elementwise product, not a dot product, so
            # this is not a true projection onto pred_mean; confirm intent.
            s = (pred * pred_mean) / (torch.norm(pred_mean, 2, dim=[1, 2], keepdim=True) ** 2)
            pred = (1 - w) * pred_mean * s + w * pred

        next_step_velocity = pred
        img = img + dt * pred
    return img, info

def denoise_rf_solver(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    inverse,
    info,
    guidance: float = 4.0
):
    """Integrate the flow with a second-order (RF-Solver style) update.

    Each interval evaluates the model at its start and at its midpoint. It
    estimates the velocity's time derivative as ``(v_mid - v) / (dt / 2)`` and
    updates ``img += dt * v + 0.5 * dt**2 * derivative``.

    ``info`` is threaded through every model call. This function reads
    ``'inject_step'`` and writes ``'t'``, ``'inverse'``, ``'second_order'``
    and ``'inject'``. The grid is walked in reverse when ``inverse`` is true.

    Returns:
        ``(img, info)``: the integrated latent and the last info from the model.
    """
    flags = [True] * info['inject_step'] + [False] * (len(timesteps[:-1]) - info['inject_step'])
    if inverse:
        timesteps, flags = timesteps[::-1], flags[::-1]
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

    def _predict(x, t_vec):
        # One no-grad model call; stores the returned info in the enclosing scope.
        nonlocal info
        with torch.no_grad():
            out, info = model(
                img=x,
                img_ids=img_ids,
                txt=txt,
                txt_ids=txt_ids,
                y=vec,
                timesteps=t_vec,
                guidance=guidance_vec,
                info=info,
            )
        return out

    for idx, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
        dt = t_prev - t_curr
        half_dt = dt / 2
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        info['t'] = t_prev if inverse else t_curr
        info['inverse'] = inverse
        info['second_order'] = False
        info['inject'] = flags[idx]
        v_start = _predict(img, t_vec)

        img_mid = img + half_dt * v_start
        t_vec_mid = torch.full((img.shape[0],), (t_curr + half_dt), dtype=img.dtype, device=img.device)
        info['second_order'] = True
        v_mid = _predict(img_mid, t_vec_mid)

        derivative = (v_mid - v_start) / half_dt
        img = img + dt * v_start + 0.5 * dt ** 2 * derivative

    return img, info

def denoise_fireflow(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    inverse,
    info,
    guidance: float = 4.0
):
    """Midpoint integration that reuses each midpoint velocity for the next step.

    Only the first interval queries the model at its start point. Every later
    interval starts from the previous interval's midpoint prediction, so after
    the first step each step costs one model call. The midpoint velocity
    drives the update ``img += dt * v_mid``.

    ``info`` is threaded through every model call. This function reads
    ``'inject_step'`` and writes ``'t'``, ``'inverse'``, ``'second_order'``
    and ``'inject'``. The grid is walked in reverse when ``inverse`` is true.

    Returns:
        ``(img, info)``: the integrated latent and the last info from the model.
    """
    flags = [True] * info['inject_step'] + [False] * (len(timesteps[:-1]) - info['inject_step'])
    if inverse:
        timesteps, flags = timesteps[::-1], flags[::-1]
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

    def _predict(x, t_vec):
        # One no-grad model call; stores the returned info in the enclosing scope.
        nonlocal info
        with torch.no_grad():
            out, info = model(
                img=x,
                img_ids=img_ids,
                txt=txt,
                txt_ids=txt_ids,
                y=vec,
                timesteps=t_vec,
                guidance=guidance_vec,
                info=info,
            )
        return out

    carried = None
    for step, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
        half = (t_prev - t_curr) / 2
        info['t'] = t_prev if inverse else t_curr
        info['inverse'] = inverse
        info['second_order'] = False
        info['inject'] = flags[step]
        if carried is None:
            t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
            start_velocity = _predict(img, t_vec)
        else:
            start_velocity = carried

        img_mid = img + half * start_velocity
        t_vec_mid = torch.full((img.shape[0],), t_curr + half, dtype=img.dtype, device=img.device)
        info['second_order'] = True
        carried = _predict(img_mid, t_vec_mid)

        img = img + (t_prev - t_curr) * carried

    return img, info


def denoise_fireflow_M(
        model: Flux,
        # model input
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        vec: Tensor,
        # sampling parameters
        timesteps: list[float],
        inverse,
        info,
        guidance: float = 4.0,

):
    """FireFlow-style midpoint integration with mean-velocity guidance.

    Per interval:

    1. Take the start velocity ``v``. The model is queried at (img, t_curr)
       only on the first interval; afterwards the previous interval's
       adjusted midpoint velocity is reused.
    2. Query the midpoint velocity ``v_mid``.
    3. Add ``dt * v_mid`` to a running displacement and turn it into an
       average velocity ``v_mean``.
    4. Adjust ``v_mid`` as below, then update ``img += dt * v_mid``.

    The adjustment depends on the mode:

    * ``inverse``, or ``info['recon']`` truthy: subtract a fixed-length step
      along the unit-normalized gradient of
      ``sum(|v_mid - v_prev| + 0.5 * (v_mid - v_mean) ** 2)``. Here
      ``v_prev`` is the previous adjusted midpoint velocity; the ``|.|`` term
      is dropped on the first step.
    * otherwise: ``v_mid = (1 - w) * v_mean * s + w * v_mid`` with ``w = 0.94``.

    Args:
        model: called with keyword inputs and ``info=``; must return
            ``(prediction, info)``.
        timesteps: time grid; walked in reverse when ``inverse`` is true.
        inverse: walk the grid backwards (inversion) instead of forwards.
        info: mutable dict threaded through every model call. This function
            reads ``'inject_step'`` and ``'recon'`` and writes ``'t'``,
            ``'inverse'``, ``'second_order'`` and ``'inject'``.
        guidance: value filled into the per-sample guidance vector.

    Returns:
        ``(img, info)``: the integrated latent and the last info from the model.
    """
    inject_list = [True] * info['inject_step'] + [False] * (len(timesteps[:-1]) - info['inject_step'])

    if inverse:
        timesteps = timesteps[::-1]
        inject_list = inject_list[::-1]
    # NOTE(review): upstream Flux comments that guidance is ignored by schnell; confirm.
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

    # Weight of the raw midpoint velocity in the forward (non-recon) blend.
    w = 0.94
    # Length of the unit gradient step, multiplied by dt below.
    # NOTE(review): the origin of sqrt(2*22 + 3*sqrt(2*22)) is not documented here; confirm.
    grad_step_scale = math.sqrt(2 * 22 + 3 * math.sqrt(2 * 22))

    next_step_velocity = None
    mean_velocity = None
    for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
        dt = t_prev - t_curr
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        info['t'] = t_prev if inverse else t_curr
        info['inverse'] = inverse
        info['second_order'] = False
        info['inject'] = inject_list[i]

        with torch.no_grad():
            if next_step_velocity is None:
                pred, info = model(
                    img=img,
                    img_ids=img_ids,
                    txt=txt,
                    txt_ids=txt_ids,
                    y=vec,
                    timesteps=t_vec,
                    guidance=guidance_vec,
                    info=info
                )
            else:
                # Reuse the previous (adjusted) midpoint velocity as this start velocity.
                pred = next_step_velocity

        img_mid = img + dt / 2 * pred

        t_vec_mid = torch.full((img.shape[0],), t_curr + dt / 2, dtype=img.dtype, device=img.device)
        info['second_order'] = True
        with torch.no_grad():
            pred_mid, info = model(
                img=img_mid,
                img_ids=img_ids,
                txt=txt,
                txt_ids=txt_ids,
                y=vec,
                timesteps=t_vec_mid,
                guidance=guidance_vec,
                info=info
            )

        # Running displacement sum(dt * v_mid) of the raw midpoint velocities.
        if next_step_velocity is None:
            mean_velocity = dt * pred_mid
        else:
            mean_velocity = mean_velocity + dt * pred_mid
        # Average velocity: divide by the elapsed time, assuming the walk starts
        # at t=0 when inverting and t=1 when sampling (as get_schedule yields).
        pred_mean = (1 / (t_prev if inverse else t_prev - 1)) * mean_velocity

        if info['recon'] or inverse:
            with torch.enable_grad():
                v = pred_mid.detach().requires_grad_(True)
                if next_step_velocity is not None:
                    fx = (v - next_step_velocity).abs() + norm2(v, pred_mean)
                else:
                    fx = norm2(v, pred_mean)
                grad = torch.autograd.grad(fx.sum(), v)[0]
            # keepdim makes the per-sample norm broadcast over the batch axis
            # (a bare (B,) norm would align with the last axis instead).
            grad_norm = torch.norm(grad, 2, dim=[1, 2], keepdim=True)
            # An all-zero gradient would give 0/0 = NaN. Divide such samples
            # by 1 so they are left unchanged.
            grad_norm = torch.where(grad_norm > 0, grad_norm, torch.ones_like(grad_norm))
            pred_mid = pred_mid - grad_step_scale * dt * (grad / grad_norm)
        else:
            # NOTE(review): s is elementwise, not a dot product, and is normalized
            # by ||pred_mid|| here but by ||pred_mean|| in denoise_m; confirm intent.
            s = (pred_mid * pred_mean) / (torch.norm(pred_mid, 2, dim=[1, 2], keepdim=True) ** 2)
            pred_mid = (1 - w) * pred_mean * s + w * pred_mid

        next_step_velocity = pred_mid
        img = img + dt * pred_mid

    return img, info





def denoise_midpoint(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    inverse,
    info,
    guidance: float = 4.0
):
    """Integrate the flow with the explicit midpoint rule (two model calls per step).

    Each interval queries the model at its start and takes a half step to
    the midpoint. It then queries the model again there and applies the full
    update ``img += dt * v_mid``.

    ``info`` is threaded through every model call. This function reads
    ``'inject_step'`` and writes ``'t'``, ``'inverse'``, ``'second_order'``
    and ``'inject'``. The grid is walked in reverse when ``inverse`` is true.

    Returns:
        ``(img, info)``: the integrated latent and the last info from the model.
    """
    n_intervals = len(timesteps[:-1])
    inject_step = info['inject_step']
    inject_flags = [True] * inject_step + [False] * (n_intervals - inject_step)
    if inverse:
        timesteps = timesteps[::-1]
        inject_flags = inject_flags[::-1]
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

    def _predict(x, t_vec):
        # One no-grad model call; stores the returned info in the enclosing scope.
        nonlocal info
        with torch.no_grad():
            out, info = model(
                img=x,
                img_ids=img_ids,
                txt=txt,
                txt_ids=txt_ids,
                y=vec,
                timesteps=t_vec,
                guidance=guidance_vec,
                info=info,
            )
        return out

    for k, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
        half = (t_prev - t_curr) / 2
        t_start = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        info['t'] = t_prev if inverse else t_curr
        info['inverse'] = inverse
        info['second_order'] = False
        info['inject'] = inject_flags[k]
        v_start = _predict(img, t_start)

        img_mid = img + half * v_start
        t_half = torch.full((img.shape[0],), t_curr + half, dtype=img.dtype, device=img.device)
        info['second_order'] = True
        v_mid = _predict(img_mid, t_half)

        img = img + (t_prev - t_curr) * v_mid

    return img, info




def unpack(x: Tensor, height: int, width: int) -> Tensor:
    """Undo the 2x2 patch packing: (b, h*w, c*2*2) -> (b, c, 2*h, 2*w).

    The token grid is ``ceil(height / 16)`` by ``ceil(width / 16)``.
    ``height`` and ``width`` are presumably the pixel size of the image,
    with a latent downsampling of 8 times the 2x2 patch; confirm against
    callers.
    """
    grid_h = math.ceil(height / 16)
    grid_w = math.ceil(width / 16)
    return rearrange(
        x,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=grid_h,
        w=grid_w,
        ph=2,
        pw=2,
    )
