from dataclasses import dataclass
import torch
from torch import Tensor, nn
import types
from typing import Dict
from flux.model import Flux
from flux.util import configs, hf_hub_download, load_sft, optionally_expand_state_dict, print_load_warning
from scaling_cache.cache_functions import cache_init, cache_release, cal_type
from scaling_cache.forwards.flux import enhance_model_forwards

def load_flow_model(
    name: str,
    mode: str,
    device: str | torch.device = "cuda",
    hf_download: bool = True,
    verbose: bool = False,
    update_alpha: bool = False,
    dynamic_cache: bool = False,
    use_alpha: bool = False,
    first_enhance: int = 10,
) -> Flux:
    """Build a Flux model, load its weights, and attach the caching helpers.

    The checkpoint is fetched with ``hf_hub_download`` when ``hf_download`` is
    true and ``configs[name]`` defines both ``repo_id`` and ``repo_flow``.
    Without a checkpoint the model is constructed directly on ``device`` and
    keeps its freshly initialised parameters.

    Args:
        name: Key into ``configs`` selecting the model definition.
        mode: Stored unchanged as ``model.mode``.
        device: Device the checkpoint tensors are loaded onto (or the model is
            constructed on when there is no checkpoint).
        hf_download: Allow fetching the checkpoint via ``hf_hub_download``.
        verbose: When true, report missing/unexpected state-dict keys through
            ``print_load_warning``.
        update_alpha: Stored unchanged as ``model.update_alpha``.
        dynamic_cache: Stored unchanged as ``model.dynamic_cache``.
        use_alpha: Stored unchanged as ``model.use_alpha``.
        first_enhance: Stored unchanged as ``model.first_enhance``.

    Returns:
        The ``Flux`` instance with ``cal_type``, ``cache_init``,
        ``cache_release`` and ``enhance_model_forwards`` bound as methods, and
        with ``enhance_model_forwards()`` and ``cache_init()`` already invoked.

    Raises:
        KeyError: If ``name`` is not a key of ``configs``.
    """
    spec = configs[name]

    # Resolve the checkpoint location (only the hub download path is supported).
    ckpt_path = None
    if spec.repo_id is not None and spec.repo_flow is not None and hf_download:
        ckpt_path = hf_hub_download(spec.repo_id, spec.repo_flow)

    # With a checkpoint available, build on the "meta" device so no parameter
    # storage is allocated; the real tensors are assigned from the state dict.
    with torch.device("meta" if ckpt_path is not None else device):
        model = Flux(spec.params).to(torch.bfloat16)

    if ckpt_path is not None:
        # load_sft doesn't support torch.device
        sd = load_sft(ckpt_path, device=str(device))
        sd = optionally_expand_state_dict(model, sd)
        # assign=True replaces the meta parameters with the loaded tensors.
        missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
        if verbose:
            print_load_warning(missing, unexpected)

    # Configuration attributes stored on the model instance; presumably read
    # by the scaling_cache helpers bound below -- confirm against that package.
    model.mode = mode
    model.task = "flux-dev"
    model.dynamic_cache = dynamic_cache
    model.update_alpha = update_alpha
    model.use_alpha = use_alpha
    model.first_enhance = first_enhance

    # Bind the module-level helpers as methods of this particular instance.
    model.cal_type = types.MethodType(cal_type, model)
    model.cache_init = types.MethodType(cache_init, model)
    model.cache_release = types.MethodType(cache_release, model)
    model.enhance_model_forwards = types.MethodType(enhance_model_forwards, model)

    # NOTE(review): cache_init() runs both before and after
    # enhance_model_forwards(). Kept as in the original call sequence; whether
    # the second call is required could not be determined from this file.
    model.cache_init()
    model.enhance_model_forwards()

    model.cache_init()
    return model