import torch
import numpy as np

###############=============================##############
# Adapted from https://github.com/kvfrans/shortcut-models/blob/main/targets_shortcut.py
##########################################################

def get_targets_ST_CSL(args, model, images, labels):
    """Build one merged training batch of bootstrap and flow-matching targets.

    The returned batch is the concatenation (bootstrap rows first) of

    * bootstrap rows, built from the leading rows of ``images``: the velocity
      target is the clamped average of two model predictions taken half a
      step apart at the next finer step-size level (``dt_base + 1``);
    * flow-matching rows, built from all of ``images``: ``x_t`` is a noisy
      interpolation between Gaussian noise and the image, the target is
      ``x_1 - (1 - 1e-5) * x_0`` and the step-size level is
      ``log2(args.denoise_timesteps)``.

    In addition a second set of bootstrap targets is produced: the model
    first takes one full step of size ``dt`` at level ``dt_base`` and the
    two-half-step average is evaluated from that new point. Rows whose level
    is 0 (``dt == 1``) are dropped from this second set.

    All model calls made here run under ``torch.no_grad()`` with the model in
    eval mode. The mode the model had on entry is restored before returning.

    Args:
        args: Namespace providing ``batch_size``, ``bootstrap_every``,
            ``denoise_timesteps`` and ``device``.
        model: Callable invoked as ``model(x, t, labels, dt_base)``; must
            provide ``eval()`` (and ``train()`` if it exposes ``training``).
        images: Data batch; indexed as 4-D ``[batch, ., ., .]`` via the
            ``view(-1, 1, 1, 1)`` broadcasts below.
        labels: Per-row labels aligned with ``images``.

    Returns:
        Tuple ``(x_t, v_t, t, dt_base, labels, bst_v_oc, bst_dt_bootstrap)``.
        The first five entries are the merged batch; ``bst_v_oc`` holds the
        second-step bootstrap targets and ``bst_dt_bootstrap`` the matching
        step sizes ``dt`` (both without the level-0 rows).
    """
    # 1) =========== Sample dt. ============
    # How many elements of the current batch are requested for self-consistency.
    bootstrap_batchsize = args.batch_size // (args.bootstrap_every * 2)

    # Number of step-size levels, e.g. denoise_timesteps=128 -> 7 levels.
    log2_sections = int(torch.log2(torch.tensor(args.denoise_timesteps)).item())
    repeats = bootstrap_batchsize // log2_sections

    # Levels log2_sections-1 ... 0, each repeated `repeats` times.
    levels = torch.arange(log2_sections - 1, -1, -1, dtype=torch.int32)
    dt_base = torch.repeat_interleave(levels, repeats)

    # Pad with level 0 up to the requested size, then append `num_zeros`
    # further level-0 rows. The bootstrap part therefore grows by `num_zeros`.
    remainder = bootstrap_batchsize - len(dt_base)
    num_zeros = repeats + remainder
    padding = torch.zeros(remainder + num_zeros, dtype=torch.int32)
    dt_base = torch.cat([dt_base, padding]).to(args.device)
    bootstrap_batchsize = bootstrap_batchsize + num_zeros

    # The trailing 2 * num_zeros rows are exactly the level-0 rows; only the
    # leading `num_keep` rows are used for the second-step targets.
    num_keep = bootstrap_batchsize - 2 * num_zeros

    dt = 1 / (2 ** dt_base)
    dt_base_bootstrap = dt_base + 1
    dt_bootstrap = dt / 2
    half_step = dt_bootstrap.view(-1, 1, 1, 1)

    # 2) =========== Sample t. ============
    # Half as many sections as steps at this level, so that t lands on
    # multiples of 2 * dt (levels 0 and 1 always give t == 0).
    dt_sections = (2 ** dt_base) / 2
    t = torch.randint(0, dt_sections.int().max().item(), (bootstrap_batchsize,),
                      dtype=torch.float32).to(args.device)
    t = t % dt_sections  # ensure values are within [0, dt_sections)
    t /= dt_sections     # normalise to [0, 1)
    t_full = t.view(-1, 1, 1, 1)  # [batch, 1, 1, 1]

    # 3) =========== Generate Bootstrap Targets ============
    x_1 = images[:bootstrap_batchsize]
    x_0 = torch.randn_like(x_1)  # sampled from a normal distribution
    x_t = (1 - (1 - 1e-5) * t_full) * x_0 + t_full * x_1
    bst_labels = labels[:bootstrap_batchsize]

    # Remember the caller's mode so the eval() below does not leak out.
    was_training = getattr(model, "training", False)
    model.eval()
    try:
        with torch.no_grad():
            # Two half-steps starting at (x_t, t).
            v_b1 = model(x_t, t, bst_labels, dt_base_bootstrap)
            t2 = t + dt_bootstrap
            x_t2 = torch.clamp(x_t + half_step * v_b1, -4, 4)
            v_b2 = model(x_t2, t2, bst_labels, dt_base_bootstrap)
            v_target = torch.clamp((v_b1 + v_b2) / 2, -4, 4)

            # One full step of size dt at level dt_base ...
            v_star = model(x_t, t, bst_labels, dt_base)
            x_t3 = torch.clamp(x_t + half_step * 2 * v_star, -4, 4)
            t3 = t + dt_bootstrap * 2

            # ... followed by two half-steps starting at (x_t3, t3).
            v_b3 = model(x_t3, t3, bst_labels, dt_base_bootstrap)
            t4 = t3 + dt_bootstrap
            x_t4 = torch.clamp(x_t3 + half_step * v_b3, -4, 4)
            v_b4 = model(x_t4, t4, bst_labels, dt_base_bootstrap)
            v_target2 = torch.clamp((v_b3 + v_b4) / 2, -4, 4)[:num_keep]
    finally:
        if was_training:
            model.train()

    bst_v = v_target
    bst_dt = dt_base
    bst_t = t
    bst_xt = x_t
    bst_l = bst_labels

    bst_v_oc = v_target2
    bst_dt_bootstrap = dt[:num_keep]

    # 4) =========== Generate Flow-Matching Targets ============
    # Sample t on the grid {0, 1, ..., denoise_timesteps - 1} / denoise_timesteps.
    t = torch.randint(0, args.denoise_timesteps, (images.shape[0],),
                      dtype=torch.float32).to(args.device)
    t /= args.denoise_timesteps
    t_full = t.view(-1, 1, 1, 1)  # [batch, 1, 1, 1]

    # Sample flow pairs x_t, v_t.
    x_0 = torch.randn_like(images)  # sampled from a normal distribution
    x_1 = images
    x_t = (1 - (1 - 1e-5) * t_full) * x_0 + t_full * x_1
    v_t = x_1 - (1 - 1e-5) * x_0

    dt_flow = int(torch.log2(torch.tensor(args.denoise_timesteps)).item())
    dt_base = torch.full((images.shape[0],), dt_flow, dtype=torch.int32).to(args.device)

    # ==== 5) Merge Flow+Bootstrap (bootstrap rows first) ====
    x_t = torch.cat([bst_xt, x_t], dim=0)
    t = torch.cat([bst_t, t], dim=0)
    dt_base = torch.cat([bst_dt, dt_base], dim=0)
    v_t = torch.cat([bst_v, v_t], dim=0)
    labels = torch.cat([bst_l, labels], dim=0)

    return x_t, v_t, t, dt_base, labels, bst_v_oc, bst_dt_bootstrap


def loss_ST_CSL(args, inputs, model, loss_dict):
    """Compute the combined flow-matching, bootstrap and second-step loss.

    The bootstrap partition sizes are recomputed from ``args`` with the same
    formula used in ``get_targets_ST_CSL`` so that the merged batch can be
    split back into its bootstrap and flow-matching rows.

    Three squared-error terms are concatenated along the batch dimension and
    averaged over all elements:

    * the bootstrap rows of ``(v_pred - v_t) ** 2``;
    * the flow-matching rows of ``(v_pred - v_t) ** 2``;
    * a second-step term: the leading bootstrap rows are advanced by
      ``dt_bootstrap * v_pred`` (the prediction is not detached, so
      gradients flow through this step) and the model's prediction at the
      new point is compared against ``v_oc``.

    Args:
        args: Namespace providing ``batch_size``, ``bootstrap_every`` and
            ``denoise_timesteps``.
        inputs: Tuple ``(x_t, v_t, t, dt_base, labels, v_oc, dt_bootstrap)``
            in the order returned by ``get_targets_ST_CSL``.
        model: Callable invoked as ``model(x, t, labels, dt_base)``.
        loss_dict: Mapping whose ``'train_loss'`` entry is incremented by the
            detached loss value (running total for logging).

    Returns:
        The scalar loss tensor, still attached to the autograd graph.
    """
    bootstrap_batchsize = args.batch_size // (args.bootstrap_every * 2)
    log2_sections = int(torch.log2(torch.tensor(args.denoise_timesteps)).item())
    repetition = bootstrap_batchsize // log2_sections
    num_zeros = repetition + bootstrap_batchsize - log2_sections * repetition

    # Rows [0, num_bootstrap) of the merged batch are bootstrap rows; only the
    # first `num_keep` of them have a second-step target in `v_oc`.
    num_bootstrap = bootstrap_batchsize + num_zeros
    num_keep = bootstrap_batchsize - num_zeros

    x_t, v_t, t, dt_base, y_, v_oc, dt_bootstrap = inputs

    v_pred = model(x_t, t, y_, dt_base)
    squared_error = (v_pred - v_t) ** 2
    loss_bst = squared_error[:num_bootstrap]
    loss_fm = squared_error[num_bootstrap:]

    # Advance the kept bootstrap rows by one step using the model's own prediction.
    x_next = x_t[:num_keep] + dt_bootstrap.view(-1, 1, 1, 1) * v_pred[:num_keep]
    t_next = t[:num_keep] + dt_bootstrap

    v_pred2 = model(x_next, t_next, y_[:num_keep], dt_base[:num_keep])
    loss2 = (v_pred2 - v_oc) ** 2

    loss = torch.mean(torch.cat([loss_fm, loss_bst, loss2], dim=0))

    # Accumulate a detached value: adding the graph-attached tensor would keep
    # every iteration's autograd graph alive for as long as the total lives.
    loss_dict['train_loss'] += loss.detach()

    return loss
