from typing import Optional
import warnings
import torch as th
import numpy as np
import torchquantum as tq

from torchquantum.plugin.qiskit import op_history2qiskit
from qiskit import transpile
from tqdm import tqdm
import io,contextlib
from torch.optim import Adam, SGD


class BQK(tq.QuantumModule):
  """Base class for quantum kernels computed on a torchquantum device.

  Subclasses provide `process` (used by the projector and simulation modes)
  and `evolve` (used by the fidelity mode). `train` additionally expects
  `self.optimizer` to have been set; this base class does not create one.
  """

  def __init__(self, eta:int, inputs:int, mode='fidelity', intial_state:float=0.0, verbose:bool=True) -> None:
    """BaseQuantumKernel Class Args:
      eta (int): number of qubits
      inputs (int): input dimension (sigma)
      mode (str): projector | fidelity | simulation (mode of kernel computation)
        fidelity evolves both batches once and combines them with a single matmul
        projector evaluates every pair (a, b) separately on a fresh device
        simulation transpiles and runs every pair on `self.backend` (must be set by the caller)
      intial_state (float): value written into every entry of the stored
        [1, 2**eta] complex64 state tensor, defaults to 0.0
      verbose (bool): print training progress, defaults to True"""
    super().__init__()
    self.eta = eta
    self.device = tq.QuantumDevice(n_wires=eta)
    # The (misspelled) parameter and attribute names are kept as-is so that
    # existing callers and subclasses keep working.
    self.inital_state = th.tensor([[intial_state]* 2 ** self.eta], dtype=th.complex64)
    self.verbose = verbose
    self.mode = mode
    self.inputs = inputs
    self.calls = 0        # number of forward() invocations
    self.backend = None   # backend used by simulation mode; None until assigned
    self.depths = []      # transpiled circuit depths, accumulated across kernel() calls
    self.shots = 256      # shots per circuit in simulation mode


  def process(self, data, inverse=False):
    """Process the supplied data using self.device (projector and simulation modes)"""
    raise NotImplementedError("Please Implement this method")


  def evolve(self, data):
    """Init device and evolve the data, supporting batched execution in fidelity mode"""
    raise NotImplementedError("Please Implement this method")


  def project(self, a, b):
    """Apply process(a) followed by process(b, inverse=True) on a fresh
    recording device and return abs(amplitude**2) for entry 0 of the
    flattened state."""
    self.device = tq.QuantumDevice(n_wires=self.eta, record_op=True)
    self.process(a)
    self.process(b, inverse=True)
    state = self.device.get_states_1d().view(-1)
    return th.abs(state[0]**2)


  def kernel(self, A:th.tensor, B:th.tensor, dry_run:bool=False):
    """Compute the kernel of batch inputs A and B with dim [batchsize, inputs]. Args:
      A, B (th.tensor): batches of inputs
      dry_run (bool): simulation mode only; transpile every circuit but skip
        execution and return `self.depths` instead of a kernel matrix
    Returns:
      projector: tensor built from project(a, b) for every pair
      fidelity: abs(evolve(A) @ evolve(B).conj().T) ** 2
      simulation: float32 tensor [A.shape[0], B.shape[0]] holding the fraction
        of shots that measured the all-zero bitstring (or `self.depths` on dry_run)
    Raises:
      AssertionError: unsupported mode, or simulation mode without a backend"""
    if self.mode == 'projector':
      # NOTE(review): th.tensor(...) copies the values into a new tensor, so
      # gradients presumably do not flow through this mode - confirm before
      # training with mode='projector'.
      return th.tensor([[self.project(a,b) for b in B] for a in A])

    if self.mode == 'fidelity':
      evolved_a = self.evolve(A)
      evolved_b = self.evolve(B)
      return th.abs(th.matmul(evolved_a, evolved_b.conj().transpose(0, 1))) ** 2

    if self.mode == 'simulation':
      # Explicit raise instead of `assert` so the check survives `python -O`.
      if self.backend is None:
        raise AssertionError("Simulation mode requires a backend")
      result = th.zeros((A.shape[0], B.shape[0]), dtype=th.float32)
      rows = A if dry_run else tqdm(A, desc="Computing kernel matrix")
      zero_bitstring = "0" * self.eta
      for i, a in enumerate(rows):
        for j, b in enumerate(B):
          self.device = tq.QuantumDevice(n_wires=self.eta, record_op=True)
          self.process(a)
          self.process(b, inverse=True)
          # Silence conversion / transpilation chatter on stdout and RuntimeWarnings.
          with contextlib.redirect_stdout(io.StringIO()):
            with warnings.catch_warnings():
              warnings.simplefilter("ignore", RuntimeWarning)
              circ = op_history2qiskit(self.eta, self.device.op_history)
              circ.measure_all()
              # Original author's notes on optimization_level:
              #   0: layout / routing, 1: inverse + 1Q gates (used here), 2: commutative
              transpiled_circ = transpile(circ, backend=self.backend, optimization_level=1)
            self.depths.append(transpiled_circ.depth())
          if not dry_run:
            job = self.backend.run(transpiled_circ, shots=self.shots)
            counts = job.result().get_counts()
            result[i, j] = counts.get(zero_bitstring, 0) / self.shots
      if dry_run: return self.depths
      return result

    raise AssertionError(f'{self.mode} not supported')


  def forward(self, A:th.tensor, B:Optional[th.tensor]=None):
    """Count the call and return kernel(A, B); B defaults to A."""
    self.calls += 1
    return self.kernel(A, A if B is None else B)


  def train(self, X:th.tensor, Y:th.tensor, loss:callable, epochs:int=100,
            batch_size:Optional[int]=None,  callback:Optional[callable]=None,
            log_interval:int=20):
    """ Train the kernel using the supplied data and loss function. Args:
      X, Y (th.tensor): input and target data
      loss (callable): loss function, called as loss(self, X, Y); its result must support .backward()
      epochs (int): number of training epochs
      batch_size (int): batch size for training (sampled with replacement each epoch), None uses all data
      callback (callable): called as callback(self, loss, step) every log_interval epochs
      log_interval (int): log interval for training progress
    Returns:
      list of the per-epoch loss values
    Raises:
      AssertionError: if no optimizer has been set on the instance

    NOTE: this method shadows torch.nn.Module.train(mode) (assuming
    tq.QuantumModule derives from torch.nn.Module), so Module.eval(), which
    calls self.train(False), will not work on this class."""
    # getattr: __init__ never defines `optimizer`, so a plain attribute access
    # would raise AttributeError before the intended message could be shown.
    if getattr(self, 'optimizer', None) is None:
      raise AssertionError("Optimizer not set")
    losses = []
    n_samples = X.size(0)
    for ep in range(epochs):
      self.optimizer.zero_grad()
      if batch_size is not None:
        # Sample from the FULL data set every epoch; rebinding X/Y here would
        # make later epochs resample from the previous mini-batch only.
        batch_idx = np.random.choice(n_samples, batch_size)
        X_batch, Y_batch = X[batch_idx], Y[batch_idx]
      else:
        X_batch, Y_batch = X, Y
      loss_val = loss(self, X_batch, Y_batch)
      losses.append(loss_val)
      loss_val.backward()
      self.optimizer.step()
      if (ep+1)%log_interval==0:
        if callback is not None:
          callback(self, loss_val, ep+1)
        if self.verbose: print(f"KTA {-loss_val:.3f} @Step {ep+1}")
    return losses


class BCK(th.nn.Module):
    """A torch-based kernel base class with optional trainable preprocessing.

    Subclasses implement `kernel`. When `compression != 1.0` a small
    feed-forward feature extractor is applied to both kernel arguments and an
    optimizer over its parameters is created.
    """

    def __init__(self, inputs: int, hidden_dims: list[int] = [], compression: float = 1.0,
                 optimizer: Optional[str] = None, lr: float = 0.01, verbose: bool = True):
        """Kernel Args:
        inputs (int): input dimension
        hidden_dims (list): hidden dimensions for preprocessing network, defaults to []
            (only read, never mutated, so the shared default list is safe)
        compression (float): compression factor for feature extractor, defaults to 1.0
            (1.0 disables the feature extractor; otherwise the output dimension
            is int(inputs / compression))
        optimizer (str): name of a torch.optim optimizer class, e.g. "Adam" or "SGD";
            required when compression != 1.0, defaults to None
        lr (float): learning rate for training, defaults to 0.01
        verbose (bool): print training progress, defaults to True

        Raises:
        AssertionError: compression != 1.0 but no optimizer was given
        ValueError: the optimizer name does not resolve to a torch.optim optimizer
        """
        super().__init__()
        self.inputs = inputs
        self.compression = compression
        # Stays the raw argument when no feature extractor is built; it is only
        # replaced by an optimizer instance in the branch below.
        self.optimizer = optimizer
        self.verbose = verbose
        self.calls = 0  # number of forward() invocations

        # Create preprocessing network
        if self.compression == 1.0:
            self.feature_extractor = th.nn.Identity()
        else:
            # Explicit raise instead of `assert` so the check survives `python -O`.
            if optimizer is None:
                raise AssertionError("Optimizer required for feature extractor")
            # Resolve the class by name in torch.optim instead of eval()-ing the
            # string; a dotted prefix such as "th.optim.Adam" is tolerated.
            optimizer_cls = getattr(th.optim, optimizer.rsplit(".", 1)[-1], None)
            if not (isinstance(optimizer_cls, type) and issubclass(optimizer_cls, th.optim.Optimizer)):
                raise ValueError(
                    f"Unknown optimizer {optimizer!r}; expected a torch.optim class name such as 'Adam' or 'SGD'"
                )

            fe_dims = [self.inputs, *hidden_dims, int(self.inputs / self.compression)]
            layers = []
            for n_in, n_out in zip(fe_dims[:-1], fe_dims[1:]):
                layers.append(th.nn.Linear(n_in, n_out, dtype=th.float64))
                layers.append(th.nn.Tanh())
            # Drop the trailing Tanh so the final layer is linear.
            self.feature_extractor = th.nn.Sequential(*layers[:-1])
            self.optimizer = optimizer_cls(self.feature_extractor.parameters(), lr=lr)
        self.F = lambda x: x if self.compression == 1.0 else self.feature_extractor(x)


    def kernel(self, A: th.tensor, B: th.tensor):
        """Compute the kernel between (already preprocessed) batches A and B."""
        raise NotImplementedError("Please Implement this method")


    def forward(self, A: th.tensor, B: Optional[th.tensor] = None):
        """Count the call, preprocess both inputs with F and return kernel(F(A), F(B)); B defaults to A."""
        self.calls += 1
        return self.kernel(self.F(A), self.F(A if B is None else B))

    def train(self, X: th.tensor, Y: th.tensor, loss: callable, epochs: int = 100,
                    batch_size: Optional[int] = None, callback: Optional[callable] = None,
                    log_interval: int = 20):
        """Train the kernel using the supplied data and loss function.

        Args:
        X, Y (th.tensor): input and target data
        loss (callable): called as loss(self, X_batch, Y_batch); its result must support .backward()
        epochs (int): number of training epochs
        batch_size (int): mini-batch size, sampled with replacement from the
            full data every epoch; None uses all data
        callback (callable): called as callback(self, loss, step) every log_interval epochs
        log_interval (int): log interval for training progress

        Returns:
        list of the per-epoch loss values

        Raises:
        AssertionError: if no optimizer is set

        NOTE: this method shadows torch.nn.Module.train(mode); Module.eval()
        calls self.train(False) and therefore will not work on this class.
        """
        if self.optimizer is None:
            raise AssertionError("Optimizer not set")
        losses = []
        n_samples = X.size(0)

        for ep in range(epochs):
            self.optimizer.zero_grad()

            if batch_size is not None:
                batch_idx = np.random.choice(n_samples, batch_size)
                X_batch, Y_batch = X[batch_idx], Y[batch_idx]
            else:
                X_batch, Y_batch = X, Y

            loss_val = loss(self, X_batch, Y_batch)
            losses.append(loss_val)
            loss_val.backward()
            self.optimizer.step()

            if (ep + 1) % log_interval == 0:
                if callback is not None:
                    callback(self, loss_val, ep + 1)
                if self.verbose:
                    message = f"Loss {loss_val:.6f} @Step {ep + 1}"
                    # `gamma` is not defined by this base class; only report it
                    # when the instance actually provides one.
                    gamma = getattr(self, "gamma", None)
                    if gamma is not None:
                        gamma_val = gamma.item() if hasattr(gamma, "item") else gamma
                        message += f", Gamma: {gamma_val:.6f}"
                    print(message)
        return losses

