class QCL:
    """Quantum circuit learning (QCL) classifier simulated with Qiskit.

    Each of the first ``n_qubits`` input features is angle-encoded as
    ``RY(pi * x[i])`` on its own qubit. This is followed by ``circuit_depth``
    variational layers: RY-RZ-RY on every qubit, then CNOTs over every qubit
    pair in both directions. A final RY-RZ output layer precedes measurement.
    Every qubit is measured, and each bitstring is assigned to class
    ``int(bits, 2) % n_classes``.
    """

    def __init__(self, n_qubits=4, circuit_depth=3, n_shots=2000, n_classes=3):
        # Circuit parameters
        self.n_qubits = n_qubits
        self.circuit_depth = circuit_depth
        self.n_shots = n_shots
        # Number of output classes the measured bitstrings are folded into.
        self.n_classes = n_classes
        # 3 rotations per qubit per variational layer + 2 per qubit in the output layer.
        self.n_params = n_qubits * circuit_depth * 3 + n_qubits * 2

        # Training parameters
        self.learning_rate = 0.3
        self.batch_size = 4
        self.n_epochs = 300
        # Shift s used by the parameter-shift rule in compute_gradient.
        # s = pi/2 gives exact probability gradients for RY/RZ rotations.
        self.gradient_epsilon = np.pi / 2
        self.early_stopping_patience = 20
        self.min_improvement = 0.001
        self.target_loss = 0.1
        self.lr_decay_rate = 0.97
        self.lr_decay_patience = 12

        # Setup
        self.backend = Aer.get_backend('qasm_simulator')
        self.trained_params = None
        self.feature_names = None

    def create_circuit(self, x, params):
        """Create the complete quantum circuit using measure_all().

        Args:
            x: Feature vector. Entry ``x[i]`` for ``i < n_qubits`` is encoded
                as ``RY(pi * x[i])`` on qubit ``i``.
            params: Flat sequence of at least ``n_params`` rotation angles,
                consumed in circuit order.

        Returns:
            The parameterised ``QuantumCircuit``, ending in ``measure_all()``.
        """
        qc = QuantumCircuit(self.n_qubits)
        param_idx = 0

        # Input encoding
        for i in range(self.n_qubits):
            qc.ry(x[i] * np.pi, i)

        # The entangling pattern is identical in every layer, so build it once.
        qubit_pairs = list(combinations(range(self.n_qubits), 2))

        # Variational layers
        for _ in range(self.circuit_depth):
            for i in range(self.n_qubits):
                qc.ry(params[param_idx], i)
                qc.rz(params[param_idx + 1], i)
                qc.ry(params[param_idx + 2], i)
                param_idx += 3

            # Entanglement: CNOT over every pair, forward then reversed.
            for q1, q2 in qubit_pairs:
                qc.cx(q1, q2)
            for q1, q2 in qubit_pairs:
                qc.cx(q2, q1)

        # Output layer
        for i in range(self.n_qubits):
            qc.ry(params[param_idx], i)
            qc.rz(params[param_idx + 1], i)
            param_idx += 2

        # Measurements - automatic classical register creation
        qc.measure_all()
        return qc

    def get_probabilities(self, x, params):
        """Estimate class probabilities for input ``x`` from ``n_shots`` shots.

        Each measured bitstring is read as a binary integer and counted towards
        class ``value % n_classes``. The resulting frequencies are clipped to
        ``[1e-10, 1 - 1e-10]`` so that ``log`` stays finite in the loss.
        """
        qc = self.create_circuit(x, params)
        job = self.backend.run(qc, shots=self.n_shots)
        counts = job.result().get_counts()

        class_counts = np.zeros(self.n_classes)
        for state, count in counts.items():
            # Drop register separators in case the result has several registers.
            class_counts[int(state.replace(" ", ""), 2) % self.n_classes] += count

        probs = class_counts / self.n_shots
        return np.clip(probs, 1e-10, 1 - 1e-10)  # Avoid log(0)

    def compute_loss(self, params, x_batch, y_batch):
        """Return the mean cross-entropy ``-sum(y * log(p))`` over the batch.

        ``y_batch`` rows are expected to be one-hot (or probability) vectors.
        """
        total_loss = 0.0
        for x, y in zip(x_batch, y_batch):
            pred = self.get_probabilities(x, params)
            total_loss -= np.sum(y * np.log(pred))
        return total_loss / len(x_batch)

    def compute_gradient(self, params, x_batch, y_batch):
        """Compute the gradient of ``compute_loss`` with respect to ``params``.

        The parameter-shift rule is exact for each class probability ``p_k``
        under the RY/RZ rotations used here. It is not exact for the
        non-linear cross-entropy itself, so it is applied to the
        probabilities and then combined with the chain rule:

            dp_k/dtheta = (p_k(theta + s) - p_k(theta - s)) / (2 sin s)
            dL/dtheta   = mean over samples of  -sum_k y_k / p_k * dp_k/dtheta

        Here ``s = gradient_epsilon``, and the default ``s = pi/2`` gives a
        denominator of 2. ``params`` is not modified.

        Raises:
            ValueError: If ``gradient_epsilon`` makes ``sin(s)`` vanish.
        """
        params = np.asarray(params, dtype=float)
        shift = self.gradient_epsilon
        denom = 2.0 * np.sin(shift)
        if np.isclose(denom, 0.0):
            raise ValueError("gradient_epsilon must not be a multiple of pi.")

        # dL/dp_k per sample at the unshifted parameters. Floor p at one shot's
        # worth of probability so an unobserved class cannot blow up 1/p_k.
        loss_weights = [
            -np.asarray(y, dtype=float)
            / np.maximum(self.get_probabilities(x, params), 1.0 / self.n_shots)
            for x, y in zip(x_batch, y_batch)
        ]

        gradient = np.zeros(self.n_params)
        shifted = params.copy()  # shift a copy so the caller's array is untouched
        for param_idx in range(self.n_params):
            total = 0.0
            for x, weights in zip(x_batch, loss_weights):
                shifted[param_idx] = params[param_idx] + shift
                p_plus = self.get_probabilities(x, shifted)
                shifted[param_idx] = params[param_idx] - shift
                p_minus = self.get_probabilities(x, shifted)
                total += np.dot(weights, (p_plus - p_minus) / denom)
            shifted[param_idx] = params[param_idx]
            gradient[param_idx] = total / len(x_batch)

        return gradient

    def train(self, x_train, y_train, x_val=None, y_val=None):
        """Train the quantum classifier with early stopping and learning rate decay.

        Runs shuffled mini-batch gradient descent for up to ``n_epochs`` epochs.
        The monitored loss is the validation loss when both ``x_val`` and
        ``y_val`` are given; otherwise it is the epoch's mean training loss.
        The monitored loss drives three things:

        * the best parameters seen so far are kept;
        * the learning rate is multiplied by ``lr_decay_rate`` once per
          ``lr_decay_patience`` consecutive epochs without improvement;
        * training stops after ``early_stopping_patience`` such epochs, or as
          soon as the monitored loss falls below ``target_loss``.

        Args:
            x_train: Training features, indexable by an integer index array.
            y_train: One-hot (or probability) training targets aligned with
                ``x_train``.
            x_val, y_val: Optional validation data in the same format.

        Returns:
            ``(best_params, best_loss)``. ``best_params`` is also stored in
            ``self.trained_params``.

        Raises:
            ValueError: If ``x_train`` is empty.
        """
        n_samples = len(x_train)
        if n_samples == 0:
            raise ValueError("x_train must contain at least one sample.")

        # Initialize parameters
        params = np.random.random(self.n_params) * 2 * np.pi
        best_params = params.copy()
        best_loss = float('inf')
        learning_rate = self.learning_rate
        no_improvement = 0
        has_validation = x_val is not None and y_val is not None

        train_losses = []
        val_losses = []

        print("Starting training...")
        start_time = time.time()

        n_batches = -(-n_samples // self.batch_size)  # ceiling division

        for epoch in range(self.n_epochs):
            # Shuffle data
            indices = np.random.permutation(n_samples)
            x_shuffled = x_train[indices]
            y_shuffled = y_train[indices]
            epoch_loss = 0.0

            # Mini-batch training
            for batch_idx in range(n_batches):
                start_idx = batch_idx * self.batch_size
                end_idx = min(start_idx + self.batch_size, n_samples)
                x_batch = x_shuffled[start_idx:end_idx]
                y_batch = y_shuffled[start_idx:end_idx]

                gradient = self.compute_gradient(params, x_batch, y_batch)
                # NOTE(review): the gradient is already a batch mean, so this
                # extra division shrinks the step by the batch size. It is kept
                # so that learning_rate keeps its current effective step size.
                params -= learning_rate * gradient / (end_idx - start_idx)

                epoch_loss += self.compute_loss(params, x_batch, y_batch)

            epoch_loss /= n_batches
            train_losses.append(epoch_loss)

            # Without validation data, fall back to the training loss so that
            # best_params still tracks training progress.
            if has_validation:
                monitored_loss = self.compute_loss(params, x_val, y_val)
                val_losses.append(monitored_loss)
            else:
                monitored_loss = epoch_loss

            if monitored_loss < best_loss - self.min_improvement:
                best_loss = monitored_loss
                best_params = params.copy()
                no_improvement = 0
            else:
                no_improvement += 1
                # Decay once per full patience window, not on every later epoch.
                if (self.lr_decay_patience > 0
                        and no_improvement % self.lr_decay_patience == 0):
                    learning_rate *= self.lr_decay_rate
                    print(f"\nLearning rate decayed to: {learning_rate:.6f}")

            # Early stopping
            if no_improvement >= self.early_stopping_patience:
                print(f"\nEarly stopping: No improvement for {self.early_stopping_patience} epochs.")
                break

            # Target loss reached
            if monitored_loss < self.target_loss:
                print("\nReached target loss value!")
                break

            # Progress update
            if epoch % 10 == 0:
                elapsed = time.time() - start_time
                print(f"\nEpoch {epoch}/{self.n_epochs}:")
                print(f"Train Loss: {epoch_loss:.6f}")
                if val_losses:
                    print(f"Val Loss: {val_losses[-1]:.6f}")
                print(f"Learning Rate: {learning_rate:.6f}")
                print(f"Total Time: {elapsed:.2f}s")

        self.trained_params = best_params
        # NOTE(review): _plot_training_curves is not defined in this class body;
        # presumably it is attached elsewhere. Confirm before relying on it.
        self._plot_training_curves(train_losses, val_losses)
        return best_params, best_loss

    def _resolve_params(self, params):
        """Return ``params``, falling back to ``trained_params``.

        Raises:
            ValueError: If both are None.
        """
        if params is not None:
            return params
        if self.trained_params is None:
            raise ValueError("Model not trained. Call train() first.")
        return self.trained_params

    def _as_one_hot(self, y):
        """Return ``y`` as one-hot rows; 1-D integer labels are expanded."""
        y = np.asarray(y)
        if y.ndim > 1:
            return y
        return np.eye(self.n_classes)[y.astype(int)]

    def predict(self, x, params=None):
        """Predict class labels (argmax of the estimated class probabilities)."""
        params = self._resolve_params(params)
        return np.array([np.argmax(self.get_probabilities(sample, params)) for sample in x])

    def predict_proba(self, x, params=None):
        """Predict class probabilities, one row per sample."""
        params = self._resolve_params(params)
        return np.array([self.get_probabilities(sample, params) for sample in x])

    def evaluate(self, x, y, params=None, dataset_name="Test"):
        """Evaluate accuracy and cross-entropy loss, and plot a confusion matrix.

        ``y`` may be one-hot rows or 1-D integer labels.

        Returns:
            ``(accuracy, loss, predictions)``.
        """
        params = self._resolve_params(params)
        predictions = self.predict(x, params)
        y = np.asarray(y)
        y_true = np.argmax(y, axis=1) if y.ndim > 1 else y

        accuracy = np.mean(predictions == y_true)
        # compute_loss expects one-hot targets, so expand integer labels first.
        loss = self.compute_loss(params, x, self._as_one_hot(y))

        print(f"\n{dataset_name} Results:")
        print(f"Accuracy: {accuracy * 100:.2f}%")
        print(f"Loss: {loss:.6f}")

        # Confusion matrix
        cm = confusion_matrix(y_true, predictions)
        # NOTE(review): _plot_confusion_matrix is not defined in this class body.
        self._plot_confusion_matrix(cm, dataset_name)

        return accuracy, loss, predictions

    def save_parameters(self, parameters=None, filename="new_iris_trained_params.npy"):
        """Save ``parameters`` (default: ``trained_params``) with ``np.save``.

        Raises:
            ValueError: If no parameters are given and the model is untrained.
        """
        if parameters is None:
            if self.trained_params is None:
                raise ValueError("No trained parameters to save. Train the model first.")
            parameters = self.trained_params

        np.save(filename, parameters)
        print(f"Parameters saved to {filename}")

    def load_parameters(self, filename="new_iris_trained_params.npy"):
        """Load parameters from ``filename`` into ``trained_params``.

        Loading is best-effort. A missing file, an unreadable file, or an
        array whose shape is not ``(n_params,)`` is reported with a print and
        yields None, leaving ``trained_params`` unchanged.
        """
        try:
            parameters = np.load(filename)
        except FileNotFoundError:
            print(f"File {filename} not found.")
            return None
        except Exception as e:  # best-effort load: report and return None
            print(f"Error loading parameters: {e}")
            return None

        shape = getattr(parameters, "shape", None)
        if shape != (self.n_params,):
            print(f"Error loading parameters: expected shape ({self.n_params},), got {shape}")
            return None

        self.trained_params = parameters
        print(f"Parameters loaded from {filename}")
        return parameters

    def set_feature_names(self, feature_names):
        """Set feature names for XAI methods."""
        self.feature_names = feature_names
        print(f"Feature names set: {feature_names}")


class QuantumPerturbationExplainer:
    """Perturbation-mask explainer for a trained QCL-style model.

    For one instance ``x``, a free tensor is squashed by a sigmoid into a mask
    ``m`` in (0, 1)^d, and the model is queried on ``x * m``. The mask is
    optimised with Adam on a weighted sum of five terms:

    * target: negative log-probability of the target class;
    * l1: mean of ``1 - m``, pulling the mask towards keeping features;
    * fidelity: Jensen-Shannon distance between the original and perturbed
      class probabilities;
    * entanglement: mean absolute off-diagonal of a normalised outer product
      of the centred perturbed input;
    * superposition: inverse participation ratio of ``m`` plus 0.1 times its
      binary entropy.

    The wrapped model must expose ``get_probabilities(x, params)`` and
    ``trained_params``, as ``QCL`` does.
    """

    class _ModelProbabilities(torch.autograd.Function):
        """Differentiable bridge from a torch input to model class probabilities.

        The forward pass calls ``qcl.get_probabilities`` on the detached input.
        The backward pass applies the parameter-shift rule to the inputs.
        It assumes each feature is encoded as ``RY(pi * x_i)``, as in
        ``QCL.create_circuit``; shifting ``x_i`` by ``h`` then rotates by
        ``pi * h``, giving

            dp/dx_i = pi * (p(x + h e_i) - p(x - h e_i)) / (2 sin(pi * h)).

        Only the first ``qcl.n_qubits`` features (if that attribute exists)
        get gradients, since the circuit ignores the rest.
        """

        @staticmethod
        def forward(ctx, x, qcl, params, shift):
            x_np = x.detach().cpu().numpy().astype(np.float64)
            ctx.qcl = qcl
            ctx.params = params
            ctx.shift = shift
            ctx.x_np = x_np
            probs = qcl.get_probabilities(x_np, params)
            return torch.as_tensor(np.asarray(probs), dtype=x.dtype)

        @staticmethod
        def backward(ctx, grad_output):
            x_np = ctx.x_np
            shift = ctx.shift
            scale = np.pi / (2.0 * np.sin(np.pi * shift))
            n_features = x_np.shape[0]
            n_used = min(n_features, getattr(ctx.qcl, "n_qubits", n_features))

            jacobian = np.zeros((grad_output.shape[0], n_features))
            for i in range(n_used):
                x_plus = x_np.copy()
                x_plus[i] += shift
                x_minus = x_np.copy()
                x_minus[i] -= shift
                jacobian[:, i] = scale * (
                    np.asarray(ctx.qcl.get_probabilities(x_plus, ctx.params))
                    - np.asarray(ctx.qcl.get_probabilities(x_minus, ctx.params))
                )

            grad_x = grad_output @ torch.as_tensor(jacobian, dtype=grad_output.dtype)
            return grad_x, None, None, None

    def __init__(self, qcl_model, target_coeff=1.0, l1_coeff=0.1, fidelity_coeff=0.5,
                 entanglement_coeff=0.3, superposition_coeff=0.2, input_shift=0.5):
        """Store the model and the loss-term weights.

        Args:
            qcl_model: Model providing ``get_probabilities`` and ``trained_params``.
            target_coeff, l1_coeff, fidelity_coeff, entanglement_coeff,
                superposition_coeff: Weights of the five loss terms.
            input_shift: Feature shift ``h`` used for input gradients. The
                default 0.5 corresponds to an RY rotation shift of pi/2.

        Raises:
            ValueError: If ``input_shift`` is an integer, which would make
                ``sin(pi * h)`` zero.
        """
        self.qcl = qcl_model
        self.target_coeff = target_coeff
        self.l1_coeff = l1_coeff
        self.fidelity_coeff = fidelity_coeff
        self.entanglement_coeff = entanglement_coeff
        self.superposition_coeff = superposition_coeff
        if np.isclose(np.sin(np.pi * input_shift), 0.0):
            raise ValueError("input_shift must not be an integer.")
        self.input_shift = input_shift

    def _trained_params(self):
        """Return the model's trained parameters or raise if it is untrained."""
        params = getattr(self.qcl, "trained_params", None)
        if params is None:
            raise ValueError("Model not trained. Call train() first.")
        return params

    def explain_instances(self, x_test, target_classes=None, n_iterations=1000, lr=0.001,
                         weight_decay=1e-5, progressive_training=True, verbose=True):
        """
        Explain multiple instances using perturbation-based method.

        Args:
            x_test: Test instances (2D array: [n_instances, n_features])
            target_classes: Target class indices for each instance (if None, uses predicted classes)
            n_iterations: Number of optimization iterations per instance
            lr: Learning rate for Adam optimizer
            weight_decay: Weight decay for regularization
            progressive_training: Whether to progressively add loss terms
            verbose: Whether to show progress bar

        Returns:
            masks: Feature importance masks for all instances
            loss_histories: Loss histories for all instances
            summary_stats: Summary statistics across all instances

        Raises:
            ValueError: If x_test is empty, the model is untrained, or
                target_classes does not have one entry per instance.
        """
        x_test = np.array(x_test)
        n_instances = x_test.shape[0]
        print(f"We are processing {n_instances} instances")
        if n_instances == 0:
            raise ValueError("x_test must contain at least one instance.")
        params = self._trained_params()

        # If no target classes provided, use predicted classes
        if target_classes is None:
            target_classes = [
                int(np.argmax(self.qcl.get_probabilities(x_test[i], params)))
                for i in range(n_instances)
            ]
        elif len(target_classes) != n_instances:
            raise ValueError("target_classes must have one entry per instance.")

        masks = []
        loss_histories = []

        iterator = tqdm(range(n_instances), desc="Explaining instances") if verbose else range(n_instances)

        for i in iterator:
            if verbose and not isinstance(iterator, range):
                iterator.set_description(f"Explaining instance {i+1}/{n_instances}")

            mask, loss_history = self.explain_instance(
                x_test[i],
                target_classes[i],
                n_iterations=n_iterations,
                lr=lr,
                weight_decay=weight_decay,
                progressive_training=progressive_training,
            )
            masks.append(mask)
            loss_histories.append(loss_history)

        summary_stats = self._calculate_summary_stats(masks, loss_histories)
        return masks, loss_histories, summary_stats

    @staticmethod
    def _build_phases(n_iterations, progressive_training):
        """Split ``n_iterations`` into training phases with loss-term switches.

        Progressive training uses four phases that add loss terms
        cumulatively. The last phase absorbs the remainder, so the phases
        always sum to ``n_iterations``.
        """
        if not progressive_training:
            return [{'iterations': n_iterations, 'coeffs': [1, 1, 1, 1, 1]}]
        quarter = n_iterations // 4
        return [
            {'iterations': quarter, 'coeffs': [1, 0, 0, 0, 0]},  # Only target
            {'iterations': quarter, 'coeffs': [1, 1, 0, 0, 0]},  # + L1
            {'iterations': quarter, 'coeffs': [1, 1, 1, 0, 0]},  # + fidelity
            {'iterations': n_iterations - 3 * quarter, 'coeffs': [1, 1, 1, 1, 1]},  # All terms
        ]

    def explain_instance(self, x_instance, target_class, n_iterations=1000, lr=0.001,
                        weight_decay=1e-5, progressive_training=True):
        """
        Explain a single instance using perturbation-based method.

        Args:
            x_instance: Input instance to explain
            target_class: Target class index
            n_iterations: Number of optimization iterations
            lr: Learning rate for Adam optimizer
            weight_decay: Weight decay for regularization
            progressive_training: Whether to progressively add loss terms

        Returns:
            final_mask: ``sigmoid(mask)`` as a numpy array, one value per feature.
            loss_history: Dict of per-iteration weighted loss values under the keys
                'total', 'target', 'l1', 'fidelity', 'entanglement', 'superposition'.

        Raises:
            ValueError: If the model is untrained.
        """
        params = self._trained_params()

        mask = torch.randn(len(x_instance), requires_grad=True)
        optimizer = torch.optim.Adam([mask], lr=lr, weight_decay=weight_decay)

        x_tensor = torch.tensor(x_instance, dtype=torch.float32)
        target = torch.tensor(target_class, dtype=torch.long)

        original_probs_tensor = torch.tensor(
            self.qcl.get_probabilities(x_instance, params), dtype=torch.float32
        )

        term_names = ('target', 'l1', 'fidelity', 'entanglement', 'superposition')
        base_coeffs = (self.target_coeff, self.l1_coeff, self.fidelity_coeff,
                       self.entanglement_coeff, self.superposition_coeff)
        loss_history = {'total': [], **{name: [] for name in term_names}}

        print(f"Starting perturbation-based explanation for target class {target_class}")

        iteration = 0
        for phase_idx, phase in enumerate(self._build_phases(n_iterations, progressive_training)):
            print(f"Phase {phase_idx + 1}: {phase['iterations']} iterations")
            coeffs = [base * switch for base, switch in zip(base_coeffs, phase['coeffs'])]

            for _ in range(phase['iterations']):
                optimizer.zero_grad()

                sigmoid_mask = torch.sigmoid(mask)
                perturbed_x = x_tensor * sigmoid_mask

                # Differentiable with respect to perturbed_x via the input
                # parameter-shift rule, so the target and fidelity terms
                # actually shape the mask.
                perturbed_probs = self._ModelProbabilities.apply(
                    perturbed_x, self.qcl, params, self.input_shift
                )

                terms = (
                    self._calculate_target_loss(perturbed_probs, target),
                    torch.mean(torch.abs(1 - sigmoid_mask)),  # Prefer keeping features
                    self._calculate_fidelity_loss(original_probs_tensor, perturbed_probs,
                                                  method='jensen_shannon'),
                    self._calculate_entanglement_loss(perturbed_x),
                    self._calculate_superposition_loss(sigmoid_mask),
                )
                weighted = [coeff * term for coeff, term in zip(coeffs, terms)]
                loss = sum(weighted)

                loss.backward()
                optimizer.step()

                loss_history['total'].append(loss.item())
                for name, value in zip(term_names, weighted):
                    loss_history[name].append(value.item())

                iteration += 1
                if iteration % 200 == 0:
                    print(f"Iteration {iteration}: Loss = {loss.item():.4f}")

        final_mask = torch.sigmoid(mask).detach().numpy()
        return final_mask, loss_history

    def _calculate_convergence_rate(self, loss_histories):
        """Average relative drop of the total loss, first 20% vs last 20%.

        Only histories longer than 100 iterations with a positive early mean
        are included. Returns 0.0 when none qualify.
        """
        convergence_rates = []
        for hist in loss_histories:
            losses = hist['total']
            if len(losses) <= 100:
                continue
            window = len(losses) // 5
            early_loss = np.mean(losses[:window])
            late_loss = np.mean(losses[-window:])
            if early_loss > 0:
                convergence_rates.append((early_loss - late_loss) / early_loss)

        return np.mean(convergence_rates) if convergence_rates else 0.0

    def _calculate_summary_stats(self, masks, loss_histories):
        """Calculate summary statistics across all instances.

        Histories with no iterations are skipped for the final-loss
        statistics, which are NaN if no history has any iterations.
        """
        masks_array = np.array(masks)
        n_instances, n_features = masks_array.shape

        feature_importance_stats = {
            'mean': np.mean(masks_array, axis=0),
            'std': np.std(masks_array, axis=0),
            'median': np.median(masks_array, axis=0),
            'min': np.min(masks_array, axis=0),
            'max': np.max(masks_array, axis=0),
            'q25': np.percentile(masks_array, 25, axis=0),
            'q75': np.percentile(masks_array, 75, axis=0),
        }

        final_losses = [hist['total'][-1] for hist in loss_histories if hist['total']]
        loss_stats = {
            'mean_final_loss': np.mean(final_losses) if final_losses else float('nan'),
            'std_final_loss': np.std(final_losses) if final_losses else float('nan'),
            'convergence_rate': self._calculate_convergence_rate(loss_histories),
        }

        # Fraction of features per instance whose mask value exceeds 0.5.
        kept_fractions = [np.sum(mask > 0.5) / len(mask) for mask in masks]
        sparsity_stats = {
            'mean_sparsity': np.mean(kept_fractions),
            'std_sparsity': np.std(kept_fractions),
        }

        return {
            'feature_importance': feature_importance_stats,
            'loss': loss_stats,
            'sparsity': sparsity_stats,
            'n_instances': n_instances,
            'n_features': n_features,
        }

    def _calculate_target_loss(self, probs, target):
        """Calculate target loss (negative log likelihood for target class)."""
        return -torch.log(probs[target] + 1e-10)

    def _calculate_fidelity_loss(self, original_probs, perturbed_probs, method='l2'):
        """Calculate fidelity loss between original and perturbed predictions.

        Args:
            original_probs: Reference class probabilities.
            perturbed_probs: Class probabilities for the perturbed input.
            method: 'l2' for mean squared error, or 'jensen_shannon' for the
                Jensen-Shannon distance.

        For 'jensen_shannon', the result is the square root of the JS
        divergence with natural logarithms. This is the same quantity as
        ``scipy.spatial.distance.jensenshannon`` with its default base, but it
        is computed in torch so that gradients reach ``perturbed_probs``. The
        divergence is clamped at 1e-12, so identical inputs give about 1e-6.

        NOTE: for state vectors, fidelity would instead be
        ``|<original|perturbed>|^2``. Only probability vectors are handled here.

        Raises:
            ValueError: For an unknown ``method``.
        """
        if method == 'l2':
            return F.mse_loss(perturbed_probs, original_probs)
        if method == 'jensen_shannon':
            eps = 1e-12
            p = original_probs / original_probs.sum()
            q = perturbed_probs / perturbed_probs.sum()
            m = 0.5 * (p + q)
            log_m = torch.log(m + eps)
            kl_pm = torch.sum(p * (torch.log(p + eps) - log_m))
            kl_qm = torch.sum(q * (torch.log(q + eps) - log_m))
            js_div = 0.5 * (kl_pm + kl_qm)
            return torch.sqrt(torch.clamp(js_div, min=eps))
        raise ValueError("Method must be 'l2' or 'jensen_shannon'")

    def _calculate_entanglement_loss(self, perturbed_x):
        """Penalise the off-diagonal of a normalised outer product of the input.

        The perturbed input is centred by its mean, its outer product with
        itself is divided by its variance, and the mean absolute off-diagonal
        entry is returned. Inputs with fewer than 2 features return 0.

        NOTE(review): this is a single-sample proxy, not a sample correlation
        matrix or a quantum entanglement measure.
        """
        if len(perturbed_x) < 2:
            return torch.tensor(0.0)

        x_centered = perturbed_x - torch.mean(perturbed_x)
        corr_matrix = torch.outer(x_centered, x_centered) / (torch.std(perturbed_x) ** 2 + 1e-10)

        off_diagonal_mask = ~torch.eye(len(perturbed_x), dtype=bool)
        return torch.mean(torch.abs(corr_matrix[off_diagonal_mask]))

    def _calculate_superposition_loss(self, sigmoid_mask):
        """Inverse participation ratio of the mask plus 0.1 x its binary entropy.

        Mask values are clamped to [1e-6, 1 - 1e-6] before the entropy. In
        float32, a saturated sigmoid equals exactly 1.0, and an unclamped
        ``(1 - p) * log(1 - p)`` would then be NaN.
        """
        participation_ratio = torch.sum(sigmoid_mask) ** 2 / (torch.sum(sigmoid_mask ** 2) + 1e-10)
        participation_loss = 1.0 / (participation_ratio + 1e-10)  # Encourage sparsity

        p = torch.clamp(sigmoid_mask, 1e-6, 1 - 1e-6)
        entropy = -torch.sum(p * torch.log(p) + (1 - p) * torch.log(1 - p))

        return participation_loss + 0.1 * entropy

