import random
import numpy as np
import torch
from base_processor import BaseProcessor

class SALProcessor(BaseProcessor):

    # def process_data(self, data, usermode, device):
    #     self.data = data
    #     self.usermode = usermode
    #     self.vecs = {}
        
    #     for demo in self.data:
                        
    #         # Print information about this demographic attribute
    #         print(f"\n===== Distribution for {demo} =====")
            
    #         categories = list(self.data[demo].keys())
    #         print(f"Categories: {categories}")
            
    #         # Print sample counts
    #         for cat in categories:
    #             sample_count = len(self.data[demo][cat])
    #             print(f"Category '{cat}' has {sample_count} samples")


    #         self.vecs[demo] = {}
    #         # Store normalized difference vectors between categories
    #         categories = list(self.data[demo].keys())
            
    #         # Make sure data is on the correct device
    #         cat1_data = self.data[demo][categories[0]].to(device)
    #         cat2_data = self.data[demo][categories[1]].to(device)
            
    #         cat1_mean = cat1_data.mean(axis=0)
    #         cat2_mean = cat2_data.mean(axis=0)
    #         diff = cat2_mean - cat1_mean
    #         normalized_diff = diff / torch.norm(diff)
            
    #         # Store for both categories
    #         self.vecs[demo][categories[0]] = normalized_diff
    #         self.vecs[demo][categories[1]] = normalized_diff

    #         if "enoise" in self.usermode:
    #             if not hasattr(self, "edist"):
    #                 self.edist = {}
    #             self.edist[demo] = {}
    #             for cat in self.data[demo]:
    #                 # Move the data to device first, then perform matmul
    #                 cat_data = self.data[demo][cat].to(device)
    #                 self.edist[demo][cat] = cat_data.matmul(normalized_diff)
    #                 print(f"cat_data shape: {cat_data.shape}, normalized_diff shape: {normalized_diff.shape}, self.edist[demo][cat].shape: {self.edist[demo][cat].shape}")


    # def transform(self, demo, X_test, proj):
    #     original_dtype = X_test.dtype
    #     original_device = X_test.device
    #     proj = proj.to(device=original_device, dtype=original_dtype)
        
    #     # Apply projection first
    #     print(f"X_test shape: {X_test.shape}")

    #     if len(X_test.shape) == 2:
    #         X = torch.matmul(proj, X_test.T).T
    #     elif len(X_test.shape) == 3:
    #         X = torch.matmul(proj, X_test.transpose(1, 2)).transpose(1, 2)
        
    #     # Add empirical noise using actual examples
    #     if "enoise" in self.usermode:
    #         scale = 1.0 if self.usermode.get("enoise") is None else float(self.usermode["enoise"])
            
    #         if not hasattr(self, "rand_group"):
    #             # Store group selection and noise consistently across calls
    #             categories = list(self.edist[demo].keys())
    #             self.rand_group = torch.randint(0, len(categories), (X_test.shape[0],))
                
    #             noise = []
    #             for i in range(X_test.shape[0]):
    #                 cat = categories[self.rand_group[i].item()]
    #                 cat_projections = self.edist[demo][cat]
    #                 idx = torch.randint(0, len(cat_projections), (1,))
    #                 print(f"cat_projections shape: {cat_projections.shape}, idx: {idx}, cat_projections[idx]: {cat_projections[idx]}")
    #                 # Extract scalar value properly
    #                 noise_val = cat_projections[idx].item()
    #                 noise.append(torch.tensor([noise_val], device=original_device, dtype=original_dtype))
                
    #             self.noise = torch.cat(noise)
            
    #         # Get directions for each example's group
    #         categories = list(self.vecs[demo].keys())
    #         directions = torch.stack([self.vecs[demo][categories[g.item()]] for g in self.rand_group])
    #         directions = directions.to(device=original_device, dtype=original_dtype)
            
    #         # Apply noise consistently
    #         delta = scale * self.noise.unsqueeze(-1) * directions
            
    #         if len(X.shape) == 2:
    #             X = X + delta
    #         elif len(X.shape) == 3:
    #             X = X + delta.unsqueeze(1)
        
    #     return X


    def process_data(self, data, usermode, device):
        """Fit one unit direction per demographic attribute in ``data``.

        For every attribute ``demo`` the direction is the normalised difference
        between the mean feature vectors (mean over dim 0) of its first two
        categories.  The same vector is stored under both category names in
        ``self.vecs[demo]``.  When ``"enoise"`` is in ``usermode``, the
        projection of every sample of every category onto that direction is
        stored in ``self.edist[demo][category]``.

        Processing is best-effort: a failure on one attribute is printed and
        the remaining attributes are still processed.  An attribute that is
        skipped (fewer than two categories, an empty category among the first
        two, or an error while computing the direction) gets no entry in
        ``self.vecs`` at all, so callers never see a half-initialised entry.

        Args:
            data: Mapping ``{demo: {category: tensor}}``; each tensor has the
                sample axis first.
            usermode: Option container; only membership of ``"enoise"`` is
                inspected here.  ``None`` is treated as "no options".
            device: Device the tensors are moved to before any arithmetic.

        Returns:
            None.  Results are stored on ``self.data``, ``self.usermode``,
            ``self.vecs`` and, when noise is requested, ``self.edist``.
        """
        self.data = data
        self.usermode = usermode
        self.vecs = {}

        # ``usermode`` may be None (or another non-container); treat that as
        # "no noise requested" rather than failing half-way through.
        try:
            want_noise = "enoise" in usermode
        except TypeError:
            want_noise = False
        if want_noise and not hasattr(self, "edist"):
            self.edist = {}

        try:
            demos = list(self.data)
        except Exception as e:
            print(f"Error in process_data: {e}")
            return

        for demo in demos:
            try:
                self._fit_demo(demo, device, want_noise)
            except Exception as e:
                print(f"Error processing demographic {demo}: {e}")

    def _fit_demo(self, demo, device, want_noise):
        """Fit the direction (and optional projections) for one attribute."""
        print(f"\n===== Distribution for {demo} =====")

        groups = self.data[demo]
        categories = list(groups.keys())
        print(f"Categories: {categories}")

        if len(categories) < 2:
            print(f"Warning: Need at least 2 categories for {demo}, found {len(categories)}")
            return

        for cat in categories:
            try:
                sample_count = len(groups[cat])
                print(f"Category '{cat}' has {sample_count} samples")
            except Exception as e:
                print(f"Error getting sample count for {cat}: {e}")

        # Only the first two categories define the direction, even when the
        # attribute has more.
        first, second = categories[0], categories[1]
        try:
            first_data = groups[first].to(device)
            second_data = groups[second].to(device)

            if first_data.shape[0] == 0 or second_data.shape[0] == 0:
                print(f"Warning: Empty data for one of the categories in {demo}")
                return

            diff = second_data.mean(dim=0) - first_data.mean(dim=0)
            norm = torch.norm(diff)
            if norm > 1e-8:
                direction = diff / norm
            else:
                print(f"Warning: Difference vector has near-zero norm for {demo}")
                # The means coincide, so no direction is defined; fall back to
                # a random unit vector so downstream code still has one.
                direction = torch.randn_like(diff)
                direction = direction / torch.norm(direction)
        except Exception as e:
            print(f"Error computing direction vector for {demo}: {e}")
            return

        # Commit only after the direction was computed successfully.
        self.vecs[demo] = {first: direction, second: direction}

        if not want_noise:
            return

        # Start from a fresh dict so projections left over from an earlier
        # call with different data cannot survive.
        projections = {}
        self.edist[demo] = projections
        for cat in categories:
            try:
                projections[cat] = groups[cat].to(device).matmul(direction)

                # Print projection stats for debugging
                proj = projections[cat]
                print(f"Category '{cat}' projections - shape: {proj.shape}, "
                      f"min: {proj.min().item():.4f}, max: {proj.max().item():.4f}, "
                      f"mean: {proj.mean().item():.4f}, std: {proj.std().item():.4f}")
            except Exception as e:
                print(f"Error processing projections for {cat}: {e}")

    def transform(self, demo, X_test, proj, clamping=True):
        """Apply ``proj`` to the feature axis of ``X_test`` and optionally add noise.

        Steps, in order:

        1. ``proj`` is moved to the device and dtype of ``X_test``.
        2. Every feature vector ``x`` (last axis) is replaced by ``proj @ x``.
           When ``"mean_centering"`` is in ``self.usermode`` the mean (over
           dim 0 for 2D input, dim 1 for 3D input) is subtracted before the
           projection and added back afterwards.
        3. When ``clamping`` is true the result is clamped to ``[-1, 1]``.
        4. When ``"enoise"`` is in ``self.usermode`` each example along dim 0
           is shifted along the stored direction for ``demo`` by a projection
           value sampled from ``self.edist[demo]``, scaled by
           ``self.usermode["enoise"]`` (``None`` means a scale of 1.0).

        The sampled noise is cached per ``(demo, batch size)`` so that repeated
        calls for the same attribute and batch size add the same noise, while a
        different attribute or batch size gets its own sample.

        Args:
            demo: Attribute name used to look up ``self.vecs`` / ``self.edist``.
            X_test: 2D or 3D tensor whose last axis is the feature axis.
            proj: Matrix applied to the feature axis (presumably square, since
                the mean is added back after projecting -- confirm with caller).
            clamping: Whether to clamp the projected values to ``[-1, 1]``.

        Returns:
            The transformed tensor, same shape as ``X_test``.

        Raises:
            ValueError: If ``X_test`` is neither 2D nor 3D.
        """
        original_dtype = X_test.dtype
        original_device = X_test.device
        proj = proj.to(device=original_device, dtype=original_dtype)

        rank = X_test.dim()
        if rank not in (2, 3):
            raise ValueError(
                f"transform expects a 2D or 3D tensor, got shape {tuple(X_test.shape)}"
            )

        def project(t):
            # Rows hold feature vectors, so transpose, multiply, transpose back.
            if rank == 2:
                return torch.matmul(proj, t.T).T
            return torch.matmul(proj, t.transpose(1, 2)).transpose(1, 2)

        if "mean_centering" in self.usermode:
            print(f"\n\nX_test is {X_test}, X_test shape is {X_test.shape}\n\n\n")
            # 2D: mean over dim 0; 3D: mean over dim 1.
            mu = X_test.mean(dim=0 if rank == 2 else 1, keepdim=True)
            X = project(X_test - mu) + mu
            print(f"\n\n\nBefore clamping, Maximum value in X: {X.max().item()}, Minimum value in X: {X.min().item()}\n\n\n")
        else:
            X = project(X_test)
            print(f"X is {X}")
            print(f"\n\n\nBefore clamping, Maximum value in X: {X.max().item()}, Minimum value in X: {X.min().item()}")

        if clamping:
            X = X.clamp(-1, 1)  # Clamping to avoid extreme values
            print(f"After clamping, Maximum value in X: {X.max().item()}, Minimum value in X: {X.min().item()}\n\n\n")
        else:
            print(f"Clamping is disabled, Maximum value in X: {X.max().item()}, Minimum value in X: {X.min().item()}\n\n\n")

        # Noise is added after clamping, so the returned values may leave [-1, 1].
        if "enoise" in self.usermode:
            X = self._add_empirical_noise(demo, X, X_test.shape[0])

        return X

    def _add_empirical_noise(self, demo, X, batch_size):
        """Shift each example of ``X`` along the ``demo`` direction by sampled noise.

        Best-effort: on any failure a message is printed and ``X`` is returned
        without noise.
        """
        raw_scale = self.usermode.get("enoise")
        # A bare "enoise" flag (value None) means "use the default scale".
        scale = 1.0 if raw_scale is None else float(raw_scale)

        if not hasattr(self, "_noise_cache"):
            self._noise_cache = {}
        key = (demo, batch_size)

        if key not in self._noise_cache:
            projections = getattr(self, "edist", {}).get(demo, {})
            categories = list(projections.keys())
            if len(categories) == 0:
                print("Warning: No categories found for noise computation")
                return X

            try:
                # One random category per example, then one random stored
                # projection value from that category.
                groups = torch.randint(0, len(categories), (batch_size,))

                noise_values = []
                for i, group in enumerate(groups.tolist()):
                    try:
                        cat = categories[group]
                        cat_projections = projections[cat]

                        if len(cat_projections) == 0:
                            print(f"Warning: Empty projections for category {cat}")
                            noise_values.append(0.0)  # Default to zero
                            continue

                        idx = torch.randint(0, len(cat_projections), (1,))
                        noise_values.append(cat_projections[idx].item())
                    except Exception as e:
                        print(f"Error processing noise for example {i}: {e}")
                        noise_values.append(0.0)  # Default to zero

                try:
                    noise = torch.tensor(noise_values, device=X.device, dtype=X.dtype)
                except Exception as e:
                    print(f"Error creating noise tensor: {e}")
                    return X  # Return without adding noise
            except Exception as e:
                print(f"Error in noise computation: {e}")
                return X  # Return without adding noise

            self._noise_cache[key] = (groups, noise)

        groups, noise = self._noise_cache[key]
        # Keep the original public attributes pointing at the sample in use.
        self.rand_group, self.noise = groups, noise

        try:
            vec_by_cat = self.vecs[demo]
            vec_categories = list(vec_by_cat.keys())
            # Groups are drawn over the projection categories, which can
            # outnumber the stored directions; fall back to the first one.
            directions = torch.stack([
                vec_by_cat[vec_categories[g if g < len(vec_categories) else 0]]
                for g in groups.tolist()
            ]).to(device=X.device, dtype=X.dtype)
        except Exception as e:
            print(f"Error preparing directions: {e}")
            return X

        try:
            noise_column = noise.to(device=X.device, dtype=X.dtype).view(-1, 1)
            delta = scale * noise_column * directions
            if X.dim() == 3:
                # Same shift for every position along dim 1.
                delta = delta.unsqueeze(1)
            return X + delta
        except Exception as e:
            print(f"Error applying noise to data: {e}")
            return X
    
    def modify_embedding(self, pipe, prompt_embeds, pooled_prompt_embeds, usermode=None, exp_dir="."):
        """Project prompt embeddings away from a protected attribute's subspace.

        The projection is built from the stored features of the first
        attribute in ``self.data``: the features ``X`` of all categories are
        stacked, a one-hot membership matrix ``Z`` is formed (column ``j``
        marks the samples of the ``j``-th category), and the left singular
        vectors of ``X.T @ Z / n`` beyond the first ``k`` (``k`` = number of
        categories) span the retained subspace.  The resulting projector is
        applied via ``self.transform`` to the pooled embedding and to the two
        halves of ``prompt_embeds`` split at ``pooled_prompt_embeds.shape[1]``,
        once per attribute listed in ``usermode["protect"]``.

        Args:
            pipe: Not used by this method.
            prompt_embeds: 3D tensor whose last axis is split at ``self.dim``.
            pooled_prompt_embeds: 2D tensor; its second dimension defines
                ``self.dim`` and its device is the working device.
            usermode: Options dict; ``usermode["protect"]`` must list the
                attributes to process.  ``None`` is treated as an empty dict.
            exp_dir: Directory holding ``extracted_features.pt``; only read
                when ``self.vecs`` has not been set yet.

        Returns:
            Tuple ``(prompt_embeds, pooled_prompt_embeds)`` after projection.

        Raises:
            ValueError: If the attribute does not have 2 or 3 categories.
            KeyError: If ``usermode`` has no ``"protect"`` entry.
        """
        # A shared ``{}`` default would be one object reused across calls (and
        # stored on ``self.usermode``), so None stands in for "no options".
        if usermode is None:
            usermode = {}

        target_device = pooled_prompt_embeds.device

        # Load data if needed
        if not hasattr(self, "vecs"):
            # NOTE(security): torch.load unpickles the file; only point
            # ``exp_dir`` at trusted locations.  ``map_location`` lets features
            # saved on another device (e.g. CUDA) load on this one.
            data = torch.load(f"{exp_dir}/extracted_features.pt", map_location=target_device)
            self.process_data(data, usermode, target_device)

        self.dim = pooled_prompt_embeds.shape[1]
        print(f"self.dim is {self.dim}")
        self.usermode = usermode

        # Split embeddings
        prompt_embeds1 = prompt_embeds[:, :, :self.dim]
        prompt_embeds2 = prompt_embeds[:, :, self.dim:]

        # NOTE(review): the projection is always fitted on the first attribute
        # in ``self.data``, even when ``usermode["protect"]`` lists other
        # attributes -- confirm this is intended.
        protect = list(self.data.keys())[0]
        attribute_data = self.data[protect]
        keys = list(attribute_data.keys())
        num_categories = len(keys)
        if num_categories not in (2, 3):
            raise ValueError("Only 2 or 3 categories are supported for protected attributes.")

        # Ensure data is on correct device
        data_tensors = [attribute_data[key].to(target_device) for key in keys]
        X = torch.cat(data_tensors, dim=0)

        if num_categories == 2:
            print(f"Creating Z matrix for two categories: {keys[0]} and {keys[1]}")
        else:
            print(f"Creating Z matrix for three categories: {keys[0]}, {keys[1]}, and {keys[2]}")

        # One-hot membership: column j is 1 exactly on the rows that came from
        # keys[j], in the same order the rows were concatenated into X.
        sizes = [len(attribute_data[key]) for key in keys]
        Z = torch.zeros(sum(sizes), num_categories, dtype=X.dtype, device=target_device)
        row = 0
        for col, size in enumerate(sizes):
            Z[row:row + size, col] = 1
            row += size

        # Print Z matrix information
        print(f"Z matrix shape: {Z.shape}")
        print("Z matrix (first 10 rows):")
        print(Z[:10])

        # Count values to verify correct creation
        for col, key in enumerate(keys):
            count = (Z[:, col] == 1).sum().item()
            print(f"Count of examples with Z[{col}]=1: {count} (should match '{key}' count)")

        # Compute alignment matrix
        A = torch.mm(X.T, Z) / X.shape[0]

        # Compute SVD
        u, s, vh = torch.linalg.svd(A, full_matrices=True)

        print(f"u shape: {u.shape}, s shape: {s.shape}, vh shape: {vh.shape}, X shape: {X.shape}, Z shape: {Z.shape}, A shape: {A.shape}")

        # Drop one leading singular direction per category; the remaining
        # columns of u span the subspace that is kept.
        removal = num_categories
        print(f"removal is {removal}")

        u_r = u[:, removal:]
        proj = torch.mm(u_r, u_r.T).to(target_device)

        print(f"proj shape after removal: {proj.shape}, proj is {proj}")

        # Apply transform with noise
        for demo in usermode["protect"]:
            pooled_prompt_embeds = self.transform(demo, pooled_prompt_embeds, proj)
            prompt_embeds1 = self.transform(demo, prompt_embeds1, proj)
            prompt_embeds2 = self.transform(demo, prompt_embeds2, proj)

        # Recombine embeddings
        prompt_embeds = torch.cat((prompt_embeds1, prompt_embeds2), dim=-1)

        return prompt_embeds, pooled_prompt_embeds