import wandb
import torch
from tqdm import tqdm
import time
import numpy as np
from utility import functions as uf
# from traintest.sampling import sample_chain, sample, sample_sweep_conditional

def k_nearest(points, k):
    """Pick a random point and return the indices of its k nearest neighbours.

    Args:
        points: 2-D tensor with one point per row; distances are computed with
            ``torch.cdist`` between the selected row and every row.
        k: Number of neighbours requested. Clamped to ``points.size(0) - 1``
            because the query point itself is never returned.

    Returns:
        Tuple ``(random_point_index, nearest_indices)``: the query index as a
        tensor of shape ``(1,)`` and a 1-D tensor of neighbour indices ordered
        by increasing distance to the query.
    """
    num_points = points.size(0)
    random_point_index = torch.randint(0, num_points, (1,)).to(points.device)
    random_point = points[random_point_index]

    # Distance between the random point and all points (shape: 1 x num_points).
    distances = torch.cdist(random_point, points)

    # Exclude the query by masking its own distance instead of asking for k+1
    # results and filtering afterwards: with duplicate points (or any tie at
    # distance zero) the query is not guaranteed to be among the k+1 closest,
    # in which case filtering removed nothing and k+1 neighbours were returned.
    is_query = torch.arange(num_points, device=points.device) == random_point_index
    distances = distances.masked_fill(is_query.unsqueeze(0), float("inf"))

    # There are at most num_points - 1 other points to choose from.
    k = max(0, min(k, num_points - 1))
    _, nearest_indices = torch.topk(distances, k, largest=False)

    return random_point_index, nearest_indices[0]

def prepare_batch_data(args, data, device, dtype):
    """Turn a flat, concatenated batch into padded per-sample tensors.

    Args:
        args: Unused; kept for interface compatibility.
        data: Batch object exposing ``pos``, ``x``, ``ptr`` and
            ``pro_mol_cutoff``. ``ptr`` is differenced to obtain the number of
            rows per sample (presumably a PyG-style batch pointer -- confirm
            against the data loader).
        device: Target device for the returned tensors.
        dtype: Target dtype for ``pos``, ``x`` and ``pro_mol_cutoff``.

    Returns:
        Tuple ``(pos_curr, pos_next, x, pro_mol_cutoff, node_mask)`` where
        ``pos_curr`` / ``pos_next`` come from the first three / remaining
        columns of ``data.pos``, each split per sample and padded with
        ``uf.padding_list_tensor``, and ``node_mask`` is an int tensor of shape
        ``(num_samples, max_atoms, 1)`` that is 1 on real rows and 0 on padding.
    """
    # Cast/transfer the positions once and slice afterwards instead of moving
    # the full tensor to the device twice.
    pos = data.pos.to(device, dtype)
    pro_mol_cutoff = torch.tensor(data.pro_mol_cutoff).to(device, dtype)

    # Rows per sample; computed once and reused for splitting and masking.
    atom_counts = torch.diff(data.ptr)
    num_atoms = atom_counts.tolist()

    pos_curr = torch.split_with_sizes(pos[:, :3], num_atoms)
    pos_next = torch.split_with_sizes(pos[:, 3:], num_atoms)
    x = torch.split_with_sizes(data.x.to(device, dtype), num_atoms)

    # Build the mask on the device that holds the counts so the comparison is
    # valid wherever ``data.ptr`` lives, then move the result to ``device``.
    max_atoms = torch.max(atom_counts)
    row_indices = torch.arange(max_atoms, device=atom_counts.device).unsqueeze(0)
    node_mask = (row_indices < atom_counts.unsqueeze(1)).int().unsqueeze(-1).to(device)

    pos_curr = uf.padding_list_tensor(pos_curr)
    pos_next = uf.padding_list_tensor(pos_next)
    x = uf.padding_list_tensor(x)
    return pos_curr, pos_next, x, pro_mol_cutoff, node_mask

def train_epoch(args, loader, epoch, model, model_ema, ema, device, dtype, optim, gradnorm_queue, rank, dist):
    """Run one pass over ``loader``, optimising ``model`` on every batch.

    Args:
        args: Run configuration; reads ``clip_grad``, ``ema_decay``,
            ``save_model``, ``checkpoint_dir`` and ``exp_name`` and writes
            ``current_epoch`` / ``current_iteration`` before checkpointing.
        loader: Sized iterable of batches accepted by ``prepare_batch_data``.
        epoch: Epoch index, used for progress text and checkpoint names.
        model: Model being trained. Checkpointing reads ``model.module``, so a
            wrapped (e.g. DDP-style) model is expected whenever ``dist`` is set.
        model_ema: Averaged copy of the model, updated when ``args.ema_decay > 0``.
        ema: Object providing ``update_model_average(model_ema, model)``.
        device, dtype: Forwarded to ``prepare_batch_data``.
        optim: Optimiser stepping ``model``'s parameters.
        gradnorm_queue: Forwarded to ``uf.gradient_clipping``.
        rank: Process rank; only rank 0 reports progress and writes checkpoints.
        dist: Object exposing ``barrier()`` (presumably ``torch.distributed``),
            or ``None`` for a single-process run.

    Returns:
        The loss of the last trained batch as a float, or ``nan`` when no
        batch was trained (empty loader or every batch skipped).
    """
    # model.train()
    # NOTE(review): only filled at checkpoint iterations and never read here.
    nll_epoch = []
    n_iterations = len(loader)
    if rank == 0:
        tqdm_info = tqdm(range(n_iterations))
    loss = None
    for i, data in enumerate(loader):
        pos_curr, pos_next, x, pro_mol_cutoff, node_mask = prepare_batch_data(args, data, device, dtype)
        # Batches padded to 10000 or more rows per sample are skipped.
        if pos_curr.size(1) < 10000:
            optim.zero_grad()
            # transform batch through flow
            nll, reg_term, mean_abs_z = compute_loss_and_nll(args, model, pos_curr, pos_next, x, pro_mol_cutoff, node_mask)
            # standard nll from forward KL
            loss = nll
            loss.backward()

            if args.clip_grad:
                grad_norm = uf.gradient_clipping(model, gradnorm_queue)
            else:
                grad_norm = 0.

            optim.step()
            # Update EMA if enabled.
            if args.ema_decay > 0:
                ema.update_model_average(model_ema, model)

        # ``loss is not None`` guarantees nll/grad_norm exist; after a skipped
        # batch the values shown are those of the last trained batch.
        if rank == 0 and loss is not None:
            train_des = (f"\rE: {epoch}, ({pos_curr.size(1)})"
                         f"L {loss.item():.2f}, NLL: {nll.item():.2f}, "
                         f"GN: {grad_norm:.2f}")
            tqdm_info.update(1)
            tqdm_info.set_description(train_des)
            if args.save_model and (i % 10000 == 0 or i == 85000) and dist is not None:
                nll_epoch.append(nll.item())
                args.current_epoch = epoch
                args.current_iteration = i
                checkpoint = {
                    "model": model.module.state_dict(),
                    "model_ema": model_ema.state_dict(),
                    "opt": optim.state_dict(),
                    "args": args
                }
                # Rolling "latest" checkpoint plus a per-epoch copy.
                checkpoint_path = f"{args.checkpoint_dir}/{args.exp_name}.pt"
                torch.save(checkpoint, checkpoint_path)
                checkpoint_path = f"{args.checkpoint_dir}/{args.exp_name}_e{epoch}.pt"
                torch.save(checkpoint, checkpoint_path)

        # A barrier is a collective: every rank has to enter it. It used to be
        # called from the rank-0 branch only, which leaves rank 0 waiting on
        # peers that never arrive, and it raised AttributeError for dist=None.
        if dist is not None:
            dist.barrier()

    if rank == 0:
        tqdm_info.close()
    # if rank == 0:
    #     wandb.log({"Train Epoch NLL": np.mean(nll_epoch)}, commit=False)
    return loss.item() if loss is not None else float("nan")


def test(args, loader, epoch, eval_model, device, dtype, rank, partition='Test'):
    """Evaluate ``eval_model`` on ``loader`` and return the mean per-sample loss.

    Args:
        args: Run configuration, forwarded to the batch/loss helpers.
        loader: Sized iterable of batches accepted by ``prepare_batch_data``.
        epoch: Epoch index, used only in the progress text.
        eval_model: Model to evaluate; switched to eval mode.
        device, dtype: Forwarded to ``prepare_batch_data``.
        rank: Process rank; only rank 0 shows a progress bar.
        partition: Label shown in the progress text.

    Returns:
        Batch-size-weighted average of the loss over all evaluated batches, or
        ``nan`` when nothing was evaluated (empty loader or all batches skipped).
    """
    eval_model.eval()
    nll_epoch = 0
    n_samples = 0

    n_iterations = len(loader)
    if rank == 0:
        loader = tqdm(loader, ncols=80)

    with torch.no_grad():
        for i, data in enumerate(loader):
            pos_curr, pos_next, x, pro_mol_cutoff, node_mask = prepare_batch_data(args, data, device, dtype)
            # NOTE(review): train_epoch skips batches with size >= 10000 while
            # this skips only > 10000 -- confirm which bound is intended.
            if pos_curr.size(1) > 10000:
                continue
            batch_size = pos_curr.size(0)
            nll, reg_term, mean_abs_z = compute_loss_and_nll(args, eval_model, pos_curr, pos_next, x, pro_mol_cutoff,
                                                             node_mask)
            # Weight by batch size so the final value is a per-sample mean.
            nll_epoch += nll.item() * batch_size
            n_samples += batch_size
            if rank == 0:
                train_des = (f"\r {partition} \t e: {epoch}, i: {i}/{n_iterations}, "
                             f"NLL: {nll_epoch/n_samples:.2f}")
                loader.set_description(train_des)

    # Avoid ZeroDivisionError when no batch contributed.
    if n_samples == 0:
        return float("nan")
    return nll_epoch/n_samples


def simulate(args, model, pos, x, device, dtype):
    """Roll the model forward from the first frame and compare with ``pos``.

    Args:
        args: Unused; kept for interface compatibility.
        model: Object providing ``eval()`` and
            ``simulate(pos_curr, x, pro_mol_cutoff, node_mask)``, which returns
            the next frame with a leading batch dimension of one.
        pos: Reference trajectory indexed by frame along the first dimension;
            ``pos[0]`` seeds the rollout and ``pos[1:]`` are the targets.
        x: Per-node features for a single sample (a batch dimension is added).
        device, dtype: Where/how the tensors are placed before simulation.

    Returns:
        Tuple ``(traj_predict, mse, elapsed_time)``: the predicted trajectory
        with the initial frame prepended, a dict with the MSE over all frames
        (``'average'``), the first (``'next'``) and the final (``'last'``)
        predicted frame, and the rollout duration in seconds. For a
        single-frame ``pos`` there is nothing to predict and all three MSE
        entries are ``nan``.
    """
    model.eval()
    pos_curr = pos[0].unsqueeze(0).to(device, dtype)
    pos_next = pos[1:].to(device, dtype)
    x = x.unsqueeze(0).to(device, dtype)
    pro_mol_cutoff = 0
    # Single un-padded sample: every node is valid.
    node_mask = torch.ones_like(pos_curr[:, :, 0]).unsqueeze(-1)
    traj_predict = torch.zeros_like(pos_next)
    num_steps = pos_next.size(0)

    # perf_counter is monotonic, unlike the wall clock returned by time.time().
    start_time = time.perf_counter()
    for index in range(num_steps):
        # Detach after each step so the autograd graph does not keep growing
        # across the whole rollout; this is evaluation only.
        pos_curr = model.simulate(pos_curr, x, pro_mol_cutoff, node_mask).detach()
        traj_predict[index] = pos_curr[0]
    elapsed_time = time.perf_counter() - start_time

    if num_steps > 0:
        loss = torch.nn.MSELoss()
        mse = {'average': loss(traj_predict, pos_next), 'next': loss(traj_predict[0], pos_next[0]), 'last': loss(traj_predict[-1], pos_next[-1])}
    else:
        # No target frames: indexing [0] / [-1] would raise, report nan instead.
        undefined = torch.full((), float("nan"), dtype=pos_next.dtype, device=pos_next.device)
        mse = {'average': undefined, 'next': undefined, 'last': undefined}

    traj_predict = torch.cat((pos[0].unsqueeze(0).to(device, dtype), traj_predict), dim=0)
    return traj_predict, mse, elapsed_time



def compute_loss_and_nll(args, generative_model, pos_curr, pos_next, x, pro_mol_cutoff, node_mask):
    """Score the model's next-position prediction against the target.

    Args:
        args: Unused; kept for interface compatibility.
        generative_model: Callable invoked as
            ``generative_model(pos_curr, x, pro_mol_cutoff, node_mask)``.
        pos_curr, pos_next, x: Padded tensors; each is checked against
            ``node_mask`` with ``uf.assert_correctly_masked`` first.
        pro_mol_cutoff: Forwarded to the model unchanged.
        node_mask: Mask used for the checks and forwarded to the model.

    Returns:
        Tuple ``(nll, reg_term, mean_abs_z)``: the mean squared error between
        the prediction and ``pos_next`` (averaged over all elements), a
        one-element zero tensor on the same device, and the constant ``0.``.
    """
    # Validate masking of every input before running the model.
    uf.assert_correctly_masked(pos_curr, node_mask)
    uf.assert_correctly_masked(pos_next, node_mask)
    uf.assert_correctly_masked(x, node_mask)

    prediction = generative_model(pos_curr, x, pro_mol_cutoff, node_mask)

    # Default reduction is the mean over all elements.
    nll = torch.nn.functional.mse_loss(prediction, pos_next)

    # Placeholders kept so callers can unpack three values.
    reg_term = torch.zeros(1, device=nll.device)
    mean_abs_z = 0.
    return nll, reg_term, mean_abs_z