from typing import Optional
import numpy as np
import torch as th
from torch.optim import Adam, SGD

import torchquantum as tq
import torchquantum.functional as tqf
from kernels.vgg import VGG
from kernels.base import BQK

class QGK(BQK):
  """Quantum Generator Kernel: embeds data through a VGG layer, optionally
  preceded by a trainable linear feature extractor."""

  # Optimizers selectable by name; any other name is looked up in torch.optim.
  _OPTIMIZERS = {'Adam': Adam, 'SGD': SGD}

  def __init__(self, eta:int, inputs:int, groups:Optional[int]=None, projection:int=1,
               mode='fidelity', hidden_dims:Optional[list[int]]=None, optimizer:Optional[str]=None, lr:float=0.1,
               verbose:bool=False) -> None:
    """Quantum Generator Kernel Args:
        eta (int): number of qubits
        inputs (int): input dimension (sigma)
        groups (int): number of groups; defaults to inputs when no optimizer is given,
          otherwise None is forwarded to VGG unchanged
        projection (int): projection width for merging generators in [1,2eta] (comparable to stride)
        mode (str): projector | fidelity | simulation (mode of kernel computation)
          fidelity supports batched execution and is executed using einsum
          projector is executed using bmm and might be slower for batched inputs
        hidden_dims (list): hidden dimensions for feature extractor, defaults to None,
          i.e. no hidden layers
        optimizer (str): optimizer for training [Adam, SGD, None], defaults to None;
          a feature extractor is only built when an optimizer is given
        lr (float): learning rate for training, defaults to 0.1
        verbose (bool): print training progress, defaults to False
      Raises:
        ValueError: if optimizer does not name a known optimizer class
        AssertionError: if inputs differs from the VGG group count and no optimizer is given"""

    super().__init__(eta, inputs, mode=mode, verbose=verbose)

    # Fresh list per call instead of a shared mutable default.
    hidden_dims = [] if hidden_dims is None else list(hidden_dims)

    if groups is None and optimizer is None: groups = inputs

    comp_method = 'bmm' if mode == 'projector' else 'einsum'
    self.vgg = VGG(eta, groups=groups, projection=projection, comp_method=comp_method, verbose=verbose)

    if optimizer:
      # Linear+Sigmoid per consecutive pair of dims; the trailing Sigmoid is dropped
      # so the extractor ends on a Linear layer.
      fe_dims = [self.inputs, *hidden_dims, self.vgg.groups]
      layers = []
      for i, o in zip(fe_dims[:-1], fe_dims[1:]):
        layers += [th.nn.Linear(i, o, dtype=th.float64), th.nn.Sigmoid()]
      self.feature_extractor = th.nn.Sequential(*layers[:-1])
      self.optimizer = self._resolve_optimizer(optimizer)(lr=lr, params=self.feature_extractor.parameters())
    else:
      self.feature_extractor = None
      self.optimizer = None

    if self.inputs != self.vgg.groups: assert self.optimizer is not None, "Optimizer required for feature extractor"
    # Identity when no feature extractor was built.
    self.F = lambda x: self.feature_extractor(x) if optimizer else x


  @classmethod
  def _resolve_optimizer(cls, name:str):
    """Map an optimizer name to its class without evaluating arbitrary code."""
    opt = cls._OPTIMIZERS.get(name)
    if opt is None: opt = getattr(th.optim, name, None)
    if not (isinstance(opt, type) and issubclass(opt, th.optim.Optimizer)):
      raise ValueError(f"Unknown optimizer '{name}', expected one of {sorted(cls._OPTIMIZERS)} "
                       "or the name of a torch.optim optimizer class")
    return opt


  def process(self, data, inverse=False):
    """Apply the embedding of `data` to self.device.
    Forward: Hadamard on every qubit, then the VGG layer.
    Inverse: the VGG layer with inverse=True, then Hadamard on every qubit."""
    if not inverse:
      for i in range(self.eta): self.device.h(i)
    self.vgg.sigma = self.F(data if th.is_tensor(data) else th.tensor(data))
    self.vgg(self.device, range(self.eta), inverse=inverse)
    if inverse:
      for i in range(self.eta): self.device.h(i)


  def evolve(self, data):
    """Init device and evolve the data, supporting batched execution in fidelity mode.
    Returns the result of get_states_1d() on the fresh device."""
    # Convert first so the batch size can also be read from non-tensor inputs.
    data = data if th.is_tensor(data) else th.tensor(data)
    qc = tq.QuantumDevice(self.eta, bsz=data.shape[0])
    self.vgg.sigma = self.F(data)
    for i in range(self.eta): tqf.h(qc, i)
    self.vgg(qc, range(self.eta))
    return qc.get_states_1d()
  

def kta(K, Y):
    """Compute the kernel-target alignment of kernel matrix K with labels Y.

    The target matrix is the outer product of Y with itself. The result is the
    elementwise inner product of K and the target, divided by the square root
    of the product of their summed squares.
    """
    target = th.outer(Y, Y)
    alignment = (K * target).sum()
    scale = ((K * K).sum() * (target * target).sum()).sqrt()
    return alignment / scale

# Ad-hoc entry point: runs a single training experiment when this module is
# executed as a script.
if __name__ == '__main__':
  # NOTE(review): the star import presumably supplies MOONS and the
  # QGK(...) / QGK_STATIC(...) config factories. If so, `QGK` below is the
  # config factory shadowing the class defined above (its result is unpacked
  # with **, which requires a mapping) -- confirm against config.
  from config import *
  from run.train import fit, train, execute
  # Alternative run using `fit` with the static configuration, kept for reference:
  # metrics, _,_,_ = execute({**MOONS, **QGK_STATIC(1)}, [42], fit)
  # metrics, _,_,_ 
  # Merge the two configs and run `train` with the seed list [42].
  metrics, kernel, D, cfg = execute({**MOONS, **QGK(1, 100)}, [42], train)
  # print(f"Final Accuracy: {metrics['Test/Accuracy'][:,-1]}, KTA: {metrics['Test/KTA'][:,-1]}")

