import copy

from bond_type_prediction.utils import get_adj_matrix_from_batch
from bond_type_prediction.losses import adjacency_matrix_loss, atom_types_and_formal_charges_loss
from bond_type_prediction.eval import reconstruction_error

import torch
import wandb
import numpy as np


def train_epoch(args, model, model_ema, ema, epoch, train_loader, optimizer, lr_scheduler, device, class_weight_dict=None, logger='wandb'):
    """Train ``model`` for one epoch and return aggregated loss and accuracy metrics.

    Args:
        args: run configuration. Reads ``train_on_noisy_graphs``, ``sigma`` and ``p``
            (input perturbation), ``modify_h`` (whether the model also predicts atom
            features), ``ema_decay`` and ``log_interval``.
        model: network being trained. Called on the (possibly perturbed) batch; must
            return ``(adj_pred, h_pred)`` when ``args.modify_h`` is set and ``adj_pred``
            otherwise.
        model_ema: EMA copy of ``model``; updated after each optimizer step when
            ``args.ema_decay > 0``.
        ema: object providing ``update_model_average(model_ema, model)``.
        epoch: epoch index passed to ``lr_scheduler.step`` and printed in logs.
        train_loader: iterable of dict batches of tensors; ``len(train_loader)`` and
            ``train_loader.dataset`` are only used for progress printing.
        optimizer: optimizer over ``model``'s parameters.
        lr_scheduler: optional scheduler, stepped once at the start of the epoch.
        device: device every tensor of every batch is moved to.
        class_weight_dict: optional class weights. The whole dict is given to the atom
            type / formal charge loss and its ``'edges'`` entry to the adjacency loss.
            When ``None``, the adjacency loss receives ``weight=None``.
        logger: ``'wandb'`` to log the epoch metrics to Weights & Biases.

    Returns:
        dict with ``'epoch'``, the summed batch-size-weighted ``'loss'``, the raw
        counters summed over the epoch, and the derived ``'avg_loss'``,
        ``'edge_accuracy'``, ``'atom_types_accuracy'``, ``'formal_charges_accuracy'``
        and ``'molecule_accuracy'`` (percentages). A ratio whose denominator is zero
        (empty loader, or a NaN loss on the very first batch) is reported as NaN.
        A NaN loss stops the epoch early; the NaN is kept in ``'loss'``/``'avg_loss'``
        so callers can detect it.
    """
    counter_keys = ('correct_edges', 'total_edges',
                    'correct_atom_types', 'correct_formal_charges', 'total_atoms',
                    'correct_molecules', 'total_molecules')
    metric_names = ('avg_loss', 'edge_accuracy', 'atom_types_accuracy',
                    'formal_charges_accuracy', 'molecule_accuracy')

    # NOTE(review): passing `epoch` to scheduler.step() is deprecated in recent PyTorch;
    # kept so the schedule semantics stay unchanged.
    if lr_scheduler is not None:
        lr_scheduler.step(epoch)
    model.train()
    res = {'epoch': epoch, 'loss': 0}
    res.update({key: 0 for key in counter_keys})

    # class_weight_dict defaults to None, so only index it when it was provided.
    edge_weight = None if class_weight_dict is None else class_weight_dict['edges']

    for batch_idx, batch in enumerate(train_loader):
        for key in batch:
            batch[key] = batch[key].to(device)

        # The model may see a perturbed copy; targets always come from the clean batch.
        input_batch = perturb_batch(batch, args.sigma, args.p, device) if args.train_on_noisy_graphs else batch

        adj_gt = get_adj_matrix_from_batch(batch).to(device)
        optimizer.zero_grad()

        if args.modify_h:
            # predict edges and atoms
            adj_pred, h_pred = model(input_batch)
            atom_types_loss, formal_charges_loss = atom_types_and_formal_charges_loss(h_pred, batch['one_hot'], batch['charges'], weight_dict=class_weight_dict)
            adj_loss = adjacency_matrix_loss(adj_pred, adj_gt, weight=edge_weight)
            loss = atom_types_loss + formal_charges_loss + adj_loss
        else:
            # predict edges only; atom features are forwarded unchanged
            adj_pred = model(input_batch)
            h_pred = [batch['one_hot'], batch['charges']]  # TODO: fix if needed: need to make charges one-hot and cat them along dim=2
            loss = adjacency_matrix_loss(adj_pred, adj_gt, weight=edge_weight)

        # Weight by the actual batch size (the last batch may be smaller).
        current_bs = adj_gt.shape[0]
        res['loss'] += loss.item() * current_bs
        # Check BEFORE backward/step so a NaN loss never reaches the weights or the EMA copy.
        if np.isnan(res['loss']):
            print('Detected loss = nan. Stopping Training loop')
            break

        loss.backward()
        optimizer.step()

        # Update EMA if enabled.
        if args.ema_decay > 0:
            ema.update_model_average(model_ema, model)

        results_batch = reconstruction_error(adj_pred, adj_gt, h_pred, batch)
        for key in counter_keys:
            res[key] += results_batch[key]

        if batch_idx != 0 and args.log_interval is not None and batch_idx % args.log_interval == 0:
            print('===> Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * current_bs, len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item()))

    def _percent(num, den):
        # NaN instead of ZeroDivisionError when nothing was counted.
        return (num / den) * 100 if den else float('nan')

    res['avg_loss'] = res['loss'] / res['total_molecules'] if res['total_molecules'] else float('nan')
    res['edge_accuracy'] = _percent(res['correct_edges'], res['total_edges'])
    res['atom_types_accuracy'] = _percent(res['correct_atom_types'], res['total_atoms'])
    res['formal_charges_accuracy'] = _percent(res['correct_formal_charges'], res['total_atoms'])
    res['molecule_accuracy'] = _percent(res['correct_molecules'], res['total_molecules'])

    if logger == 'wandb':
        # One call so all epoch metrics share a single wandb step.
        wandb.log({f'train_{metric}': res[metric] for metric in metric_names})

    print('Epoch %i Train avg loss: %.4f \nedge_accuracy: %.4f \natom_types_accuracy: %.4f \nformal_charges_accuracy: %.4f \nmolecule_accuracy: %.4f' % (epoch, res['avg_loss'], res['edge_accuracy'], res['atom_types_accuracy'], res['formal_charges_accuracy'], res['molecule_accuracy']))
    return res


def test_epoch(args, model, epoch, test_loader, device, class_weight_dict=None, test_or_val='val', logger='wandb', test_on_noisy_graphs=False):
    """Evaluate ``model`` on ``test_loader`` without gradients and return aggregated metrics.

    Args:
        args: run configuration. Reads ``modify_h`` and, when perturbing inputs,
            ``sigma`` and ``p``.
        model: network under evaluation; switched to ``eval()`` mode. Must return
            ``(adj_pred, h_pred)`` when ``args.modify_h`` is set and ``adj_pred`` otherwise.
        epoch: epoch index, only used for printing/returning.
        test_loader: iterable of dict batches of tensors.
        device: device every tensor of every batch is moved to.
        class_weight_dict: optional class weights. The whole dict is given to the atom
            type / formal charge loss and its ``'edges'`` entry to the adjacency loss.
            When ``None``, the adjacency loss receives ``weight=None``.
        test_or_val: split label used in printed output and wandb metric names.
        logger: ``'wandb'`` to log the metrics to Weights & Biases.
        test_on_noisy_graphs: feed the model a ``perturb_batch`` copy of each batch.
            Losses and accuracies are still computed against the clean batch.

    Returns:
        dict with ``'epoch'``, the summed batch-size-weighted ``'loss'``, the raw
        counters summed over the loader, and the derived ``'avg_loss'``,
        ``'edge_accuracy'``, ``'atom_types_accuracy'``, ``'formal_charges_accuracy'``
        and ``'molecule_accuracy'`` (percentages). Ratios with a zero denominator
        (e.g. an empty loader) are reported as NaN.
    """
    counter_keys = ('correct_edges', 'total_edges',
                    'correct_atom_types', 'correct_formal_charges', 'total_atoms',
                    'correct_molecules', 'total_molecules')
    metric_names = ('avg_loss', 'edge_accuracy', 'atom_types_accuracy',
                    'formal_charges_accuracy', 'molecule_accuracy')

    model.eval()
    res = {'epoch': epoch, 'loss': 0}
    res.update({key: 0 for key in counter_keys})

    # class_weight_dict defaults to None, so only index it when it was provided.
    edge_weight = None if class_weight_dict is None else class_weight_dict['edges']

    with torch.no_grad():
        for batch in test_loader:
            for key in batch:
                batch[key] = batch[key].to(device)

            adj_gt = get_adj_matrix_from_batch(batch).to(device)

            # The model may see a perturbed copy; targets always come from the clean batch.
            input_batch = perturb_batch(batch, args.sigma, args.p, device) if test_on_noisy_graphs else batch

            if args.modify_h:
                # predict edges and atoms
                adj_pred, h_pred = model(input_batch)
                atom_types_loss, formal_charges_loss = atom_types_and_formal_charges_loss(h_pred, batch['one_hot'], batch['charges'], weight_dict=class_weight_dict)
                adj_loss = adjacency_matrix_loss(adj_pred, adj_gt, weight=edge_weight)
                loss = atom_types_loss + formal_charges_loss + adj_loss
            else:
                # predict edges only; atom features are forwarded unchanged
                adj_pred = model(input_batch)
                h_pred = [batch['one_hot'], batch['charges']]  # TODO: fix if needed: need to make charges one-hot and cat them along dim=2
                loss = adjacency_matrix_loss(adj_pred, adj_gt, weight=edge_weight)

            # Weight by the actual batch size (the last batch may be smaller).
            current_bs = adj_gt.shape[0]
            res['loss'] += loss.item() * current_bs

            results_batch = reconstruction_error(adj_pred, adj_gt, h_pred, batch)
            for key in counter_keys:
                res[key] += results_batch[key]

    def _percent(num, den):
        # NaN instead of ZeroDivisionError when nothing was counted.
        return (num / den) * 100 if den else float('nan')

    res['avg_loss'] = res['loss'] / res['total_molecules'] if res['total_molecules'] else float('nan')
    res['edge_accuracy'] = _percent(res['correct_edges'], res['total_edges'])
    res['atom_types_accuracy'] = _percent(res['correct_atom_types'], res['total_atoms'])
    res['formal_charges_accuracy'] = _percent(res['correct_formal_charges'], res['total_atoms'])
    res['molecule_accuracy'] = _percent(res['correct_molecules'], res['total_molecules'])

    if logger == 'wandb':
        # Keys look like 'val_avg_loss' or 'val_noisy_avg_loss' (no empty segment / double underscore).
        key_prefix = f'{test_or_val}_noisy_' if test_on_noisy_graphs else f'{test_or_val}_'
        # One call so all metrics share a single wandb step.
        wandb.log({f'{key_prefix}{metric}': res[metric] for metric in metric_names})

    print('Epoch %i %s avg loss: %.4f \nedge_accuracy: %.4f \natom_types_accuracy: %.4f \nformal_charges_accuracy: %.4f \nmolecule_accuracy: %.4f' % (epoch, test_or_val, res['avg_loss'], res['edge_accuracy'], res['atom_types_accuracy'], res['formal_charges_accuracy'], res['molecule_accuracy']))
    return res


def perturb_batch(batch, sigma, p, device):
    """
    Introduce noise into the molecules' 3D positions, atom types and formal charges
    that the model learns to be robust against. The noise level is controlled by
    ``sigma`` and ``p``.

    Args:
        batch (dict): tensors describing the molecules. Reads 'positions', 'one_hot',
            'charges' and 'atom_mask'. 'one_hot' is unpacked as
            (bs, n_nodes, n_atom_types); 'atom_mask' is assumed to be (bs, n_nodes)
            with 0 marking padded atoms — TODO confirm against the data loader.
        sigma (float): standard deviation of the isotropic Gaussian noise added to
            the positions of non-padded atoms.
        p (float): probability that an atom's type (and, independently, its formal
            charge) is replaced by a uniformly sampled one. The sample may coincide
            with the original value.
        device: device the sampled atom types / charges are moved to.

    Returns:
        noisy_batch (dict): a deep copy of ``batch`` whose 'positions', 'one_hot' and
        'charges' are replaced by perturbed tensors. 'one_hot' and 'charges' keep
        their original dtypes. Padded atoms are left unmoved in 'positions' and
        zeroed in 'one_hot' and 'charges'. ``batch`` itself is not modified.
    """
    noisy_batch = copy.deepcopy(batch)
    atom_mask = batch['atom_mask']
    node_mask = atom_mask.unsqueeze(-1)  # broadcast the per-atom mask over the last feature dim

    # NOTE: the order of the random draws below is kept fixed so seeded runs reproduce the same noise.

    # positions: isotropic Gaussian noise with std=sigma, applied to real atoms only
    positions = batch['positions']
    noisy_batch['positions'] = positions + sigma * torch.randn_like(positions) * node_mask

    # atom types: uniformly sampled one-hot vectors, each used with probability p
    one_hot = batch['one_hot']
    bs, n_nodes, n_atom_types = one_hot.shape
    atom_types_dist = torch.distributions.one_hot_categorical.OneHotCategorical(torch.ones(n_atom_types))
    random_one_hot = atom_types_dist.sample((bs, n_nodes)).to(device)  # same shape as one_hot
    replace_type = torch.bernoulli(p * torch.ones_like(atom_mask)).unsqueeze(-1)  # 1 with prob. p
    noisy_one_hot = (1 - replace_type) * one_hot + replace_type * random_one_hot
    # Mixing with the float mask promotes to float; cast back (values are exactly 0/1)
    # so the noisy batch has the same dtype the model sees on clean batches.
    noisy_batch['one_hot'] = (noisy_one_hot * node_mask).to(one_hot.dtype)

    # formal charges: same scheme as atom types, with the three categories -1, 0, 1
    charges = batch['charges']
    charges_dist = torch.distributions.categorical.Categorical(torch.ones(3))
    random_charges = charges_dist.sample(charges.shape).to(device) - 1  # {0, 1, 2} -> {-1, 0, 1}
    replace_charge = torch.bernoulli(p * torch.ones_like(atom_mask)).unsqueeze(-1)
    noisy_charges = (1 - replace_charge) * charges + replace_charge * random_charges
    # Values are exact integers in {-1, 0, 1}, so casting back to the input dtype is lossless.
    noisy_batch['charges'] = (noisy_charges * node_mask).to(charges.dtype)

    return noisy_batch
