import torch
import numpy as np
import sklearn


class RunningAverage:
    """Incremental arithmetic mean of a stream of values.

    The mean is updated as ``mean_n = mean_{n-1} + (x - mean_{n-1}) / n``,
    so no history of observations is kept. Before any update the mean is 0.
    """

    def __init__(self) -> None:
        self.current_value = 0  # mean of all values seen so far
        self.n = 0  # how many values have been folded in

    def update(self, x):
        """Fold the observation ``x`` into the running mean."""
        self.n += 1
        step = (x - self.current_value) / self.n
        self.current_value += step

    def get(self):
        """Return the current mean (0 if nothing has been added yet)."""
        return self.current_value

def test_running_average():
    """RunningAverage must agree with np.mean on random samples."""
    for _trial in range(10):
        samples = np.random.rand(10)
        avg = RunningAverage()
        for sample in samples:
            avg.update(sample)
        error = np.abs(avg.get() - np.mean(samples))
        assert error <= 1e-10

# Vectorized (whole-batch) counterpart of get_adj_matrix_from_batch_slow.
def get_adj_matrix_from_batch(batch):
    """Build dense, symmetric bond-order adjacency matrices for a padded batch.

    Vectorized counterpart of ``get_adj_matrix_from_batch_slow``.

    Args:
        batch: dict with
            ``'positions'``: 3-D tensor whose first two dims give
                (batch_size, n_nodes);
            ``'adj_list'``: tensor indexed as ``[:, :, k]`` whose rows are
                ``(i, j, bond_order, ...)``. Rows whose entries sum to 0 are
                treated as padding and ignored.

    Returns:
        torch.LongTensor of shape (batch_size, n_nodes, n_nodes) holding the
        bond order of each listed edge, with aromatic bonds (1.5) mapped to 4.
        It is symmetrized by adding its transpose. This presumes each bond is
        listed once; a bond listed in both directions ends up doubled. The
        result lives on the same device as ``batch['adj_list']``.
    """
    bs, n_nodes, _ = batch['positions'].shape
    adj_list = batch['adj_list']

    # Boolean mask for presence/absence of an edge row.
    edge_mask = adj_list.sum(-1) != 0

    # nonzero() enumerates True entries in row-major order. That is the same
    # order in which boolean-mask indexing below returns the edge rows, so
    # batch_idx[e] is the molecule that owns edges[e].
    batch_idx, _ = edge_mask.nonzero(as_tuple=True)
    edges = adj_list[edge_mask]
    src = edges[:, 0].long()
    dst = edges[:, 1].long()
    # Edge types are stored in the third column of adj_list.
    bond_order = edges[:, 2].float()
    # Aromatic bonds (1.5) become class 4.
    bond_order = torch.where(bond_order == 1.5, torch.full_like(bond_order, 4.0), bond_order)

    adj_matrix = torch.zeros((bs, n_nodes, n_nodes), dtype=torch.float, device=adj_list.device)
    # Tuple indexing (not a list of lists) is the supported advanced-indexing
    # form, and it also works when there are no edges at all.
    adj_matrix[batch_idx, src, dst] = bond_order
    adj_matrix = adj_matrix.long()

    # Symmetrize: each bond is written in one direction above.
    return adj_matrix + adj_matrix.transpose(1, 2)

def get_adj_matrix_from_batch_slow(batch):
    """Loop-based reference builder of symmetric bond-order adjacency matrices.

    Every non-zero row ``(i, j, bond_order)`` of ``batch['adj_list'][b]`` is
    written to both ``[b, i, j]`` and ``[b, j, i]``. Aromatic bonds (1.5) are
    stored as 4. Rows with bond order 0 are skipped.

    Returns:
        Integer tensor of shape (batch_size, n_nodes, n_nodes), where
        batch_size and n_nodes come from ``batch['positions'].shape``.
        An assertion checks that the result is symmetric.
    """
    n_graphs, n_nodes, _ = batch['positions'].shape

    adj_matrix = torch.zeros((n_graphs, n_nodes, n_nodes), dtype=int)
    for graph_idx, edge_rows in enumerate(batch['adj_list']):
        for src, dst, order in edge_rows:
            if order == 0:
                continue  # padding row / no bond
            src, dst = src.int(), dst.int()
            # aromatic bonds get class 4, everything else keeps its order
            value = 4 if order == 1.5 else order
            adj_matrix[graph_idx, src, dst] = value
            adj_matrix[graph_idx, dst, src] = value

    # If one-hot encoded targets are ever needed:
    # bond_types = torch.Tensor([0, 1, 2, 3]).unsqueeze(0).unsqueeze(0).unsqueeze(0)
    # adj_matrix_one_hot = (adj_matrix.unsqueeze(-1) == bond_types).int()
    # assert (adj_matrix == torch.argmax(adj_matrix_one_hot, dim=-1)).all()
    # adj_matrix = adj_matrix_one_hot
    assert (adj_matrix == torch.transpose(adj_matrix, 1, 2)).all()
    return adj_matrix


def compute_class_weight(loader, dataset_info, recompute_class_weight=False):
    """Collect the class weights of all prediction targets into one dict.

    Args:
        loader: training data loader, forwarded to the per-target helpers.
        dataset_info: dict with at least ``'name'`` (used for cached weights).
        recompute_class_weight: bypass cached weights and recompute from ``loader``.

    Returns:
        dict with keys ``'atom_types'``, ``'formal_charges'`` and ``'edges'``.
    """
    atom_w, charge_w = compute_class_weight_atom_types_and_formal_charges(
        loader, dataset_info, recompute_class_weight=recompute_class_weight
    )
    edge_w = compute_class_weight_edges(
        loader,
        dataset_name=dataset_info['name'],
        recompute_class_weight=recompute_class_weight,
    )
    return dict(atom_types=atom_w, formal_charges=charge_w, edges=edge_w)


def compute_class_weight_atom_types_and_formal_charges(loader, dataset_info, recompute_class_weight=False):
    """
    (Cached Version)
    Balanced class weights for atom types and formal charges.

    For ``'zinc250k'`` and ``'zinc250k_explicitH'`` precomputed weights are
    returned unless ``recompute_class_weight`` is True. Otherwise the weights
    are computed from ``loader`` with scikit-learn's "balanced" heuristic,
    ``n_samples / (n_classes * count_c)``. The formula is evaluated directly
    from per-class counts (same float64 arithmetic as sklearn), so no
    per-atom label vector is materialised.

    Args:
        loader: iterable of batch dicts with ``'atomic_numbers_one_hot'``,
            ``'formal_charges'`` and ``'num_atoms'``. It must also expose
            ``loader.data['formal_charges']``, from which the set of possible
            charges is taken.
        dataset_info: dict with ``'name'`` and ``'atom_decoder'`` (one entry
            per atom type; only needed when recomputing).
        recompute_class_weight: ignore cached weights and recompute.

    Returns:
        (weight_atom_types, weight_formal_charges): float32 tensors with one
        weight per atom-type index, and one per distinct formal charge in
        ascending order.

    Raises:
        KeyError: an atom-type index or a formal charge outside the known
            classes is found in a batch.
        ValueError: some class never occurs in ``loader`` (its balanced weight
            would be infinite).
    """
    dataset_name = dataset_info['name']

    if dataset_name == 'zinc250k' and not recompute_class_weight:
        weight_atom_types = torch.Tensor([1.5080e-01, 9.1038e-01, 1.1136e+00, 8.0893e+00, 4.8796e+03, 6.2405e+00, 1.4963e+01, 5.0355e+01, 7.1199e+02])
        weight_formal_charges = torch.Tensor([73.0149,  0.3394, 25.0233])
        return weight_atom_types, weight_formal_charges
    elif dataset_name == 'zinc250k_explicitH' and not recompute_class_weight:
        weight_atom_types = torch.Tensor([5.1298e+00, 1.3842e-01, 8.3563e-01, 1.0222e+00, 7.4252e+00, 4.4790e+03, 5.7281e+00, 1.3735e+01, 4.6220e+01, 6.5354e+02])
        weight_formal_charges = torch.Tensor([74.4665,  0.3393, 25.5208])
        return weight_atom_types, weight_formal_charges

    def _valid_nodes(values, num_atoms):
        """Entries of ``values`` (batch, max_nodes, ...) that belong to real, unpadded nodes."""
        positions = torch.arange(values.shape[1], device=values.device)
        mask = positions.unsqueeze(0) < num_atoms.to(values.device).unsqueeze(1)
        return values[mask]

    def _balanced(counts, what):
        """scikit-learn 'balanced' weights n_samples / (n_classes * count) from per-class counts."""
        counts = counts.double()
        if (counts == 0).any():
            raise ValueError(
                f"every {what} class must occur at least once to compute balanced "
                f"class weights; got counts {counts.long().tolist()}"
            )
        return (counts.sum() / (counts.numel() * counts)).float()

    n_atom_types = len(dataset_info['atom_decoder'])
    atom_types_count = torch.zeros(n_atom_types, dtype=torch.long)

    possible_formal_charges = torch.unique(loader.data['formal_charges'], sorted=True)  # most often will be [-1, 0, 1]
    possible_formal_charges = possible_formal_charges.int().tolist()
    charge_values = torch.tensor(possible_formal_charges, dtype=torch.long)
    formal_charges_count = torch.zeros(len(possible_formal_charges), dtype=torch.long)

    for batch in loader:
        # reshape(-1) accepts num_atoms given as (bs,) or (bs, 1), or as a list
        num_atoms = torch.as_tensor(batch['num_atoms']).reshape(-1)

        atom_types = _valid_nodes(batch['atomic_numbers_one_hot'].float().argmax(-1), num_atoms)
        batch_atom_counts = torch.bincount(atom_types.reshape(-1), minlength=n_atom_types).cpu()
        if batch_atom_counts.numel() > n_atom_types:
            raise KeyError(
                f"atom type index {batch_atom_counts.numel() - 1} is outside atom_decoder "
                f"(size {n_atom_types})"
            )
        atom_types_count += batch_atom_counts

        charges = _valid_nodes(batch['formal_charges'], num_atoms).reshape(-1, 1)
        matches = charges == charge_values.to(charges.device)  # (n_atoms, n_possible_charges)
        if not matches.any(dim=1).all():
            unknown = charges[~matches.any(dim=1)].unique().tolist()
            raise KeyError(f"formal charges {unknown} not in {possible_formal_charges}")
        formal_charges_count += matches.sum(dim=0).cpu()

    weight_atom_types = _balanced(atom_types_count, 'atom type')
    weight_formal_charges = _balanced(formal_charges_count, 'formal charge')

    return weight_atom_types, weight_formal_charges


def compute_class_weight_edges(loader, dataset_name=None, recompute_class_weight=False):
    """
    (Cached Version)
    Estimate class weights for the (unbalanced) edge-type classes.
    This is very important for edge prediction, since molecules are usually
    sparsely connected and "no bond" dominates.

    Uses scikit-learn's "balanced" heuristic,
    https://scikit-learn.org/stable/modules/generated/sklearn.utils.class_weight.compute_class_weight.html
    It is evaluated directly from per-class counts as
    ``n_samples / (n_classes * count_c)``, with the same float64 arithmetic,
    so no label vector with one entry per atom pair is materialised.

    Classes: 0 = no bond, 1/2/3 = bond order, 4 = aromatic bond.

    Args:
        loader: train data loader yielding batch dicts accepted by
            ``get_adj_matrix_from_batch``, plus ``'num_atoms'``.
        dataset_name: name used to look up cached weights
            (``'qm9'``, ``'zinc250k'``, ``'zinc250k_explicitH'``).
        recompute_class_weight: ignore cached weights and recompute.

    Returns:
        weight (torch.Tensor): float32 tensor of 5 weights, to be given as the
        ``weight`` argument of the cross_entropy loss function.

    Raises:
        ValueError: some edge class never occurs in ``loader`` (including an
            empty loader).
    """
    if dataset_name == 'qm9' and not recompute_class_weight:
        return torch.Tensor([0.2752,  0.9033, 11.3114, 24.8436,  7.6710])
    elif dataset_name == 'zinc250k' and not recompute_class_weight:
        return torch.Tensor([2.0982e-01, 5.8085e+00, 1.2263e+02, 3.1189e+03, 1.8729e+01])
    elif dataset_name == 'zinc250k_explicitH' and not recompute_class_weight:
        return torch.Tensor([2.2019e-01, 4.1518e+00, 3.4856e+01, 8.8651e+02, 5.3236e+00])

    n_bond_types = 5  # 0, 1, 2, 3 and 4 (aromatic)
    # Exact int64 counts. Accumulating per-molecule halves as float32 tensors
    # stops being exact beyond 2**24 pairs, which ZINC-scale data exceeds.
    bond_counts = torch.zeros(n_bond_types, dtype=torch.long)

    for batch in loader:
        adj_gt = get_adj_matrix_from_batch(batch)
        n_max = adj_gt.shape[-1]
        device = adj_gt.device
        num_atoms = torch.as_tensor(batch['num_atoms'], device=device).reshape(-1)

        idx = torch.arange(n_max, device=device)
        # only pairs of real (unpadded) atoms
        node_mask = idx.unsqueeze(0) < num_atoms.unsqueeze(1)  # (bs, n_max)
        # Strict upper triangle (i < j): each undirected pair is counted once
        # and the diagonal is excluded.
        upper = idx.unsqueeze(1) < idx.unsqueeze(0)  # (n_max, n_max)
        pair_mask = node_mask.unsqueeze(2) & node_mask.unsqueeze(1) & upper

        bond_orders = adj_gt[pair_mask]
        for bond_type in range(n_bond_types):
            bond_counts[bond_type] += int((bond_orders == bond_type).sum())

    counts = bond_counts.double()
    if (counts == 0).any():
        raise ValueError(
            "every edge class must occur at least once to compute balanced class "
            f"weights; got counts {bond_counts.tolist()}"
        )
    weight = counts.sum() / (n_bond_types * counts)
    return weight.float()
