from typing import Optional
import numpy as np
import torch as th
import torchquantum as tq
import torchquantum.functional as tqf

from math import sqrt 
from scipy.special import factorial

import time


class VGG(tq.QuantumModule):
  """Variational Generator Group (VGG) layer on ``eta`` qubits.

  The ``4**eta - 1`` traceless Hermitian generators of a ``2**eta``-dimensional
  space (symmetric and antisymmetric off-diagonal matrices plus diagonal
  matrices, in the style of generalized Gell-Mann matrices) are merged into
  ``groups`` Hermitian matrices ``H[0, g]``. Each group is weighted by one entry
  of ``self.sigma`` and the resulting unitary is applied with
  ``tqf.qubitunitaryfast``.

  Adapted from torchquantum.operator OpHamilExpMatrix.
  """

  def __init__(self, eta:int, groups:Optional[int]=None, projection:int=1,
               comp_method:str='bmm', fast_exp:bool=True, verbose:bool=False) -> None:
    """Merge the available generators for ``eta`` qubits into groups.

    Args:
      eta (int): number of qubits.
      groups (int, optional): number of groups to merge the generators into.
        If falsy, defaults to ``generators // _default_group_size(eta)``.
        Must not exceed the number of generators ``4**eta - 1``.
      projection (int): projection width for merging generators, in
        [1, 2*eta] (comparable to a stride). The k-th visited generator is
        ``k * 2**projection mod generators``.
      comp_method (str): matrix-multiplication method forwarded to
        ``tqf.qubitunitaryfast`` ['bmm' | 'einsum'].
      fast_exp (bool): if True, apply the single unitary e^{-i * sum_g h_g*s_g};
        otherwise apply the product of per-group unitaries
        prod_g e^{-i * h_g*s_g} one at a time (default: True).
      verbose (bool): print how many generators were merged and the build time.
    """
    super().__init__()
    start = time.perf_counter()
    self.fast_exp = fast_exp
    self.comp_method = comp_method
    self.projection = projection

    # Hilbert-space dimension for eta qubits and the number of generators
    # (dim**2 - 1 traceless Hermitian basis matrices).
    self.eta = eta
    dim = 2 ** self.eta
    self.generators = dim ** 2 - 1
    groups = groups or self.generators // self._default_group_size(eta)
    assert groups <= self.generators, (
      f"groups ({groups}) must not exceed the number of generators ({self.generators})")
    self.groups = groups

    # Per-group weights (input data / context). This is a plain tensor
    # attribute, not a registered parameter or buffer.
    self.sigma = th.ones(self.groups)
    # Merged Hamiltonians, shape (1, groups, dim, dim), complex128.
    self.H = self.build_groups(projection)

    if verbose:
      print(f"Merged {self.generators} generators into {self.groups} groups "
            f"for {self.eta} qubits in {time.perf_counter()-start}s")

  @staticmethod
  def _default_group_size(eta:int) -> int:
    """Divisor used to derive the default number of groups for ``eta`` qubits.

    Iterative form of g(n) = 2*g(n-1) + (2 if n is odd else 0) - 1 for n > 1,
    with g(n) = 1 otherwise.
    """
    size = 1
    for n in range(2, eta + 1):
      size = 2 * size + (2 if n % 2 else 0) - 1
    return size

  def build_groups(self, projection:int):
    """Merge all generators into ``self.groups`` Hermitian matrices.

    Args:
      projection (int): stride exponent. Generators are visited in the order
        ``k * 2**projection mod self.generators`` and the n-th visited
        generator is added to group ``n % self.groups``.

    Returns:
      th.Tensor: complex128 tensor of shape ``(1, groups, 2**eta, 2**eta)``.

    Raises:
      AssertionError: if a merged group is not Hermitian (internal sanity check).
    """
    dim = 2 ** self.eta
    H = th.zeros(1, self.groups, dim, dim, dtype=th.complex128)

    def norm(i):
      # NOTE(review): diagonal generators are scaled by 1/sqrt(i!). The
      # generalized Gell-Mann convention is sqrt(2/(i*(i+1))); the two agree
      # only for i = 1 and i = 3. Confirm this normalisation is intended.
      return sqrt(factorial(i))

    # Each generator is (index, values). ``index`` is a (rows, cols) pair used
    # for advanced indexing into a dim x dim matrix. ``values`` are the
    # entries written at those positions.
    upper = [(i, j) for i in range(dim - 1) for j in range(i + 1, dim)]
    G = [
      # Symmetric off-diagonal: 1 at (i, j) and (j, i).
      *[(((i, j), (j, i)), (1 + 0j, 1 + 0j)) for i, j in upper],
      # Antisymmetric off-diagonal: +i at (i, j), -i at (j, i).
      *[(((i, j), (j, i)), (0 + 1j, 0 - 1j)) for i, j in upper],
      # Diagonal: diag(1, ..., 1, -i) / norm(i) on the first i + 1 entries.
      *[((tuple(range(i + 1)),) * 2, (1 / norm(i),) * i + (-i / norm(i),))
        for i in range(1, dim)],
    ]

    # w is a power of two and generators = 4**eta - 1 is odd, so they are
    # coprime and this order visits every generator exactly once.
    w = 2 ** projection
    order = [(k * w) % self.generators for k in range(self.generators)]

    # Round-robin the visited generators over the groups. Within each group,
    # contributions are summed in increasing visit order.
    for n, k in enumerate(order):
      index, values = G[k]
      H[0, n % self.groups][index] += th.tensor(values, dtype=th.complex128)

    assert th.allclose(H[0], H[0].conj_physical().transpose(-2, -1)), \
      "merged generator groups must be Hermitian"
    return H

  @property
  def phi(self):
    """``sigma`` viewed as ``(*sigma.shape, 1, 1)`` so it broadcasts over the
    two matrix dims of ``H``."""
    return self.sigma.view(*self.sigma.shape, 1, 1)

  @property
  def exponent_matrix(self):
    """``-1j * sigma_g * H_g`` for every group.

    Shape ``(..., groups, d, d)``; when ``fast_exp`` is set, the groups axis
    (dim 1) is summed away, giving ``(..., d, d)``.
    """
    matrix = self.H * -1j * self.phi
    if self.fast_exp:
      return matrix.sum(dim=1)
    return matrix

  @property
  def matrix(self):
    """Matrix exponential of ``exponent_matrix`` (batched over leading dims)."""
    return th.matrix_exp(self.exponent_matrix)

  def forward(self, qdev, wires, inverse=False):
    """Apply the VGG unitary to a quantum device.

    Args:
      qdev: the QuantumDevice whose state is updated.
      wires: the wires the ``2**eta``-dimensional unitary acts on.
      inverse (bool): apply the adjoint of the unitary instead.
    """
    if self.fast_exp:
      tqf.qubitunitaryfast(qdev, wires, params=self.matrix.to(qdev.device),
                           inverse=inverse, comp_method=self.comp_method)
    else:
      # Here matrix has shape (batch, groups, d, d). Apply one unitary per
      # group, taken along the groups axis (dim 1), in group order.
      unitaries = self.matrix.to(qdev.device).unbind(dim=1)
      if inverse:
        # (U_{G-1} ... U_0)^dagger = U_0^dagger ... U_{G-1}^dagger, so undo
        # the groups in reverse order.
        unitaries = reversed(unitaries)
      for unitary in unitaries:
        tqf.qubitunitaryfast(qdev, wires, params=unitary,
                             inverse=inverse, comp_method=self.comp_method)
