# -*- coding: utf-8 -*-


from itertools import product
import torch
import torch.nn as nn
import torch.optim as optim
import torch.sparse as sparse
import numpy as np
from qiskit.quantum_info import Pauli, state_fidelity, SparsePauliOp
import matplotlib.pyplot as plt
from scipy.sparse.linalg import eigs
#%%
import numpy as np 
import math
import torch
from torch.optim.optimizer import Optimizer
class Frankenstein(Optimizer):
    r"""Implements the Frankenstein optimizer.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        eps (float, optional): term added to the squared gradient to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        weight_decouple (boolean, optional): (default: True) If set as True, then
            the optimizer uses decoupled weight decay as in AdamW; otherwise the
            decay term is added to the gradient
        fixed_beta (float, optional): when fixed_beta != 0, it is used as a
            constant momentum coefficient beta_1; when fixed_beta == 0, beta_1
            is derived from the learning rate as
            ``1 - clamp((1 - base_beta) * sqrt(lr / base_lr), 0.05, 0.95)``
            (default: 0)
        base_lr (float, optional): reference learning rate used in the
            automatic beta_1 rule above (default: 1e-3)
        base_beta (float, optional): reference beta used in the automatic
            beta_1 rule above (default: 0.9)

    Raises:
        ValueError: if ``lr``, ``eps`` or ``weight_decay`` is negative, if
            ``fixed_beta`` is outside ``[0, 1)``, or if ``base_lr`` is not
            strictly positive.
    """

    def __init__(self, params, lr=1e-3, eps=1e-8,
                 weight_decay=0, weight_decouple=True, fixed_beta=0,
                 base_lr=1e-3, base_beta=0.9):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= fixed_beta < 1.0:
            raise ValueError("Invalid beta_1 value: {}".format(fixed_beta))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        # base_lr is a divisor (under a square root) in the automatic beta rule.
        if not 0.0 < base_lr:
            raise ValueError("Invalid base_lr value: {}".format(base_lr))
        defaults = dict(lr=lr, eps=eps, weight_decay=weight_decay,
                        weight_decouple=weight_decouple,
                        fixed_beta=fixed_beta, base_lr=base_lr,
                        base_beta=base_beta)

        super().__init__(params, defaults)

        # Clamp bounds for the argument of the log in rho_factor, so that
        # rho_factor itself lies in [-0.2, 1.03].
        self.max_rho = float(np.exp(1.03))
        self.min_rho = float(np.exp(-0.2))
        # The automatic (1 - beta_1) is kept inside [max_beta_adj, 1 - max_beta_adj].
        self.max_beta_adj = float(0.05)
        self.pi = float(math.pi)

    def __setstate__(self, state):
        super().__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            lr = group['lr']
            eps = group['eps']
            weight_decay = group['weight_decay']

            # beta_1 depends only on group settings, so compute it once per
            # group instead of once per parameter.
            if group['fixed_beta'] != 0:
                beta_1 = group['fixed_beta']
            else:
                beta_1 = 1.0 - max(
                    self.max_beta_adj,
                    min(1 - self.max_beta_adj,
                        (1 - group['base_beta']) * math.sqrt(lr / group['base_lr'])))

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError(
                        'Frankenstein does not support sparse gradients, please consider SparseAdam instead')
                state = self.state[p]
                if len(state) == 0:
                    state['m'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # 's' holds the previous step's x_t; it starts at lr.
                    state['s'] = torch.full_like(p, lr, memory_format=torch.preserve_format)
                    state['vmax'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                m, s, vmax = state['m'], state['s'], state['vmax']

                if weight_decay > 0:
                    if group['weight_decouple']:
                        p.mul_(1.0 - lr * weight_decay)
                    else:
                        # Out-of-place on purpose: the caller's p.grad must not
                        # be modified by the optimizer.
                        grad = grad.add(p, alpha=weight_decay)

                # acos(tanh(m * grad)) / pi lies in [0, 1] and equals 0.5 when
                # m * grad == 0.
                p_factor = torch.div(torch.acos(torch.tanh(torch.mul(m, grad))), self.pi)  # frankenstein
                # 1.60653065971 is numerically 1 + exp(-0.5).
                xi = torch.div(1.60653065971,
                               torch.add(torch.exp(-torch.abs(torch.add(s, -p_factor))), 1.0))

                x_t = torch.add(torch.mul(grad, grad), eps)
                v_t = torch.max(vmax, x_t)
                sqrt_v = torch.sqrt(v_t)
                alpha_xi_sqrt_v = torch.mul(torch.div(lr, sqrt_v), xi)
                # 3.21828182846 is numerically e + 0.5.
                rho_factor = torch.log(torch.clamp(3.21828182846 - p_factor + sqrt_v,
                                                   min=self.min_rho, max=self.max_rho))
                scaled_step = torch.mul(-grad, alpha_xi_sqrt_v)

                m.mul_(torch.mul(rho_factor, beta_1)).add_(scaled_step)  # Momentum update
                beta_2 = torch.mul(torch.clamp(torch.div(x_t, s), 0.0, 1.0),
                                   torch.abs(p_factor - 0.5))  # actually, 1-beta_2
                p.add_(torch.add(torch.mul(m, beta_1), scaled_step))  # Parameter update
                vmax.copy_(torch.add(torch.mul(v_t, 1.0 - beta_2),
                                     torch.mul(beta_2, x_t)))  # v_t update
                s.copy_(x_t)
        return loss
def str2WeightedPaulis(s):
    """Parse a weighted Pauli-sum string into a ``SparsePauliOp``.

    The string is read as a sequence of terms, each a coefficient immediately
    followed by a Pauli label made of the letters I, X, Y, Z; terms are
    separated by the ``+`` or ``-`` sign that starts the next coefficient.
    An ``i`` inside a coefficient is rewritten to ``j`` before it is handed to
    ``complex()``.

    Arguments:
        s (str): the textual operator; surrounding whitespace is stripped.

    Returns:
        ``SparsePauliOp`` built from the parsed labels and coefficients.
    """
    text = s.strip()
    pauli_letters = ('I', 'X', 'Y', 'Z')
    sign_chars = ('+', '-')

    weights = []
    labels = []
    start = 0
    reading_coef = True
    # The appended '+' acts as a sentinel that flushes the last label.
    for pos, ch in enumerate(text + '+'):
        if pos == 0:
            # The first character is never treated as a boundary, so a
            # leading sign stays part of the first coefficient.
            continue
        if reading_coef:
            if ch in pauli_letters:
                # First Pauli letter: everything since `start` is the coefficient.
                weights.append(complex(text[start:pos].replace('i', 'j')))
                reading_coef = False
                start = pos
        elif ch in sign_chars:
            # A sign while reading a label closes the label.
            labels.append(Pauli(text[start:pos]))
            reading_coef = True
            start = pos
    return SparsePauliOp(labels, weights)

def tran(tin):
    """Translate a Pauli letter into its digit code.

    'I', 'X', 'Y', 'Z' map to '0', '1', '2', '3' respectively; any other
    value is returned unchanged.
    """
    for letter, code in (('I', '0'), ('X', '1'), ('Y', '2'), ('Z', '3')):
        if tin == letter:
            return code
    return tin
    
def _periodic_bonds(na, nb, step, skip_reversed):
    """List site pairs ``[a, b]`` that are ``step`` apart on a periodic grid.

    Sites of the ``na`` x ``nb`` grid are numbered row by row (site
    ``i * nb + j`` sits in row ``i``, column ``j``), e.g. for 4 x 4:

         0  1  2  3
         4  5  6  7
         8  9 10 11
        12 13 14 15

    Bonds along rows are generated first, then bonds along columns, both
    with wrap-around.  A pair whose two ends coincide is dropped.  When
    ``skip_reversed`` is true, a pair is also dropped if its reverse
    ``[b, a]`` has already been collected.
    """
    bonds = []
    # Partner is `step` columns to the right in the same row.
    for i in range(na):
        for j in range(nb):
            a = i * nb + j
            b = i * nb + (j + step) % nb
            if a != b and not (skip_reversed and [b, a] in bonds):
                bonds.append([a, b])
    # Partner is `step` rows below in the same column.
    for i in range(na):
        for j in range(nb):
            a = i * nb + j
            b = ((i + step) % na) * nb + j
            if a != b and not (skip_reversed and [b, a] in bonds):
                bonds.append([a, b])
    return bonds


def J1J2(na, nb, J1, J2):
    """Build the sparse Hamiltonian matrix for an ``na`` x ``nb`` periodic lattice.

    Every bond contributes ``XX + YY + ZZ`` on its two sites, weighted by
    ``J1`` for sites one step apart along a row or column and by ``J2`` for
    sites two steps apart along a row or column.

    NOTE: reversed duplicates are filtered only for the J2 bonds; for J1 a
    lattice dimension of exactly 2 yields both ``[a, b]`` and ``[b, a]``, so
    that bond is counted twice.  This matches the original behaviour.

    Arguments:
        na (int): number of rows.
        nb (int): number of columns.
        J1: coupling for the one-step bonds; formatted with ``str()`` into
            the Pauli string.
        J2: coupling for the two-step bonds; formatted with ``str()`` into
            the Pauli string.

    Returns:
        The sum of all terms, accumulated from
        ``str2WeightedPaulis(...).to_matrix(sparse=1)``.
    """
    n = na * nb
    terms = []
    for coupling, step, skip_reversed in ((J1, 1, False), (J2, 2, True)):
        for a, b in _periodic_bonds(na, nb, step, skip_reversed):
            for axis in ('X', 'Y', 'Z'):
                label = ['I'] * n
                label[a] = axis
                label[b] = axis
                terms.append(str(coupling) + ''.join(label))

    # Start from an all-zero matrix of the right shape, then add each term.
    H = 0 * str2WeightedPaulis('1' + 'I' * n).to_matrix(sparse=1)
    for term in terms:
        H += str2WeightedPaulis(term).to_matrix(sparse=1)
    return H

#%%
# Build the sparse Hamiltonian H with J1J2() on an A x B periodic lattice
# (replace this with your own matrix if needed).
# H is of size 2**(A*B) x 2**(A*B).
A = 4  # Example size, adjust as needed
B = 3
j1 = 1
j2 = 0.5  # J2 coupling expressed in units of j1
# Pass j1 itself rather than a literal 1 so that changing j1 affects both couplings.
H = J1J2(A, B, j1, j1*j2)
# M = 2**(A*B)
Q = A*B
# Three eigenvalues with the smallest real part; used as the reference energy.
u, w = eigs(H, which='SR', k=3)
gs = min(u)
# Convert to COO once (PyTorch sparse tensors are built from COO indices/values).
H_as_coo = H.tocoo()
# Keep the real part explicitly instead of relying on an implicit
# complex -> float32 cast.
H_coo = torch.tensor(H_as_coo.data.real, dtype=torch.float32)
row_indices = torch.tensor(H_as_coo.row, dtype=torch.int64)
col_indices = torch.tensor(H_as_coo.col, dtype=torch.int64)

# Real-valued feed-forward network (ReLU hidden layers) that generates v;
# its output is divided by its norm so that v is a unit vector.
class RealLinear(nn.Module):
    """Stack of ``nn.Linear`` layers with ReLU between them and a normalised output.

    Arguments:
        layer_sizes (iterable): ``(in_features, out_features)`` pairs, one per
            linear layer, in the order the layers are applied.
    """

    def __init__(self, layer_sizes):
        super().__init__()
        self.layers = nn.ModuleList()
        for n_in, n_out in layer_sizes:
            self.layers.append(nn.Linear(n_in, n_out))

    def forward(self, x):
        """Apply all layers and return the result divided by its norm.

        ReLU follows every layer except the last; the final output is scaled
        by ``torch.norm`` taken over the whole output tensor.
        """
        hidden_layers = self.layers[:-1]
        head = self.layers[-1]
        for hidden in hidden_layers:
            x = torch.relu(hidden(x))
        out = head(x)
        return out / torch.norm(out)


# Example usage: train one network per seed and log the energy error per epoch.

# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_epochs = 10000

# Everything below is identical for every seed and epoch, so build it once.
layer_sizes = [(Q, Q*10), (Q*10, Q*10), (Q*10, Q*4), (Q*4, 1)]
# All 2**Q bit-strings of length Q, one per row; the network maps each row to
# one entry of the trial vector.
v = torch.tensor(list(product(*[[0, 1] for _ in range(Q)])), dtype=torch.float32).to(device)
sparse_H = torch.sparse_coo_tensor(torch.stack([row_indices, col_indices]), H_coo, size=(2**Q, 2**Q)).to(device)
# Reference value subtracted from the loss when logging.
ground_energy = min(u.real)

log = []
for seed in range(32):
    E1 = []
    torch.manual_seed(seed)
    net = RealLinear(layer_sizes).to(device)
    # Define an optimizer (e.g., stochastic gradient descent)
    # optimizer = optim.SGD(net.parameters(), lr=0.01)
    optimizer = optim.Adam(net.parameters(), lr=1e-3)
    #optimizer = Frankenstein(net.parameters(), lr=1e-3)
    #optimizer = optim.SGD(net.parameters(), lr=0.01,momentum=0.9,nesterov=True)

    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)

    # Training loop (you can adjust the number of iterations)
    for epoch in range(n_epochs):
        optimizer.zero_grad()
        v_nn = net(v).reshape(-1)  # Generate v using the neural network
        # Loss is v^T H v; v_nn already has unit norm (see RealLinear.forward).
        result = torch.sparse.mm(sparse_H, v_nn.view(-1, 1)).squeeze()
        loss = torch.dot(v_nn, result)
        energy_error = loss.item() - ground_energy
        E1.append(energy_error)
        loss.backward()
        optimizer.step()
        scheduler.step()
        if epoch % 1000 == 0:
            print(f"Epoch {epoch+1}, Loss: {energy_error}")
    log.append(E1)

np.save('adam1', np.array(log))