import numpy as np
import torch
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.rdForceFieldHelpers import MMFFOptimizeMolecule
from typing import Tuple, List, Dict


# Fixed integer seeds shipped with this module.
# NOTE(review): presumably one seed per repeated run, to be passed to
# ``set_seed`` -- confirm against the callers that consume this list.
DEFAULT_SEEDS = [
    16880611,
    17760704,
    17890714,
    19491001,
    19900612,
]


def set_seed(seed: int, use_cuda: bool = False):
    """Seed the NumPy and PyTorch random number generators.

    Args:
        seed: Value handed to ``np.random.seed`` and ``torch.manual_seed``.
        use_cuda: When true, the CUDA generators are seeded as well.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        # ``torch.cuda.manual_seed`` only seeds the current device; the
        # ``_all`` variant covers every GPU so multi-GPU runs stay
        # reproducible too.
        torch.cuda.manual_seed_all(seed)


def split_by_interval(n, i, given_list: list = None):
    """Split the first ``n`` positions into consecutive chunks of size ``i``.

    The final chunk holds whatever remains and may be shorter than ``i``.
    When ``n <= 0`` the result is a single empty chunk (``[[]]``).

    Args:
        n: Number of positions to split.
        i: Chunk size; must be a positive integer.
        given_list: Optional indexable sequence. When it is non-empty, each
            chunk contains ``given_list[j]`` for the positions ``j`` in that
            chunk; when it is ``None`` or empty, chunks contain the integer
            positions themselves.

    Returns:
        A list of lists, one per chunk, in order.

    Raises:
        ValueError: If ``i`` is not positive (the split would never finish).
        IndexError: If ``given_list`` is used and has fewer than ``n`` items.
    """
    if i <= 0:
        raise ValueError(
            "interval i must be a positive integer, got {!r}".format(i)
        )

    # ``or [0]`` keeps the long-standing result of one empty chunk for n <= 0.
    starts = list(range(0, n, i)) or [0]
    index_chunks = [list(range(start, min(start + i, n))) for start in starts]

    if given_list:
        # Index element by element (rather than slicing) so that a too-short
        # ``given_list`` still fails loudly with IndexError.
        return [[given_list[j] for j in chunk] for chunk in index_chunks]
    return index_chunks


def get_mean_std(array: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the mean and standard deviation of ``array`` along dim 0.

    Both results keep the reduced dimension (``keepdim=True``), so they
    broadcast against ``array``. The standard deviation is taken over the
    mean-centred values using ``std``'s default settings.

    Returns:
        A ``(mean, std)`` tuple of tensors.
    """
    column_mean = array.mean(dim=0, keepdim=True)
    centred = array - column_mean
    column_std = centred.std(dim=0, keepdim=True)
    return column_mean, column_std


def dict_cuda_copy(d: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Build a new dict mapping each key of ``d`` to ``value.cuda()``.

    The input dict is left untouched; key order is preserved.
    """
    return {key: tensor.cuda() for key, tensor in d.items()}


def dict_list_cuda_copy(d: Dict[str, List[torch.Tensor]]) -> Dict[str, List[torch.Tensor]]:
    """Build a new dict whose lists hold ``item.cuda()`` for every item.

    Each key of ``d`` maps to a freshly built list; neither the input dict
    nor its lists are modified.
    """
    return {
        key: [tensor.cuda() for tensor in tensors]
        for key, tensors in d.items()
    }


def optimize_pos(pos: torch.Tensor, smiles: str) -> torch.Tensor:
    """Refine atom coordinates with RDKit's MMFF optimiser.

    A molecule is built from ``smiles``, given a conformer via
    ``AllChem.EmbedMolecule``, and that conformer's coordinates are
    overwritten row by row with ``pos`` before ``MMFFOptimizeMolecule`` runs.

    NOTE(review): row ``k`` of ``pos`` is written to atom index ``k`` of the
    molecule parsed from ``smiles``, and no explicit hydrogens are added
    here -- presumably ``pos`` holds one xyz row per atom of that molecule;
    confirm with callers.

    Args:
        pos: Tensor whose rows each provide at least three coordinates.
        smiles: SMILES string describing the molecule.

    Returns:
        The optimised positions as a tensor with the same type as ``pos``,
        or ``pos`` itself when the SMILES cannot be parsed or RDKit raises
        ``ValueError`` along the way.
    """
    pos_ = pos.cpu().detach().numpy()
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # MolFromSmiles reports an unparsable SMILES by returning None;
            # passing that on would raise a non-ValueError error and bypass
            # the fallback below, so fall back explicitly.
            return pos
        AllChem.EmbedMolecule(mol)
        # Fetched once: it does not change while the coordinates are set.
        conformer = mol.GetConformer()
        for i, p in enumerate(pos_):
            conformer.SetAtomPosition(i, [float(p[0]), float(p[1]), float(p[2])])
        MMFFOptimizeMolecule(mol)
        pos_ = mol.GetConformer().GetPositions()
    except ValueError:
        # Best effort: keep the caller's coordinates when RDKit fails.
        return pos
    return torch.FloatTensor(pos_).type_as(pos)


if __name__ == '__main__':
    # Earlier ad-hoc checks, left disabled:
    # print(split_by_interval(4, 1))
    # print(split_by_interval(12, 4))
    # print(split_by_interval(13, 4))
    # print(split_by_interval(15, 4))
    # a = torch.FloatTensor([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
    # mean, std = get_mean_std(a)
    # print(mean)
    # print(std)
    # print((a - mean) / std)

    # Demo: feed six planar points and the SMILES 'c1ccccc1' through
    # optimize_pos and print the coordinates it returns.
    ring_coords = [
        [1, 0, 0],
        [0.5, 0.85, 0],
        [-0.5, 0.85, 0],
        [-1, 0, 0],
        [-0.5, -0.99, 0],
        [0.5, -0.85, 0],
    ]
    refined = optimize_pos(torch.FloatTensor(ring_coords), 'c1ccccc1')
    print(refined)
