import torch
import torch.nn as nn
from enum import Enum, auto
from contextlib import contextmanager
from lora_diffusion import LoraInjectedLinear, LoraInjectedConv2d

class PipelineMode(Enum):
    """Selects which forward path the FARI wrapper layers take.

    The value is stored on each ``FARILinear`` / ``FARIConv2d`` instance as
    ``mode`` and is switched temporarily by the ``fari_mode`` context manager.
    """

    # Wrappers call only the wrapped layer's ``linear`` / ``conv`` submodule.
    GENERATION = auto()
    # Wrappers call the wrapped LoRA-injected layer itself.
    INVERSION = auto()

class FARILinear(nn.Module):
    """Mode-switched wrapper around a LoRA-injected linear layer.

    The wrapped layer is expected to expose a ``linear`` submodule
    (presumably the original, LoRA-free projection -- confirm against
    ``lora_diffusion.LoraInjectedLinear``). Depending on ``self.mode`` the
    forward pass uses either that submodule alone or the whole wrapped layer.
    """

    def __init__(self, lora_layer):
        """Store ``lora_layer`` and start in ``PipelineMode.GENERATION``."""
        super().__init__()
        self.mode = PipelineMode.GENERATION
        self.lora_layer = lora_layer

    def forward(self, input):
        """Run ``input`` through the path selected by ``self.mode``.

        Returns:
            ``self.lora_layer.linear(input)`` in generation mode, or
            ``self.lora_layer(input)`` in inversion mode.

        Raises:
            ValueError: if ``self.mode`` is neither of the two known modes.
        """
        # Generation bypasses the wrapped layer and calls only its `linear`.
        if self.mode == PipelineMode.GENERATION:
            return self.lora_layer.linear(input)

        # Inversion goes through the full wrapped layer.
        if self.mode == PipelineMode.INVERSION:
            return self.lora_layer(input)

        raise ValueError(f"Unknown mode: {self.mode}")

class FARIConv2d(nn.Module):
    """Mode-switched wrapper around a LoRA-injected 2D convolution layer.

    The wrapped layer is expected to expose a ``conv`` submodule (presumably
    the original, LoRA-free convolution -- confirm against
    ``lora_diffusion.LoraInjectedConv2d``). Depending on ``self.mode`` the
    forward pass uses either that submodule alone or the whole wrapped layer.
    """

    def __init__(self, lora_layer):
        """Store ``lora_layer`` and start in ``PipelineMode.GENERATION``."""
        super().__init__()
        self.mode = PipelineMode.GENERATION
        self.lora_layer = lora_layer

    def forward(self, input):
        """Run ``input`` through the path selected by ``self.mode``.

        Returns:
            ``self.lora_layer.conv(input)`` in generation mode, or
            ``self.lora_layer(input)`` in inversion mode.

        Raises:
            ValueError: if ``self.mode`` is neither of the two known modes.
        """
        # Generation bypasses the wrapped layer and calls only its `conv`.
        if self.mode == PipelineMode.GENERATION:
            return self.lora_layer.conv(input)

        # Inversion goes through the full wrapped layer.
        if self.mode == PipelineMode.INVERSION:
            return self.lora_layer(input)

        raise ValueError(f"Unknown mode: {self.mode}")

@contextmanager
def fari_mode(pipe, mode):
    """Temporarily set ``mode`` on every FARI wrapper inside ``pipe.unet``.

    Args:
        pipe: object exposing a ``unet`` attribute whose ``modules()`` yields
            the layers to inspect.
        mode: value assigned to the ``mode`` attribute of each
            ``FARILinear`` / ``FARIConv2d`` found.

    On exit, including when the body raises, every wrapper that was switched
    gets back the mode it had before entering the context.
    """
    fari_layer_types = (FARILinear, FARIConv2d)
    previous_mode_of = {}
    try:
        for layer in pipe.unet.modules():
            if not isinstance(layer, fari_layer_types):
                continue
            # Remember the prior mode before overwriting it so that a failure
            # part-way through the loop still restores what was changed.
            previous_mode_of[layer] = layer.mode
            layer.mode = mode
        yield
    finally:
        for layer, previous_mode in previous_mode_of.items():
            layer.mode = previous_mode

def inject_fari(model):
    """Wrap every LoRA-injected layer of ``model`` in its FARI counterpart.

    ``LoraInjectedLinear`` children are replaced by ``FARILinear`` wrappers and
    ``LoraInjectedConv2d`` children by ``FARIConv2d`` wrappers, in place on
    their parent module.

    The call is idempotent: a LoRA layer that already sits directly inside a
    FARI wrapper is left alone, so calling this twice does not produce a
    wrapper-around-a-wrapper (which would break generation mode, since a
    wrapper has no ``linear`` / ``conv`` attribute).

    Args:
        model: the ``nn.Module`` whose submodules are scanned.

    Returns:
        ``model`` itself, modified in place. If ``model`` is itself a
        LoRA-injected layer there is no parent to re-attach it to, so the
        newly created wrapper is returned instead.
    """
    wrapper_by_lora_class = (
        (LoraInjectedLinear, FARILinear),
        (LoraInjectedConv2d, FARIConv2d),
    )
    fari_classes = (FARILinear, FARIConv2d)

    # Snapshot the traversal first: replacing children while the
    # ``named_modules()`` generator is still running mutates the very
    # containers it is iterating over.
    for name, module in list(model.named_modules()):
        wrapper_cls = next(
            (
                wrapper
                for lora_cls, wrapper in wrapper_by_lora_class
                if isinstance(module, lora_cls)
            ),
            None,
        )
        if wrapper_cls is None:
            continue

        if not name:
            # The root module has the empty name and no parent; assigning to
            # attribute '' would register the wrapper as a child of the layer
            # it wraps. Hand the wrapper back to the caller instead.
            return wrapper_cls(module)

        *parent_path, attr_name = name.split('.')
        parent_module = model
        for part in parent_path:
            parent_module = getattr(parent_module, part)

        if isinstance(parent_module, fari_classes):
            # Already wrapped by an earlier call; do not wrap again.
            continue

        setattr(parent_module, attr_name, wrapper_cls(module))

    return model

def save_fari_model(model, path):
    """Save only the LoRA up/down entries of ``model``'s state dict to ``path``.

    An entry is kept when its state-dict key contains the substring
    ``"lora_up"`` or ``"lora_down"``; everything else is dropped. The
    resulting plain dict is written with ``torch.save``.

    Args:
        model: module providing ``state_dict()``.
        path: destination passed straight to ``torch.save``.
    """
    lora_markers = ("lora_up", "lora_down")
    selected = {
        key: tensor
        for key, tensor in model.state_dict().items()
        if any(marker in key for marker in lora_markers)
    }
    torch.save(selected, path)

def one_step_inversion(sd_pipe, latents, prompt_embeds):
    """Map ``latents`` to a noised latent with a single UNet evaluation.

    The UNet is evaluated once at timestep ``0`` with the FARI wrappers
    switched to ``PipelineMode.INVERSION``. Its first output is then blended
    with the input latents using the last entry of the scheduler's
    ``alphas_cumprod``:

        sqrt(alpha) * latents + sqrt(1 - alpha) * noise_pred

    Args:
        sd_pipe: pipeline exposing ``scheduler`` (with ``scale_model_input``
            and ``alphas_cumprod``) and ``unet``.
        latents: starting latents (presumably clean image latents -- confirm
            against the caller).
        prompt_embeds: passed to the UNet as ``encoder_hidden_states``.

    Returns:
        The blended latents.
    """
    scaled_latents = sd_pipe.scheduler.scale_model_input(latents, 0)

    # Only this UNet call runs in inversion mode; the previous modes are
    # restored as soon as the `with` block exits.
    with fari_mode(sd_pipe, PipelineMode.INVERSION):
        unet_output = sd_pipe.unet(
            scaled_latents, 0, encoder_hidden_states=prompt_embeds
        )
        noise_pred = unet_output[0]

    # Last cumulative alpha of the schedule.
    alpha_prod_T = sd_pipe.scheduler.alphas_cumprod[-1]
    signal_scale = alpha_prod_T.sqrt()
    noise_scale = (1 - alpha_prod_T).sqrt()
    return signal_scale * latents + noise_scale * noise_pred