import numpy as np
from config import (
    USE_GPU,
    GRAD_CLIP_NORM,
    LR_DECAY_STEPS,
    LR_DECAY_RATE,
)
from qulacs import QuantumState, Observable, QuantumCircuit, ParametricQuantumCircuit
from qulacs.state import inner_product
try:
    from qulacs import QuantumStateGpu
except ImportError:  # CPU-only installation
    QuantumStateGpu = QuantumState
from sklearn.metrics import log_loss
from scipy.optimize import minimize
from qcl_utils import create_time_evol_gate, min_max_scaling, softmax

from tqdm import tqdm
from joblib import Parallel, delayed

class QclClassification:
    """Solve classification problems using quantum circuit learning.

    The model applies an input-encoding circuit ``U_in(x)`` followed by a
    trainable circuit ``U_out(θ)``. Two decoding schemes are available:

    * observable-based (``pred`` / ``fit``): softmax over ``⟨Z_i⟩`` for
      ``i < num_class``; this requires ``num_class <= nqubit``.
    * amplitude-based (``pred_amplitude`` and the ``fit_*_inner_product``
      trainers): probabilities of the computational basis states
      ``|0⟩ … |num_class-1⟩``; this requires ``num_class <= 2**nqubit``.
    """

    def __init__(self, nqubit, c_depth, num_class=None):
        """
        :param nqubit: Number of qubits; ``2**nqubit`` must be at least the number of classes.
        :param c_depth: Circuit depth (number of repeated layers in ``U_out``).
        :param num_class: Number of classes. ``None`` lets ``fit`` (or the
            ``fit_*_inner_product`` trainers) infer it from the training targets.
        """
        self.nqubit = nqubit
        self.c_depth = c_depth

        self.input_state_list = []  # Encoded input states |ψ_in⟩ (filled by set_input_state)
        self.theta = []  # Flat parameter vector θ of U_out

        self.output_gate = None  # U_out

        self.num_class = num_class  # Number of classes

        # Training bookkeeping; populated by fit().
        self.y_list = None
        self.n_iter = 0
        self.maxiter = None

        # One Z observable per class; only possible when num_class <= nqubit.
        self.obs = None
        if self.num_class is not None and self.num_class <= self.nqubit:
            self._initialize_observable()

    def _initialize_observable(self):
        """Prepare one ``Z`` observable per class for expectation-based predictions.

        Observable ``i`` measures ``Z`` on qubit ``i``, so this decoding only
        works when ``num_class <= nqubit``. Does nothing if ``num_class`` is unset.

        :raises ValueError: If ``num_class`` exceeds ``nqubit``.
        """
        if self.num_class is None:
            return
        if self.num_class > self.nqubit:
            raise ValueError("num_class exceeds number of qubits for observable-based decoding")
        observables = []
        for i in range(self.num_class):
            obs_i = Observable(self.nqubit)
            obs_i.add_operator(1.0, f"Z {i}")
            observables.append(obs_i)
        self.obs = observables

    def create_input_gate(self, x):
        """Build the circuit ``U_in(x)`` that encodes the feature vector ``x``.

        Qubit ``i`` receives ``RY(arcsin(x_k))`` followed by ``RZ(arccos(x_k**2))``
        with ``k = i % len(x)``. Features are therefore reused cyclically when
        there are more qubits than features.

        :param x: Feature vector whose elements are assumed to lie in ``[-1, 1]``
            (outside that range ``arcsin``/``arccos`` produce NaN).
        :return: ``QuantumCircuit`` on ``nqubit`` qubits.
        """
        x = np.asarray(x, dtype=float)  # accept plain lists; ``x**2`` needs an array
        u = QuantumCircuit(self.nqubit)

        angle_y = np.arcsin(x)
        angle_z = np.arccos(x ** 2)

        for i in range(self.nqubit):
            idx = i % len(x)
            u.add_RY_gate(i, angle_y[idx])
            u.add_RZ_gate(i, angle_z[idx])

        return u

    def set_input_state(self, x_list):
        """Encode every sample of ``x_list`` and store the states in ``input_state_list``.

        Features are first rescaled with ``min_max_scaling`` (intended to map each
        feature to ``[-1, 1]``). ``QuantumStateGpu`` is used when ``USE_GPU`` is
        set; on a CPU-only qulacs install that name falls back to ``QuantumState``.
        """
        x_list_normalized = min_max_scaling(x_list)

        state_cls = QuantumStateGpu if USE_GPU else QuantumState

        st_list = []
        for x in x_list_normalized:
            st = state_cls(self.nqubit)
            self.create_input_gate(x).update_quantum_state(st)
            st_list.append(st)  # a fresh state per sample, so no extra copy is needed
        self.input_state_list = st_list

    def create_initial_output_gate(self):
        """Assemble ``U_out`` with random parameters and store it in ``output_gate``.

        Each of the ``c_depth`` layers is the gate returned by
        ``create_time_evol_gate`` followed by ``RX, RZ, RX`` on every qubit.
        Angles are drawn uniformly from ``[0, 2π)``. ``self.theta`` holds them
        flattened in (layer, qubit, gate) order, which is the order in which the
        parametric gates, and hence parameter indices, are added.
        """
        u_out = ParametricQuantumCircuit(self.nqubit)
        time_evol_gate = create_time_evol_gate(self.nqubit)
        theta = 2.0 * np.pi * np.random.rand(self.c_depth, self.nqubit, 3)
        self.theta = theta.flatten()
        for d in range(self.c_depth):
            u_out.add_gate(time_evol_gate)
            for i in range(self.nqubit):
                u_out.add_parametric_RX_gate(i, theta[d, i, 0])
                u_out.add_parametric_RZ_gate(i, theta[d, i, 1])
                u_out.add_parametric_RX_gate(i, theta[d, i, 2])
        self.output_gate = u_out

    def update_output_gate(self, theta):
        """Load parameters ``θ`` into ``U_out`` and remember them in ``self.theta``."""
        self.theta = theta
        for i, value in enumerate(self.theta):
            self.output_gate.set_parameter(i, value)

    def get_output_gate_parameter(self):
        """Return the current parameters ``θ`` of ``U_out`` as an ``np.ndarray``."""
        parameter_count = self.output_gate.get_parameter_count()
        theta = [self.output_gate.get_parameter(ind) for ind in range(parameter_count)]
        return np.array(theta)

    def _forward_with_input(self, input_gate):
        """Run ``U_out · U_in`` on ``|0…0⟩``.

        :param input_gate: Encoding circuit from ``create_input_gate``.
        :return: ``(circuit, psi)``. ``circuit`` is a copy of ``U_out`` with the
            encoding gates prepended. The prepended gates are non-parametric, so
            parameter indices still match ``self.theta``. ``psi`` is the output
            state vector.
        """
        circuit = self.output_gate.copy()
        # Insert at position 0 in reverse so the encoding gates keep their order.
        for idx in reversed(range(input_gate.get_gate_count())):
            circuit.add_gate(input_gate.get_gate(idx), position=0)

        state = QuantumState(self.nqubit)
        state.set_zero_state()
        circuit.update_quantum_state(state)
        return circuit, state.get_vector()

    def _gradient_step(self, grad, lr):
        """Apply one plain gradient-descent update ``θ ← θ - lr·grad`` to ``U_out``."""
        self.update_output_gate(self.get_output_gate_parameter() - lr * np.asarray(grad))

    def _expectations(self, theta):
        """Return the raw ``⟨Z_i⟩`` values (pre-softmax logits) for every stored input.

        As a side effect, ``U_out`` (and ``self.theta``) is updated to ``theta``.

        :return: List with one entry per input state, each a list of ``num_class`` floats.
        :raises ValueError: If no observables are set up (``num_class > nqubit``).
        """
        if self.obs is None:
            raise ValueError(
                "Observable-based prediction requires num_class <= nqubit; "
                "use pred_amplitude instead"
            )
        self.update_output_gate(theta)
        outputs = []
        for st_in in self.input_state_list:
            st = st_in.copy()  # keep the stored input state untouched
            self.output_gate.update_quantum_state(st)
            outputs.append([o.get_expectation_value(st) for o in self.obs])
        return outputs

    def pred(self, theta):
        """Return softmax class probabilities for the stored inputs under ``θ``.

        :param theta: Parameters for ``U_out`` (also stored in ``self.theta``).
        :return: ``np.ndarray`` of shape ``(n_samples, num_class)``.
        """
        return np.array([softmax(r).tolist() for r in self._expectations(theta)])

    def pred_amplitude(self, x_list):
        """Return class probabilities decoded from computational-basis amplitudes.

        Each sample of ``min_max_scaling(x_list)`` is encoded and passed through
        ``U_out``. The squared amplitudes of ``|0⟩ … |num_class-1⟩`` are then
        renormalised via ``softmax(log(p + 1e-10))``.
        NOTE(review): scaling only sees ``x_list`` itself, not the training data.

        :return: ``np.ndarray`` of shape ``(n_samples, num_class)``.
        """
        x_scaled = min_max_scaling(x_list)
        res = []
        for x in x_scaled:
            _, psi = self._forward_with_input(self.create_input_gate(x))
            probs = np.abs(psi[: self.num_class]) ** 2
            res.append(softmax(np.log(probs + 1e-10)))
        return np.array(res)

    def cost_func(self, theta):
        """Compute the cost function value (cross entropy against ``self.y_list``).

        :param theta: List of rotation-gate angles ``θ``.
        """
        y_pred = self.pred(theta)

        # cross-entropy loss
        loss = log_loss(self.y_list, y_pred)

        return loss

    # for BFGS
    def B_grad(self, theta, logits=False):
        """Parameter-shift differences of the model outputs.

        For each parameter ``j`` this evaluates
        ``(f(θ + π/2·e_j) - f(θ - π/2·e_j)) / 2``.

        * With ``logits=True``, ``f`` is the raw ``⟨Z_i⟩`` output, and the result
          is the parameter-shift-rule Jacobian ``∂z/∂θ`` for the rotation gates.
        * With the default ``logits=False``, ``f`` is ``pred`` (post-softmax).
          The shift rule is not exact through the softmax, so this variant is only
          an approximation.

        ``U_out`` is restored to ``theta`` before returning.

        :return: ``np.ndarray`` of shape ``(n_params, n_samples, num_class)``.
        """
        theta = np.asarray(theta, dtype=float)
        shifts = np.eye(len(theta)) * (np.pi / 2.0)  # built once, not per parameter

        def model(th):
            return np.array(self._expectations(th)) if logits else self.pred(th)

        grad = [
            (model(theta + shifts[i]) - model(theta - shifts[i])) / 2.0
            for i in tqdm(range(len(theta)), desc="param", leave=False)
        ]
        self.update_output_gate(theta)  # leave U_out at θ, not at the last shifted point
        return np.array(grad)

    @staticmethod
    def _softmax_xent_grad(residual, jac):
        """Chain rule for the mean softmax cross entropy.

        With ``p = softmax(z)`` and ``L = -(1/N) Σ_n Σ_c y_nc log p_nc``
        (one-hot ``y``), ``∂L/∂z_nc = (p_nc - y_nc) / N``.

        :param residual: ``p - y`` of shape ``(n_samples, num_class)``.
        :param jac: ``∂z/∂θ`` of shape ``(n_params, n_samples, num_class)``.
        :return: ``(1/N) Σ_{n,c} residual[n, c] · jac[:, n, c]``, shape ``(n_params,)``.
        """
        residual = np.asarray(residual, dtype=float)
        return np.tensordot(jac, residual, axes=([1, 2], [0, 1])) / residual.shape[0]

    # for BFGS
    def cost_func_grad(self, theta):
        """Gradient of ``cost_func`` with respect to ``θ``.

        The softmax cross-entropy chain rule (``(p - y) / N``) is combined with
        the parameter-shift Jacobian of the raw ``⟨Z_i⟩`` logits. This assumes
        ``qcl_utils.softmax`` is the standard ``exp(z) / Σ exp(z)``.
        """
        theta = np.asarray(theta, dtype=float)
        residual = self.pred(theta) - np.asarray(self.y_list, dtype=float)
        jac = self.B_grad(theta, logits=True)
        return self._softmax_xent_grad(residual, jac)

    def fit(self, x_list, y_list, maxiter=1000):
        """Train ``U_out`` with BFGS on the observable-based cross-entropy cost.

        :param x_list: Training inputs ``x``.
        :param y_list: One-hot training targets ``y`` of shape ``(n_samples, num_class)``.
        :param maxiter: Number of iterations for ``scipy.optimize.minimize``.
        :return: ``(result, theta_init, theta_opt)``. These are the
            ``scipy.optimize`` result, the initial parameters, and the optimized
            parameters (``result.x``, also loaded into ``U_out``).
        :raises ValueError: If ``num_class`` exceeds ``nqubit`` (observable
            decoding needs one qubit per class), or if ``y_list`` does not have
            ``num_class`` columns.
        """
        # Determine num_class from y_list if not preset
        if self.num_class is None:
            self.num_class = y_list.shape[1]
        if self.num_class > 2 ** self.nqubit:
            raise ValueError("num_class exceeds representable classes for given qubits")
        if self.obs is None:
            # Raises ValueError when num_class > nqubit instead of failing later in pred().
            self._initialize_observable()
        # Ensure that y_list matches num_class in dimensionality
        if y_list.shape[1] != self.num_class:
            raise ValueError("y_list and num_class mismatch")

        # Prepare initial states
        self.set_input_state(x_list)

        # Create a random U_out
        self.create_initial_output_gate()
        theta_init = self.theta.copy()

        # Store ground-truth labels
        self.y_list = y_list

        # for callbacks
        self.n_iter = 0
        self.maxiter = maxiter

        print("Initial parameter:")
        print()
        print(f"Initial value of cost function:  {self.cost_func(self.theta):.4f}")
        print()
        print('============================================================')
        print("Iteration count...")
        result = minimize(self.cost_func,
                          theta_init,
                          method='BFGS',
                          jac=self.cost_func_grad,
                          options={"maxiter": maxiter},
                          callback=self.callbackF)
        # self.theta holds the last *evaluated* point (possibly a shifted gradient
        # probe), so take the optimizer's answer and load it into U_out.
        theta_opt = result.x
        self.update_output_gate(theta_opt)
        print('============================================================')
        print()
        print("Optimized parameter:")
        print()
        print(f"Final value of cost function:  {self.cost_func(theta_opt):.4f}")
        print()
        return result, theta_init, theta_opt

    def fit_backprop_inner_product(self, x_list, y_label, lr=0.1, n_iter=100):
        """Gradient descent training using ``backprop_inner_product``.

        Minimises the mean of ``-log p_label``, where ``p_label = |⟨label|ψ(x)⟩|²``
        is the probability of the computational basis state of the sample's class.

        Parameters
        ----------
        x_list : array-like
            Training features.
        y_label : array-like
            Class labels as integers or one-hot vectors.
        lr : float, default 0.1
            Learning rate for gradient descent.
        n_iter : int, default 100
            Number of optimization steps.

        Returns
        -------
        np.ndarray
            Final parameters ``θ``.

        Raises
        ------
        ValueError
            If ``y_label`` is empty or ``num_class`` exceeds ``2**nqubit``.
        """
        eps = 1e-10
        y_label = np.asarray(y_label)
        labels = np.argmax(y_label, axis=1) if y_label.ndim > 1 else y_label
        labels = labels.astype(int)
        if len(labels) == 0:
            raise ValueError("y_label must contain at least one sample")
        if self.num_class is None:
            self.num_class = int(labels.max()) + 1
        if self.num_class > 2 ** self.nqubit:
            raise ValueError("num_class exceeds representable classes for given qubits")

        x_scaled = min_max_scaling(x_list)
        input_gates = [self.create_input_gate(x) for x in x_scaled]

        self.create_initial_output_gate()

        for step in range(n_iter):
            total_grad = np.zeros(len(self.theta))
            total_loss = 0.0

            for gate, label in zip(input_gates, labels):
                circuit, psi = self._forward_with_input(gate)
                amp = psi[label]  # ⟨label|ψ⟩
                prob = abs(amp) ** 2
                total_loss += -np.log(prob + eps)

                # Bistate |b⟩ = amp·|label⟩ = |label⟩⟨label|ψ⟩, so
                # Re⟨b|∂ψ/∂θ_j⟩ = Re(conj(amp)·⟨label|∂ψ/∂θ_j⟩) = ½ ∂p_label/∂θ_j.
                # np.real makes this hold whether the binding returns the real part
                # or the full complex inner product.
                bistate_vec = np.zeros_like(psi)
                bistate_vec[label] = amp
                bistate = QuantumState(self.nqubit)
                bistate.load(bistate_vec)
                half_dprob = np.real(np.asarray(circuit.backprop_inner_product(bistate)))
                total_grad += -2.0 * half_dprob / (prob + eps)

            total_grad /= len(labels)
            self._gradient_step(total_grad, lr)

            if step % 10 == 0 or step == n_iter - 1:
                print(f"[{step:03d}] loss={total_loss/len(labels):.6f}")

        return self.theta

    @staticmethod
    def _probs_to_target_diag(t_probs, nqubit):
        """Return the diagonal of ``ρ_target = Σ_i p_i |i⟩⟨i|`` (length ``2**nqubit``).

        Only the diagonal is needed, because ``ρ_target|ψ⟩`` and
        ``⟨ψ|ρ_target|ψ⟩`` are both computable from it.

        :raises ValueError: If ``t_probs`` is empty or has more entries than ``2**nqubit``.
        """
        dim = 2 ** nqubit
        diag = np.zeros(dim, dtype=np.float64)
        indices = list(range(len(t_probs)))
        if len(indices) == 0 or max(indices) >= dim:
            raise ValueError(
                f"Class basis indices {indices} exceed dimension {dim} for {nqubit} qubits"
            )
        diag[indices] = t_probs  # class probs on the matching computational basis states
        return diag

    @staticmethod
    def _pair_bags_with_targets(bag_sampler, targets):
        """Return ``[(indices, targets[b]), ...]`` for each non-empty bag ``b``.

        Each bag is paired with its target row by its position in ``bag_sampler``
        *before* empty bags are dropped, so rows stay aligned with their bags.
        """
        return [
            (indices, targets[b])
            for b, indices in enumerate(bag_sampler)
            if len(indices) > 0
        ]

    def fit_llp_inner_product(
        self,
        x_list,
        bag_sampler,
        teacher_probs,
        lr=0.1,
        n_iter=100,
        loss="ce",
        n_jobs=1,
        return_history=False,
    ):
        """Train using bag-level label proportions.

        For each bag, the objective is ``-log F``, with
        ``F = mean_i ⟨ψ_i|ρ_target|ψ_i⟩`` and ``ρ_target = Σ_c p_c |c⟩⟨c|``
        built from the bag's label proportions. Bag losses and gradients are
        averaged over bags at every step.

        Parameters
        ----------
        x_list : array-like
            Full training features.
        bag_sampler : Iterable of index lists
            Defines the bags generated by ``create_fixed_proportion_batches``.
            Empty bags are skipped.
        teacher_probs : array-like of shape (n_bags, num_class)
            Label proportions for each bag, in the order ``bag_sampler`` yields them.
        lr : float, default 0.1
            Learning rate.
        n_iter : int, default 100
            Number of optimization steps.
        loss : {"ce", "kl"}, default "ce"
            Accepted for API compatibility but currently ignored: the objective
            is always ``-log F`` as described above.
        n_jobs : int, default 1
            Number of parallel (threading-backend) jobs when computing bag gradients.
        return_history : bool, default False
            If True, also return a dict with ``loss_history``, ``best_loss``,
            ``best_step`` and ``best_params``.

        Returns
        -------
        np.ndarray or (np.ndarray, dict)
            Final parameters ``θ``, plus the history dict when ``return_history``.

        Raises
        ------
        ValueError
            If ``teacher_probs`` is not 2-D, ``num_class`` exceeds ``2**nqubit``,
            or ``bag_sampler`` yields no non-empty bag.

        Notes
        -----
        The learning rate can be decayed every ``LR_DECAY_STEPS`` steps
        by ``LR_DECAY_RATE``. Gradients may also be clipped to
        ``GRAD_CLIP_NORM`` before each parameter update. These constants
        are defined in ``config.py``.
        """
        eps = 1e-10
        teacher = (
            teacher_probs.numpy() if hasattr(teacher_probs, "numpy") else teacher_probs
        )
        teacher = np.asarray(teacher, dtype=float)
        if teacher.ndim != 2:
            raise ValueError("teacher_probs must have shape (n_bags, num_class)")

        if self.num_class is None:
            self.num_class = teacher.shape[1]
        if self.num_class > 2 ** self.nqubit:
            raise ValueError("num_class exceeds representable classes for given qubits")

        x_scaled = min_max_scaling(x_list)
        input_gates = [self.create_input_gate(x) for x in x_scaled]

        self.create_initial_output_gate()

        bags = self._pair_bags_with_targets(bag_sampler, teacher)
        if not bags:
            raise ValueError("bag_sampler yielded no non-empty bags")

        losses = []
        best_loss = float("inf")
        best_step = 0
        best_params = self.get_output_gate_parameter().copy()

        def bag_grad_and_loss(bag_indices, t_probs):
            """Return ``(-log F, ∂(-log F)/∂θ)`` for one bag."""
            rho_t_diag = self._probs_to_target_diag(t_probs, self.nqubit)

            g_lists = []    # Re⟨ρ_target ψ_i | ∂ψ_i/∂θ_j⟩ per sample
            quad_vals = []  # ⟨ψ_i|ρ_target|ψ_i⟩ per sample
            for idx in bag_indices:
                circuit, psi = self._forward_with_input(input_gates[idx])
                # ρ_target is diagonal, so ρ_target|ψ_i⟩ is an elementwise product.
                rho_t_psi = rho_t_diag * psi
                quad_vals.append(np.vdot(psi, rho_t_psi).real)

                # ρ_target is real-diagonal (Hermitian), so
                # ∂⟨ψ|ρ|ψ⟩/∂θ_j = 2·Re⟨ρψ|∂ψ/∂θ_j⟩. The bistate need not be normalized.
                target = QuantumState(self.nqubit)
                target.load(rho_t_psi)
                g_lists.append(np.real(np.asarray(circuit.backprop_inner_product(target))))

            F = float(np.mean(quad_vals))
            total_loss = -np.log(F + eps)
            # ∂(-log F)/∂θ_j = -(1/F)·(2/m)·Σ_i Re⟨ρ_target ψ_i|∂ψ_i/∂θ_j⟩
            scale = -(2.0 / (F + eps) / len(bag_indices))
            bag_grad = scale * np.sum(g_lists, axis=0)
            return total_loss, bag_grad

        for step in tqdm(range(n_iter), desc="Learning iteration"):
            results = Parallel(n_jobs=n_jobs, backend="threading")(
                delayed(bag_grad_and_loss)(indices, t_probs)
                for indices, t_probs in bags
            )

            losses_step, grads = zip(*results)
            avg_grad = np.mean(grads, axis=0)
            avg_loss = np.nanmean(losses_step)

            losses.append(avg_loss)

            # avg_loss was measured at the current (pre-update) parameters.
            if avg_loss < best_loss:
                best_loss = avg_loss
                best_params = self.get_output_gate_parameter().copy()
                best_step = step

            # Optional gradient clipping
            if GRAD_CLIP_NORM is not None:
                grad_norm = np.linalg.norm(avg_grad)
                if grad_norm > GRAD_CLIP_NORM:
                    avg_grad = avg_grad * (GRAD_CLIP_NORM / (grad_norm + 1e-12))

            self._gradient_step(avg_grad, lr)

            # Simple learning rate scheduler
            if LR_DECAY_STEPS and (step + 1) % LR_DECAY_STEPS == 0:
                lr *= LR_DECAY_RATE

            if step % 10 == 0 or step == n_iter - 1:
                print(f"[{step:03d}] loss={avg_loss:.6f} lr={lr:.6f}")

        if return_history:
            history = {
                "loss_history": losses,
                "best_loss": best_loss,
                "best_step": best_step,
                "best_params": best_params,
            }
            return self.theta, history
        return self.theta

    def callbackF(self, theta):
        """``scipy.optimize.minimize`` callback that logs the cost.

        Logging happens whenever ``10 * n_iter`` is a multiple of ``maxiter``,
        which is about every ``maxiter / 10`` iterations.
        """
        self.n_iter += 1
        if (10 * self.n_iter) % self.maxiter == 0:
            print(f"Iteration: {self.n_iter} / {self.maxiter},   Value of cost_func: {self.cost_func(theta):.4f}")
