import copy

import numpy as np
import torch
import torch.nn.functional as F
from equivariant_diffusion.utils import assert_mean_zero_with_mask, remove_mean_with_mask,\
    assert_correctly_masked
from qm9.analyze import check_stability


def rotate_chain(z, n_steps=30):
    """Build a chain of progressively rotated copies of a single sample.

    The first three channels of the last dimension of ``z`` are rotated by the
    same small rotation at every step; all remaining channels are copied
    unchanged into every frame.

    Args:
        z: Tensor of shape ``(1, n_nodes, 3 + n_features)``.
        n_steps: Number of rotated frames appended after the input frame.
            Each step rotates by ``0.6 * pi / n_steps`` radians about the
            z, x and y axes (composed as ``Qz @ Qx @ Qy``).

    Returns:
        Tensor of shape ``(n_steps + 1, n_nodes, 3 + n_features)`` whose first
        frame is ``z`` itself.

    Raises:
        AssertionError: If the first dimension of ``z`` is not 1.
        ValueError: If ``n_steps`` is smaller than 1.
    """
    assert z.size(0) == 1
    if n_steps < 1:
        raise ValueError(f'n_steps must be at least 1, got {n_steps}')

    theta = 0.6 * np.pi / n_steps
    cos_t, sin_t = float(np.cos(theta)), float(np.sin(theta))

    Qz = torch.tensor(
        [[cos_t, -sin_t, 0.],
         [sin_t, cos_t, 0.],
         [0., 0., 1.]]
    ).float()
    Qx = torch.tensor(
        [[1., 0., 0.],
         [0., cos_t, -sin_t],
         [0., sin_t, cos_t]]
    ).float()
    Qy = torch.tensor(
        [[cos_t, 0., sin_t],
         [0., 1., 0.],
         [-sin_t, 0., cos_t]]
    ).float()

    Q = torch.mm(torch.mm(Qz, Qx), Qy)
    Q = Q.to(z.device)
    # Match the input's floating dtype so that e.g. float64 inputs do not
    # fail in matmul. Non-floating inputs are left alone (casting a rotation
    # matrix to an integer dtype would silently corrupt it).
    if z.is_floating_point():
        Q = Q.to(z.dtype)

    # The non-coordinate channels never change, so slice them once.
    z_h = z[:, :, 3:]

    frames = [z]
    for _ in range(n_steps):
        prev_x = frames[-1][:, :, :3]
        next_x = torch.matmul(prev_x.reshape(-1, 3), Q.T).view(1, -1, 3)
        frames.append(torch.cat([next_x, z_h], dim=2))

    return torch.cat(frames, dim=0)


def reverse_tensor(x):
    """Return a copy of ``x`` with the entries along dimension 0 in reverse order."""
    return torch.flip(x, dims=(0,))


def sample_chain(args, device, flow, n_tries, dataset_info, prop_dist=None):
    """Sample a chain for a single molecule, retrying until one is stable.

    Up to ``n_tries`` chains are drawn from ``flow.sample_chain``. Each chain
    is reversed along its first dimension, its last frame is repeated ten
    extra times, and that last frame is passed to ``check_stability``. The
    loop stops at the first chain reported stable; otherwise the last chain
    drawn is returned.

    Args:
        args: Namespace read for ``dataset``, ``context_node_nf`` and
            ``probabilistic_model``.
        device: Device the masks and context are moved to.
        flow: Model exposing ``sample_chain(n_samples, n_nodes, node_mask,
            edge_mask, context, keep_frames=...)``.
        n_tries: Maximum number of chains to draw.
        dataset_info: Dict read for ``'with_h'`` (qm9 / zinc250k only) and
            ``'atom_decoder'``; also forwarded to ``check_stability``.
        prop_dist: Property distribution with a ``sample(n_nodes)`` method;
            only used when ``args.context_node_nf > 0``.

    Returns:
        Tuple ``(one_hot, charges, x)`` computed over the whole chain, or
        ``(None, None, None)`` when ``n_tries`` is 0.

    Raises:
        ValueError: For an unknown ``args.dataset`` or when
            ``args.probabilistic_model`` is not ``'diffusion'``.
    """
    n_samples = 1

    # The number of atoms is fixed per dataset.
    if args.dataset in ('qm9', 'qm9_second_half', 'qm9_first_half'):
        n_nodes = 19 if dataset_info['with_h'] else 9
    elif args.dataset == 'zinc250k':
        n_nodes = 44 if dataset_info['with_h'] else 24
    elif args.dataset == 'geom':
        n_nodes = 44
    else:
        raise ValueError()

    # TODO FIX: This conditioning just zeros.
    context = None
    if args.context_node_nf > 0:
        sampled_props = prop_dist.sample(n_nodes).unsqueeze(1).unsqueeze(0)
        context = sampled_props.repeat(1, n_nodes, 1).to(device)

    # Every node is present; every ordered pair except self-pairs is an edge.
    node_mask = torch.ones(n_samples, n_nodes, 1).to(device)
    off_diagonal = (1 - torch.eye(n_nodes)).unsqueeze(0)
    edge_mask = off_diagonal.repeat(n_samples, 1, 1).view(-1, 1).to(device)

    if args.probabilistic_model != 'diffusion':
        raise ValueError

    one_hot, charges, x = None, None, None
    for attempt in range(n_tries):
        chain = flow.sample_chain(n_samples, n_nodes, node_mask, edge_mask, context, keep_frames=100)
        chain = reverse_tensor(chain)

        # Repeat last frame to see final sample better.
        chain = torch.cat([chain, chain[-1:].repeat(10, 1, 1)], dim=0)

        # Stability is judged on the final frame only. Channel layout used
        # here: [0:3] coordinates, [3:-1] atom-type scores, [-1] charge.
        final_frame = chain[-1:]
        final_x = final_frame[:, :, 0:3].squeeze(0).cpu().detach().numpy()
        final_types = torch.argmax(final_frame[:, :, 3:-1], dim=2).squeeze(0).cpu().detach().numpy()
        final_charges = torch.round(final_frame[:, :, -1]).long()
        # avoid numerical issues with very large charges sampled at the beginning of training
        final_charges = torch.clamp(final_charges, min=-1000000, max=1000000)
        final_charges = final_charges.squeeze(0).cpu().detach().numpy()
        mol_stable = check_stability(final_x, final_types, final_charges, dataset_info)[0]

        # Prepare entire chain.
        x = chain[:, :, 0:3]
        type_index = torch.argmax(chain[:, :, 3:-1], dim=2)
        one_hot = F.one_hot(type_index, num_classes=len(dataset_info['atom_decoder']))
        charges = torch.round(chain[:, :, -1:]).long()

        if mol_stable:
            print('Found stable molecule to visualize :)')
            break
        if attempt == n_tries - 1:
            print('Did not find stable molecule, showing last sample.')

    return one_hot, charges, x


def sample(args, device, generative_model, dataset_info,
           prop_dist=None, nodesxsample=torch.tensor([10]), context=None,
           fix_noise=False, prop_encoder=None, enforce_unconditional_generation=False,
           regressor_guidance=False):
    """Sample a batch of molecules and pad them to the dataset's maximum size.

    The batch is generated at the size of its largest molecule to save
    computation and memory; every per-atom output is then zero-padded along
    the atom axis up to ``dataset_info['max_n_nodes']`` so that batches can be
    assembled by the caller.

    Args:
        args: Namespace read for ``context_node_nf``, ``conditioning``,
            ``condition_dropout`` and ``probabilistic_model``.
        device: Device the masks and context are moved to.
        generative_model: Model exposing ``sample(batch_size, n_nodes,
            node_mask, edge_mask, context, fix_noise=...)``. When
            ``regressor_guidance`` is set, its ``target_prop`` attribute is
            overwritten.
        dataset_info: Dict read for ``'max_n_nodes'``.
        prop_dist: Property distribution used to sample and un-normalize
            conditioning values.
        nodesxsample: 1-D tensor with the number of atoms of each sample.
        context: Optional pre-computed conditioning values; sampled from
            ``prop_dist`` when ``None`` and the model is conditional.
        fix_noise: Forwarded to ``generative_model.sample``.
        prop_encoder: Optional callable applied to the squeezed context.
        enforce_unconditional_generation (bool): only relevant when using a
            conditional model trained with condition dropout. If True, the
            conditioning vector is replaced by (0, 1) to generate
            unconditional molecules.
        regressor_guidance: If True, set ``generative_model.target_prop``
            from the conditioning values.

    Returns:
        Tuple ``(x, atom_types, formal_charges, adjacency_matrices, node_mask,
        edge_mask, z_h, context_global)``. ``context_global`` is ``None`` for
        an unconditional model without regressor guidance.

    Raises:
        AssertionError: If a sample requests more atoms than
            ``dataset_info['max_n_nodes']``.
        ValueError: If ``args.probabilistic_model`` is not ``'diffusion'``.
    """
    dataset_max_nodes = dataset_info['max_n_nodes']  # this is the maximum node_size in the dataset

    # Use the largest requested molecule as the node count for this batch.
    # The outputs are padded back to the dataset maximum after generation.
    batch_max_nodes = int(torch.max(nodesxsample))
    # make sure we did not sample more nodes than in the training set
    assert batch_max_nodes <= dataset_max_nodes
    print(f'max_n_nodes in the current batch: {batch_max_nodes}')
    batch_size = len(nodesxsample)

    # Row i has ones in its first nodesxsample[i] columns.
    node_mask = torch.zeros(batch_size, batch_max_nodes)
    for row in range(batch_size):
        node_mask[row, :nodesxsample[row]] = 1

    # An edge is valid when both endpoints exist and it is not a self-loop.
    pair_mask = node_mask.unsqueeze(1) * node_mask.unsqueeze(2)
    off_diagonal = ~torch.eye(pair_mask.size(1), dtype=torch.bool).unsqueeze(0)
    edge_mask = pair_mask * off_diagonal
    edge_mask = edge_mask.view(batch_size * batch_max_nodes * batch_max_nodes, 1).to(device)
    node_mask = node_mask.unsqueeze(2).to(device)

    # TODO FIX: This conditioning just zeros.
    # TODO: check what the previous TODO means.
    # TODO: fix when conditioning on multiple properties.
    if args.context_node_nf > 0:
        if context is None:
            context = prop_dist.sample_batch(nodesxsample)
        # Keep an un-normalized copy of the raw conditioning values for the caller.
        context_global = prop_dist.unnormalize_tensor(copy.deepcopy(context), args.conditioning[0])
        if prop_encoder is not None:
            context = prop_encoder(context.squeeze())
        if args.condition_dropout:
            if enforce_unconditional_generation:
                # Conditional model trained with condition dropout, but
                # unconditional samples are requested.
                halves = (torch.zeros_like(context), torch.ones_like(context))
            else:
                # Conditional model trained with condition dropout and
                # conditional molecules are requested: condition on the true
                # vectors so generation quality and the MAE to the conditions
                # can be assessed.
                halves = (context, torch.zeros_like(context))
            context = torch.cat(halves, dim=1)
        # Broadcast the per-sample context to every node and zero out padding nodes.
        context = context.unsqueeze(1).repeat(1, batch_max_nodes, 1).to(device) * node_mask
    else:
        context = None
        context_global = None

    if regressor_guidance:
        if context is None:
            context_guidance = prop_dist.sample_batch(nodesxsample)
            first_property = list(prop_dist.normalizer.keys())[0]
            context_global = prop_dist.unnormalize_tensor(copy.deepcopy(context_guidance), first_property)
        else:
            # regressor guidance on top of conditional model
            context_guidance = context[:, 0, 0]
        generative_model.target_prop = context_guidance.squeeze().to(device)

    if args.probabilistic_model != 'diffusion':
        raise ValueError(args.probabilistic_model)

    x, h = generative_model.sample(batch_size, batch_max_nodes, node_mask, edge_mask, context, fix_noise=fix_noise)

    assert_correctly_masked(x, node_mask)
    assert_mean_zero_with_mask(x, node_mask)

    atom_types = h['atom_types']
    formal_charges = h['formal_charges']
    adjacency_matrices = h['adjacency_matrices']
    z_h = h['z_h']  # latent features, useful for debugging

    # Pad all attributes to the maximal dimension along the atom axis so that
    # they can be assembled in the calling routine.
    n_pad = dataset_info['max_n_nodes'] - batch_max_nodes
    pad_second_last = (0, 0, 0, n_pad)  # atom axis is the second-to-last dim
    pad_last = (0, n_pad)               # atom axis is the last dim
    pad_last_two = (0, n_pad, 0, n_pad)  # both trailing dims are atom axes

    x = F.pad(x, pad_second_last, "constant", 0)
    atom_types = F.pad(atom_types, pad_last, "constant", 0)
    formal_charges = F.pad(formal_charges, pad_last, "constant", 0)
    adjacency_matrices = F.pad(adjacency_matrices, pad_last_two, "constant", 0)
    node_mask = F.pad(node_mask, pad_second_last, "constant", 0)
    square_edge_mask = edge_mask.reshape(batch_size, batch_max_nodes, batch_max_nodes)
    edge_mask = F.pad(square_edge_mask, pad_last_two, "constant", 0).reshape(-1, 1)
    z_h = F.pad(z_h, pad_second_last, "constant", 0)

    return x, atom_types, formal_charges, adjacency_matrices, node_mask, edge_mask, z_h, context_global


def sample_sweep_conditional(args, device, generative_model, dataset_info, prop_dist, n_nodes=19, n_frames=100):
    """Sample molecules while sweeping each conditioning property linearly.

    For every property in ``prop_dist.distributions`` a column of ``n_frames``
    evenly spaced values is built between that property's normalized minimum
    and maximum for molecules with ``n_nodes`` atoms. The resulting context is
    passed to ``sample`` with ``fix_noise=True``.

    Args:
        args: Forwarded to ``sample``.
        device: Device the context is moved to; also forwarded to ``sample``.
        generative_model: Forwarded to ``sample``.
        dataset_info: Forwarded to ``sample``.
        prop_dist: Object whose ``distributions[key][n_nodes]['params']``
            holds a ``(min, max)`` pair and whose ``normalizer[key]`` holds
            ``'mean'`` and ``'mad'``.
        n_nodes: Number of atoms of every sampled molecule.
        n_frames: Number of sweep steps, i.e. the batch size.

    Returns:
        Tuple ``(atom_types, formal_charges, x, node_mask)`` taken from the
        values returned by ``sample``.
        NOTE(review): the first two entries are returned exactly as ``sample``
        produces them; confirm that callers expecting one-hot atom types and
        charges accept this encoding.
    """
    nodesxsample = torch.tensor([n_nodes] * n_frames)

    context_columns = []
    for key in prop_dist.distributions:
        min_val, max_val = prop_dist.distributions[key][n_nodes]['params']
        mean, mad = prop_dist.normalizer[key]['mean'], prop_dist.normalizer[key]['mad']
        # Normalize the sweep bounds the same way the property is normalized.
        min_val = (min_val - mean) / mad
        max_val = (max_val - mean) / mad
        sweep = torch.tensor(np.linspace(min_val, max_val, n_frames)).unsqueeze(1)
        context_columns.append(sweep)
    context = torch.cat(context_columns, dim=1).float().to(device)

    # sample() returns eight values; unpacking them into four names raised
    # "too many values to unpack" on every call. Unpack all of them and
    # forward only the ones this function exposes.
    (x, atom_types, formal_charges, _adjacency_matrices,
     node_mask, _edge_mask, _z_h, _context_global) = sample(
        args, device, generative_model, dataset_info, prop_dist,
        nodesxsample=nodesxsample, context=context, fix_noise=True)
    return atom_types, formal_charges, x, node_mask