import numpy as np
import torch

from typing import Literal

from qtorch import CTYPE

data_dict = dict([
    # H2, Jordan-Wigner encoding
    (('H2', 2, 'jw'),
     'examples/qas/vqe/mol_data/H2_2q_geom_H_.0_.0_.0;_H_.0_.0_0.7414_jordan_wigner.npz'),
    (('H2', 3, 'jw'),
     'examples/qas/vqe/mol_data/H2_3q_geom_H_.0_.0_.0;_H_.0_.0_0.7414_jordan_wigner.npz'),
    (('H2', 4, 'jw'),
     'examples/qas/vqe/mol_data/H2_4q_geom_H_.0_.0_+.35;_H_.0_.0_-.35_jordan_wigner.npz'),
    # LiH, parity (4 qubits) and Jordan-Wigner (6 qubits) encodings
    (('LiH', 4, 'parity'),
     'examples/qas/vqe/mol_data/LiH_4q_geom_Li_.0_.0_.0;_H_.0_.0_3.4_parity.npz'),
    (('LiH', 6, 'jw'),
     'examples/qas/vqe/mol_data/LiH_6q_geom_Li_.0_.0_.0;_H_.0_.0_2.2_jordan_wigner.npz'),
    # H2O, Jordan-Wigner encoding
    (('H2O', 8, 'jw'),
     'examples/qas/vqe/mol_data/H2O_8q_geom_H_-0.021,_-0.002,_.0;_O_0.835,_0.452,_0;_H_1.477,_-0.273,_0_jordan_wigner.npz'),
])
'''
    Dictionary mapping ``(molecule, num_qubits, transform)`` keys to the file
    paths of the corresponding Hamiltonian data files.

    The molecular Hamiltonian data were obtained from the anonymous repository
    for the CRLQAS paper (https://openreview.net/forum?id=rINBD8jPoP),
    available at: https://anonymous.4open.science/r/CRLQAS/README.md

    The folder `qulacs_noiseless/mol_data/` was copied into our repository as
    `examples/qas/vqe/mol_data/`. The dictionary only points to the files
    corresponding to the molecular geometries listed in Appendix L of the
    paper. The paths are relative, so they are resolved against the current
    working directory (presumably the repository root).
'''

def load_hamiltonian_data(molecule:Literal['H2','LiH','H2O'],
                          num_qubits:int,
                          transform:Literal['jw','parity']
                          )->tuple[torch.Tensor, float]:
    '''
    Load a molecular Hamiltonian matrix and its ground-state energy.

    The `.npz` file is looked up in `data_dict` under the key
    ``(molecule, num_qubits, transform)``. The stored ``energy_shift`` is
    added to the Hamiltonian as ``energy_shift * I`` and to the smallest
    stored eigenvalue, so both returned quantities include the shift.

    Args:
        molecule: Molecule name ('H2', 'LiH' or 'H2O').
        num_qubits: Number of qubits; the stored Hamiltonian must be a
            ``2**num_qubits`` square matrix for the identity shift to apply.
        transform: Fermion-to-qubit mapping ('jw' or 'parity').

    Returns:
        ``(ham_matrix, min_eig)``: the shifted Hamiltonian as a tensor of
        dtype `CTYPE`, and the smallest stored eigenvalue plus the energy
        shift as a Python scalar.

    Raises:
        KeyError: If no data file is registered for the requested
            combination.
    '''
    key = (molecule, num_qubits, transform)
    try:
        file = data_dict[key]
    except KeyError:
        raise KeyError(
            f'No Hamiltonian data registered for {key}; '
            f'available combinations: {list(data_dict)}'
        ) from None

    # For `.npz` files np.load returns an NpzFile that keeps the file open;
    # the context manager closes it once the arrays are read into memory.
    with np.load(file) as ham_data:
        ham_matrix = ham_data['hamiltonian']
        energy_shift = ham_data['energy_shift']
        eigvals = ham_data['eigvals']

    # In-place add keeps the stored dtype of the Hamiltonian array.
    ham_matrix += energy_shift * np.eye(2**num_qubits)
    min_eig = np.min(eigvals) + energy_shift

    return torch.tensor(ham_matrix,dtype=CTYPE), min_eig.item()

def electron_count(molecule:Literal['H2','LiH','H2O'],
                   num_qubits:int)->int:
    '''
    Return the number of electrons used for a molecule / qubit-count pair.

    H2 always has 2 electrons and H2O always has 8. For LiH the count
    depends on `num_qubits`: 2 electrons on 4 qubits, 4 electrons on
    6 qubits.

    Args:
        molecule: Molecule name ('H2', 'LiH' or 'H2O').
        num_qubits: Number of qubits of the encoding (only consulted for LiH).

    Returns:
        The electron count.

    Raises:
        ValueError: If `molecule` is unknown, or if `num_qubits` is neither
            4 nor 6 for LiH. Previously these cases silently returned None.
    '''
    if molecule == 'H2':
        return 2
    if molecule == 'H2O':
        return 8
    if molecule == 'LiH':
        lih_electrons = {4: 2, 6: 4}
        if num_qubits in lih_electrons:
            return lih_electrons[num_qubits]
        raise ValueError(
            f'No LiH electron count defined for {num_qubits} qubits; '
            'expected 4 or 6.')
    raise ValueError(
        f'Unknown molecule {molecule!r}; expected "H2", "LiH" or "H2O".')

def hartree_fock_state(num_qubits:int, 
                       num_electrons:int, 
                       transform:Literal['jw','parity'],
                       density_matrix:bool=False)->torch.Tensor:
    '''
    Build the Hartree-Fock reference state as a computational basis state.

    The first `num_electrons` modes are taken as occupied. They are placed
    on the most significant bits of the basis-state index (presumably
    qubit 0 is the most significant bit in qtorch; confirm).

    - 'jw': occupations map directly to qubits, so the `num_electrons`
      most significant bits are set.
    - 'parity': qubit j stores the parity of the occupations of modes
      0..j, which for this occupation is ``min(j+1, num_electrons) % 2``.

    Args:
        num_qubits: Number of qubits.
        num_electrons: Number of occupied modes, 0 <= num_electrons <= num_qubits.
        transform: Fermion-to-qubit mapping ('jw' or 'parity').
        density_matrix: If True, return the projector |HF><HF| instead of the
            state vector.

    Returns:
        A tensor of dtype `CTYPE`: shape ``(2**num_qubits,)`` for the state
        vector, or ``(2**num_qubits, 2**num_qubits)`` for the density matrix.

    Raises:
        ValueError: If `transform` is not 'jw' or 'parity', or if
            `num_electrons` is outside ``[0, num_qubits]``.
    '''
    if transform not in ('jw', 'parity'):
        raise ValueError('`transform` must be either "jw" or "parity"!')
    if not 0 <= num_electrons <= num_qubits:
        raise ValueError(
            f'`num_electrons` must be between 0 and `num_qubits` '
            f'({num_qubits}), got {num_electrons}!')

    if transform == 'jw':
        idx = ((1<<num_electrons) - 1) << (num_qubits-num_electrons)
    else:
        # Build the parity bits from the most significant bit down. Unlike a
        # bit-flipping loop seeded with 1, this yields the all-zero vacuum
        # for num_electrons == 0 and index 0 for num_qubits == 0.
        idx = 0
        for j in range(num_qubits):
            idx = (idx << 1) | (min(j + 1, num_electrons) % 2)

    dim = 2**num_qubits
    if density_matrix:
        rho_hf = torch.zeros([dim, dim], dtype=CTYPE)
        rho_hf[idx, idx] = 1.0
        return rho_hf
    psi_hf = torch.zeros(dim, dtype=CTYPE)
    psi_hf[idx] = 1.0
    return psi_hf

if __name__ == '__main__':
    # Quick manual check: print the index of the nonzero amplitude of the
    # 5-qubit, 2-electron Jordan-Wigner Hartree-Fock state.
    print(hartree_fock_state(5,2,'jw').nonzero())