import re

import pennylane as qml
import torch
import torch.nn as nn
import torch.nn.functional as F
from pennylane import numpy as np
from qiskit import Aer, QuantumCircuit, execute, transpile
from qiskit.quantum_info import Statevector

from utils import get_matrix, replace, to_qiskit


class FidLossDotProd(nn.Module):
    """Module wrapper around ``dot_product_loss`` bound to one circuit.

    The circuit and the loss options are stored on the instance so the loss
    can be invoked as ``loss_fn(params, target_state)``.
    """

    def __init__(self, qnode: nn.Module, **kwargs):
        """
        Args:
            qnode: the PQC that has been transformed to torch.nn.Module
            **kwargs: extra keyword options handed unchanged to
                ``dot_product_loss`` on every call
        """
        super().__init__()
        self._kwargs = kwargs
        self._qnode = qnode

    def forward(self, params, target_state):
        """Evaluate ``dot_product_loss`` of the stored circuit at ``params``."""
        loss_options = self._kwargs
        circuit = self._qnode
        return dot_product_loss(params, circuit, target_state, **loss_options)

class FidLossDotProdAAE(nn.Module):
    """Per-sample infidelity loss on already-computed circuit outputs.

    Unlike a loss that runs the circuit itself, ``forward`` receives the
    circuit result directly (used for AAE encoder training).
    """

    def __init__(self, is_noisy=False) -> None:
        """
        Args:
            is_noisy: when True, ``result_state`` is treated as a density
                matrix (rho) instead of a state vector.
        """
        super().__init__()
        self.is_noisy = is_noisy

    def forward(self, result_state, target_state):  # for aae encoder training
        """Return ``1 - fidelity`` for every sample in the batch.

        Args:
            result_state: noiseless mode: same shape as ``target_state``.
                Noisy mode: either one 2-D matrix (only valid when the batch
                size is 1, as before) or a batch of matrices ``(bsz, N, N)``.
            target_state: batch of target states; its first dimension is the
                batch size ``bsz``.

        Returns:
            Noiseless mode: tensor of shape ``(bsz,)`` holding
            ``1 - |<result|target>|^2`` (plain dot product, no conjugation).
            Noisy mode: tensor of shape ``(bsz, 1, 1)`` holding
            ``1 - <psi|rho|psi>``.
        """
        bsz = target_state.shape[0]
        if not self.is_noisy:
            assert result_state.shape == target_state.shape

            # Calculate batch dot product: result_state -> (bsz, 1, N), target_state -> (bsz, N, 1)
            dot_product = torch.bmm(
                result_state.view(bsz, 1, -1), target_state.view(bsz, -1, 1)
            )

            # Remove last two dimensions
            dot_product = dot_product.squeeze(-1).squeeze(-1)

            return 1 - dot_product.abs() ** 2

        if result_state.dim() == 2:
            # A single unbatched matrix: prepend the batch axis. This is the
            # original behavior and still requires bsz == 1.
            rho = result_state.view((bsz,) + tuple(result_state.shape))
        else:
            # Already batched as (bsz, N, N); prepending another axis here
            # would make the batched matmul below fail.
            rho = result_state

        # first get: item = <psi| rho |psi>
        item = torch.bmm(target_state.view(bsz, 1, -1), rho)
        item = torch.bmm(item, target_state.view(bsz, -1, 1))
        # loss = 1 - <psi| rho |psi>; shape stays (bsz, 1, 1)
        return 1 - item


# class CircMatrixLayer(nn.Module):
#     """Transform qml.matrix to torch Module"""
#
#     def __init__(self, qnode):
#         super().__init__()


class FidLossMSE(nn.Module):
    """Mean-squared-error loss between a circuit's output and a target state."""

    # TODO: consider a common base class shared with the dot-product loss
    def __init__(self, qnode: nn.Module, **kwargs):
        """
        Args:
            qnode: the PQC that has been transformed to torch.nn.Module
            **kwargs: stored but not read by ``forward``; accepted only so the
                constructor is call-compatible with the dot-product loss
        """
        super().__init__()
        self._kwargs = kwargs
        self._qnode = qnode

    # NOTE: forward takes (params, target_state) and runs the circuit itself;
    # it does not take an already-computed result state.
    def forward(self, params, target_state):
        """Run the circuit on ``params`` and return the MSE against the target.

        Only the real part of the circuit output is compared, after casting
        it to float32. The output shape must equal ``target_state``'s shape.
        """
        circuit_out = self._qnode(params)
        predicted = circuit_out.real.to(torch.float32)
        assert predicted.shape == target_state.shape
        return F.mse_loss(predicted, target_state)


def dot_product_loss(
    params, qnode, target_state, avg: bool = False, noisy: bool = False
):
    """Batch mode of https://github.com/Zhaoyilunnn/qenc/blob/master/train_encoders.py#L38

    Runs ``qnode`` on ``params``, keeps the real part of its output as
    float32, and reduces the per-sample infidelity against ``target_state``
    to a scalar.

    Args:
        params: input passed unchanged to ``qnode``
        qnode: callable producing the circuit output; expected to return
            (bsz, N) states, or (bsz, N, N) density matrices when ``noisy``
        target_state: (bsz, N)
        avg: use loss.mean() or loss.sum()
        noisy: whether qc return density matrix

    Returns:
        Scalar tensor: mean (``avg=True``) or sum of the per-sample losses,
        ``1 - |<result|target>|^2`` or, when ``noisy``, ``1 - <psi|rho|psi>``.
    """
    # Only the real part of the circuit output enters the loss.
    result_state = qnode(params).real.to(torch.float32)
    bsz = result_state.shape[0]

    if not noisy:
        assert result_state.shape == target_state.shape

        # Calculate batch dot product: result_state -> (bsz, 1, N), target_state -> (bsz, N, 1)
        dot_product = torch.bmm(
            result_state.view(bsz, 1, -1), target_state.view(bsz, -1, 1)
        )

        # Remove last two dimensions
        dot_product = dot_product.squeeze(-1).squeeze(-1)

        loss = 1 - dot_product.abs() ** 2
    else:
        # first get: item = <psi| rho |psi>
        item = torch.bmm(target_state.view(bsz, 1, -1), result_state)
        item = torch.bmm(item, target_state.view(bsz, -1, 1))
        # loss = 1 - <psi| rho |psi>
        loss = 1 - item

    # Both branches are reduced over the whole batch to a scalar.
    return loss.mean() if avg else loss.sum()
