from util.utils import assert_mean_zero_with_mask, remove_mean_with_mask,\
    assert_correctly_masked, sample_center_gravity_zero_gaussian_with_mask
import qm9.visualizer as vis
from qm9.analyze import analyze_stability_for_molecules
from qm9.sampling import sample_chain, sample, sample_sweep_conditional
import utils
import qm9.utils as qm9utils
from qm9 import losses
import torch
import logging

def _conditioning_label(conditioning, data, property_norms, device, dtype):
    """Return the normalised conditioning target ``(value - mean) / mad`` for one batch.

    Only the LAST key of ``conditioning`` determines the label; earlier keys are
    ignored, so multi-property conditioning is not supported by ``train_epoch``.
    NOTE(review): confirm this is intended, or switch to
    ``qm9utils.prepare_context`` as ``test`` does (that helper returns a per-node,
    atom-masked context, i.e. a different shape).

    Raises:
        ValueError: if ``conditioning`` is empty, since no label can be built.
    """
    keys = list(conditioning)
    if not keys:
        raise ValueError('train_epoch requires at least one conditioning property '
                         '(args.conditioning is empty)')
    key = keys[-1]
    label = (data[key] - property_norms[key]['mean']) / property_norms[key]['mad']
    return label.to(device, dtype)


def train_epoch(args, loader, epoch, model, model_dp, model_ema, ema, device, dtype, property_norms, optim,
                nodes_dist, gradnorm_queue, dataset_info, prop_dist, lr_scheduler, partition='train'):
    """Run one pass over ``loader`` and return the sample-weighted mean loss.

    For ``partition == 'train'``:
      * ``model_dp`` and ``model`` are put in train mode.
      * The loss from ``model_dp`` is back-propagated.
      * ``model`` gradients are clipped when ``args.clip_grad`` is set.
      * ``optim`` steps.
      * ``model_ema`` is updated through ``ema`` when ``args.ema_decay > 0``.
      * ``lr_scheduler`` is stepped once, after the last optimiser step of the
        epoch. PyTorch expects ``scheduler.step()`` after ``optimizer.step()``;
        stepping first skips the first value of the schedule.

    For any other partition, the loss is computed with ``model_ema`` in eval
    mode under ``torch.no_grad()``, and no parameters are updated.

    Args:
        args: run configuration. Reads ``include_charges``, ``augment_noise``,
            ``data_augmentation``, ``conditioning``, ``clip_grad``,
            ``ema_decay``, ``n_report_steps`` and ``break_train_epoch``.
        loader: iterable of batch dicts with ``positions``, ``atom_mask``,
            ``edge_mask``, ``one_hot``, ``charges`` and the conditioning keys.
        epoch: epoch index, used only in log messages.
        model, model_dp: the model and its (presumably DataParallel) wrapper.
        model_ema, ema: EMA copy of the model and the object that updates it.
        device, dtype: target device / dtype for the batch tensors.
        property_norms: ``{key: {'mean': ..., 'mad': ...}}`` used to normalise labels.
        optim: optimiser, stepped once per training batch.
        nodes_dist, dataset_info, prop_dist: accepted for interface
            compatibility; not used here.
        gradnorm_queue: state passed to ``utils.gradient_clipping``.
        lr_scheduler: stepped once per training epoch; untouched otherwise.
        partition: ``'train'`` to optimise; anything else only evaluates.

    Returns:
        float: ``sum(loss * batch_size) / number of samples seen``.

    Raises:
        ValueError: if ``args.conditioning`` is empty.
    """
    is_train = partition == 'train'
    if is_train:
        model_dp.train()
        model.train()
    else:
        model_ema.eval()

    res = {'loss': 0, 'counter': 0, 'loss_arr': []}
    n_iterations = len(loader)
    for i, data in enumerate(loader):
        x = data['positions'].to(device, dtype)
        batch_size, _, _ = x.size()
        node_mask = data['atom_mask'].to(device, dtype).unsqueeze(2)
        edge_mask = data['edge_mask'].to(device, dtype)
        one_hot = data['one_hot'].to(device, dtype)
        charges = (data['charges'] if args.include_charges else torch.zeros(0)).to(device, dtype)

        x = remove_mean_with_mask(x, node_mask)

        if args.augment_noise > 0:
            # Add noise eps ~ N(0, augment_noise) around points.
            eps = sample_center_gravity_zero_gaussian_with_mask(x.size(), x.device, node_mask)
            x = x + eps * args.augment_noise

        x = remove_mean_with_mask(x, node_mask)
        if args.data_augmentation:
            x = utils.random_rotation(x).detach()

        check_mask_correct([x, one_hot, charges], node_mask)
        assert_mean_zero_with_mask(x, node_mask)

        h = {'categorical': one_hot, 'integer': charges}
        label = _conditioning_label(args.conditioning, data, property_norms, device, dtype)

        if is_train:
            optim.zero_grad()
            # transform batch through flow
            loss = losses.compute_loss(model_dp, x, h, node_mask, edge_mask, label)
            loss.backward()
            if args.clip_grad:
                # Presumably clips `model`'s gradients in place; the returned norm is unused.
                utils.gradient_clipping(model, gradnorm_queue)
            optim.step()
            # Update EMA if enabled.
            if args.ema_decay > 0:
                ema.update_model_average(model_ema, model)
        else:
            # Evaluation only: no autograd graph is needed (same as `test`).
            with torch.no_grad():
                loss = losses.compute_loss(model_ema, x, h, node_mask, edge_mask, label)

        loss_value = loss.item()
        res['loss'] += loss_value * batch_size
        res['counter'] += batch_size
        res['loss_arr'].append(loss_value)

        if i % args.n_report_steps == 0:
            window = res['loss_arr'][-10:]
            logging.info(f" Epoch: {epoch}, iter: {i}/{n_iterations}, "
                         f"Loss {sum(window) / len(window):.4f} ")

        if args.break_train_epoch:
            break

    if is_train:
        # Step the schedule after this epoch's optimiser updates (PyTorch >= 1.1 ordering).
        lr_scheduler.step()
    return res['loss'] / res['counter']

def check_mask_correct(variables, node_mask):
    """Assert that every non-empty tensor in ``variables`` respects ``node_mask``.

    Empty entries, such as the ``torch.zeros(0)`` placeholder used when charges
    are disabled, are skipped rather than checked.
    """
    non_empty = (var for var in variables if len(var) > 0)
    for var in non_empty:
        assert_correctly_masked(var, node_mask)


def test(args, loader, epoch, eval_model, device, dtype, property_norms, nodes_dist, partition='Test'):
    """Evaluate ``eval_model`` over ``loader`` with gradients disabled.

    Each batch is processed as follows:
      * moved to ``device`` / ``dtype``;
      * optionally perturbed with centre-of-gravity-free Gaussian noise scaled
        by ``args.augment_noise``;
      * re-centred and checked against its masks;
      * scored with ``losses.compute_loss``, using the context built by
        ``qm9utils.prepare_context``.

    ``nodes_dist`` is accepted for interface compatibility and is not used.

    Returns:
        float: the batch-size-weighted mean loss over the whole loader.
    """
    eval_model.eval()

    loss_sum = 0
    sample_count = 0
    batch_losses = []
    total_steps = len(loader)
    with torch.no_grad():
        for step, batch in enumerate(loader):
            x = batch['positions'].to(device, dtype)
            n_batch, _, _ = x.size()
            node_mask = batch['atom_mask'].to(device, dtype).unsqueeze(2)
            edge_mask = batch['edge_mask'].to(device, dtype)
            one_hot = batch['one_hot'].to(device, dtype)
            if args.include_charges:
                raw_charges = batch['charges']
            else:
                raw_charges = torch.zeros(0)
            charges = raw_charges.to(device, dtype)

            if args.augment_noise > 0:
                # Perturb positions with eps ~ N(0, augment_noise), centre-of-gravity free.
                noise = sample_center_gravity_zero_gaussian_with_mask(x.size(), x.device, node_mask)
                x = x + noise * args.augment_noise

            x = remove_mean_with_mask(x, node_mask)
            check_mask_correct([x, one_hot, charges], node_mask)
            assert_mean_zero_with_mask(x, node_mask)

            h = {'categorical': one_hot, 'integer': charges}

            # Context: (Bs,) property -> (value - mean) / mad -> repeated to (Bs, N, 1) -> masked.
            context = qm9utils.prepare_context(args.conditioning, batch, property_norms).to(device, dtype)
            assert_correctly_masked(context, node_mask)

            loss = losses.compute_loss(eval_model, x, h, node_mask, edge_mask, context)
            value = loss.item()
            loss_sum += value * n_batch
            sample_count += n_batch
            batch_losses.append(value)

            if step % args.n_report_steps == 0:
                window = batch_losses[-10:]
                logging.info(f"{partition} Epoch: {epoch}, iter: {step}/{total_steps}, "
                             f"Loss {sum(window) / len(window):.4f} ")

    return loss_sum / sample_count



def save_and_sample_chain(model, args, device, dataset_info, prop_dist,
                          epoch=0, id_from=0, batch_id=''):
    """Sample one chain with ``model`` and write it out as an XYZ file.

    The file goes to ``outputs/<exp_name>/epoch_<epoch>_<batch_id>/chain/`` with
    the name prefix ``chain``, and numbering starts at ``id_from``.

    Returns:
        tuple: ``(one_hot, charges, x)`` as produced by ``sample_chain``.
    """
    chain = sample_chain(args=args, device=device, flow=model, n_tries=1,
                         dataset_info=dataset_info, prop_dist=prop_dist)
    one_hot, charges, x = chain

    out_dir = f'outputs/{args.exp_name}/epoch_{epoch}_{batch_id}/chain/'
    vis.save_xyz_file(out_dir, one_hot, charges, x, dataset_info, id_from, name='chain')

    return one_hot, charges, x


def sample_different_sizes_and_save(model, nodes_dist, args, device, dataset_info, prop_dist,
                                    n_samples=5, epoch=0, batch_size=100, batch_id=''):
    """Sample ``n_samples`` molecules of varying size and save them as XYZ files.

    For each batch, molecule sizes are drawn from ``nodes_dist``. Each batch is
    written to ``outputs/<exp_name>/epoch_<epoch>_<batch_id>/`` with the name
    prefix ``molecule``, and file ids continue from the previous batch.

    When ``n_samples`` is not a multiple of ``batch_size``, a final partial
    batch is sampled so exactly ``n_samples`` molecules are produced. Nothing is
    sampled when ``n_samples <= 0``; the old code raised ZeroDivisionError for 0.
    """
    if n_samples <= 0:
        return
    batch_size = min(batch_size, n_samples)
    out_dir = f'outputs/{args.exp_name}/epoch_{epoch}_{batch_id}/'
    for start in range(0, n_samples, batch_size):
        # The last batch may be smaller than batch_size.
        current = min(batch_size, n_samples - start)
        nodesxsample = nodes_dist.sample(current)
        one_hot, charges, x, node_mask = sample(args, device, model, prop_dist=prop_dist,
                                                nodesxsample=nodesxsample,
                                                dataset_info=dataset_info)
        # Lazy formatting: the tensor is only rendered if DEBUG logging is enabled.
        logging.debug("Generated molecules: positions %s", x)
        # `start` equals batch_size * batch_index, so ids continue across batches.
        vis.save_xyz_file(out_dir, one_hot, charges, x, dataset_info,
                          start, name='molecule')


def analyze_and_save(epoch, model_sample, nodes_dist, args, device, dataset_info, prop_dist,
                     n_samples=1000, batch_size=100):
    """Sample molecules, analyse their stability, and log the resulting metrics.

    ``n_samples`` molecules are drawn in batches of ``batch_size``, with
    molecule sizes taken from ``nodes_dist``. The batches are concatenated on
    the CPU and passed to ``analyze_stability_for_molecules``.

    Returns:
        dict: ``validity_dict`` as returned by ``analyze_stability_for_molecules``.

    Raises:
        AssertionError: if ``n_samples`` is not a multiple of ``batch_size``.
    """
    print(f'Analyzing molecule stability at epoch {epoch}...')
    batch_size = min(batch_size, n_samples)
    assert n_samples % batch_size == 0
    molecules = {'one_hot': [], 'x': [], 'node_mask': []}
    for _ in range(n_samples // batch_size):
        nodesxsample = nodes_dist.sample(batch_size)
        one_hot, charges, x, node_mask = sample(args, device, model_sample, dataset_info, prop_dist,
                                                nodesxsample=nodesxsample)

        molecules['one_hot'].append(one_hot.detach().cpu())
        molecules['x'].append(x.detach().cpu())
        molecules['node_mask'].append(node_mask.detach().cpu())

    molecules = {key: torch.cat(chunks, dim=0) for key, chunks in molecules.items()}
    validity_dict, rdkit_tuple = analyze_stability_for_molecules(molecules, dataset_info)

    # Metrics are reported through `logging`, like the rest of this module;
    # `wandb` is not imported here, so `wandb.log` would raise NameError.
    logging.info(f"Epoch {epoch} stability: {validity_dict}")
    if rdkit_tuple is not None:
        metrics = rdkit_tuple[0]
        logging.info(f"Validity: {metrics[0]}, Uniqueness: {metrics[1]}, Novelty: {metrics[2]}")
    return validity_dict


def save_and_sample_conditional(args, device, model, prop_dist, dataset_info, epoch=0, id_from=0):
    """Run a conditional property sweep with ``model`` and save the molecules as XYZ.

    Files go to ``outputs/<exp_name>/epoch_<epoch>/conditional/`` with the name
    prefix ``conditional``. Numbering starts at ``id_from``, and the sampled
    ``node_mask`` is forwarded so that padded atoms can be excluded.

    Returns:
        tuple: ``(one_hot, charges, x)`` from ``sample_sweep_conditional``.
    """
    sweep = sample_sweep_conditional(args, device, model, dataset_info, prop_dist)
    one_hot, charges, x, node_mask = sweep

    out_dir = 'outputs/%s/epoch_%d/conditional/' % (args.exp_name, epoch)
    vis.save_xyz_file(out_dir, one_hot, charges, x, dataset_info, id_from,
                      name='conditional', node_mask=node_mask)

    return one_hot, charges, x
