from typing import Optional
import torch as th
from torch.optim import Adam, SGD

import torchquantum as tq
import torchquantum.functional as tqf
from kernels.base import BQK


class HEE(BQK):
  """Hardware Efficient Embedding: layers of per-qubit RX rotations followed by a
  nearest-neighbour CX chain, optionally preceded by a trainable feature extractor."""

  def __init__(self, eta:int, inputs:int, layers:int=1, mode='fidelity', 
               hidden_dims:Optional[list[int]]=None, optimizer:Optional[str]=None, lr:float=0.1, 
               verbose:bool=False) -> None:
    """Hardware Efficient Embedding  
        eta (int): number of qubits, raised to 2 (with a printed notice) if smaller
        inputs (int): input dimension (sigma), forwarded to the base class
        layers (int): number of rotation layers, defaults to 1
        mode (str): projector | fidelity | simulation (mode of kernel computation, forwarded to the base class)
          fidelity supports batched execution and is executed using einsum
          projector is executed using bmm and might be slower for batched inputs
          (NOTE(review): the einsum/bmm behaviour lives in the base class - confirm there)
        hidden_dims (list): hidden dimensions for feature extractor, defaults to None, i.e. single linear layer
        optimizer (str): optimizer for training [Adam, SGD, None], defaults to None, i.e. no feature extractor;
          other torch.optim optimizer names are accepted as well
        lr (float): learning rate for training, defaults to 0.1
        verbose (bool): forwarded to the base class, defaults to False

        Raises ValueError if optimizer does not name a torch.optim optimizer."""
    if eta < 2:
      eta = 2
      print("HEE requires at least 2 qubits")
    super().__init__(eta, inputs, mode=mode, verbose=verbose)
    self.layers = layers

    if optimizer is None:
      # No optimizer means no trainable feature extractor: inputs are used as-is.
      self.feature_extractor = None
      self.optimizer = None
    else:
      # Resolve the optimizer by name instead of eval()-ing the string.
      # Qualified names such as 'th.optim.Adam' are tolerated by using the last component.
      name = optimizer.rsplit('.', 1)[-1]
      optimizer_cls = {'Adam': Adam, 'SGD': SGD}.get(name) or getattr(th.optim, name, None)
      if not (isinstance(optimizer_cls, type) and issubclass(optimizer_cls, th.optim.Optimizer)):
        raise ValueError(f"Unknown optimizer {optimizer!r}; expected 'Adam', 'SGD' or another torch.optim optimizer name")

      # Linear -> Sigmoid -> ... -> Linear, mapping self.inputs features onto eta rotation angles.
      fe_dims = [self.inputs, *(hidden_dims or []), eta]
      modules = []
      for n_in, n_out in zip(fe_dims[:-1], fe_dims[1:]):
        modules += [th.nn.Linear(n_in, n_out, dtype=th.float64), th.nn.Sigmoid()]
      # Drop the trailing Sigmoid: simple linear fe if hidden_dims is empty
      self.feature_extractor = th.nn.Sequential(*modules[:-1])
      self.optimizer = optimizer_cls(lr=lr, params=self.feature_extractor.parameters())

    # Identity when no feature extractor is configured (checked at call time).
    self.F = lambda x: self.feature_extractor(x) if self.feature_extractor is not None else x


  def process(self, data, inverse=False):
    """Apply the embedding circuit for `data` on self.device, or its inverse.

        data: indexable holding one rotation angle per qubit (data[j] for qubit j)
        inverse (bool): apply the adjoint circuit, i.e. negated angles with every
          gate in exactly reversed order"""
    sign = -1 if inverse else 1
    operations = [
      (range(self.eta), lambda j: self.device.rx(j, sign*data[j])),
      (range(self.eta - 1), lambda j: self.device.cx([j, j + 1])),
    ]
    if inverse:
      # Adjacent CX gates share a qubit and do not commute, so the qubit order inside
      # each operation must be reversed too, not only the order of the operations.
      # r[::-1] stays a reusable range (a reversed() iterator would be exhausted after one layer).
      operations = [(r[::-1], op) for r, op in reversed(operations)]
    # All layers are identical, so their iteration order needs no reversal.
    for _ in range(self.layers):
      for qubits, op in operations:
        for j in qubits:
          op(j)


  def evolve(self, data):
    """Init device and evolve the data, supporting batched execution in fidelity mode

        data: tensor (or anything th.tensor accepts); after the feature extractor it is
          indexed as data[:, j], i.e. one row per sample and one column per qubit
        Returns the result of get_states_1d() on the evolved device."""
    data = self.F(data if th.is_tensor(data) else th.tensor(data))
    qc = tq.QuantumDevice(self.eta, bsz=data.shape[0])
    for _ in range(self.layers):
      for j in range(self.eta):
        tqf.rx(qc, j, data[:, j])
      for j in range(self.eta - 1):
        tqf.cx(qc, [j, j + 1])
    return qc.get_states_1d()


if __name__ == '__main__':
  # Project-local experiment helpers.
  # NOTE(review): HEE(...) is unpacked as a mapping below, so the star import
  # presumably rebinds HEE to a config factory that shadows the class - confirm.
  from config import *
  from run.train import fit, train, execute

  eta = 2
  seed = 42
  steps = 100

  # PCA preprocessing: fit with the input reduced to `eta` features
  pca_setup = {**MOONS, **HEE(eta)}
  metrics, kernel, D, cfg = execute(pca_setup, [seed], fit, add_loss=True, pca_features=eta)

  # Linear feature extractor: run the training routine with `steps` passed to the config
  fe_setup = {**MOONS, **HEE(eta, steps)}
  metrics, _,_,_ = execute(fe_setup, [seed], train)
