import torch
import numpy as np
from sklearn.utils import shuffle
from torch_geometric.data import DataLoader, Data
from torch_scatter import scatter
from torch_sparse import SparseTensor
from math import sqrt, pi as PI

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



def load_md17(dataset, train_size, valid_size):
    """Load ``datasets/<dataset>_dft.npz`` and split it into train/val/test lists.

    The archive is read for the keys ``'E'``, ``'F'``, ``'R'`` and ``'z'``.
    One ``Data`` object is built per entry of ``E``: ``R[i]`` becomes ``pos``,
    ``E[i]`` becomes ``y``, ``F[i]`` becomes ``force`` and the single ``z``
    array is attached unchanged to every molecule.

    Args:
        dataset: Name used to build the archive path
            ``'datasets/' + dataset + '_dft.npz'``.
        train_size: Number of shuffled samples placed in the training list.
        valid_size: Number of shuffled samples placed in the validation list.

    Returns:
        A tuple ``(train_dataset, val_dataset, test_dataset, num_atom)`` where
        the first three are lists of ``Data`` objects and ``num_atom`` is
        ``len(z)``. The test list receives every sample left over after the
        training and validation slices.
    """
    # np.load on an .npz returns a lazily-read NpzFile that keeps the file
    # open; the context manager closes it once the arrays are in memory.
    with np.load('datasets/' + dataset + '_dft.npz') as archive:
        E = archive['E']
        F = archive['F']
        R = archive['R']
        z = archive['z']

    num_atom = len(z)
    num_molecule = len(E)

    molecules = [
        Data(
            pos=torch.tensor(R[i], dtype=torch.float32),
            z=torch.tensor(z, dtype=torch.int64),
            y=torch.tensor(E[i], dtype=torch.float32),
            force=torch.tensor(F[i], dtype=torch.float32),
        )
        for i in range(num_molecule)
    ]

    # Fixed seed so the split is reproducible across runs.
    ids = shuffle(range(num_molecule), random_state=42)
    val_end = train_size + valid_size

    train_dataset = [molecules[int(i)] for i in ids[:train_size]]
    val_dataset = [molecules[int(i)] for i in ids[train_size:val_end]]
    test_dataset = [molecules[int(i)] for i in ids[val_end:]]
    return train_dataset, val_dataset, test_dataset, num_atom


def load_qm9(dataset, target, train_size, valid_size):
    """Load ``datasets/<dataset>_eV.npz`` and split it into train/val/test lists.

    The archive is read for the keys ``'R'``, ``'Z'``, ``'N'`` and ``target``.
    ``R`` and ``Z`` are stored concatenated over all molecules and are cut
    into per-molecule chunks at the cumulative sums of ``N``. One ``Data``
    object is built per entry of ``data[target]``, with the target value
    stored in ``y`` as a length-1 float tensor.

    Args:
        dataset: Name used to build the archive path
            ``'datasets/' + dataset + '_eV.npz'``.
        target: Key of the archive array used as the regression target.
        train_size: Number of shuffled samples placed in the training list.
        valid_size: Number of shuffled samples placed in the validation list.

    Returns:
        A tuple ``(train_dataset, val_dataset, test_dataset)`` of lists of
        ``Data`` objects. The test list receives every sample left over after
        the training and validation slices.
    """
    # np.load on an .npz returns a lazily-read NpzFile that keeps the file
    # open; the context manager closes it once the arrays are in memory.
    with np.load('datasets/' + dataset + '_eV.npz') as archive:
        R = archive['R']
        Z = archive['Z']
        N = archive['N']
        y = np.expand_dims(archive[target], axis=-1)

    # Splitting at every cumulative count also yields one trailing chunk past
    # the last boundary; it is never indexed below.
    boundaries = np.cumsum(N)
    R_qm9 = np.split(R, boundaries)
    Z_qm9 = np.split(Z, boundaries)

    num_molecule = len(y)
    molecules = [
        Data(
            pos=torch.tensor(R_qm9[i], dtype=torch.float32),
            z=torch.tensor(Z_qm9[i], dtype=torch.int64),
            y=torch.tensor(y[i], dtype=torch.float32),
        )
        for i in range(num_molecule)
    ]

    # Fixed seed so the split is reproducible across runs.
    ids = shuffle(range(num_molecule), random_state=42)
    val_end = train_size + valid_size

    train_dataset = [molecules[int(i)] for i in ids[:train_size]]
    val_dataset = [molecules[int(i)] for i in ids[train_size:val_end]]
    test_dataset = [molecules[int(i)] for i in ids[val_end:]]
    return train_dataset, val_dataset, test_dataset



def xyztodat(pos, edge_index, num_nodes):
    """Compute per-edge distances and per-triplet angles and torsions.

    Args:
        pos: Node coordinates, indexed by node id along the first axis; the
            cross products below require three components on the last axis.
        edge_index: Pair ``(j, i)`` of index tensors describing edges j->i.
        num_nodes: Number of nodes, used as the size of the sparse adjacency.

    Returns:
        A tuple ``(dist, angle, torsion, i, j, idx_kj, idx_ji)``:

        * ``dist`` - Euclidean length of every edge.
        * ``angle`` - for every triplet k->j->i with ``k != i``, the angle in
          ``[0, pi]`` between ``pos[i] - pos[j]`` and ``pos[k] - pos[j]``.
        * ``torsion`` - for every triplet, the minimum over the other
          neighbours ``k_n != i`` of ``j`` of the dihedral angle about the
          j->i axis, mapped to ``(0, 2*pi]``.
        * ``i``, ``j`` - the edge endpoints unpacked from ``edge_index``.
        * ``idx_kj``, ``idx_ji`` - edge indices of the k->j and j->i edges of
          every triplet.
    """
    j, i = edge_index  # j->i

    # Calculate distances. # number of edges
    dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()

    # Sparse adjacency whose stored value is the edge id, so that indexing
    # rows by j yields, per edge, the incoming edges of its source node.
    value = torch.arange(j.size(0), device=j.device)
    adj_t = SparseTensor(row=i, col=j, value=value, sparse_sizes=(num_nodes, num_nodes))
    adj_t_row = adj_t[j]
    num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)

    # Node indices (k->j->i) for triplets.
    idx_i = i.repeat_interleave(num_triplets)
    idx_j = j.repeat_interleave(num_triplets)
    idx_k = adj_t_row.storage.col()
    mask = idx_i != idx_k  # drop degenerate triplets i->j->i
    # Neighbour count of j for every surviving triplet. Taking it from the
    # masked expansion (instead of assuming each edge loses exactly one entry,
    # i.e. ``num_triplets - 1``) stays correct when a reverse edge is absent
    # and gives the same result when every reverse edge is present.
    num_triplets_t = num_triplets.repeat_interleave(num_triplets)[mask]
    idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask]

    # Edge indices (k-j, j->i) for triplets.
    idx_kj = adj_t_row.storage.value()[mask]
    idx_ji = adj_t_row.storage.row()[mask]

    # Calculate angles. 0 to pi
    pos_ji = pos[idx_i] - pos[idx_j]
    pos_jk = pos[idx_k] - pos[idx_j]
    a = (pos_ji * pos_jk).sum(dim=-1)  # cos_angle * |pos_ji| * |pos_jk|
    # dim=-1 is explicit: without it torch.cross uses the first axis of size
    # 3, which is the triplet axis whenever there are exactly three triplets.
    b = torch.cross(pos_ji, pos_jk, dim=-1).norm(dim=-1)  # sin_angle * |pos_ji| * |pos_jk|
    angle = torch.atan2(b, a)

    # Pair every triplet with every neighbour k_n of j; idx_batch remembers
    # which triplet each pair belongs to. It is created on the inputs' device
    # so CPU inputs work on a CUDA-capable host.
    idx_batch = torch.arange(len(idx_i), device=idx_i.device)
    idx_k_n = adj_t[idx_j].storage.col()
    idx_i_t = idx_i.repeat_interleave(num_triplets_t)
    idx_j_t = idx_j.repeat_interleave(num_triplets_t)
    idx_k_t = idx_k.repeat_interleave(num_triplets_t)
    idx_batch_t = idx_batch.repeat_interleave(num_triplets_t)
    mask = idx_i_t != idx_k_n
    idx_i_t, idx_j_t, idx_k_t, idx_k_n, idx_batch_t = (
        idx_i_t[mask], idx_j_t[mask], idx_k_t[mask], idx_k_n[mask], idx_batch_t[mask]
    )

    # Calculate torsions.
    pos_j0 = pos[idx_k_t] - pos[idx_j_t]
    pos_ji = pos[idx_i_t] - pos[idx_j_t]
    pos_jk = pos[idx_k_n] - pos[idx_j_t]
    dist_ji = pos_ji.pow(2).sum(dim=-1).sqrt()
    plane1 = torch.cross(pos_ji, pos_j0, dim=-1)
    plane2 = torch.cross(pos_ji, pos_jk, dim=-1)
    a = (plane1 * plane2).sum(dim=-1)  # cos_angle * |plane1| * |plane2|
    b = (torch.cross(plane1, plane2, dim=-1) * pos_ji).sum(dim=-1) / dist_ji
    torsion1 = torch.atan2(b, a)  # -pi to pi
    torsion1[torsion1 <= 0] += 2 * PI  # 0 to 2pi
    # Keep the smallest torsion per triplet.
    torsion = scatter(torsion1, idx_batch_t, reduce='min')

    return dist, angle, torsion, i, j, idx_kj, idx_ji
