import os

import torch
import numpy as np
from scipy.optimize import minimize, NonlinearConstraint
from functools import reduce
import pennylane as qml
from torchvision.utils import save_image
import time
from tools.model_loader import cidx2qidx
from tools import Log


def Z_on_qubit(i, n_qubits=8):
    """Return the 2^n x 2^n matrix of PauliZ acting on qubit ``i`` (0-indexed).

    The result is the Kronecker product ``I ⊗ ... ⊗ Z ⊗ ... ⊗ I`` of
    ``n_qubits`` single-qubit factors, with Z placed at factor index ``i``
    (factor 0 is the leftmost operand of the product).
    """
    pauli_z = np.array([[1, 0], [0, -1]])
    identity = np.eye(2)
    factors = (pauli_z if q == i else identity for q in range(n_qubits))
    return reduce(np.kron, factors)


def proj_on_qubit(i, cidx, n_qubits=8):
    """Return the projector |1><1| (cidx == 1) or |0><0| (otherwise) on qubit ``i``.

    The single-qubit projector is embedded in an ``n_qubits`` register as a
    Kronecker product with 2x2 identities on every other factor position.
    """
    if cidx == 1:
        local = np.array([[0, 0],
                          [0, 1]])
    else:
        local = np.array([[1, 0],
                          [0, 0]])
    factors = [np.eye(2)] * n_qubits
    factors[i] = local
    return reduce(np.kron, factors)


def output_proj(model_n, cidx, n_qubits=8, class_idx=None):
    """Return the PauliZ observable on the qubit that reads out class ``cidx``.

    Args:
        model_n: model identifier forwarded to ``cidx2qidx``.
        cidx: class index whose readout qubit is requested.
        n_qubits: register size; the result is a 2^n_qubits square matrix.
        class_idx: class list forwarded to ``cidx2qidx``; defaults to [0, 1].
            A fresh list is built per call so a shared default can never be
            mutated across calls.
    """
    if class_idx is None:
        class_idx = [0, 1]
    qidx = cidx2qidx(model_n, cidx, n_qubits, class_idx)
    return Z_on_qubit(qidx, n_qubits=n_qubits)


def state2img(phi, img_shape=(16, 16, 1)):
    """Turn a flat state tensor into a channel-first image scaled by its maximum.

    ``phi`` is reshaped to ``img_shape`` (height, width, channels), permuted to
    (channels, height, width), and divided by its largest entry.
    """
    hwc = phi.reshape(img_shape)
    chw = hwc.permute(2, 0, 1)
    return chw / chw.max()


def pad_img(img, cir_dim):
    """Zero-pad (if needed) and L2-normalise amplitude vectors to length ``cir_dim``.

    Args:
        img: array of shape ``(dim,)`` or ``(batch, dim)``; the last axis holds
            the amplitudes of each vector.
        cir_dim: state-vector length expected by the circuit.

    Returns:
        ``(rho, padded)``: the normalised array, whose last axis has length
        ``cir_dim``, and ``True`` if zero padding was appended.

    Raises:
        ValueError: if the input vectors are longer than ``cir_dim``.
    """
    img = np.asarray(img)
    img_dim = img.shape[-1]
    if img_dim > cir_dim:
        raise ValueError(f"input dimension {img_dim} exceeds circuit dimension {cir_dim}")
    if img_dim == cir_dim:
        return img / np.linalg.norm(img, axis=-1, keepdims=True), False

    # Keep the batch shape and only extend the amplitude axis; a plain
    # np.zeros(cir_dim) would be 1-D and cannot hold a batch.
    rho = np.zeros(img.shape[:-1] + (cir_dim,), dtype=np.result_type(img.dtype, np.float64))
    rho[..., :img_dim] = img
    rho /= np.linalg.norm(rho, axis=-1, keepdims=True)
    return rho, True


def optimal_robust_bound(rho, M_l, M_k, U, log, n=256, need_pad=False, eps=1e-5):
    """Find the state closest to ``rho`` whose class-k score exceeds its class-l score.

    Solves, over real amplitude vectors ``phi`` with scipy's ``trust-constr``
    method (started from ``rho``)::

        min   1 - |<rho|phi>|^2
        s.t.  ||phi||^2 = 1
              Re tr((M_l - M_k) U |phi><phi| U^dagger) <= -eps
              phi[n:] = 0                     (only when need_pad)

    Args:
        rho: reference amplitude vector of length ``U.shape[0]``; also the
            initial guess.
        M_l, M_k: observable matrices for the true class l and target class k.
        U: square circuit matrix.
        log: callable that accepts a message string.
        n: number of leading amplitudes returned (and kept free when padding).
        need_pad: if True, amplitudes from index ``n`` on are forced to zero.
        eps: required margin by which the k score must exceed the l score.

    Returns:
        ``(delta_k, fi, phi)`` with ``delta_k = 1 - fi``, the fidelity ``fi``
        and the first ``n`` amplitudes of the normalised optimum; or
        ``(np.inf, None, None)`` when the optimiser reports failure.
    """
    m = U.shape[0]
    # For real phi, Re tr(D U phi phi^T U^dagger) = phi^T Re(U^dagger D U) phi,
    # so the trace constraint is a quadratic form in Q (O(m^2) per call instead
    # of two dense m x m products). Symmetrising Q leaves the form unchanged
    # and makes 2 Q phi its exact gradient.
    Q = np.real(U.conj().T @ (M_l - M_k) @ U)
    Q = 0.5 * (Q + Q.T)
    identity = np.eye(m)

    def objective(phi):
        return 1 - np.abs(np.vdot(rho, phi)) ** 2

    def objective_grad(phi):
        # d/dphi_j |<rho|phi>|^2 = 2 Re(<rho|phi> rho_j) for real phi
        return -2 * np.real(np.vdot(rho, phi) * rho)

    def constraint_norm(phi):
        return np.sum(np.abs(phi) ** 2) - 1

    def norm_jac(phi):
        return 2 * phi

    def norm_hess(phi, v):
        return 2 * v[0] * identity

    def constraint_trace(phi):
        return -(phi @ Q @ phi) - eps

    def trace_jac(phi):
        return -2 * (Q @ phi)

    def trace_hess(phi, v):
        return -2 * v[0] * Q

    constraints = [
        NonlinearConstraint(constraint_norm, 0, 0, jac=norm_jac, hess=norm_hess),
        NonlinearConstraint(constraint_trace, 0, np.inf, jac=trace_jac, hess=trace_hess),
    ]
    if need_pad:
        # Padding amplitudes must stay exactly zero; the constraint is linear.
        pad_jac = identity[n:]
        zero_hess = np.zeros((m, m))
        constraints.append(
            NonlinearConstraint(lambda phi: phi[n:], 0, 0,
                                jac=lambda phi: pad_jac,
                                hess=lambda phi, v: zero_hess)
        )

    res = minimize(objective, rho.copy(), jac=objective_grad, method='trust-constr',
                   constraints=constraints, options={'maxiter': 1000})

    if not res.success:
        log(f"Optimization failed: {res.message}")
        return np.inf, None, None

    phi_opt = res.x
    # The norm constraint only holds to solver tolerance; renormalise so the
    # reported fidelity is between unit vectors.
    norm = np.linalg.norm(phi_opt)
    if norm > 0:
        phi_opt = phi_opt / norm
    fi = np.abs(np.vdot(rho, phi_opt)) ** 2
    log(f"fidelity = {fi}")
    if fi > 1:
        # rho itself may be only approximately normalised (e.g. float32 input)
        fi = 1
        log('! This fidelity is larger than 1. Reset to 1.')
    delta_k = 1 - fi
    return delta_k, fi, phi_opt[:n]


class Verifier:
    """Search, per test input, for the closest state the circuit assigns to another class.

    For every test input the configured solver is run once per other class;
    the state with the smallest infidelity is re-classified, and the attack
    success rate (ASR) and the fidelities of successful attacks are written to
    the log under ``<save_path>/theoretical``.
    """

    def __init__(self, model_setting, save_path):
        """
        Args:
            model_setting: mapping with keys 'model_n', 'class_idx', 'n_qubits'
                and 'img_shape'.
            save_path: base directory; the log and images go to its
                'theoretical' subdirectory, created if missing.
        """
        self.model_n = model_setting['model_n']
        self.class_idx = model_setting['class_idx']
        self.n_qubits = model_setting['n_qubits']
        self.img_shape = model_setting['img_shape']

        save_path = os.path.join(save_path, 'theoretical')
        os.makedirs(save_path, exist_ok=True)

        self.log = Log(os.path.join(save_path, 'log.txt'))
        self.p_c = save_path

    def init_solver(self, name='QCQP'):
        """Select the optimisation routine used by :meth:`verify`.

        Raises:
            NotImplementedError: for 'SDP', which has no implementation yet.
            ValueError: for any other unknown solver name.
        """
        if name == 'QCQP':
            self.solver = optimal_robust_bound
        elif name == 'SDP':
            # Fail here instead of storing None and crashing at the first call.
            raise NotImplementedError("SDP solver is not implemented")
        else:
            raise ValueError(f"Unknown solver: {name!r}")

    def preprocess_input(self, inputs, model):
        """Turn a batch of flattened inputs into amplitude vectors for the circuit.

        Returns:
            ``(states, need_pad)``: a numpy array with one row per input and
            whether zero padding was appended to reach ``2 ** n_qubits``.
        """
        if self.model_n == 'hqnn':
            # pass through the model's classical layer and ReLU, then pad/normalise
            inputs = torch.relu(model.cl(inputs))
            return pad_img(inputs.detach().numpy(), 2 ** self.n_qubits)
        elif self.model_n == 'drnn':
            # every input is replaced by the first basis vector e_0
            inputs = torch.zeros((inputs.shape[0], 2**self.n_qubits)).numpy()
            inputs[:, 0] = 1.
            return inputs, False
        else:  # amplitude encoding
            return pad_img(inputs.numpy(), 2 ** self.n_qubits)

    def transform_state(self, phi):
        """Return ``(state, is_image)``; amplitude-encoded models get an image tensor."""
        if self.model_n in ['hqnn', 'drnn']:
            return phi, False
        else:
            return state2img(phi, img_shape=self.img_shape), True

    def _predict(self, model, U, phi, projs):
        """Return the predicted class index (0-d tensor) for the adversarial state ``phi``."""
        if self.model_n == 'hqnn':
            # Score each class by Re tr(M_c U |phi><phi| U^dagger).
            col = phi.reshape((U.shape[0], 1)).numpy()
            E = U @ (col @ col.conj().T) @ U.conj().T
            scores = torch.zeros(len(projs))
            for c, M_c in enumerate(projs):
                scores[c] = np.real(np.trace(M_c @ E))
            self.log(f'class scores: {scores}')
            return scores.argmax()
        return torch.argmax(torch.tensor(model.predict(phi.unsqueeze(0))[0]))

    def verify(self, model, m_params, test_x, test_y, solver='QCQP'):
        """Attack every test input and log the ASR, timing and fidelity statistics.

        Args:
            model: model exposing ``circuit2matrix_wo_embed`` (and ``cl`` /
                ``predict`` depending on ``model_n``).
            m_params: circuit parameters forwarded to the model.
            test_x: batch of inputs; each sample is flattened.
            test_y: class indices (positions in ``class_idx``) of the inputs.
            solver: solver name, see :meth:`init_solver`.
        """
        self.init_solver(name=solver)
        n_classes = len(self.class_idx)
        # Observables depend only on the class index, so build them once.
        projs = [output_proj(self.model_n, c, self.n_qubits, self.class_idx)
                 for c in range(n_classes)]
        asr = 0
        fidelity_l = []
        start = time.time()
        # The circuit matrix is taken from the first sample and reused for all
        # inputs (per-input U is still a todo for drnn).
        U = model.circuit2matrix_wo_embed(torch.flatten(test_x[0]), m_params).detach().numpy()
        test_x = torch.flatten(test_x, start_dim=1)
        n = test_x.shape[1]
        test_rho, need_pad = self.preprocess_input(test_x, model)
        for i, rho in enumerate(test_rho):
            self.log(f'------ Start for {i}th input, original label: {test_y[i]}')
            l = test_y[i]

            delta_l = []
            phi_l = []
            fi_l = []
            M_l = output_proj(self.model_n, l, self.n_qubits, self.class_idx)
            for k in range(n_classes):
                if k == l:
                    continue
                self.log(f'compute for class {k}...')
                delta_k, fi, phi_opt = self.solver(rho, M_l, projs[k], U, self.log, n, need_pad)
                if fi is None:
                    continue
                delta_l.append(delta_k)
                phi_l.append(phi_opt)
                fi_l.append(fi)
                self.log(f"δ_{k} = {delta_k}")
            if not delta_l:
                continue
            k_ = int(torch.argmin(torch.tensor(delta_l)))

            phi, img_form = self.transform_state(torch.tensor(phi_l[k_]))
            y = self._predict(model, U, phi, projs)
            self.log(f'Predicted as {y}')

            if y != test_y[i]:
                asr += 1
                fidelity_l.append(fi_l[k_])

            if img_form:
                save_image(phi, os.path.join(self.p_c,
                                             str(i) + '_' + str(test_y[i].item()) + '_' + str(y.item()) + '.png'))
        end = time.time()
        self.log(f"average time: {(end - start)/test_x.shape[0]:.4f} s")
        if fidelity_l:
            fidelities = torch.tensor(fidelity_l)
            self.log(f'average fidelity: {fidelities.mean()}, std: {fidelities.std()}')
        else:
            self.log('average fidelity: n/a (no successful attack)')
        asr /= test_x.shape[0]
        self.log(f'ASR: {asr*100}%')