import warnings
import torch
import torch.nn.functional as F
import time
import numpy as np

class DeltaDistributionTransducer:
    """Select anchor candidates for query embeddings from a fixed memory bank.

    The bank (``train_embs``) is searched with either Euclidean distance or
    cosine similarity (``cfg.transducer.anchor_metric``).  One of the
    sampling strategies in ``_STRATEGIES`` (``cfg.transducer.sampling_strategy``)
    then decides which ``num_candidates`` bank rows are returned per query.
    """

    _METRICS = ('euclidean', 'cosine')
    _STRATEGIES = ('topk', 'diverse_topm', 'temperature', 'adaptive_mask')

    def __init__(self, args, cfg, memory_bank, smiles_bank):
        """Store the bank on ``args.device`` and read the transducer config.

        Args:
            args: Object with a ``device`` attribute; all tensors live there.
            cfg: Config with a ``transducer`` section providing
                ``num_candidates``, ``anchor_metric``, ``sampling_strategy`` and
                ``adaptive_mask.{k_min,k_max}``.  Strategy-specific keys
                (``diverse_topm.diversity_factor``,
                ``adaptive_mask.density_threshold_factor``, ``temperature``)
                are read when that strategy runs.
            memory_bank: Tensor of bank embeddings.  It must be 2-D
                ``(num_train, latent_dim)`` to be compared via ``torch.cdist``
                against ``(batch, latent_dim)`` queries.
            smiles_bank: Sequence of SMILES strings indexed by the same row
                indices as ``memory_bank``.

        Raises:
            ValueError: If the strategy is ``'temperature'`` but no temperature
                is configured.
        """
        self.args = args
        self.cfg = cfg
        self.device = args.device
        self.num_candidates = cfg.transducer.num_candidates
        self.adaptive_k_min = cfg.transducer.adaptive_mask.k_min
        self.adaptive_k_max = cfg.transducer.adaptive_mask.k_max
        self.anchor_metric = cfg.transducer.anchor_metric
        self.sampling_strategy = cfg.transducer.sampling_strategy
        self.temperature = self._read_temperature(cfg.transducer)
        if self.sampling_strategy == 'temperature' and self.temperature is None:
            raise ValueError(
                "sampling_strategy='temperature' requires cfg.transducer.temperature "
                "(a number or a node with a 'value' field)."
            )

        self.smiles_bank = np.array(smiles_bank)
        self.train_embs = memory_bank.to(self.device)
        self.n_train_embs = self.train_embs.shape[0]

        # Every strategy calls torch.topk with num_candidates (or more), which
        # fails if k exceeds the bank size, so clamp regardless of strategy.
        if self.num_candidates > self.n_train_embs:
            warnings.warn(f"Requested k={self.num_candidates} anchors, but only {self.n_train_embs} training samples exist. Using {self.n_train_embs}.")
            self.num_candidates = self.n_train_embs

        if self.sampling_strategy == 'adaptive_mask':
            if self.adaptive_k_min > self.num_candidates:
                warnings.warn(f"adaptive_k_min ({self.adaptive_k_min}) > num_candidates ({self.num_candidates}). Clamping k_min to {self.num_candidates}")
                self.adaptive_k_min = self.num_candidates
            # k_max can never exceed the number of retrieved candidates; then
            # make sure k_min <= k_max still holds after that clamp.
            self.adaptive_k_max = min(self.adaptive_k_max, self.num_candidates)
            if self.adaptive_k_min > self.adaptive_k_max:
                warnings.warn(f"adaptive_k_min ({self.adaptive_k_min}) > adaptive_k_max ({self.adaptive_k_max}). Clamping k_min to k_max ({self.adaptive_k_max})")
                self.adaptive_k_min = self.adaptive_k_max

    @staticmethod
    def _read_temperature(transducer_cfg):
        """Return the configured sampling temperature as a float, or None.

        Looks for ``temperature`` and, for compatibility, the misspelled
        ``temparature`` key.  Each may be a plain number or a node whose
        ``value`` attribute holds the number.
        """
        for key in ('temperature', 'temparature'):
            node = getattr(transducer_cfg, key, None)
            value = getattr(node, 'value', node)
            if value is not None:
                return float(value)
        return None

    def _similarity_scores(self, query):
        """Score every bank row against every query row.

        Returns:
            ``(scores, largest)``: ``scores`` is ``(batch, num_train)``.
            ``largest`` tells ``torch.topk`` whether higher scores are better:
            True for cosine similarity, False for Euclidean distance.

        Raises:
            ValueError: If ``anchor_metric`` is not a supported metric.
        """
        if self.anchor_metric == 'euclidean':
            return torch.cdist(query, self.train_embs, p=2), False
        if self.anchor_metric == 'cosine':
            query_norm = F.normalize(query, p=2, dim=1)
            train_norm = F.normalize(self.train_embs, p=2, dim=1)
            return torch.matmul(query_norm, train_norm.t()), True
        raise ValueError(
            f"Unknown anchor_metric {self.anchor_metric!r}; expected one of {self._METRICS}"
        )

    def choose_multiple_anchors(self, query, return_smiles=False):
        """Pick ``num_candidates`` anchors per query with the configured strategy.

        Args:
            query: ``(batch, latent_dim)`` tensor of query embeddings.
            return_smiles: When True, also return the SMILES strings of the
                chosen anchors.

        Returns:
            Tuple ``(candidates, anchor_weights, attention_mask, indices,
            candidates_smiles)``:

            - ``candidates``: ``(batch, k, latent_dim)`` anchor embeddings.
            - ``anchor_weights``: ``(batch, k)`` raw scores of the anchors
              (distances for euclidean, similarities for cosine).
            - ``attention_mask``: ``(batch, k)`` bool mask (True = masked out)
              for ``'adaptive_mask'``, otherwise None.
            - ``indices``: ``(batch, k)`` bank row indices.
            - ``candidates_smiles``: list of per-query lists of SMILES when
              ``return_smiles`` is True, otherwise None.

        Raises:
            ValueError: On an unknown sampling strategy or anchor metric.
        """
        k = self.num_candidates
        strategy = self.sampling_strategy
        if strategy not in self._STRATEGIES:
            raise ValueError(
                f"Unknown sampling_strategy {strategy!r}; expected one of {self._STRATEGIES}"
            )
        attention_mask = None

        if strategy == 'diverse_topm':
            # Computes its own, larger top-M pool.
            anchor_weights, indices = self.sampling_by_diverse_topm(query, k)
        else:
            scores, largest = self._similarity_scores(query)
            if strategy == 'temperature':
                anchor_weights, indices = self.sampling_by_temperature(scores, k, largest)
            else:
                sorted_scores, indices_topk = torch.topk(
                    scores, k, dim=1, largest=largest, sorted=True
                )
                if strategy == 'topk':
                    anchor_weights, indices = sorted_scores, indices_topk
                else:  # 'adaptive_mask'
                    anchor_weights, indices, attention_mask = self.sampling_by_adaptive_mask(
                        query, k, sorted_scores, indices_topk
                    )

        # Advanced indexing: (batch, k) indices -> (batch, k, latent_dim) rows.
        candidates = self.train_embs[indices.long()]

        candidates_smiles = None
        if return_smiles:
            candidates_smiles = [
                [self.smiles_bank[j] for j in row] for row in indices.tolist()
            ]
        return candidates, anchor_weights, attention_mask, indices, candidates_smiles

    def sampling_by_temperature(self, scores, k, largest):
        """Sample ``k`` distinct anchors per row from a temperature softmax.

        Args:
            scores: ``(batch, num_train)`` distances or similarities.
            k: Number of anchors to draw per row.
            largest: Passed to ``torch.topk`` for the top-k fallbacks.

        Returns:
            ``(anchor_weights, indices)``, each ``(batch, k)``.  The weights are
            the raw ``scores`` at the sampled indices.  Indices are drawn
            without replacement and are not sorted by score.

        Raises:
            ValueError: If no temperature is configured.
        """
        temperature = getattr(self, 'temperature', None)
        if temperature is None:
            raise ValueError("Temperature sampling requires a configured temperature.")

        if temperature <= 1e-8:  # Effectively top-k for a vanishing temperature.
            warnings.warn("Temperature is near zero, falling back to top-k sampling.")
            values, indices = torch.topk(scores, k, dim=1, largest=largest, sorted=True)
            return values, indices

        if self.anchor_metric == 'euclidean':
            # Lower distance should mean higher probability, so negate.
            logits = -scores / temperature
        elif self.anchor_metric == 'cosine':
            # Map similarity from [-1, 1] to [0, 1] before scaling.
            logits = ((scores + 1.0) / 2.0) / temperature
        else:
            raise ValueError("Metric inconsistency for temperature sampling.")

        probabilities = torch.softmax(logits, dim=-1)
        # Guard against NaN/Inf from extreme logits, then renormalize per row.
        probabilities = torch.nan_to_num(
            probabilities, nan=0.0, posinf=torch.finfo(probabilities.dtype).max
        )
        probabilities = probabilities / probabilities.sum(dim=-1, keepdim=True).clamp(min=1e-8)

        try:
            indices = torch.multinomial(probabilities, k, replacement=False)
        except RuntimeError as err:
            # For example, fewer than k non-zero probabilities in some row.
            warnings.warn(
                f"torch.multinomial failed during temperature sampling ({err}); "
                "falling back to top-k sampling for this batch."
            )
            values, indices = torch.topk(scores, k, dim=1, largest=largest, sorted=True)
            return values, indices
        return torch.gather(scores, 1, indices), indices

    def sampling_by_diverse_topm(self, query, k):
        """Choose ``k`` mutually diverse anchors from a top-M pool per query.

        The pool size is ``M = int(k * diversity_factor)``, clamped to
        ``[k, num_train]``.  Within each pool the best-scoring entry is taken
        first, then entries are added greedily by
        ``_greedy_diverse_positions``.

        Returns:
            ``(anchor_weights, indices)``, each ``(batch, k)``.  The weights are
            the raw query scores of the chosen anchors.
        """
        batch_size = query.shape[0]
        pool_size = int(k * self.cfg.transducer.diverse_topm.diversity_factor)
        if pool_size > self.n_train_embs:
            pool_size = self.n_train_embs
            warnings.warn(f"Diversity pool size M clamped to {pool_size}")
        pool_size = max(pool_size, k)  # Always retrieve at least k candidates.

        scores, largest = self._similarity_scores(query)
        topm_scores, topm_indices = torch.topk(
            scores, pool_size, dim=1, largest=largest, sorted=True
        )
        topm_embeddings = self.train_embs[topm_indices]  # (batch, M, latent_dim)

        final_indices = torch.empty(
            (batch_size, k), dtype=torch.long, device=topm_indices.device
        )
        final_scores = torch.empty(
            (batch_size, k), dtype=topm_scores.dtype, device=topm_scores.device
        )
        for row in range(batch_size):
            positions = self._greedy_diverse_positions(topm_embeddings[row], k)
            final_indices[row] = topm_indices[row, positions]
            final_scores[row] = topm_scores[row, positions]
        return final_scores, final_indices

    def _greedy_diverse_positions(self, pool_embeddings, k):
        """Return ``k`` positions into a score-sorted pool, chosen for spread.

        Position 0 (the best-scoring entry) is taken first.  Each later pick
        is the remaining entry with:

        - the largest minimum Euclidean distance to the picks so far, for the
          euclidean metric;
        - the smallest maximum cosine similarity to the picks so far,
          otherwise.

        If the pool is exhausted before ``k`` picks, position 0 is repeated.
        """
        chosen = [0]
        taken = torch.zeros(pool_embeddings.shape[0], dtype=torch.bool, device=pool_embeddings.device)
        taken[0] = True
        while len(chosen) < k:
            remaining = (~taken).nonzero(as_tuple=True)[0]
            if remaining.numel() == 0:
                break
            candidates = pool_embeddings[remaining]
            picked = pool_embeddings[chosen]
            if self.anchor_metric == 'euclidean':
                min_dist = torch.cdist(candidates, picked, p=2).min(dim=1).values
                best = remaining[torch.argmax(min_dist)]
            else:
                sims = torch.matmul(
                    F.normalize(candidates, p=2, dim=1),
                    F.normalize(picked, p=2, dim=1).t(),
                )
                best = remaining[torch.argmin(sims.max(dim=1).values)]
            best = int(best)
            chosen.append(best)
            taken[best] = True
        chosen.extend([chosen[0]] * (k - len(chosen)))
        return torch.tensor(chosen, dtype=torch.long, device=pool_embeddings.device)

    def sampling_by_adaptive_mask(self, query, k, sorted_scores, indices_topk):
        """Keep the top-k anchors but mask out a per-query adaptive tail.

        The score of the ``k_min``-th neighbour sets a density threshold:

        - euclidean: ``score * density_threshold_factor``; neighbours at or
          below it are counted;
        - cosine: ``score / density_threshold_factor``, or the raw score when
          the factor is <= 1e-6, clipped at 1.0; neighbours at or above it
          are counted.

        The count is clamped to ``[adaptive_k_min, adaptive_k_max]``, and
        positions at or beyond it are masked.  With cosine and a negative
        ``score_k_min``, dividing by a factor > 1 raises the threshold, so the
        count typically collapses to ``k_min``.

        Returns:
            ``(sorted_scores, indices_topk, attention_mask)``.
            ``attention_mask`` is a ``(batch, k)`` bool tensor where True
            means the position is masked out.
        """
        batch_size = query.shape[0]
        factor = self.cfg.transducer.adaptive_mask.density_threshold_factor

        # 0-based column of the k_min-th neighbour, kept inside [0, k-1].
        k_min_idx = min(max(self.adaptive_k_min - 1, 0), sorted_scores.shape[1] - 1)
        score_k_min = sorted_scores[:, k_min_idx:k_min_idx + 1]  # (batch, 1)

        if self.anchor_metric == 'euclidean':
            counts = (sorted_scores <= score_k_min * factor).sum(dim=1)
        elif self.anchor_metric == 'cosine':
            # Avoid dividing by a zero or tiny factor.
            threshold = score_k_min if factor <= 1e-6 else score_k_min / factor
            threshold = threshold.clamp(max=1.0)  # A similarity cannot exceed 1.
            counts = (sorted_scores >= threshold).sum(dim=1)
        else:
            counts = torch.full(
                (batch_size,), self.adaptive_k_min, dtype=torch.long, device=self.device
            )

        k_adaptive = torch.clamp(counts, min=self.adaptive_k_min, max=self.adaptive_k_max)
        positions = torch.arange(k, device=self.device).unsqueeze(0)  # (1, k)
        attention_mask = positions >= k_adaptive.unsqueeze(1)  # (batch, k)
        return sorted_scores, indices_topk, attention_mask
                 
            
def define_transducer(args, cfg, memory_bank, smiles_bank):
    """Factory for the anchor transducer.

    Forwards every argument unchanged to ``DeltaDistributionTransducer``
    and returns the new instance.
    """
    return DeltaDistributionTransducer(
        args=args,
        cfg=cfg,
        memory_bank=memory_bank,
        smiles_bank=smiles_bank,
    )


