from functools import lru_cache

import numpy as np
import pyscf
from pyscf.data.elements import (
    NUC as element_to_atomic_number,
    ELEMENTS as atomic_number_to_element,
)

import torch
from e3nn import o3

def build_irreps_from_mol(mol):
    """Return the e3nn irreps describing the AO basis of a pyscf molecule.

    One ``{nctr}x{l}{parity}`` term is emitted per pyscf shell, in shell order,
    where ``nctr = mol.bas_nctr(shell)`` is the number of contracted functions
    that share the shell's exponents. A shell with ``nctr > 1`` (generally
    contracted basis sets) contributes ``nctr * (2l + 1)`` spherical AOs, so
    emitting a single copy per shell would undercount the AO dimension.
    Parity is the natural ``(-1)**l``.

    NOTE(review): assumes a spherical (non-cartesian) basis and that pyscf lays
    out a multi-contraction shell one full ``2l+1`` block per contraction
    (matching e3nn's ``mul x ir`` layout) -- confirm against ``mol.ao_labels()``.
    """
    terms = []
    for shell in range(mol.nbas):
        l = mol.bas_angular(shell)
        nctr = mol.bas_nctr(shell)
        terms.append(f"{nctr}x{l}{'e' if l % 2 == 0 else 'o'}")
    return o3.Irreps('+'.join(terms))


def ref_etb(aobasis, beta, atom_type_names):
    """Generate an even-tempered (ETB) auxiliary basis for a set of element types.

    A throw-away reference molecule holding one atom of every requested element
    is built with the AO basis ``aobasis`` and passed to ``pyscf.df.aug_etb``
    with the progression factor ``beta``. All atoms sit at the origin because
    only the basis, not the geometry, is of interest. The spin is set to the
    parity of the total nuclear charge so that it matches the electron count of
    the neutral molecule.

    Returns the result of ``pyscf.df.aug_etb`` unchanged (indexed by element
    symbol elsewhere in this module -- presumably a dict of shell lists).
    """
    atomic_numbers = [element_to_atomic_number[name] for name in atom_type_names]
    origin = np.zeros((len(atomic_numbers), 3))
    reference_mol = pyscf.M(
        atom=[(z, xyz) for z, xyz in zip(atomic_numbers, origin)],
        basis=aobasis,
        spin=sum(atomic_numbers) % 2,
    )
    return pyscf.df.aug_etb(reference_mol, beta)


def parse_basis_name(basis_name, atom_type_names):
    """Resolve a basis name into something pyscf accepts.

    Plain names (e.g. ``'def2-svp'``) are returned unchanged. Names of the form
    ``'etb:<ao basis>:<beta>'`` (e.g. ``'etb:def2-svp:1.5'``) are expanded into
    an even-tempered auxiliary basis for ``atom_type_names`` via ``ref_etb``.
    """
    if not basis_name.startswith('etb:'):
        return basis_name
    _, ao_basis_name, beta_text = basis_name.split(':')
    return ref_etb(ao_basis_name, float(beta_text), atom_type_names)


def per_element_irreps_from_basis_name(basis_name, elements):
    """Map each element symbol in ``elements`` to the e3nn irreps of its basis.

    ``basis_name`` is first resolved with ``parse_basis_name``:

    * a named pyscf basis: a neutral single atom of each element (spin set by
      the parity of Z) is built and its irreps are read off with
      ``build_irreps_from_mol``;
    * an explicit per-element basis (the ``etb:`` case): every shell entry
      ``[l, ...]`` contributes one copy of irrep ``l`` with parity ``(-1)**l``.
    """
    resolved = parse_basis_name(basis_name, elements)
    parity_of = lambda l: 'e' if l % 2 == 0 else 'o'

    if not isinstance(resolved, str):
        return {
            e: o3.Irreps('+'.join(f'1x{shell[0]}{parity_of(shell[0])}' for shell in resolved[e]))
            for e in elements
        }

    irreps_by_element = {}
    for e in elements:
        z = element_to_atomic_number[e]
        atom_mol = pyscf.M(atom=[[z, [0, 0, 0]]], basis=resolved, charge=0, spin=z % 2)
        irreps_by_element[e] = build_irreps_from_mol(atom_mol)
    return irreps_by_element


@lru_cache(maxsize=16)
def pyscf_to_standard_perm_D_for_single_irrep(l, p, dtype=torch.float32):
    """Permutation taking one pyscf irrep block to the standard (e3nn) order.

    Only l=1 differs: pyscf stores (x, y, z), and the returned matrix, applied
    to a column vector, yields (y, z, x). Every other l uses the identity.
    ``p`` is accepted for a uniform signature but not used. Results are cached
    per ``(l, p, dtype)``, so the returned tensor is shared between callers and
    must not be modified in place.
    """
    if l != 1:
        return torch.eye(2 * l + 1, dtype=dtype)
    return torch.tensor([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=dtype)


@lru_cache(maxsize=32)
def e3nn_change_of_coord_D_for_single_irrep(l, p, dtype=torch.float32):
    """Representation matrix of irrep ``(l, p)`` for the yzx -> xyz change of basis.

    The fixed 3x3 matrix below is converted by ``Irreps.D_from_matrix`` into the
    ``(2l+1) x (2l+1)`` D-matrix of a single copy of the irrep. Results are
    cached per ``(l, p, dtype)``, so the returned tensor is shared between
    callers and must not be modified in place.
    """
    yzx_to_xyz = torch.tensor(
        [
            [0., 0., 1.],
            [1., 0., 0.],
            [0., 1., 0.],
        ],
        dtype=dtype,
    )
    single_irrep = o3.Irreps([(1, (l, p))])
    return single_irrep.D_from_matrix(yzx_to_xyz)


@lru_cache(maxsize=32)
def pyscf_to_e3nn_D_for_single_irrep(l, p, dtype=torch.float32):
    """Matrix product of the pyscf->standard permutation and the e3nn change of coordinates.

    Computes ``pyscf_to_standard_perm_D_for_single_irrep(l, p) @
    e3nn_change_of_coord_D_for_single_irrep(l, p)`` for one irrep. Cached per
    ``(l, p, dtype)``; do not modify the returned tensor in place.
    """
    return torch.matmul(
        pyscf_to_standard_perm_D_for_single_irrep(l, p, dtype=dtype),
        e3nn_change_of_coord_D_for_single_irrep(l, p, dtype=dtype),
    )


def e3nn_change_of_coord_D(irreps, dtype=torch.float32):
    """Block-diagonal yzx -> xyz change-of-coordinates matrix for ``irreps``.

    Each ``mul x (l, p)`` entry contributes ``mul`` copies of the single-irrep
    block from ``e3nn_change_of_coord_D_for_single_irrep``. Emitting one block
    per copy keeps the result ``irreps.dim x irreps.dim`` when ``mul != 1``,
    matching the per-copy loops of the ``transform_*_fast_*`` helpers.
    """
    blocks = [
        e3nn_change_of_coord_D_for_single_irrep(mul_ir.ir.l, mul_ir.ir.p, dtype=dtype)
        for mul_ir in irreps
        for _ in range(mul_ir.mul)
    ]
    return torch.block_diag(*blocks)


def pyscf_to_standard_perm_D(atomic_orbital_irreps, dtype=torch.float32):
    """Block-diagonal pyscf -> standard AO permutation matrix for ``atomic_orbital_irreps``.

    Each ``mul x (l, p)`` entry contributes ``mul`` copies of the single-irrep
    permutation from ``pyscf_to_standard_perm_D_for_single_irrep``. Emitting one
    block per copy keeps the result ``dim x dim`` when ``mul != 1``, matching the
    per-copy loops of the ``transform_*_fast_*`` helpers.
    """
    blocks = [
        pyscf_to_standard_perm_D_for_single_irrep(mul_ir.ir.l, mul_ir.ir.p, dtype=dtype)
        for mul_ir in atomic_orbital_irreps
        for _ in range(mul_ir.mul)
    ]
    # defensive cast so the result always carries the requested dtype
    return torch.block_diag(*blocks).to(dtype=dtype)


def get_pyscf_to_e3nn_D(atomic_orbital_irreps, dtype=torch.float32):
    """Block-diagonal pyscf -> e3nn matrix for ``atomic_orbital_irreps``.

    Each ``mul x (l, p)`` entry contributes ``mul`` copies of
    ``pyscf_to_e3nn_D_for_single_irrep(l, p)``. Emitting one block per copy
    keeps the result ``dim x dim`` when ``mul != 1``, matching the per-copy
    loops of the ``transform_*_fast_*`` helpers.
    """
    blocks = [
        pyscf_to_e3nn_D_for_single_irrep(mul_ir.ir.l, mul_ir.ir.p, dtype=dtype)
        for mul_ir in atomic_orbital_irreps
        for _ in range(mul_ir.mul)
    ]
    return torch.block_diag(*blocks)


def transform_from_pyscf_to_std_fast_1d(irreps, array: np.ndarray):
    """Reorder AO components along axis 0 from pyscf order to standard order.

    Only l=1 triples change: each pyscf (x, y, z) triple becomes (y, z, x).
    All other components keep their position, and any trailing axes are carried
    along. A new array is returned; ``array`` is left untouched.
    """
    assert array.shape[0] == irreps.dim
    gather = np.arange(irreps.dim)
    pos = 0
    for mul_ir in irreps:
        if mul_ir.ir.l != 1:
            pos += mul_ir.dim
            continue
        for _ in range(mul_ir.mul):
            # new (y, z, x) <- old positions 1, 2, 0 of the (x, y, z) triple
            gather[pos:pos + 3] = (pos + 1, pos + 2, pos)
            pos += 3
    return array[gather]


def transform_from_std_to_pyscf_fast_1d(irreps, array: np.ndarray):
    """Reorder AO components along axis 0 from standard order back to pyscf order.

    Only l=1 triples change: each standard (y, z, x) triple becomes pyscf's
    (x, y, z), the inverse of ``transform_from_pyscf_to_std_fast_1d``. All other
    components keep their position, and any trailing axes are carried along.
    A new array is returned; ``array`` is left untouched.
    """
    assert array.shape[0] == irreps.dim
    gather = np.arange(irreps.dim)
    pos = 0
    for mul_ir in irreps:
        if mul_ir.ir.l != 1:
            pos += mul_ir.dim
            continue
        for _ in range(mul_ir.mul):
            # new (x, y, z) <- old positions 2, 0, 1 of the (y, z, x) triple
            gather[pos:pos + 3] = (pos + 2, pos, pos + 1)
            pos += 3
    return array[gather]


def transform_from_pyscf_to_std_fast_2d(irreps, matrix: np.ndarray):
    """Reorder rows and columns of an AO matrix from pyscf order to standard order.

    The same permutation as ``transform_from_pyscf_to_std_fast_1d`` (each l=1
    triple (x, y, z) -> (y, z, x)) is applied to both of the first two axes,
    i.e. ``P @ M @ P.T``. A new array is returned; ``matrix`` is left untouched.
    """
    assert matrix.shape[0] == matrix.shape[1] == irreps.dim
    gather = np.arange(irreps.dim)
    pos = 0
    for mul_ir in irreps:
        if mul_ir.ir.l != 1:
            pos += mul_ir.dim
            continue
        for _ in range(mul_ir.mul):
            gather[pos:pos + 3] = (pos + 1, pos + 2, pos)
            pos += 3
    return matrix[np.ix_(gather, gather)]


def transform_from_std_to_pyscf_fast_2d(irreps, matrix: np.ndarray):
    """Reorder rows and columns of an AO matrix from standard order back to pyscf order.

    The same permutation as ``transform_from_std_to_pyscf_fast_1d`` (each l=1
    triple (y, z, x) -> (x, y, z)) is applied to both of the first two axes. It
    is the inverse of ``transform_from_pyscf_to_std_fast_2d``. A new array is
    returned; ``matrix`` is left untouched.
    """
    assert matrix.shape[0] == matrix.shape[1] == irreps.dim
    gather = np.arange(irreps.dim)
    pos = 0
    for mul_ir in irreps:
        if mul_ir.ir.l != 1:
            pos += mul_ir.dim
            continue
        for _ in range(mul_ir.mul):
            gather[pos:pos + 3] = (pos + 2, pos, pos + 1)
            pos += 3
    return matrix[np.ix_(gather, gather)]


class GTOBasis:
    """Per-element e3nn irreps of a Gaussian-type-orbital basis.

    Attributes are filled in ``__init__`` in the order the elements are given:
    per_element_irreps maps element symbol -> irreps of that element's functions;
    per_element_numel maps element symbol -> number of functions (``irreps.dim``);
    allowed_elements lists the element symbols;
    allowed_atomic_numbers lists the matching atomic numbers, index-aligned.
    """
    per_element_irreps: dict[str, o3.Irreps]
    per_element_numel: dict[str, int]
    allowed_elements: list[str]
    allowed_atomic_numbers: list[int]

    def __init__(self, per_element_irreps: dict[str, str]):
        """Build the basis from ``{element symbol: irreps}``.

        Values may be irreps strings such as ``'1x0e+1x1o'`` or anything else
        ``o3.Irreps`` accepts; ``from_basis_name`` passes ``o3.Irreps`` objects.
        """
        self.per_element_irreps = {}
        self.per_element_numel = {}
        self.allowed_elements = []
        self.allowed_atomic_numbers = []

        for elem, irstr in per_element_irreps.items():
            irreps = o3.Irreps(irstr)
            self.per_element_irreps[elem] = irreps
            self.per_element_numel[elem] = irreps.dim
            self.allowed_elements.append(elem)
            self.allowed_atomic_numbers.append(element_to_atomic_number[elem])

    @classmethod
    def from_basis_name(cls, basis_name, elements: list[str]):
        """Construct from a pyscf basis name (or an ``'etb:<ao basis>:<beta>'`` spec) for ``elements``."""
        per_element_irreps = per_element_irreps_from_basis_name(basis_name, elements)
        return cls(per_element_irreps=per_element_irreps)

    def irreps_for_mol(self, atom_types: list[int]):
        """Concatenate the per-element irreps of each atom, in atom order.

        ``atom_types`` are atomic numbers. The lookup raises if an atomic number
        has no entry in this basis.
        """
        atom_elements = [atomic_number_to_element[z] for z in atom_types]
        # Concatenate the (mul, ir) entries directly rather than joining their
        # string forms: an element with empty irreps stringifies to '' and would
        # otherwise produce an unparsable '...++...' irreps string.
        return o3.Irreps([
            mul_ir
            for e in atom_elements
            for mul_ir in self.per_element_irreps[e]
        ])


class GTOProductBasisHelper:
    """Convert square AO matrices into padded per-atom / per-atom-pair blocks and back.

    All elements share one padded layout, ``padded_irrep``. Its functions are
    grouped by angular momentum l = 0..lmax (parity ``(-1)**l``), and each l
    group holds as many copies as the element with the most l-functions.
    ``per_element_mask[e]`` is a boolean vector over the padded layout marking
    the slots element ``e`` actually fills. Its l-functions occupy the leading
    slots of every l group.

    NOTE(review): the masks only line up with an element's AO order if that
    order is itself grouped by increasing l (all s, then all p, ...). This is
    not checked here -- confirm for the basis sets in use.
    """
    basis: "GTOBasis"
    padded_irrep: "o3.Irreps"
    per_element_mask: dict[str, np.ndarray]

    def __init__(self, basis: "GTOBasis"):
        self.basis = basis

        maxl = max(basis.per_element_irreps[e].lmax for e in basis.allowed_elements)
        parity = lambda l: 'e' if l % 2 == 0 else 'o'
        ls = list(range(maxl + 1))
        l_irreps = [f'{l}{parity(l)}' for l in ls]

        # per element: total multiplicity of each l (natural parity only)
        per_element_irrep_mul = {
            e: [basis.per_element_irreps[e].count(l_ir) for l_ir in l_irreps]
            for e in basis.allowed_elements
        }

        # padded layout: for each l, the largest multiplicity over all elements
        max_mul = [max(per_element_irrep_mul[e][l] for e in basis.allowed_elements) for l in ls]
        self.padded_irrep = o3.Irreps(
            '+'.join(f'{mul}x{l_ir}' for mul, l_ir in zip(max_mul, l_irreps))
        )

        self.per_element_mask = {}
        for e in basis.allowed_elements:
            mask = np.zeros(self.padded_irrep.dim, dtype=bool)
            for l, lslice in zip(ls, self.padded_irrep.slices()):
                # the element's l-functions fill the leading slots of this l group
                n_used = per_element_irrep_mul[e][l] * (2 * l + 1)
                mask[lslice.start:lslice.start + n_used] = True
            self.per_element_mask[e] = mask

    def _atom_layout(self, atom_types):
        """Return each atom's element symbol and AO slice, in atom order."""
        # Z -> element via the basis' own tables, so the symbols always match the
        # keys of per_element_numel / per_element_mask. Unknown Z raises KeyError.
        z_to_element = dict(zip(self.basis.allowed_atomic_numbers, self.basis.allowed_elements))
        atom_elements = [z_to_element[int(z)] for z in np.asarray(atom_types, dtype=np.int32)]
        atom_slices = []
        start = 0
        for e in atom_elements:
            stop = start + self.basis.per_element_numel[e]
            atom_slices.append(slice(start, stop))
            start = stop
        return atom_elements, atom_slices

    def transform_from_pyscf_to_std(self, atom_types: list[int], matrix: np.ndarray):
        """Permute l=1 rows/columns of ``matrix`` from pyscf (x, y, z) to standard (y, z, x) order."""
        atomic_orbital_irreps = self.basis.irreps_for_mol(atom_types)
        return transform_from_pyscf_to_std_fast_2d(atomic_orbital_irreps, matrix)

    def transform_from_std_to_pyscf(self, atom_types: list[int], matrix: np.ndarray):
        """Permute l=1 rows/columns of ``matrix`` from standard (y, z, x) back to pyscf (x, y, z) order."""
        atomic_orbital_irreps = self.basis.irreps_for_mol(atom_types)
        return transform_from_std_to_pyscf_fast_2d(atomic_orbital_irreps, matrix)

    def split_matrix_to_padded_blocks(
        self,
        atom_types: list[int],
        matrix: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Cut an AO matrix into padded on-site and lower-triangle pair blocks.

        With ``D = padded_irrep.dim`` and ``npair = natom * (natom - 1) // 2``,
        returns:
            diag_blocks: (natom, D, D) padded on-site blocks, same dtype as ``matrix``.
            diag_masks: (natom, D, D) bool, True where ``diag_blocks`` holds real entries.
            tril_blocks: (npair, D, D) padded blocks ``matrix[atom i, atom j]`` for
                j < i, ordered by i then j.
            tril_masks: (npair, D, D) bool, True where ``tril_blocks`` holds real entries.
            tril_edge_index: (2, npair) int64, row 0 = i, row 1 = j for each pair.
        """
        atom_elements, atom_slices = self._atom_layout(atom_types)
        masks = [self.per_element_mask[e] for e in atom_elements]
        natom = len(atom_elements)
        npair = natom * (natom - 1) // 2
        pdim = self.padded_irrep.dim

        diag_blocks = np.zeros((natom, pdim, pdim), dtype=matrix.dtype)
        diag_masks = np.zeros((natom, pdim, pdim), dtype=bool)
        tril_blocks = np.zeros((npair, pdim, pdim), dtype=matrix.dtype)
        tril_masks = np.zeros((npair, pdim, pdim), dtype=bool)
        tril_edge_index = np.zeros((2, npair), dtype=np.int64)

        for i in range(natom):
            diag_masks[i] = masks[i][:, None] & masks[i][None, :]
            # boolean-mask assignment fills the True slots in row-major order,
            # matching the row-major flattening of the unpadded block
            diag_blocks[i, diag_masks[i]] = matrix[atom_slices[i], atom_slices[i]].reshape(-1)

        iblock = 0
        for i in range(natom):
            for j in range(i):  # strictly lower triangle
                tril_masks[iblock] = masks[i][:, None] & masks[j][None, :]
                tril_blocks[iblock, tril_masks[iblock]] = matrix[atom_slices[i], atom_slices[j]].reshape(-1)
                tril_edge_index[0, iblock] = i
                tril_edge_index[1, iblock] = j
                iblock += 1

        return diag_blocks, diag_masks, tril_blocks, tril_masks, tril_edge_index

    def assemble_matrix_from_padded_blocks(
        self,
        atom_types: list[int],
        padded_diag_blocks: list[np.ndarray],
        padded_tril_blocks: list[np.ndarray],  # in lower triangular order
    ) -> np.ndarray:
        """Inverse of ``split_matrix_to_padded_blocks`` (masks are recomputed from ``atom_types``).

        The upper triangle is filled with the plain transpose of each lower
        block, so the result is symmetric. For complex input this is not the
        conjugate transpose. The output dtype is that of ``padded_diag_blocks[0]``.
        """
        atom_elements, atom_slices = self._atom_layout(atom_types)
        masks = [self.per_element_mask[e] for e in atom_elements]
        natom = len(atom_elements)
        nao = sum(self.basis.per_element_numel[e] for e in atom_elements)
        matrix = np.zeros((nao, nao), dtype=padded_diag_blocks[0].dtype)

        for i in range(natom):
            matrix[atom_slices[i], atom_slices[i]] = padded_diag_blocks[i][masks[i]][:, masks[i]]

        iblock = 0
        for i in range(natom):
            for j in range(i):  # same pair order as split_matrix_to_padded_blocks
                block = padded_tril_blocks[iblock][masks[i]][:, masks[j]]
                matrix[atom_slices[i], atom_slices[j]] = block
                matrix[atom_slices[j], atom_slices[i]] = block.T
                iblock += 1

        return matrix


class GTOAuxDensityHelper:
    """Split one molecule's AO-indexed vector by element and reassemble it.

    For a fixed molecule (``atom_types``, atomic numbers in atom order) and
    ``basis``, ``__init__`` precomputes:
    atomic_orbital_masks_by_element[e] is a boolean mask over the AO axis that
    selects the functions of every atom of element ``e``;
    atom_indices_by_element[e] holds the indices of the atoms of element ``e``.
    """
    atom_types: np.ndarray  # int32 atomic numbers, in atom order
    basis: "GTOBasis"
    atomic_orbital_irreps: "o3.Irreps"
    atomic_orbital_masks_by_element: dict[str, np.ndarray]
    atom_indices_by_element: dict[str, np.ndarray]

    def __init__(self, atom_types: list[int], basis: "GTOBasis"):
        self.atom_types = np.array(atom_types, dtype=np.int32)
        self.basis = basis
        self.atomic_orbital_irreps = self.basis.irreps_for_mol(self.atom_types)

        # Z -> element via the basis' own tables, so symbols match the keys of
        # per_element_numel. An unsupported Z raises KeyError.
        z_to_element = dict(zip(basis.allowed_atomic_numbers, basis.allowed_elements))
        ao_counts = np.array(
            [basis.per_element_numel[z_to_element[int(z)]] for z in self.atom_types],
            dtype=np.intp,
        )
        # atomic number of the atom owning each AO, in AO order (empty molecule -> empty)
        ao_owner_z = np.repeat(self.atom_types, ao_counts)
        self.atomic_orbital_masks_by_element = {
            e: ao_owner_z == z
            for e, z in zip(basis.allowed_elements, basis.allowed_atomic_numbers)
        }
        # Compare the ndarray, not the raw argument: for a plain list,
        # ``atom_types == z`` is a single bool and would give wrong indices.
        self.atom_indices_by_element = {
            e: np.nonzero(self.atom_types == z)[0]
            for e, z in zip(basis.allowed_elements, basis.allowed_atomic_numbers)
        }

    def transform_from_pyscf_to_std(self, array: np.ndarray):
        """Reorder l=1 components along axis 0 from pyscf (x, y, z) to standard (y, z, x)."""
        return transform_from_pyscf_to_std_fast_1d(self.atomic_orbital_irreps, array)

    def transform_from_std_to_pyscf(self, array: np.ndarray):
        """Reorder l=1 components along axis 0 from standard (y, z, x) back to pyscf (x, y, z)."""
        return transform_from_std_to_pyscf_fast_1d(self.atomic_orbital_irreps, array)

    def split_ao_by_elements(self, array: np.ndarray):
        """Return ``{element: (n_atoms_of_element, numel_of_element) array}`` from an AO vector.

        Expects a 1-D ``array`` of length ``atomic_orbital_irreps.dim``. Rows
        follow atom order within each element. Every allowed element gets an
        entry, possibly with zero rows.
        """
        splitted = {
            e: array[self.atomic_orbital_masks_by_element[e]].reshape(-1, self.basis.per_element_numel[e])
            for e in self.basis.allowed_elements
        }
        assert sum(len(v) for v in splitted.values()) == len(self.atom_types)
        assert sum(v.size for v in splitted.values()) == self.atomic_orbital_irreps.dim
        return splitted

    def assemble_ao_from_per_element_arrays(self, splitted: dict[str, np.ndarray]):
        """Inverse of ``split_ao_by_elements``.

        The dtype is taken from the first value of ``splitted``. Entries for
        elements absent from this molecule may be omitted.
        """
        dtype = next(iter(splitted.values())).dtype
        array = np.zeros(self.atomic_orbital_irreps.dim, dtype=dtype)
        for e, mask in self.atomic_orbital_masks_by_element.items():
            if not mask.any():
                # element not present in this molecule: nothing to place
                continue
            array[mask] = splitted[e].reshape(-1)
        return array
