import random
import numpy as np
import torch
from base_processor import BaseProcessor
from utils import mutual_information_2d

class ClipCocoProcessor(BaseProcessor):

    # def process_data(self, data, usermode, device):
    #     self.data = data
    #     self.usermode = usermode
    #     self.vecs = {}
        
    #     for demo in self.data:
                        
    #         # Print information about this demographic attribute
    #         print(f"\n===== Distribution for {demo} =====")
            
    #         categories = list(self.data[demo].keys())
    #         print(f"Categories: {categories}")
            
    #         # Print sample counts
    #         for cat in categories:
    #             sample_count = len(self.data[demo][cat])
    #             print(f"Category '{cat}' has {sample_count} samples")


    #         self.vecs[demo] = {}
    #         # Store normalized difference vectors between categories
    #         categories = list(self.data[demo].keys())
            
    #         # Make sure data is on the correct device
    #         cat1_data = self.data[demo][categories[0]].to(device)
    #         cat2_data = self.data[demo][categories[1]].to(device)
            
    #         cat1_mean = cat1_data.mean(axis=0)
    #         cat2_mean = cat2_data.mean(axis=0)
    #         diff = cat2_mean - cat1_mean
    #         normalized_diff = diff / torch.norm(diff)
            
    #         # Store for both categories
    #         self.vecs[demo][categories[0]] = normalized_diff
    #         self.vecs[demo][categories[1]] = normalized_diff

    #         if "enoise" in self.usermode:
    #             if not hasattr(self, "edist"):
    #                 self.edist = {}
    #             self.edist[demo] = {}
    #             for cat in self.data[demo]:
    #                 # Move the data to device first, then perform matmul
    #                 cat_data = self.data[demo][cat].to(device)
    #                 self.edist[demo][cat] = cat_data.matmul(normalized_diff)
    #                 print(f"cat_data shape: {cat_data.shape}, normalized_diff shape: {normalized_diff.shape}, self.edist[demo][cat].shape: {self.edist[demo][cat].shape}")


    # def transform(self, demo, X_test, proj):
    #     original_dtype = X_test.dtype
    #     original_device = X_test.device
    #     proj = proj.to(device=original_device, dtype=original_dtype)
        
    #     # Apply projection first
    #     print(f"X_test shape: {X_test.shape}")

    #     if len(X_test.shape) == 2:
    #         X = torch.matmul(proj, X_test.T).T
    #     elif len(X_test.shape) == 3:
    #         X = torch.matmul(proj, X_test.transpose(1, 2)).transpose(1, 2)
        
    #     # Add empirical noise using actual examples
    #     if "enoise" in self.usermode:
    #         scale = 1.0 if self.usermode.get("enoise") is None else float(self.usermode["enoise"])
            
    #         if not hasattr(self, "rand_group"):
    #             # Store group selection and noise consistently across calls
    #             categories = list(self.edist[demo].keys())
    #             self.rand_group = torch.randint(0, len(categories), (X_test.shape[0],))
                
    #             noise = []
    #             for i in range(X_test.shape[0]):
    #                 cat = categories[self.rand_group[i].item()]
    #                 cat_projections = self.edist[demo][cat]
    #                 idx = torch.randint(0, len(cat_projections), (1,))
    #                 print(f"cat_projections shape: {cat_projections.shape}, idx: {idx}, cat_projections[idx]: {cat_projections[idx]}")
    #                 # Extract scalar value properly
    #                 noise_val = cat_projections[idx].item()
    #                 noise.append(torch.tensor([noise_val], device=original_device, dtype=original_dtype))
                
    #             self.noise = torch.cat(noise)
            
    #         # Get directions for each example's group
    #         categories = list(self.vecs[demo].keys())
    #         directions = torch.stack([self.vecs[demo][categories[g.item()]] for g in self.rand_group])
    #         directions = directions.to(device=original_device, dtype=original_dtype)
            
    #         # Apply noise consistently
    #         delta = scale * self.noise.unsqueeze(-1) * directions
            
    #         if len(X.shape) == 2:
    #             X = X + delta
    #         elif len(X.shape) == 3:
    #             X = X + delta.unsqueeze(1)
        
    #     return X


    def process_data(self, data, usermode, device):
        try:
            self.data = data
            self.usermode = usermode
            self.vecs = {}
            
            for demo in self.data:
                try:
                    # Print information about this demographic attribute
                    print(f"\n===== Distribution for {demo} =====")
                    
                    categories = list(self.data[demo].keys())
                    print(f"Categories: {categories}")
                    
                    # Check if we have enough categories
                    if len(categories) < 2:
                        print(f"Warning: Need at least 2 categories for {demo}, found {len(categories)}")
                        continue
                    
                    # Print sample counts
                    for cat in categories:
                        try:
                            sample_count = len(self.data[demo][cat])
                            print(f"Category '{cat}' has {sample_count} samples")
                        except Exception as e:
                            print(f"Error getting sample count for {cat}: {e}")
                    
                    self.vecs[demo] = {}
                    
                    # Safely compute category means and difference vector
                    try:
                        # Make sure data is on the correct device
                        cat1_data = self.data[demo][categories[0]].to(device)
                        cat2_data = self.data[demo][categories[1]].to(device)
                        
                        # Check for empty data
                        if cat1_data.shape[0] == 0 or cat2_data.shape[0] == 0:
                            print(f"Warning: Empty data for one of the categories in {demo}")
                            continue
                        
                        cat1_mean = cat1_data.mean(axis=0)
                        cat2_mean = cat2_data.mean(axis=0)
                        diff = cat2_mean - cat1_mean
                        
                        # Avoid division by zero
                        norm = torch.norm(diff)
                        if norm > 1e-8:
                            normalized_diff = diff / norm
                        else:
                            print(f"Warning: Difference vector has near-zero norm for {demo}")
                            # Create a random unit vector as fallback
                            normalized_diff = torch.randn_like(diff)
                            normalized_diff = normalized_diff / torch.norm(normalized_diff)
                        
                        # Store for both categories
                        self.vecs[demo][categories[0]] = normalized_diff
                        self.vecs[demo][categories[1]] = normalized_diff
                        
                        # Compute and store projections for noise if needed
                        if "enoise" in self.usermode:
                            if not hasattr(self, "edist"):
                                self.edist = {}
                            
                            if demo not in self.edist:
                                self.edist[demo] = {}
                            
                            for cat in self.data[demo]:
                                try:
                                    # Move the data to device first, then perform matmul
                                    cat_data = self.data[demo][cat].to(device)
                                    
                                    # Compute projections onto the bias direction
                                    self.edist[demo][cat] = cat_data.matmul(normalized_diff)
                                    
                                    # Print projection stats for debugging
                                    proj = self.edist[demo][cat]
                                    print(f"Category '{cat}' projections - shape: {proj.shape}, " +
                                        f"min: {proj.min().item():.4f}, max: {proj.max().item():.4f}, " +
                                        f"mean: {proj.mean().item():.4f}, std: {proj.std().item():.4f}")
                                except Exception as e:
                                    print(f"Error processing projections for {cat}: {e}")
                    except Exception as e:
                        print(f"Error computing direction vector for {demo}: {e}")
                except Exception as e:
                    print(f"Error processing demographic {demo}: {e}")
        except Exception as e:
            print(f"Error in process_data: {e}")
            # Initialize empty structures to avoid NoneType errors
            if not hasattr(self, "vecs"):
                self.vecs = {}
            if "enoise" in usermode and not hasattr(self, "edist"):
                self.edist = {}

    # def transform(self, demo, X_test, proj):
    #     original_dtype = X_test.dtype
    #     original_device = X_test.device
    #     proj = proj.to(device=original_device, dtype=original_dtype)
        
    #     # Apply projection first
    #     if len(X_test.shape) == 2:
    #         X = torch.matmul(proj, X_test.T).T
    #     elif len(X_test.shape) == 3:
    #         X = torch.matmul(proj, X_test.transpose(1, 2)).transpose(1, 2)
        
    #     # Add empirical noise using actual examples
    #     if "enoise" in self.usermode:
    #         # Get scale factor, properly handling 0.0
    #         scale = float(self.usermode.get("enoise", 1.0))
            
    #         # Only compute noise if we haven't already
    #         if not hasattr(self, "rand_group") or not hasattr(self, "noise"):
    #             # Safely get categories
    #             categories = list(self.edist[demo].keys())
    #             if len(categories) == 0:
    #                 print("Warning: No categories found for noise computation")
    #                 return X
                    
    #             # Create random group assignments
    #             try:
    #                 self.rand_group = torch.randint(0, len(categories), (X_test.shape[0],))
                    
    #                 # Collect noise values safely
    #                 noise_values = []
    #                 for i in range(X_test.shape[0]):
    #                     try:
    #                         cat = categories[self.rand_group[i].item()]
    #                         cat_projections = self.edist[demo][cat]
                            
    #                         if len(cat_projections) == 0:
    #                             print(f"Warning: Empty projections for category {cat}")
    #                             noise_values.append(0.0)  # Default to zero
    #                             continue
                                
    #                         idx = torch.randint(0, len(cat_projections), (1,))
    #                         noise_val = cat_projections[idx].item()
    #                         noise_values.append(noise_val)
    #                     except Exception as e:
    #                         print(f"Error processing noise for example {i}: {e}")
    #                         noise_values.append(0.0)  # Default to zero
                    
    #                 # Convert to tensor safely
    #                 try:
    #                     self.noise = torch.tensor(noise_values, device=original_device, dtype=original_dtype)
    #                 except Exception as e:
    #                     print(f"Error creating noise tensor: {e}")
    #                     return X  # Return without adding noise
    #             except Exception as e:
    #                 print(f"Error in noise computation: {e}")
    #                 return X  # Return without adding noise
            
    #         try:
    #             # Get directions for each example's group
    #             categories = list(self.vecs[demo].keys())
    #             directions = []
                
    #             for g in self.rand_group:
    #                 cat_idx = g.item() if g.numel() == 1 else 0
    #                 if cat_idx < len(categories):
    #                     directions.append(self.vecs[demo][categories[cat_idx]])
    #                 else:
    #                     # Default to first category if index out of bounds
    #                     directions.append(self.vecs[demo][categories[0]])
                
    #             directions = torch.stack(directions).to(device=original_device, dtype=original_dtype)
                
    #             # Apply noise consistently with error handling
    #             try:
    #                 # Reshape noise for broadcasting
    #                 noise_reshaped = self.noise.view(-1, 1)
                    
    #                 # Create delta (noise vectors)
    #                 delta = scale * noise_reshaped * directions
                    
    #                 # Add noise based on input dimensions
    #                 if len(X.shape) == 2:
    #                     X = X + delta
    #                 elif len(X.shape) == 3:
    #                     # For 3D tensors, add noise to all sequence positions
    #                     delta = delta.unsqueeze(1)
    #                     X = X + delta
    #             except Exception as e:
    #                 print(f"Error applying noise to data: {e}")
    #         except Exception as e:
    #             print(f"Error preparing directions: {e}")
        
    #     return X
    
    def modify_embedding(self, pipe, prompt_embeds, pooled_prompt_embeds, usermode=None, exp_dir="."):
        """Zero out embedding dimensions ranked by mutual information with a protected attribute.

        The stored feature tensors of the first protected attribute are used
        to estimate, per feature column, the mutual information with a binary
        category label. The ``num_clip`` columns with the *lowest* estimate
        are set to zero in ``pooled_prompt_embeds`` and in both slices of
        ``prompt_embeds``.

        Args:
            pipe: Unused by this method; kept for interface compatibility.
            prompt_embeds: 3-D tensor that is split along its last dimension
                at ``pooled_prompt_embeds.shape[1]``. Both slices are
                multiplied by a mask of that length, so the second slice must
                be broadcast-compatible with it.
            pooled_prompt_embeds: 2-D tensor; its second dimension defines
                ``self.dim`` and the mask length.
            usermode: Optional mapping of modes. ``usermode["removal"]`` (if
                present) is the number of columns to zero; the default is 2.
                ``None`` means an empty mapping.
            exp_dir: Directory containing ``extracted_features.pt``; read
                only when ``self.vecs`` has not been set yet.

        Returns:
            Tuple ``(prompt_embeds, pooled_prompt_embeds)`` with the selected
            columns zeroed.
        """
        # Use a fresh dict per call: the value is stored on self.usermode, so
        # a shared mutable default would leak state between calls/instances.
        if usermode is None:
            usermode = {}

        # Load data if needed
        if not hasattr(self, "vecs"):
            # NOTE(review): torch.load unpickles the file; only point exp_dir
            # at trusted locations.
            data = torch.load(f"{exp_dir}/extracted_features.pt")
            self.process_data(data, usermode, pooled_prompt_embeds.device)

        self.dim = pooled_prompt_embeds.shape[1]
        print(f"self.dim is {self.dim}")
        self.usermode = usermode

        # Split embeddings
        prompt_embeds1 = prompt_embeds[:, :, :self.dim]
        prompt_embeds2 = prompt_embeds[:, :, self.dim:]

        target_device = pooled_prompt_embeds.device

        # Get data for first protected attribute
        protect = list(self.data.keys())[0]
        attribute_data = self.data[protect]
        keys = list(attribute_data.keys())
        first = attribute_data[keys[0]]
        second = attribute_data[keys[1]]

        # Only the two labelled categories go into X so that its rows stay
        # aligned with the labels in Z (further categories have no label).
        X = torch.cat([first.to(target_device), second.to(target_device)], dim=0)

        # Create one-hot encoding for protected attributes:
        # column 0 is 1 for samples of keys[1], column 1 is 1 for keys[0].
        labels = torch.cat((torch.zeros(len(first)), torch.ones(len(second))))
        Z = torch.stack((labels, 1 - labels), dim=1).to(dtype=X.dtype, device=target_device)

        # Print Z matrix information
        print(f"Z matrix shape: {Z.shape}")
        print("Z matrix (first 10 rows):")
        print(Z[:10])

        # Count values to verify correct creation
        z0_count = (Z[:, 0] == 1).sum().item()
        z1_count = (Z[:, 1] == 1).sum().item()
        print(f"Count of examples with Z[0]=1: {z0_count} (should match '{keys[1]}' count)")
        print(f"Count of examples with Z[1]=1: {z1_count} (should match '{keys[0]}' count)")

        if "removal" in usermode:
            num_clip = int(usermode["removal"])
        else:
            num_clip = 2

        # Estimate mutual information once per feature column; the tensor to
        # NumPy conversions are loop-invariant and therefore done up front.
        labels_np = Z[:, 0].squeeze().cpu().numpy()
        features_np = X.cpu().numpy()
        mis = sorted(
            (mutual_information_2d(features_np[:, col], labels_np), col)
            for col in range(features_np.shape[1])
        )  # ascending: lowest mutual information first

        # NOTE(review): these are the columns with the LOWEST mutual
        # information. If the intent of "removal" is to drop the columns most
        # informative about the attribute, the slice should be taken from the
        # other end -- confirm before changing.
        lowest_mi_indices = [col for _, col in mis[:num_clip]]

        # Mask of dimensions to keep (1) and to zero out (0), matching the
        # dtype/device of the embeddings.
        mask = torch.ones(pooled_prompt_embeds.shape[1],
                          device=pooled_prompt_embeds.device,
                          dtype=pooled_prompt_embeds.dtype)
        if lowest_mi_indices:
            mask[lowest_mi_indices] = 0.0

        # Apply the mask; for the 3-D tensors it broadcasts over the two
        # leading dimensions.
        pooled_prompt_embeds = pooled_prompt_embeds * mask
        mask_3d = mask.view(1, 1, -1)
        prompt_embeds1 = prompt_embeds1 * mask_3d
        prompt_embeds2 = prompt_embeds2 * mask_3d

        # Recombine embeddings
        prompt_embeds = torch.cat((prompt_embeds1, prompt_embeds2), dim=-1)

        return prompt_embeds, pooled_prompt_embeds





        # num_clip = 5

        # # estimate mutual information
        # mis = []
        # for col in range(X.shape[1]):
        #     mi = mutual_information_2d(X[:,col].squeeze().cpu().numpy(), Z[:,0].squeeze().cpu().numpy())
        #     mis.append((mi, col))
        # mis = sorted(mis, reverse=False)  # Sort by mutual information (lowest to highest)

        # # Extract only the indices - keep as tuples until after extraction
        # high_mi_indices = [l[1] for l in mis[:num_clip]]  # Get the indices with lowest MI

        # # Calculate average values for each dimension
        # pooled_avg_values = pooled_prompt_embeds.mean(dim=0)
        # prompt1_avg_values = prompt_embeds1.mean(dim=(0,1))
        # prompt2_avg_values = prompt_embeds2.mean(dim=(0,1))

        # # Create masks with same precision as the tensors
        # mask = torch.ones(pooled_prompt_embeds.shape[1],
        #                 device=pooled_prompt_embeds.device,
        #                 dtype=pooled_prompt_embeds.dtype)  # Match dtype
        # for idx in high_mi_indices:
        #     mask[idx] = 0.0  # Zero in mask means "use average instead"

        # # Apply masks to create blended embeddings
        # # Original * mask + average * (1-mask)
        # pooled_prompt_embeds = pooled_prompt_embeds * mask + pooled_avg_values * (1 - mask)

        # # For multi-dimensional tensors, expand the mask properly
        # mask_expanded1 = mask.unsqueeze(0).unsqueeze(0).to(dtype=prompt_embeds1.dtype)
        # prompt_embeds1 = prompt_embeds1 * mask_expanded1 + prompt1_avg_values.unsqueeze(0).unsqueeze(0) * (1 - mask_expanded1)

        # mask_expanded2 = mask.unsqueeze(0).unsqueeze(0).to(dtype=prompt_embeds2.dtype)
        # prompt_embeds2 = prompt_embeds2 * mask_expanded2 + prompt2_avg_values.unsqueeze(0).unsqueeze(0) * (1 - mask_expanded2)

        # # Recombine embeddings
        # prompt_embeds = torch.cat((prompt_embeds1, prompt_embeds2), dim=-1)


        # return prompt_embeds, pooled_prompt_embeds