import torch
from torch.autograd import grad
import torch.optim as optim
import torch.linalg as tla
import numpy as np
import torch.nn.functional as F

from scheduler import *
from diffusion import cosine_beta_schedule

# Define global variables
# Number of diffusion steps used to build the module-level schedule below.
TIMESTEPS = 300

# Module-level noise schedule, evaluated once when this module is imported.
# NOTE(review): none of the functions visible in this part of the file read
# these tensors; they appear to be kept for importers -- confirm before removing.
# define beta schedule
betas = cosine_beta_schedule(timesteps=TIMESTEPS)

# define alphas
alphas = 1. - betas
# running product of alphas along the timestep axis
alphas_cumprod = torch.cumprod(alphas, axis=0)
# the same sequence shifted right by one step, with 1.0 inserted at the front
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
# elementwise 1 / sqrt(alpha_t)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

# calculations for diffusion q(x_t | x_{t-1}) and others
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

def normalize_normals_torch(normals):
    """Scale the vectors in *normals* to unit Euclidean length.

    The axis holding the vector components is picked by shape, in this
    order of priority:

    1. axis 1 if it has size 3 (e.g. ``(N, 3)`` or ``(B, 3, H, W)``),
    2. axis 0 if it has size 3 (e.g. ``(3, H, W)`` or a single ``(3,)`` vector),
    3. the last axis otherwise.

    Args:
        normals: tensor of vectors to normalize.

    Returns:
        A tensor with the same shape as *normals*, divided by the norm taken
        along the selected axis.  No epsilon is added: a zero-length vector
        yields NaN entries (0 / 0).
    """
    # Guard on ndim so low-rank inputs (a single (3,) vector, for instance)
    # reach the later branches instead of raising IndexError on shape[1].
    if normals.ndim > 1 and normals.shape[1] == 3:
        axis = 1
    elif normals.ndim > 0 and normals.shape[0] == 3:
        axis = 0
    else:
        axis = -1

    norm = torch.sqrt(torch.sum(normals**2, dim=axis, keepdim=True))
    return normals / norm

def angle_vector_compute(vec1, vec2):
    """Return the angle, in radians, between matching vectors of two stacks.

    Both arguments are 3-D tensors of identical shape.  The first axis is
    contracted (a dot product over that axis), giving one value per position
    of the remaining two axes.  The dot product is clamped to
    ``[-1 + 1e-4, 1 - 1e-4]`` before ``acos`` so the result and its gradient
    stay finite; the inputs are therefore expected to be unit length already.
    """
    lower, upper = -1 + 1e-4, 1 - 1e-4
    cosines = torch.einsum('ijk, ijk->jk', vec1, vec2)
    return torch.acos(cosines.clamp(min=lower, max=upper))

# linear trend along boundary seams
def linear_angle_loss(image, mask, normalize = False, patch_size = 16):
    """Penalize breaks in the linear trend of the vectors across patch seams.

    Only ``image[0]`` is used.  It is unpacked as ``(channels, height, width)``
    and must be square.  For every interior seam between two neighbouring
    patches, the four pixels straddling it are read as::

        n_0_prev, n_0 | n_1, n_1_next

    ``n_1`` is predicted by linear extrapolation from ``(n_0_prev, n_0)`` and
    ``n_0`` from ``(n_1_next, n_1)``.  The loss is the mean of the two angles
    between each normalized prediction and the normalized actual vector.

    Args:
        image: batched tensor; ``image[0]`` must have three axes with the
            vector components on the first one.
        mask: weights multiplied into the loss; it is indexed as
            ``mask[rows]`` and ``mask[:, cols]``, so presumably of shape
            ``(height, width)`` -- TODO confirm with the caller.
        normalize: if true, ``image`` is first passed through
            ``normalize_normals_torch``.
        patch_size: spacing of the seams in pixels (defaults to the previously
            hard-coded 16).  The four index masks are only guaranteed to
            select the same number of positions when ``height`` is a multiple
            of ``patch_size``.

    Returns:
        ``(loss_td, loss_lr)``: the weighted angle maps for seams between
        vertically adjacent patches (rows selected) and between horizontally
        adjacent patches (columns selected).  Neither is reduced.
    """
    if normalize:
        image = normalize_normals_torch(image)

    image = image[0]
    _, height, width = image.shape
    assert height == width

    # Build the index masks on the image's device so indexing does not need a
    # host-side index tensor.
    positions = torch.arange(height, device=image.device)
    mask_0 = (positions + 1) % patch_size == 0       # last pixel of each patch
    mask_0_prev = (positions + 2) % patch_size == 0  # pixel just before it
    mask_1 = positions % patch_size == 0             # first pixel of each patch
    mask_1_next = (positions - 1) % patch_size == 0  # pixel just after it

    # Drop the outer image border so that only interior seams remain and the
    # four masks pair up one-to-one.
    mask_0[-1] = False
    mask_0_prev[-2] = False
    mask_1[0] = False
    mask_1_next[1] = False

    def _seam_angles(pick):
        # `pick(index_mask)` selects the seam pixels along one spatial axis.
        n_0, n_0_prev = pick(mask_0), pick(mask_0_prev)
        n_1, n_1_next = pick(mask_1), pick(mask_1_next)

        pred_next = n_0 + (n_0 - n_0_prev)  # n_1 predicted from the previous patch
        pred_prev = n_1 - (n_1_next - n_1)  # n_0 predicted from the next patch

        angle_next = angle_vector_compute(pred_next / tla.norm(pred_next, dim = 0), n_1 / tla.norm(n_1, dim = 0))
        angle_prev = angle_vector_compute(pred_prev / tla.norm(pred_prev, dim = 0), n_0 / tla.norm(n_0, dim = 0))
        return (angle_next + angle_prev)/2

    loss_td = _seam_angles(lambda m: image[:, m]) * (mask)[mask_0]
    loss_lr = _seam_angles(lambda m: image[:, :, m]) * (mask)[:, mask_0]

    return loss_td, loss_lr

def integrability_loss(nx, ny, nz, mask = None, max_slope = 10.):
    """Squared discrete integrability residual of a normal field.

    The surface slopes are taken as ``p = -nx / nz`` and ``q = ny / nz``
    (this original implementation from SFT has a different axis direction for
    ny, so it is flipped; the negative sign cancels out).  Both are clamped to
    ``[-max_slope, max_slope]`` and combined over every 2x2 neighbourhood.

    Args:
        nx, ny, nz: 2-D tensors of equal shape ``(H, W)`` holding the three
            components of the normals.
        mask: optional index applied to the ``(H - 1, W - 1)`` residual map
            (e.g. a boolean tensor of that shape); ``None`` keeps every entry.
        max_slope: non-negative bound applied to ``p`` and ``q``; the default
            reproduces the previously hard-coded limit of 10.

    Returns:
        The squared residual, of shape ``(H - 1, W - 1)`` without a mask or
        the masked selection otherwise.  No reduction is applied.  Entries
        where ``nz`` is zero are not specially handled (0 / 0 stays NaN).
    """
    # torch.clamp replaces the previous in-place masked assignments.
    p = torch.clamp(-nx / nz, -max_slope, max_slope)
    q = torch.clamp(ny / nz, -max_slope, max_slope)

    # Corners of each 2x2 cell: first index i (rows), second index j (columns).
    p_i0j0, p_i1j0 = p[:-1, :-1], p[1:, :-1]
    p_i0j1, p_i1j1 = p[:-1, 1:], p[1:, 1:]

    q_i0j0, q_i1j0 = q[:-1, :-1], q[1:, :-1]
    q_i0j1, q_i1j1 = q[:-1, 1:], q[1:, 1:]

    loss = torch.square(p_i0j0 + p_i0j1 - p_i1j0 - p_i1j1 - q_i0j0 + q_i0j1 - q_i1j0 + q_i1j1)
    if mask is not None:
        loss = loss[mask]
    return loss

def guidance_DDIM(input_size, batch_img_patch, batch_img_orig, model, input_noise, ds : DDIMScheduler, mask, start_ts = 300, patch_size = 16, lr = 0.01,
                  guidance_start = 2, guidance_end = 50, int_loss_weight = 1., int_loss_ts = 40, flip_oracle = None, init_oracle = None,
                  has_anchor = False, partial_gt = None, partial_mask = None):
    """Run DDIM sampling over patches with gradient guidance on the stitched image.

    For every timestep ``t`` of ``ds.timesteps`` that is ``<= start_ts``:

    1. If the loop index ``i`` lies in ``[guidance_start, guidance_end]``, the
       current sample is refined by a few gradient steps.  Each step predicts
       the clean sample, stitches the patches into one image, evaluates the
       seam loss (``linear_angle_loss``) -- plus the weighted
       ``integrability_loss`` once ``t <= ds.timesteps[int_loss_ts]`` -- and
       moves the sample against the gradient, which is clamped to [-5, 5].
    2. A regular deterministic (``eta = 0``) scheduler step is taken from the
       refined sample.

    Args:
        input_size: side length of the stitched image; ``input_size / patch_size``
            patches are laid out per side.
        batch_img_patch: conditioning passed to ``model`` as its third argument
            after ``.float().cuda()``; its length is used as the batch size.
        batch_img_orig: unused in this function.
        model: called as ``model(sample, timesteps, conditioning)``; switched to
            eval mode.
        input_noise: initial sample.
        ds: scheduler exposing ``timesteps`` and ``step(noise_pred, t, sample[, eta])``
            whose result provides ``'pred_original_sample'`` and ``'prev_sample'``.
        mask: weight mask forwarded to ``linear_angle_loss`` (after ``.cuda()``).
        start_ts: largest timestep value that is still processed.
        patch_size: side length of one patch.  It is not forwarded to
            ``linear_angle_loss``, which places seams every 16 pixels when
            called this way.
        lr: guidance step size.  Once ``t <= ds.timesteps[40]`` it is divided
            by 1.5 after every guidance step, and the decay accumulates across
            steps and timesteps.
        guidance_start, guidance_end: inclusive range of loop indices that
            receive guidance.
        int_loss_weight: factor applied to the integrability loss.
        int_loss_ts: index into ``ds.timesteps``; the integrability loss is
            active for timesteps at or below that entry.
        flip_oracle, init_oracle, has_anchor, partial_gt, partial_mask:
            unused in this function.

    Returns:
        ``(imgs, pred_x0s)``: two lists with one NumPy array per processed
        timestep -- the sample after the scheduler step and the corresponding
        predicted clean sample.

    Note:
        Tensors are moved with ``.cuda()``, so a CUDA device is required.
        Only the first stitched image (index 0) contributes to the losses.
    """
    patch_num = int(input_size / patch_size)
    batch_num = len(batch_img_patch)
    print(patch_num, batch_num)

    imgs = []
    pred_x0s = []
    sync_freq = 1

    local_to_global = Rearrange("(b s1 s2) c x y -> b c (s1 x) (s2 y)", s1=patch_num, s2=patch_num)

    img = input_noise

    model = model.eval()
    ts = ds.timesteps[ds.timesteps <= start_ts]

    # The conditioning never changes, so convert and transfer it once instead
    # of on every model call.
    cond = batch_img_patch[0:batch_num].float().cuda()

    for i, t in enumerate(tqdm(ts)):
        img_copy = img.clone().detach()
        t_batch = torch.full((batch_num,), t, dtype=torch.long).cuda()

        # prev i 10 - 49, 3 times
        if (i + 1) % sync_freq == 0 and guidance_start <= i <= guidance_end:
            print('optimization at t = ', t.data)

            # More refinement steps for the noisiest (earliest) timesteps.
            step_num = 10 if t >= ds.timesteps[15] else 3
            use_int_loss = t <= ds.timesteps[int_loss_ts]
            decay_lr = t <= ds.timesteps[40]

            for j in range(step_num):
                # Detach first so requires_grad_ never flips the flag on a
                # tensor owned by the caller (e.g. input_noise).
                img = img.detach().requires_grad_()

                # predict noise residual
                noise_pred = model(img, t_batch, cond)
                preds = ds.step(noise_pred, t, img)

                # stitch the per-patch clean-sample estimate into one image
                x0_guess_global = local_to_global(preds['pred_original_sample'])

                # compute loss
                loss_td, loss_lr = linear_angle_loss(x0_guess_global.cuda(), mask.cuda())
                loss = (loss_lr).mean() + (loss_td).mean()

                # The integrability term is only evaluated when it is used.
                if use_int_loss:
                    int_loss = integrability_loss(x0_guess_global[0][0], x0_guess_global[0][1], x0_guess_global[0][2], None).mean()
                    loss = loss + int_loss * int_loss_weight

                if j % 100 == 0:
                    print(i, j, "loss: ", loss.item())

                norm_grad = grad(outputs = loss, inputs=img)[0]
                norm_grad = torch.clamp(norm_grad, -5, +5)

                # adjust lr for final runs
                # NOTE(review): this shrinks lr on every guidance step and the
                # reduction persists for all later timesteps -- confirm intended.
                if decay_lr:
                    lr = lr / 1.5

                img_copy = img_copy - lr * norm_grad
                img = img_copy.clone().detach()

        # regular denoising step
        with torch.no_grad():
            noise_pred = model(img_copy.cuda(), t_batch, cond)
            eta = 0
            preds = ds.step(noise_pred, t, img_copy, eta)
            img = preds['prev_sample']
            imgs.append(img.detach().cpu().numpy())
            pred_x0s.append(preds['pred_original_sample'].detach().cpu().numpy())

    return imgs, pred_x0s
  