from typing import Optional
import numpy as np
import torch as th

# from qiskit import QuantumCircuit 
# from qiskit.quantum_info import state_fidelity, Statevector
import torchquantum as tq
import torchquantum.functional as tqf

from math import sqrt 
from scipy.special import factorial
from torch import matrix_exp as expm

import time

from kernels.base import BQK
from torch.optim import Adam, SGD


class QEK(BQK):
  """Trainable quantum kernel built from layered RZ data encodings.

  Each of the `num_layers` layers applies, on `eta` qubits:
    H on every qubit -> RZ(x_k) data encoding -> trainable RY per qubit
    -> trainable CRZ on every neighbouring pair (j, j+1 mod eta), i.e. a ring.
  Feature indices wrap modulo `inputs`, so features are re-uploaded when
  there are more layer slots than input dimensions.
  """

  def __init__(self, eta:int, inputs:int, mode='fidelity', reupload:int=0,
               optimizer:Optional[str]='SGD', lr:float=0.1, verbose:bool=False) -> None:
    """Args:
      eta (int): number of qubits (values below 2 are raised to 2)
      inputs (int): input dimension (sigma)
      mode (str): projector | fidelity | simulation (mode of kernel computation)
        fidelity supports batched execution and is executed using einsum
        projector is executed using bmm and might be slower for batched inputs
      reupload (int): reupload the data n times, defaults to 0
        (num_layers = ceil(inputs/eta) * (reupload+1))
      optimizer (str | None): name of a torch.optim optimizer class (e.g. 'Adam', 'SGD'),
        an optimizer class itself, or None / 'None' to create no optimizer; defaults to 'SGD'
      lr (float): learning rate for training, defaults to 0.1
      verbose (bool): print training progress, defaults to False

    Raises:
      ValueError: if `optimizer` is a string that does not name a torch.optim optimizer."""
    if eta < 2:
      eta = 2
      print("QEK requires at least 2 qubits")
    super().__init__(eta, inputs, mode=mode, verbose=verbose)
    self.num_layers = int(np.ceil(inputs / eta) * (reupload + 1))
    # Per layer: row 0 holds the RY angles, row 1 the CRZ angles (one per qubit),
    # initialised uniformly in [0, 2*pi).
    self.params = th.nn.Parameter(
      th.empty(self.num_layers, 2, self.eta, dtype=th.float32).uniform_(0, 2 * np.pi))
    optim_cls = self._resolve_optimizer(optimizer)
    self.optimizer = None if optim_cls is None else optim_cls(lr=lr, params=self.parameters())

  @staticmethod
  def _resolve_optimizer(optimizer):
    """Map an optimizer name (or class) to a torch.optim class; None / 'None' map to None."""
    if optimizer is None or optimizer == 'None':
      return None
    if not isinstance(optimizer, str):
      return optimizer  # already an optimizer class / factory
    # Look the name up in torch.optim instead of eval()-ing arbitrary text.
    optim_cls = getattr(th.optim, optimizer, None)
    if not (isinstance(optim_cls, type) and issubclass(optim_cls, th.optim.Optimizer)):
      raise ValueError(f"Unknown optimizer '{optimizer}'; expected a torch.optim "
                       f"optimizer name such as 'Adam' or 'SGD', or None")
    return optim_cls

  def _layer_features(self, layer):
    """Feature index encoded on each qubit in `layer` (wraps modulo `inputs` for reuploading)."""
    return [(layer * self.eta + j) % self.inputs for j in range(self.eta)]

  def _apply_layer(self, qdev, params, angles):
    """Apply one feature-map layer to `qdev`; `angles[j]` is the RZ encoding angle of qubit j."""
    for j in range(self.eta):
      tqf.h(qdev, j)
    for j in range(self.eta):
      tqf.rz(qdev, j, angles[j])
    for j in range(self.eta):
      tqf.ry(qdev, j, params[0, j])
    for j in range(self.eta):
      tqf.crz(qdev, [j, (j + 1) % self.eta], params[1, j])

  def _apply_layer_inverse(self, qdev, params, angles):
    """Apply the adjoint of `_apply_layer`: same gates in reverse order with negated angles."""
    for j in reversed(range(self.eta)):
      tqf.crz(qdev, [j, (j + 1) % self.eta], -params[1, j])
    for j in reversed(range(self.eta)):
      tqf.ry(qdev, j, -params[0, j])
    for j in reversed(range(self.eta)):
      tqf.rz(qdev, j, -angles[j])
    for j in reversed(range(self.eta)):
      tqf.h(qdev, j)  # H is self-inverse

  def process(self, data, inverse=False):
    """Apply the feature map for one (unbatched) sample to `self.device`, or its adjoint.

    Args:
      data: a single sample, indexed as `data[k]`
        (NOTE(review): `self.device` is presumably a tq.QuantumDevice set up by BQK; confirm)
      inverse (bool): if True, apply the layers in reverse order, each replaced by its adjoint
    """
    layers = range(len(self.params))
    for i in (reversed(layers) if inverse else layers):
      angles = [data[k] for k in self._layer_features(i)]
      if inverse:
        self._apply_layer_inverse(self.device, self.params[i], angles)
      else:
        self._apply_layer(self.device, self.params[i], angles)

  def evolve(self, data):
    """Init device and evolve the data, supporting batched execution in fidelity mode.

    Args:
      data: batch of samples indexed as `data[:, k]`; `data.shape[0]` is the batch size
    Returns:
      the device's flattened state vectors (`qc.get_states_1d()`)
    """
    qc = tq.QuantumDevice(self.eta, bsz=data.shape[0])
    for i, params in enumerate(self.params):
      self._apply_layer(qc, params, [data[:, k] for k in self._layer_features(i)])
    return qc.get_states_1d()


if __name__ == '__main__':
  # Script entry: run one training experiment on the MOONS config.
  from config import *
  from run.train import fit, execute

  eta, seed = 2, 42
  # NOTE(review): `QEK(eta)` is unpacked with ** so it must return a mapping; it presumably
  # resolves to a config factory exported by `config` (the star import shadows the class) - confirm.
  run_cfg = {**MOONS, **QEK(eta)}
  metrics, kernel, D, cfg = execute(run_cfg, [seed], fit)
