import torch
import numpy as np

###############=============================##############
# Adapted from https://github.com/kvfrans/shortcut-models/blob/main/targets_shortcut.py
##########################################################

def _broadcast_time(t, like):
    """Reshape per-sample vector ``t`` to ``(batch, 1, ..., 1)`` so it broadcasts against ``like``."""
    # An explicit leading size (instead of -1) keeps this valid for an empty batch.
    return t.view(t.shape[0], *([1] * (like.dim() - 1)))


def _sample_bootstrap_dt(bootstrap_batchsize, log2_sections, device):
    """Return int32 step exponents ``dt_base`` (step size = 2**-dt_base) for the bootstrap samples."""
    # Spread samples evenly over exponents log2_sections-1, ..., 1, 0 (finest first);
    # the remainder is padded with exponent 0 (one full-size step). Dividing by
    # log2_sections guarantees per_level * log2_sections <= bootstrap_batchsize,
    # so the padding length is never negative.
    per_level = bootstrap_batchsize // log2_sections if log2_sections > 0 else 0
    levels = torch.arange(log2_sections - 1, -1, -1, dtype=torch.int32)
    dt_base = torch.repeat_interleave(levels, per_level)
    padding = torch.zeros(bootstrap_batchsize - dt_base.numel(), dtype=torch.int32)
    return torch.cat([dt_base, padding]).to(device)


def _two_step_target(model, x_t, t, labels, dt_base_bootstrap, dt_bootstrap):
    """Return the clamped average velocity of two consecutive half-size model steps."""
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            v_b1 = model(x_t, t, labels, dt_base_bootstrap)
            t2 = t + dt_bootstrap
            x_t2 = x_t + _broadcast_time(dt_bootstrap, x_t) * v_b1
            x_t2 = torch.clamp(x_t2, -4, 4)
            v_b2 = model(x_t2, t2, labels, dt_base_bootstrap)
            v_target = (v_b1 + v_b2) / 2
    finally:
        # Restore the caller's mode; otherwise the subsequent training forward
        # pass would silently run in eval mode (e.g. with dropout disabled).
        model.train(was_training)
    return torch.clamp(v_target, -4, 4)


def get_targets_ST(args, model, images, labels):
    """Build one training batch mixing shortcut (self-consistency) and flow-matching targets.

    The first ``bootstrap_batchsize`` rows of the result are bootstrap samples.
    For each one, the target velocity for a step of size ``dt`` is the average
    of two model predictions for steps of size ``dt / 2``, computed without
    gradients. The remaining rows are ordinary flow-matching samples at the
    finest step size.

    ``bootstrap_batchsize`` is ``args.batch_size // args.bootstrap_every``,
    capped at ``images.shape[0]``.

    Args:
        args: namespace read for ``batch_size``, ``bootstrap_every``,
            ``denoise_timesteps`` and ``device``.
        model: called as ``model(x_t, t, labels, dt_base)``. It is put into eval
            mode for the bootstrap predictions and then restored to its
            previous mode.
        images: clean data batch; dimension 0 is the batch.
        labels: per-sample labels aligned with ``images``.

    Returns:
        ``(x_t, v_t, t, dt_base, labels)``. These are the noisy inputs, target
        velocities, times in ``[0, 1)``, int32 log2 step exponents, and labels.
        Each has leading size ``bootstrap_batchsize + images.shape[0]``.
    """
    device = args.device
    # Number of shortcut sizes; for denoise_timesteps=128 the step counts are
    # 1, 2, 4, 8, 16, 32, 64, i.e. log2(128) options.
    log2_sections = int(torch.log2(torch.tensor(args.denoise_timesteps)).item())
    # Capped at the real batch size so a short final batch cannot cause a
    # shape mismatch between t and the image slice.
    bootstrap_batchsize = min(args.batch_size // args.bootstrap_every, images.shape[0])

    # 1) =========== Sample dt. ============
    dt_base = _sample_bootstrap_dt(bootstrap_batchsize, log2_sections, device)
    dt = 1 / (2 ** dt_base)
    dt_base_bootstrap = dt_base + 1  # each half step is one level finer
    dt_bootstrap = dt / 2

    # 2) =========== Sample t. ============
    # t lies on each sample's own grid {0, dt, 2*dt, ...} below 1. Every
    # dt_sections value is a power of two dividing the maximum, so
    # randint(0, max) % dt_sections is uniform over that grid.
    dt_sections = 2 ** dt_base
    high = int(dt_sections.max().item()) if bootstrap_batchsize > 0 else 1
    t = torch.randint(0, high, (bootstrap_batchsize,), dtype=torch.float32).to(device)
    t = t % dt_sections
    t /= dt_sections

    # 3) =========== Generate Bootstrap Targets ============
    x_1 = images[:bootstrap_batchsize]
    x_0 = torch.randn_like(x_1)  # Sampled from a normal distribution
    t_full = _broadcast_time(t, x_1)
    bst_xt = (1 - (1 - 1e-5) * t_full) * x_0 + t_full * x_1
    bst_l = labels[:bootstrap_batchsize]
    if bootstrap_batchsize > 0:
        bst_v = _two_step_target(model, bst_xt, t, bst_l, dt_base_bootstrap, dt_bootstrap)
    else:
        # batch_size < bootstrap_every: no bootstrap rows, so skip the model calls.
        bst_v = torch.zeros_like(bst_xt)
    bst_t, bst_dt = t, dt_base

    # 4) =========== Generate Flow-Matching Targets ============
    n = images.shape[0]
    t = torch.randint(0, args.denoise_timesteps, (n,), dtype=torch.float32).to(device)
    t /= args.denoise_timesteps
    t_full = _broadcast_time(t, images)
    x_0 = torch.randn_like(images)  # Sampled from a normal distribution
    x_t = (1 - (1 - 1e-5) * t_full) * x_0 + t_full * images
    v_t = images - (1 - 1e-5) * x_0
    # Flow-matching rows use the finest step exponent (log2_sections).
    dt_base = torch.full((n,), log2_sections, dtype=torch.int32, device=device)

    # 5) =========== Merge Bootstrap + Flow ============
    # NOTE(review): the full flow batch is appended, so the output is larger
    # than args.batch_size. The upstream implementation appears to keep only
    # batch_size - bootstrap_batchsize flow rows; confirm this is intended.
    return (
        torch.cat([bst_xt, x_t], dim=0),
        torch.cat([bst_v, v_t], dim=0),
        torch.cat([bst_t, t], dim=0),
        torch.cat([bst_dt, dt_base], dim=0),
        torch.cat([bst_l, labels], dim=0),
    )


def loss_ST(args, inputs, model, loss_dict):
    """Mean-squared velocity loss for a shortcut/flow-matching batch.

    Args:
        args: unused in this function.
        inputs: ``(x_t, v_t, t, dt_base, labels)``, the same order that
            ``get_targets_ST`` returns.
        model: called as ``model(x_t, t, labels, dt_base)``.
        loss_dict: running statistics. ``loss_dict['train_loss']`` is
            incremented by the detached loss value.

    Returns:
        The scalar loss tensor, still attached to the autograd graph so the
        caller can backpropagate through it.
    """
    x_t, v_t, t, dt_base, y_ = inputs
    v_pred = model(x_t, t, y_, dt_base)
    loss = torch.mean((v_pred - v_t) ** 2)
    # Detach before accumulating. Adding the live loss would chain each
    # iteration's autograd graph onto the running total and leak memory.
    loss_dict['train_loss'] += loss.detach()
    return loss

