import torch
import wandb


def compute_mean_mad(dataloaders, properties, dataset_name):
    """Compute per-property mean/MAD statistics on the split tied to a dataset.

    Args:
        dataloaders: Mapping from split name ('train', 'valid') to a dataloader.
        properties: Iterable of property keys to compute statistics for.
        dataset_name: Name of the dataset; selects which split is used.

    Returns:
        Whatever ``compute_mean_mad_from_dataloader`` returns for the
        selected split.

    Raises:
        Exception: If ``dataset_name`` matches none of the known datasets.
    """
    # NOTE(review): the previous version compared against 'qm9_second_half'
    # twice in the same condition; possibly a different second name was
    # intended -- confirm with the dataset configuration.
    if dataset_name == 'qm9_second_half':
        split = 'valid'
    elif (dataset_name == 'qm9'
          or 'zinc250k' in dataset_name
          or 'guacamol' in dataset_name):
        split = 'train'
    else:
        raise Exception('Wrong dataset name')
    return compute_mean_mad_from_dataloader(dataloaders[split], properties)


def compute_mean_mad_from_dataloader(dataloader, properties):
    """Return the mean and mean absolute deviation of each property.

    Args:
        dataloader: Object exposing ``dataset.data``, indexable by property key.
        properties: Iterable of property keys to summarise.

    Returns:
        Dict mapping each property key to ``{'mean': ..., 'mad': ...}``, where
        'mad' is the mean of the absolute deviations from the mean.
    """
    stats = {}
    for name in properties:
        column = dataloader.dataset.data[name]
        center = torch.mean(column)
        # Mean absolute deviation around the mean.
        spread = torch.mean(torch.abs(column - center))
        stats[name] = {'mean': center, 'mad': spread}
    return stats

def compute_properties_upper_bounds(dataloader, properties, log_to_wandb=True):
    """Estimate a baseline MAE per property from randomly shuffled targets.

    For each property the L1 loss between the values and a random permutation
    of themselves is computed ten times and averaged; a conditional model is
    expected to do better than this.

    Args:
        dataloader: Object exposing ``dataset.data``, indexable by property key.
        properties: Iterable of property keys.
        log_to_wandb: When true, log each bound to wandb under
            ``<property>_upperbound``.

    Returns:
        Dict mapping each property key to its averaged shuffled-MAE tensor.
    """
    l1 = torch.nn.L1Loss()
    bounds = {}
    for name in properties:
        column = dataloader.dataset.data[name]
        # Repeat over several permutations to average out the randomness.
        shuffled_maes = [
            l1(column, column[torch.randperm(len(column))])
            for _ in range(10)
        ]
        bound = torch.mean(torch.Tensor(shuffled_maes))
        bounds[name] = bound
        if log_to_wandb:
            wandb.log({name+'_upperbound': bound})
    return bounds

# Cache of fully-connected edge indices: edges_dic[n_nodes][batch_size].
edges_dic = {}


def get_adj_matrix(n_nodes, batch_size, device):
    """Return edge indices of a batch of fully-connected graphs.

    Every sample in the batch gets all ``n_nodes * n_nodes`` (row, col) pairs
    (self-loops included), with node indices offset by
    ``batch_idx * n_nodes`` so that samples do not share nodes.

    Results are memoised in ``edges_dic``. The previous version looked the
    cache up but never wrote to it, so the index lists were rebuilt with a
    Python triple loop on every call.

    Args:
        n_nodes: Number of nodes per sample.
        batch_size: Number of samples in the batch.
        device: Device the returned tensors should live on.

    Returns:
        A two-element list ``[rows, cols]`` of 1-D long tensors of length
        ``batch_size * n_nodes * n_nodes``. The tensors may be shared with
        the cache, so callers should not modify them in place.
    """
    per_batch = edges_dic.setdefault(n_nodes, {})
    cached = per_batch.get(batch_size)
    if cached is None:
        # node_ids[b, i] = i + b * n_nodes
        node_ids = torch.arange(batch_size * n_nodes, dtype=torch.long)
        node_ids = node_ids.view(batch_size, n_nodes)
        # Enumerate (b, i, j) with j varying fastest, matching the order of
        # the original nested loops.
        rows = node_ids.unsqueeze(2).expand(-1, -1, n_nodes).reshape(-1)
        cols = node_ids.unsqueeze(1).expand(-1, n_nodes, -1).reshape(-1)
        cached = [rows.to(device), cols.to(device)]
        per_batch[batch_size] = cached
    # The cache key does not include the device, so move on the way out;
    # ``.to`` is a no-op when the tensor is already on ``device``.
    return [edge.to(device) for edge in cached]

def preprocess_input(one_hot, charges, charge_power, charge_scale, device):
    """Combine one-hot atom types with powers of the scaled charge.

    The scaled charge ``charges / charge_scale`` is raised to the powers
    ``0 .. charge_power``; each power is multiplied with every one-hot entry
    and the result is flattened over the last two axes.

    Args:
        one_hot: One-hot atom-type tensor; its leading two axes must match
            ``charges`` (assumed (batch, n_nodes, n_types) -- TODO confirm).
        charges: Charge tensor (assumed (batch, n_nodes) -- TODO confirm).
        charge_power: Highest exponent applied to the scaled charge.
        charge_scale: Divisor applied to the charges before exponentiation.
        device: Device on which the exponent tensor is created.

    Returns:
        Tensor of shape ``charges.shape[:2] + (-1,)`` holding the per-atom
        scalar features.
    """
    exponents = torch.arange(charge_power + 1., device=device, dtype=torch.float32)
    scaled_charges = charges.unsqueeze(-1) / charge_scale
    charge_powers = torch.pow(scaled_charges, exponents)
    # Insert a singleton axis so the powers broadcast against the one-hot axis.
    charge_powers = charge_powers.view(charges.shape + (1, charge_power + 1))
    features = one_hot.unsqueeze(-1) * charge_powers
    return features.view(charges.shape[:2] + (-1,))


# called during training
def prepare_context(conditioning, minibatch, property_norms, prop_encoder=None, condition_dropout=False, dropout_p=0.):
    """Build the per-node conditioning tensor for a minibatch.

    Each conditioning property is optionally normalised with
    ``(value - mean) / mad``, expanded to one row per node where needed, and
    all properties are concatenated along the feature axis. Masked-out nodes
    are zeroed via ``minibatch['atom_mask']``.

    Args:
        conditioning: Iterable of keys into ``minibatch`` to condition on.
        minibatch: Dict with at least 'positions' (batch, n_nodes, _),
            'atom_mask' (batch, n_nodes) and one entry per conditioning key.
        property_norms: Dict ``{key: {'mean': ..., 'mad': ...}}`` or None to
            skip normalisation.
        prop_encoder: Optional callable applied to 1-D (global) properties;
            when given, its output is used as the feature vector and
            ``condition_dropout`` is not applied.
        condition_dropout: If true (and no ``prop_encoder``), each molecule's
            global property is dropped with probability ``dropout_p`` and a
            flag channel is appended: ``(value, 0)`` when kept, ``(0, 1)``
            when dropped.
        dropout_p: Probability of dropping the condition for a molecule.

    Returns:
        Tensor of shape (batch, n_nodes, context_node_nf).

    Raises:
        ValueError: If a property is neither 1-, 2- nor 3-dimensional
            (and is not 'morgan_fingerprint').
    """
    batch_size, n_nodes, _ = minibatch['positions'].size()
    node_mask = minibatch['atom_mask'].unsqueeze(2)
    context_node_nf = 0
    context_list = []
    for key in conditioning:
        properties = minibatch[key]
        if property_norms is not None:
            properties = (properties - property_norms[key]['mean']) / property_norms[key]['mad']
        n_axes = len(properties.size())
        if key == 'morgan_fingerprint':
            # Fingerprint features: one 1024-d vector per molecule, copied to
            # every node.
            assert properties.size() == (batch_size, 1024)
            reshaped = properties.unsqueeze(1).repeat(1, n_nodes, 1)
            context_list.append(reshaped)
            context_node_nf += 1024
        elif n_axes == 1:
            # Global feature: one scalar per molecule.
            assert properties.size() == (batch_size,)
            if prop_encoder is not None:
                # Encoded target property (e.g. sine/cosine encoding).
                properties = prop_encoder(properties)
                reshaped = properties.view(batch_size, 1, -1).repeat(1, n_nodes, 1)
                context_list.append(reshaped)
                context_node_nf += reshaped.size(2)
            elif condition_dropout:
                # 1 marks molecules that will NOT use the condition.
                # Create the mask on the same device as the properties; it
                # used to be created on the CPU unconditionally, which fails
                # when the minibatch lives on another device.
                drop_probs = dropout_p * torch.ones(batch_size, 1, device=properties.device)
                condition_drop_mask = torch.bernoulli(drop_probs)
                properties = properties.unsqueeze(1)
                # Either (val, 0) when conditioning or (0, 1) when not.
                condition_vector = torch.cat(
                    (properties * (1 - condition_drop_mask), condition_drop_mask), dim=1)
                condition_vector = condition_vector.unsqueeze(1).repeat(1, n_nodes, 1)
                context_list.append(condition_vector)
                context_node_nf += 2
            else:
                reshaped = properties.view(batch_size, 1, 1).repeat(1, n_nodes, 1)
                context_list.append(reshaped)
                context_node_nf += 1
        elif n_axes == 2 or n_axes == 3:
            # Node feature: already one row per node.
            assert properties.size()[:2] == (batch_size, n_nodes)

            context_key = properties

            # Inflate scalar node features to a trailing feature axis.
            if n_axes == 2:
                context_key = context_key.unsqueeze(2)

            context_list.append(context_key)
            context_node_nf += context_key.size(2)
        else:
            raise ValueError('Invalid tensor size, more than 3 axes.')
    # Concatenate along the feature axis.
    context = torch.cat(context_list, dim=2)
    # Mask disabled nodes!
    context = context * node_mask
    assert context.size(2) == context_node_nf
    return context

def prepare_regression_target(regression_target, minibatch, property_norms, normalize=True):
    """Stack the regression targets of a minibatch into one tensor.

    Args:
        regression_target: Sequence of property keys to regress on.
        minibatch: Dict holding one (batch_size,) tensor per property key and
            a 'positions' tensor whose first axis is the batch size.
        property_norms: Dict ``{key: {'mean': ..., 'mad': ...}}`` used when
            ``normalize`` is true.
        normalize: Whether to apply ``(value - mean) / mad`` to each target.

    Returns:
        Tensor of shape (batch_size,) for a single property, otherwise
        (batch_size, n_props).
    """
    targets = []
    for prop in regression_target:
        prop_target = minibatch[prop]
        if normalize:
            norms = property_norms[prop]
            prop_target = (prop_target - norms['mean']) / norms['mad']
        targets.append(prop_target.unsqueeze(1))
    targets = torch.cat(targets, dim=1)

    n_props = len(regression_target)
    batch_size = minibatch['positions'].size(0)
    assert targets.size() == (batch_size, n_props)
    # Only drop the property axis (when there is a single property). A bare
    # squeeze() would also remove the batch axis when batch_size == 1.
    return targets.squeeze(1)

def prepare_classification_target(regression_target, minibatch, property_norms):
    """Fetch the classification labels of a minibatch.

    Args:
        regression_target: Key of the label entry in ``minibatch`` (a single
            key, despite the parameter name).
        minibatch: Dict holding the labels and a 'positions' tensor whose
            first axis is the batch size.
        property_norms: Unused; present for signature parity with the
            regression helper.

    Returns:
        ``minibatch[regression_target]`` unchanged.
    """
    labels = minibatch[regression_target]

    n_samples = minibatch['positions'].size(0)
    # Sanity check: exactly one label row per sample.
    assert labels.size(0) == n_samples
    return labels

def unnormalize_regression_prediction(regression_target, pred, property_norms):
    """Map normalised predictions back to the original property scale.

    Applies ``mad * pred + mean`` per property, i.e. the inverse of the
    ``(value - mean) / mad`` normalisation.

    Args:
        regression_target: Sequence of property keys, in the column order of
            ``pred``.
        pred: Tensor of shape (batch_size,) or (batch_size, n_props).
        property_norms: Dict ``{key: {'mean': ..., 'mad': ...}}``.

    Returns:
        Tensor of shape (batch_size,) for a single property, otherwise
        (batch_size, n_props).
    """
    if len(pred.size()) == 1:
        pred = pred.unsqueeze(1)

    unnormalized_pred = []
    for i, prop in enumerate(regression_target):
        norms = property_norms[prop]
        prop_pred = norms['mad'] * pred[:, i] + norms['mean']
        unnormalized_pred.append(prop_pred.unsqueeze(1))
    unnormalized_pred = torch.cat(unnormalized_pred, dim=1)

    assert unnormalized_pred.size() == pred.size()
    # Only drop the property axis (when there is a single property). A bare
    # squeeze() would also remove the batch axis when batch_size == 1.
    return unnormalized_pred.squeeze(1)
