import numpy as np
import torch
from transformers import CLIPModel, AutoProcessor, CLIPTokenizer, CLIPTextModel, CLIPProcessor
from sklearn.decomposition import TruncatedSVD

def zero_shot_classification(img_embeds, clf_embeds, device='cuda'):
    """
    Perform zero-shot classification by cosine similarity between image and class embeddings.

    Both inputs are L2-normalized along their last dimension before the dot product.

    Args:
        img_embeds (torch.Tensor): Image embeddings, one row per image (or a single 1-D vector).
        clf_embeds (torch.Tensor): Class-name embeddings, one row per class.
        device (str): Currently unused; kept for backward compatibility. Default 'cuda'.

    Returns:
        tuple:
            - numpy.ndarray: Index of the most similar class for each image. A 0-d array
              when a single image is given (1-D input or a single-row batch), otherwise
              shape (n_images,).
            - torch.Tensor: The cosine-similarity matrix between image and class embeddings.
    """
    img_embeds = img_embeds / img_embeds.norm(p=2, dim=-1, keepdim=True)
    clf_embeds = clf_embeds / clf_embeds.norm(p=2, dim=-1, keepdim=True)

    similarity = torch.matmul(img_embeds, clf_embeds.T)

    # Argmax over the class axis only: a flattened argmax would return a position in the
    # (n_images * n_classes) matrix rather than a class index per image.
    best_class_idx = similarity.softmax(dim=-1).argmax(dim=-1)
    # Single-query case: return a scalar (0-d) index rather than a length-1 array.
    if best_class_idx.dim() == 1 and best_class_idx.shape[0] == 1:
        best_class_idx = best_class_idx[0]

    return best_class_idx.to('cpu').numpy(), similarity

def get_class_name_embed(class_names, processor, vl_model):
    """
    Embed each class name with the text tower of a vision-language model.

    Every name is processed on its own (one processor call and one forward pass per
    name, under ``torch.no_grad()``), and the per-name results are stacked along the
    first axis.

    Args:
        class_names (list): Class names to embed.
        processor (transformers.AutoProcessor): Processor used to tokenize the text.
        vl_model (transformers.CLIPModel): Model providing ``get_text_features`` and ``device``.

    Returns:
        torch.Tensor: Concatenated text features, one row per class name (on CPU).
    """
    def _features_for(name):
        # One forward pass per name; results are moved to CPU as numpy arrays.
        with torch.no_grad():
            encoded = processor(text=[name], padding=True, return_tensors="pt").to(vl_model.device)
            return vl_model.get_text_features(**encoded).to('cpu').numpy()

    stacked = np.concatenate([_features_for(name) for name in class_names])
    return torch.tensor(stacked)

class OrthogDebiaser:
    """
    Debias embeddings by orthogonal projection, optionally with pair-based regularization.

    The bias subspace is spanned by the embeddings of ``biased_prompts``; the matrix
    ``P0 = I - proj`` removes that subspace. With regularization enabled, ``P0`` is further
    multiplied by ``(lam * M + I)^-1``, where ``M`` is built from pairs of regularization
    prompts (one pair per entry of ``reg_classes``).

    Args:
        wrap_clip_pipe (CLIPTextWrapper): Wrapper providing ``get_joint_embed`` and
            ``get_pre_projection_embed``.
        biased_prompts (list, optional): Prompts whose embeddings span the bias subspace.
            Defaults to ``["an image of a man", "an image of a woman"]``.
        reg_classes (list, optional): Regularization class names (e.g. occupations).
            Defaults to ``["doctor", "lawyer"]``.
        reg_format (list, optional): Two format templates, each with one ``{}`` slot, used to
            build a prompt pair per regularization class.
            Defaults to ``["an image of a male {}.", "an image of a female {}"]``.
    """

    def __init__(self, wrap_clip_pipe, biased_prompts=None, reg_classes=None, reg_format=None):
        """
        Initialize OrthogDebiaser with prompts and class names for debiasing and regularization.
        """
        # None sentinels so instances never share a mutable default list.
        if biased_prompts is None:
            biased_prompts = ["an image of a man", "an image of a woman"]
        if reg_classes is None:
            reg_classes = ["doctor", "lawyer"]
        if reg_format is None:
            reg_format = ["an image of a male {}.", "an image of a female {}"]
        self.wrap_clip_pipe = wrap_clip_pipe
        self.biased_prompts = biased_prompts
        self.reg_classes = reg_classes
        self.reg_format = reg_format

    def _embed_prompts(self, prompts, pipe_location):
        """Embed prompts in the joint space for 'joint', otherwise in the pre-projection space."""
        if pipe_location == 'joint':
            return self.wrap_clip_pipe.get_joint_embed(prompts, normalize=True)
        return self.wrap_clip_pipe.get_pre_projection_embed(prompts, normalize=True)

    def get_biased_embeds(self, pipe_location="joint"):
        """
        Get normalized embeddings of the biased prompts.

        Args:
            pipe_location (str): 'joint' for the joint space; any other value (intended:
                'pre_projection') uses the pre-projection space.

        Returns:
            torch.Tensor: One embedding row per biased prompt.
        """
        return self._embed_prompts(self.biased_prompts, pipe_location)

    def get_reg_embeds(self, pipe_location="joint"):
        """
        Get normalized embeddings of the regularization prompts.

        Builds the prompts via ``init_S_and_candidate_prompts`` if that has not run yet.

        Args:
            pipe_location (str): 'joint' for the joint space; any other value (intended:
                'pre_projection') uses the pre-projection space.

        Returns:
            torch.Tensor: One embedding row per regularization prompt.
        """
        if not hasattr(self, 'reg_prompt'):
            self.init_S_and_candidate_prompts()
        return self._embed_prompts(self.reg_prompt, pipe_location)

    def get_A(self, z_i, z_j):
        """
        Compute ``(z_i - z_j)(z_i - z_j)^T`` for two 1-D embedding vectors.

        Args:
            z_i (numpy.ndarray): First embedding vector, shape (d,).
            z_j (numpy.ndarray): Second embedding vector, shape (d,).

        Returns:
            numpy.ndarray: A (d, d) matrix.
        """
        z_i = z_i[:, None]
        z_j = z_j[:, None]
        return (np.matmul(z_i, z_i.T) + np.matmul(z_j, z_j.T) - np.matmul(z_i, z_j.T) - np.matmul(z_j, z_i.T))

    def get_M(self, embeddings, S):
        """
        Average ``get_A`` over the index pairs in ``S``.

        Args:
            embeddings (numpy.ndarray): Embeddings, shape (n, d).
            S (list): Pairs ``[i, j]`` of row indices into ``embeddings``.

        Returns:
            numpy.ndarray: A (d, d) matrix; all zeros when ``S`` is empty.
        """
        d = embeddings.shape[1]
        M = np.zeros((d, d))
        if not S:
            # No pairs: no regularization term (avoids 0/0 -> NaN).
            return M
        for i, j in S:
            M += self.get_A(embeddings[i], embeddings[j])
        return M / len(S)

    def init_S_and_candidate_prompts(self):
        """
        Build ``self.reg_prompt`` (two prompts per lower-cased regularization class, one per
        template in ``reg_format``) and ``self.S`` (index pairs of those two prompts).
        """
        self.reg_prompt = []
        self.S = []
        for pair_idx, train_cls_i in enumerate(self.reg_classes):
            train_cls_i = train_cls_i.lower()
            self.reg_prompt += [self.reg_format[0].format(train_cls_i), self.reg_format[1].format(train_cls_i)]
            self.S.append([2 * pair_idx, 2 * pair_idx + 1])

    def get_orth_proj_matrix(self, use_reg=True, return_proj=False, pipe_location="joint", lam=500):
        """
        Compute the (optionally regularized) debiasing projection matrix.

        Args:
            use_reg (bool): Whether to multiply by ``(lam * M + I)^-1`` built from the
                regularization prompts.
            return_proj (bool): When ``use_reg`` is False, also return the bias projection
                and basis. Ignored when ``use_reg`` is True.
            pipe_location (str): 'joint' or 'pre_projection' embedding space.
            lam (float): Weight of the regularization matrix ``M``.

        Returns:
            torch.Tensor: ``P0_reg`` if ``use_reg``; otherwise ``P0``, or the tuple
            ``(P0, proj, basis)`` when ``return_proj`` is True. All are float64 tensors.
        """
        bias_embeds = self.get_biased_embeds(pipe_location=pipe_location).to('cpu').numpy()

        # Basis (as columns) of the subspace spanned by the bias embeddings.
        tSVD = TruncatedSVD(n_components=len(bias_embeds))
        tSVD.fit(bias_embeds)
        basis = tSVD.components_.T
        # proj = B (B^T B)^-1 B^T projects onto span(B); P0 removes that component.
        proj = np.linalg.inv(np.matmul(basis.T, basis))
        proj = np.matmul(basis, proj)
        proj = np.matmul(proj, basis.T)
        P0 = np.eye(proj.shape[0]) - proj

        if use_reg:
            self.init_S_and_candidate_prompts()
            reg_embeds = self.get_reg_embeds(pipe_location=pipe_location).to('cpu').numpy()
            M = self.get_M(reg_embeds, self.S)
            G = lam * M + np.eye(M.shape[0])
            P0_reg = np.matmul(P0, np.linalg.inv(G))
            return torch.tensor(P0_reg)

        if return_proj:
            return torch.tensor(P0), torch.tensor(proj), torch.tensor(basis)
        return torch.tensor(P0)

    def apply_non_proj_debiasing(self, embed, p0, proj_matrix, gamma=0.1):
        """
        Debias with ``p0`` and add back a norm-equalized fraction of the bias component.

        Each row's bias component ``embed @ proj_matrix.T`` is rescaled to the mean row norm
        (rows with a zero bias component stay zero), scaled by ``gamma``, and added to
        ``embed @ p0.T``.

        Args:
            embed (torch.Tensor): Embeddings, one per row.
            p0 (torch.Tensor): Debiasing projection matrix.
            proj_matrix (torch.Tensor): Projection onto the bias subspace.
            gamma (float): Weight of the rescaled bias component.

        Returns:
            torch.Tensor: Debiased embeddings on the same device as ``embed``.
        """
        bias_component = torch.matmul(embed, proj_matrix.T.float())

        norms = bias_component.norm(p=2, dim=1, keepdim=True)
        avg_norm = norms.mean()
        nonzero = norms > 0
        # Guard zero-norm rows: they get scale 0 instead of a NaN/inf from division by zero.
        safe_norms = torch.where(nonzero, norms, torch.ones_like(norms))
        scale = torch.where(nonzero, avg_norm / safe_norms, torch.zeros_like(norms))
        bias_component_rescaled = bias_component * scale

        proj_debiased_embed = torch.matmul(embed, p0.T.float())
        return proj_debiased_embed + gamma * bias_component_rescaled

    def apply_debiasing_matrix(self, embed, p0):
        """
        Apply the debiasing matrix to the input embeddings.

        Args:
            embed (torch.Tensor): Embeddings to debias, one per row.
            p0 (torch.Tensor): Debiasing matrix; cast to float32 before use.

        Returns:
            torch.Tensor: ``embed @ p0.T``.
        """
        debiased_embed = torch.matmul(embed, p0.T.float())
        return debiased_embed
    

class CLIPTextWrapper:
    """
    Wrapper around a Hugging Face CLIP checkpoint with helpers for text/image embeddings
    and for mapping between the text tower's pre-projection space and the joint space.

    Args:
        model_ID (str): Identifier of the pre-trained CLIP checkpoint (passed to ``from_pretrained``).
        device (str): Device the models are moved to, default is 'cuda'.
    """

    def __init__(self, model_ID, device='cuda'):
        """
        Load the CLIP model, processor, tokenizer and standalone text encoder, and precompute
        the (pseudo-)inverse of the text projection weight.
        """
        self.device = device
        self.model_ID = model_ID
        self.clip_model = CLIPModel.from_pretrained(self.model_ID).to(self.device)
        self.clip_processor = CLIPProcessor.from_pretrained(self.model_ID)
        self.tokenizer = CLIPTokenizer.from_pretrained(self.model_ID)
        self.text_encoder = CLIPTextModel.from_pretrained(self.model_ID).to(self.device)
        # Attribute name (including its spelling) is kept because callers may read it.
        self.inverse_projection_layer_martrix = self._invert_projection(
            self.clip_model.text_projection.weight.detach()
        )

    @staticmethod
    def _invert_projection(weight):
        """Return the inverse of a square weight matrix, otherwise its Moore-Penrose pseudo-inverse."""
        if weight.dim() == 2 and weight.shape[0] == weight.shape[1]:
            return torch.inverse(weight)
        # A non-square projection has no true inverse; the pseudo-inverse yields the
        # least-squares pre-image instead of raising.
        return torch.linalg.pinv(weight)

    def get_joint_embed(self, input_text: list, normalize=True):
        """
        Extract joint-space (post-projection) text embeddings using the CLIP model.

        Each string is processed separately under ``torch.no_grad()``.

        Args:
            input_text (list): List of input text strings.
            normalize (bool): Whether to L2-normalize each embedding.

        Returns:
            torch.Tensor: Text embeddings on CPU, one row per input string.
        """
        txt_embeds = []
        for txt in input_text:
            with torch.no_grad():
                text_inputs = self.clip_processor(text=[txt], padding=True, return_tensors="pt").to(self.device)
                _txt_embeds = self.clip_model.get_text_features(**text_inputs).to('cpu').numpy()
                txt_embeds.append(_txt_embeds)
        query_text_embedding = torch.tensor(np.concatenate(txt_embeds))
        if normalize:
            query_text_embedding /= query_text_embedding.norm(dim=-1, keepdim=True)
        return query_text_embedding

    def get_joint_image_embed(self, input_img: list, normalize=True):
        """
        Extract joint-space (post-projection) image embeddings using the CLIP model.

        Each image is processed separately under ``torch.no_grad()``.

        Args:
            input_img (list): List of input images accepted by the CLIP processor.
            normalize (bool): Whether to L2-normalize each embedding.

        Returns:
            torch.Tensor: Image embeddings on CPU, one row per input image.
        """
        im_embeds = []
        for im in input_img:
            with torch.no_grad():
                im_inputs = self.clip_processor(images=im, padding=True, return_tensors="pt").to(self.device)
                _im_embeds = self.clip_model.get_image_features(**im_inputs).to('cpu').numpy()
                im_embeds.append(_im_embeds)
        im_embedding = torch.tensor(np.concatenate(im_embeds))
        if normalize:
            im_embedding /= im_embedding.norm(dim=-1, keepdim=True)
        return im_embedding

    def get_unpooled_embed(self, input_text: list):
        """
        Return the standalone text encoder's per-token outputs (before pooling and projection).

        Inputs are padded/truncated to ``tokenizer.model_max_length`` tokens.

        Args:
            input_text (list): List of input text strings.

        Returns:
            torch.Tensor: First element of the text encoder output (the last hidden state),
            left on ``self.device``.
        """
        with torch.no_grad():
            _input = self.tokenizer(input_text, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
            # Only the first output is used, so intermediate hidden states are not requested.
            _embeddings = self.text_encoder(_input.input_ids.to(self.device))[0]
            return _embeddings

    def get_pre_projection_embed(self, input_text: list, normalize=True):
        """
        Recover pre-projection text embeddings by applying the inverse text projection
        to the (unnormalized) joint embeddings.

        Args:
            input_text (list): List of input text strings.
            normalize (bool): Whether to L2-normalize each resulting embedding.

        Returns:
            torch.Tensor: Pre-projection embeddings on CPU, one row per input string.
        """
        proj_embed = self.get_joint_embed(input_text=input_text, normalize=False)
        pre_proj_embed = self.map_from_joint_to_pre_projection(proj_embed)
        if normalize:
            pre_proj_embed /= pre_proj_embed.norm(dim=-1, keepdim=True)
        return pre_proj_embed

    def map_to_joint(self, pre_proj_embed):
        """
        Map pre-projection embeddings to the joint space with the CLIP text projection layer.

        Args:
            pre_proj_embed (torch.Tensor): Pre-projection embeddings (cast to float32).

        Returns:
            torch.Tensor: Joint-space embeddings on CPU.
        """
        with torch.no_grad():
            proj_embed = self.clip_model.text_projection(pre_proj_embed.float().to(self.device)).to('cpu')
        return proj_embed

    def map_from_joint_to_pre_projection(self, joint_embed):
        """
        Map joint-space embeddings back to the pre-projection space using the precomputed
        (pseudo-)inverse of the text projection weight.

        Args:
            joint_embed (torch.Tensor): Joint-space embeddings.

        Returns:
            torch.Tensor: Pre-projection embeddings on CPU.
        """
        pre_proj_embed = torch.matmul(joint_embed.to(self.device), self.inverse_projection_layer_martrix.T).to('cpu')
        return pre_proj_embed