import math
from typing import Callable

import torch
from einops import rearrange, repeat
from torch import Tensor

from .model import Flux
from .modules.conditioner import HFEmbedder

import numpy as np
from typing import Union, List
from tqdm import tqdm

def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
    """Pack a latent image and encode a prompt into the inputs consumed by the Flux model.

    The latent is cut into non-overlapping 2x2 patches that are flattened into a token
    sequence, and a positional id (0, row, col) is built for every patch. The prompt is
    encoded once with ``t5`` and once with ``clip``.

    Args:
        t5: text embedder called with a list of prompts; its output becomes ``txt``.
        clip: text embedder called with a list of prompts; its output becomes ``vec``.
        img: latent tensor of shape (b, c, h, w); h and w must be divisible by 2.
        prompt: a single prompt or a list of prompts. If ``img`` has batch 1 and a list
            is given, the effective batch size becomes ``len(prompt)``.

    Returns:
        Dict with keys ``img`` (b, h*w/4, c*4), ``img_ids`` (b, h*w/4, 3), ``txt``,
        ``txt_ids`` (zeros of shape (b, txt.shape[1], 3)) and ``vec``. Every batch-1
        tensor is repeated up to the effective batch size. All entries except ``img``
        are moved to ``img``'s device.
    """
    batch, _, height, width = img.shape
    if batch == 1 and not isinstance(prompt, str):
        # A single latent paired with several prompts is broadcast to one latent per prompt.
        batch = len(prompt)

    def _to_batch(t: Tensor) -> Tensor:
        # Repeat a batch-1 tensor along dim 0 so that it matches `batch`; leave others alone.
        if t.shape[0] == 1 and batch > 1:
            return repeat(t, "1 ... -> bs ...", bs=batch)
        return t

    # 2x2 patchify: (b, c, h, w) -> (b, h/2 * w/2, c * 4).
    tokens = _to_batch(rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2))
    device = tokens.device

    # Positional ids per patch: channel 0 stays zero, channels 1/2 hold the row/column index.
    rows, cols = height // 2, width // 2
    img_ids = torch.zeros(rows, cols, 3)
    img_ids[..., 1] = torch.arange(rows)[:, None]
    img_ids[..., 2] = torch.arange(cols)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch)

    prompts = [prompt] if isinstance(prompt, str) else prompt
    txt = _to_batch(t5(prompts))
    txt_ids = torch.zeros(batch, txt.shape[1], 3)
    vec = _to_batch(clip(prompts))

    return {
        "img": tokens,
        "img_ids": img_ids.to(device),
        "txt": txt.to(device),
        "txt_ids": txt_ids.to(device),
        "vec": vec.to(device),
    }


def time_shift(mu: float, sigma: float, t: Tensor):
    """Warp timesteps as ``exp(mu) / (exp(mu) + (1/t - 1) ** sigma)``.

    ``t == 1`` maps to 1. For ``mu > 0`` (and ``sigma == 1``) every ``t`` in (0, 1) is
    pushed upward. With a floating tensor, ``t == 0`` gives ``1/t == inf`` and therefore
    maps to 0. A plain Python float 0 would raise ZeroDivisionError instead.
    """
    scale = math.exp(mu)
    odds = (1 / t - 1) ** sigma
    return scale / (scale + odds)


def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    """Return the affine function passing through ``(x1, y1)`` and ``(x2, y2)``.

    With the defaults it maps an image token count of 256 to 0.5 and 4096 to 1.15.
    Values outside that range are extrapolated linearly.
    """
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    def line(x: float) -> float:
        return slope * x + intercept

    return line


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    """Build a descending timestep schedule from 1 to 0 with ``num_steps + 1`` points.

    When ``shift`` is true, the uniform grid is warped by ``time_shift``. Its ``mu`` is
    interpolated linearly from ``image_seq_len``: ``base_shift`` at 256 tokens and
    ``max_shift`` at 4096 tokens. The warp moves more of the schedule toward high
    timesteps for longer (larger) images.
    """
    # num_steps intervals need num_steps + 1 endpoints; the last one is exactly 0.
    grid = torch.linspace(1, 0, num_steps + 1)
    if not shift:
        return grid.tolist()

    mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
    return time_shift(mu, 1.0, grid).tolist()


def denoise(
    model: Flux,
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    timesteps: list[float],
    inverse,
    info, 
    guidance: Union[float, List[float], Tensor]
):
    """Integrate the flow over ``timesteps`` with explicit Euler steps, optionally inverted.

    Each step calls ``model`` to get a velocity ``pred`` and updates
    ``img <- img + (t_prev - t_curr) * pred``.

    Inversion mode (``inverse`` true):
        The schedule is reversed. After every step the new latent is appended to
        ``info['latents_list']`` and the ``idx`` returned by the model is collected.
        At the end the collected values are stacked into ``info['inject_idx']``.

    Forward mode (``inverse`` false):
        At step ``i`` the stored latent ``info['latents_list'][-(i + 1)]`` is consumed,
        so entries appended during an earlier inversion are read back in reverse order.
        Batch row 0 of the current latent is overwritten with row 0 of that stored latent.
        The difference this introduces (``loss_z``) is added to batch row 1 while
        ``info['t_step'] < 10``. ``sigmoid(||loss_z||_F)`` is stored in
        ``info['lambda_t']``. The input ``img`` is copied first, so the caller's tensor
        is never modified.

    Args:
        model: called as ``model(img=..., img_ids=..., txt=..., txt_ids=..., y=...,
            timesteps=..., guidance=..., info=...)`` and must return ``(pred, info, idx)``.
        img: packed latent tokens. Forward mode indexes rows 0 and 1, so it needs batch >= 2.
        img_ids, txt, txt_ids, vec: conditioning tensors forwarded to ``model``.
        timesteps: schedule of times. Presumably descending from 1 to 0 (as produced by
            ``get_schedule``); it is reversed in inversion mode.
        inverse: run inversion instead of forward sampling.
        info: mutable state dict shared with ``model``. It is read for ``'num_steps'``
            (inversion mode; assumed to equal ``len(timesteps) - 1``) and
            ``'latents_list'``. It is written with ``'t'``, ``'inverse'``, ``'t_step'``,
            ``'lambda_t'`` (forward mode) and ``'inject_idx'`` (inversion mode).
        guidance: guidance strength as a scalar, a list/tuple, or a tensor.

    Returns:
        ``(img, info)``: the final latent and the info dict as last returned by ``model``.

    Raises:
        TypeError: if ``guidance`` is not a float/int, list/tuple or Tensor.
    """
    if inverse:
        timesteps = timesteps[::-1]
    else:
        # Forward mode writes into rows 0/1 of the latent in place (below). Work on a copy
        # so the caller's tensor is not mutated; it may, for instance, alias an entry of
        # info['latents_list']. The copy also avoids in-place errors on leaf tensors that
        # require grad.
        img = img.clone()

    if isinstance(guidance, (float, int)):
        # Scalar -> 1-D tensor of shape (1,).
        guidance_vec = torch.tensor([guidance], device=img.device, dtype=img.dtype)
    elif isinstance(guidance, (list, tuple)):
        guidance_vec = torch.tensor(guidance, device=img.device, dtype=img.dtype)
    elif isinstance(guidance, Tensor):
        guidance_vec = guidance.to(device=img.device, dtype=img.dtype)
    else:
        raise TypeError("guidance 必须是 float、列表或张量")

    scores_t = []

    for i, (t_curr, t_prev) in enumerate(tqdm(zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1)):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        # Expose the step's time to the model: the destination time when inverting,
        # the current time when sampling forward.
        info['t'] = t_prev if inverse else t_curr
        info['inverse'] = inverse
        if inverse:
            # Count steps in forward order so inversion and sampling share indices.
            info['t_step'] = info['num_steps'] - 1 - i
        else:
            info['t_step'] = i
            stored = info['latents_list'][-(i + 1)][0]
            loss_z = stored - img[0]

            error_t_f2_norm = torch.norm(loss_z, p='fro')
            info['lambda_t'] = torch.sigmoid(error_t_f2_norm)

            # Pin row 0 to the stored trajectory; during the first 10 steps, shift row 1
            # by the same correction.
            img[0] = stored
            if info['t_step'] < 10:
                img[1] += loss_z

        pred, info, idx = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
            info=info
        )

        # Explicit Euler step along the predicted velocity.
        img = img + (t_prev - t_curr) * pred

        if info['inverse']:
            scores_t.append(idx)
            info['latents_list'].append(img)

    # Only inversion collects per-step indices. Guarding on the list also avoids
    # torch.stack([]) (and reading an unset info['inverse']) for a zero-step schedule.
    if scores_t:
        info['inject_idx'] = torch.stack(scores_t)

    return img, info


def unpack(x: Tensor, height: int, width: int) -> Tensor:
    """Undo the 2x2 patch packing done in ``prepare``: tokens -> (b, c, H, W) latent.

    ``height`` and ``width`` give a token grid of ``ceil(dim / 16)`` per side, and each
    token expands back into a 2x2 patch. The factor 16 is presumably the autoencoder's
    8x downsampling times the 2x2 patching; confirm against the VAE configuration.
    """
    grid_h = math.ceil(height / 16)
    grid_w = math.ceil(width / 16)
    return rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=grid_h, w=grid_w, ph=2, pw=2)
