import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict
import copy

class UCEEditor:
    """
    Unified Concept Editing (UCE) of a diffusion model's cross-attention projections.

    Every module whose qualified name contains ``attn2`` and ends with ``to_k`` or
    ``to_v`` has its weight replaced by the closed-form solution

        W_new = (lamb * W_old + sum_e s_e * v_e c_e^T + sum_p s_p * v_p c_p^T)
              @ (lamb * I     + sum_e s_e * c_e c_e^T + sum_p s_p * c_p c_p^T)^-1

    Terms:
        ``c``: a concept's text embedding (see ``_get_text_embeddings``).
        ``v_e``: the unedited projection's output for the guide paired with edit
            concept ``e``.
        ``v_p``: the unedited projection's output for preserve concept ``p``.
        ``s_e`` / ``s_p``: ``erase_scale`` / ``preserve_scale``.

    The matched weights are backed up at construction time so an edit can be
    undone with ``restore_original_weights``.
    """

    def __init__(self, model, device='cuda'):
        # NOTE(review): `model` must expose `.dtype`, `.named_modules()` and a
        # `.cond_stage_model` with a `.tokenizer`; presumably an LDM pipeline.
        self.model = model
        self.device = device  # device on which the UCE linear algebra runs
        self.torch_dtype = self.model.dtype  # dtype reported by the model
        self._store_original_weights()

    def _get_text_embeddings(self, prompt: str) -> torch.Tensor:
        """
        Return the text-encoder embedding of the last meaningful token of *prompt*.

        The attention mask marks every non-padding position. For a CLIP-style
        tokenizer that includes the start token and the trailing end-of-text
        token, so the last meaningful token sits one position before the last
        unmasked one.

        Returns:
            Tensor of shape ``(batch, dim)``; ``batch`` is 1 for a single prompt.
        """
        tokenizer = self.model.cond_stage_model.tokenizer
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        # Index of the final non-padding token (normally the end-of-text token).
        last_token_idx = int(text_inputs.attention_mask.sum(dim=-1).item()) - 1
        # Step back one position to skip the end-of-text token. Clamp at 0 so a
        # tokenizer without special tokens cannot yield a negative index that
        # would silently wrap around to the padding at the end of the sequence.
        last_meaningful_token_idx = max(last_token_idx - 1, 0)

        with torch.no_grad():
            # NOTE(review): assumes cond_stage_model tokenizes `prompt` with the
            # same settings as above, so token positions line up -- confirm.
            token_embeddings = self.model.cond_stage_model(prompt)

        return token_embeddings[:, last_meaningful_token_idx, :]

    def _get_attention_layers(self, verbose: bool = False) -> Dict[str, nn.Module]:
        """
        Collect the cross-attention key/value projections (``attn2`` ... ``to_k``/``to_v``).

        Args:
            verbose: If True, print each matched layer name and weight shape.

        Returns:
            Mapping from qualified module name to module, in ``named_modules`` order.
        """
        attn_layers = {}
        for name, module in self.model.named_modules():
            if 'attn2' in name and name.endswith(('to_k', 'to_v')):
                if verbose:
                    print(name, module.weight.shape)
                attn_layers[name] = module
        return attn_layers

    def _store_original_weights(self):
        """Back up a detached copy of every target layer's weight, keyed by layer name."""
        self.original_weights = {
            name: layer.weight.detach().clone()
            for name, layer in self._get_attention_layers().items()
        }
        print(f"UCE: Backed up original weights for {len(self.original_weights)} layers.")

    @staticmethod
    def _pair_guides(edit_concepts: List[str], guide_concepts: List[str]) -> List[str]:
        """
        Return exactly one guide per edit concept.

        A single guide is reused for every edit concept. Otherwise the lists must
        have the same length, because pairing them with ``zip`` would silently
        drop the unmatched edit concepts.

        Raises:
            ValueError: If the lengths differ and there is not exactly one guide.
        """
        guides = list(guide_concepts)
        if len(guides) == 1:
            return guides * len(edit_concepts)
        if len(guides) != len(edit_concepts):
            raise ValueError(
                f"Expected one guide concept or one per edit concept "
                f"({len(edit_concepts)}), got {len(guides)}."
            )
        return guides

    @torch.no_grad()
    def erase_concept(
        self,
        edit_concepts: List[str],
        guide_concepts: List[str],
        preserve_concepts: List[str],
        erase_scale: float = 1.0,
        preserve_scale: float = 1.0,
        lamb: float = 0.5,
    ):
        """
        Apply the closed-form UCE update to every cross-attention key/value projection.

        Args:
            edit_concepts: Prompts whose embeddings are remapped.
            guide_concepts: Prompts whose unedited projections the edit concepts
                are mapped onto, paired with ``edit_concepts`` by position. A
                single guide is reused for all edit concepts.
            preserve_concepts: Prompts whose unedited projections are kept.
            erase_scale: Weight of each erase term.
            preserve_scale: Weight of each preserve term.
            lamb: Ridge weight pulling the solution toward the current weights.

        Raises:
            ValueError: If ``guide_concepts`` cannot be paired with
                ``edit_concepts``. This is checked before any weight is modified.
        """
        paired_guides = self._pair_guides(edit_concepts, guide_concepts)
        print(f"UCE: Starting erasure. Erasing: {edit_concepts}, Guiding: {guide_concepts}")

        # Compute every embedding once, before any layer is modified.
        all_concepts = list(edit_concepts) + paired_guides + list(preserve_concepts)
        concept_embeds = {c: self._get_text_embeddings(c) for c in all_concepts}
        erase_pairs = list(zip(edit_concepts, paired_guides))

        target_layers = self._get_attention_layers()
        for name, layer in target_layers.items():
            self._edit_layer(
                name, layer, concept_embeds, erase_pairs, list(preserve_concepts),
                erase_scale, preserve_scale, lamb,
            )

        print(f"UCE editing complete. Modified {len(target_layers)} layers.")

    @torch.no_grad()
    def _edit_layer(
        self,
        name: str,
        layer: nn.Module,
        concept_embeds: Dict[str, torch.Tensor],
        erase_pairs: List[tuple],
        preserve_concepts: List[str],
        erase_scale: float,
        preserve_scale: float,
        lamb: float,
    ):
        """Solve the UCE system for one projection and overwrite its weight in place."""
        # Work in float32 regardless of the model dtype: the normal-equation
        # system is solved, which is unreliable in half precision.
        w_old = layer.weight.detach().to(self.device, dtype=torch.float32).clone()
        bias = getattr(layer, 'bias', None)
        if bias is not None:
            bias = bias.detach().to(self.device, dtype=torch.float32)

        def embed(concept: str) -> torch.Tensor:
            return concept_embeds[concept].reshape(-1).to(self.device, dtype=torch.float32)

        mat1 = lamb * w_old
        mat2 = lamb * torch.eye(w_old.shape[1], device=self.device, dtype=torch.float32)

        # (scale, concept whose embedding is the key, concept whose unedited
        # output is the target): erase terms first, then preserve terms.
        terms = [(erase_scale, e, g) for e, g in erase_pairs]
        terms += [(preserve_scale, p, p) for p in preserve_concepts]
        for scale, source, target in terms:
            c = embed(source)
            # v* is the unedited projection applied to the target embedding.
            v_star = F.linear(embed(target), w_old, bias)
            mat1 += scale * torch.outer(v_star, c)
            mat2 += scale * torch.outer(c, c)

        try:
            # mat2 is symmetric, so mat1 @ mat2^-1 == solve(mat2, mat1^T)^T,
            # which avoids forming the explicit inverse.
            w_new = torch.linalg.solve(mat2, mat1.T).T
        except torch.linalg.LinAlgError:
            print(f"Warning: Matrix inversion failed for layer {name}, using pseudoinverse.")
            w_new = mat1 @ torch.linalg.pinv(mat2)

        # copy_ writes into the existing parameter, keeping its own dtype/device.
        layer.weight.copy_(w_new)

    @torch.no_grad()
    def restore_original_weights(self):
        """Copy the backed-up weights back into every target layer that has a backup."""
        target_layers = self._get_attention_layers()
        for name, layer in target_layers.items():
            if name in self.original_weights:
                layer.weight.data.copy_(self.original_weights[name])
        print("Original model weights restored.")