import os
import numpy as np
from copy import deepcopy
from einops import rearrange
from types import SimpleNamespace

from .attn_utils import register_attention_editor_diffusers, MutualSelfAttentionControl
from tqdm import tqdm
from diffusers import DDIMScheduler, AutoencoderKL
from drag_pipeline import DragPipeline
import torch
import torch.nn.functional as F

from pytorch_lightning import seed_everything

# from .drag_utils import drag_diffusion_update
import torchvision.transforms.functional as Fu
import datetime

# from .shift_test import shift_matrix,copy_past,paint_past
torch.set_printoptions(profile="full")

from .unet_drag.unet_2d_condition import UNet2DConditionModel  # for memory
from .drag_utils import drag_image
from .continuous_drag import drag_optical_flow_ratio_interp

def preprocess_image(image,
                     device,
                     dtype=torch.float32):
    """Convert an ``(H, W, C)`` NumPy image into a ``(1, C, H, W)`` tensor.

    Values are cast to float and mapped with ``v / 127.5 - 1``, so 0 becomes
    -1 and 255 becomes 1 (8-bit pixel input is presumably expected -- confirm
    against callers).

    Args:
        image: NumPy array laid out as height x width x channels.
        device: Device the resulting tensor is moved to.
        dtype: Data type of the resulting tensor.

    Returns:
        The rescaled image with a leading batch axis, on ``device`` and cast
        to ``dtype``.
    """
    pixels = torch.from_numpy(image).float()
    scaled = pixels / 127.5 - 1  # [-1, 1]
    # Channels first, plus a singleton batch dimension.
    batched = rearrange(scaled, "h w c -> 1 c h w")
    return batched.to(device, dtype)

def resize2latent(x, new_size):
    """Resize ``x`` to ``new_size`` via ``torchvision.transforms.functional.resize``.

    Args:
        x: Input forwarded unchanged to torchvision's ``resize``.
        new_size: Target size forwarded unchanged as the second argument.

    Returns:
        Whatever ``resize`` returns; interpolation and antialiasing are left
        at torchvision's defaults.
    """
    resized = Fu.resize(x, new_size)
    return resized

def interpolation(x):
    """Fill zero-valued pixels of a 4-D tensor from their nearest neighbours.

    A pixel counts as a hole when channel 0 is exactly zero at that location
    (other channels are not inspected). For each hole, the nearest pixel with
    a non-zero channel 0 is looked up to the left, right, above and below;
    the hole is then overwritten, across all channels, with the average of
    the neighbours found, each weighted by ``1 / distance``.

    The tensor is modified in place and the holes are visited in row-major
    order. The set of holes is fixed up front, but neighbour lookups read the
    live tensor, so pixels filled earlier in the scan act as valid neighbours
    for later ones. A hole with no non-zero pixel in any of the four
    directions is left untouched.

    Args:
        x: Tensor with four dimensions ``(batch, channels, rows, cols)``.
            Every batch entry is processed, even though the assertion message
            mentions a batch size of 1.

    Returns:
        The same tensor object ``x`` after in-place filling.

    Raises:
        AssertionError: If ``x`` does not have exactly four dimensions.
    """
    assert x.dim() == 4, "Input tensor x should have shape (1, C, N, M)"
    n_batch, _, n_rows, n_cols = x.shape

    # (row step, col step), scanned in this order: left, right, up, down.
    # The order is kept fixed because it is also the floating-point
    # summation order of the weighted average below.
    directions = ((0, -1), (0, 1), (-1, 0), (1, 0))

    for b in range(n_batch):
        # Snapshot of the holes taken before anything is filled.
        holes = (x[b, 0] == 0)

        for i in range(n_rows):
            for j in range(n_cols):
                if not holes[i, j]:
                    continue

                # (weight, value) for the closest valid pixel per direction.
                samples = []
                for d_row, d_col in directions:
                    distance = 1
                    r, c = i + d_row, j + d_col
                    while 0 <= r < n_rows and 0 <= c < n_cols:
                        if x[b, 0, r, c] != 0:
                            samples.append((1 / distance, x[b, :, r, c]))
                            break
                        distance += 1
                        r += d_row
                        c += d_col

                if not samples:
                    continue

                total_weight = sum(weight for weight, _ in samples)
                weighted_sum = sum(weight * value for weight, value in samples)
                x[b, :, i, j] = weighted_sum / total_weight

    return x

def edit_latent(source_image, invert_code, handle_points, target_points, mask_cp_handle,
                lambda_mix=None, gamma_ratio=0.5, upper_scale=1.5, lower_scale=0.5, alpha=2.0, beta=2.0,
                test_lambda='linear', test_space_weight='linear', test_depth_weight='noauto', test_fusion='amplitude', fill_mode='interpolation'):
    """Apply a drag edit to an inverted latent code.

    The function is a two-step pipeline: ``drag_image`` turns the drag
    specification into what this code calls an optical flow, and
    ``drag_optical_flow_ratio_interp`` applies that flow to ``invert_code``.

    Args:
        source_image: Source image, forwarded as the first argument of
            ``drag_image``.
        invert_code: Latent code to edit; forwarded to
            ``drag_optical_flow_ratio_interp``.
        handle_points: Drag start points, forwarded to ``drag_image``.
        target_points: Drag end points, forwarded to ``drag_image``.
        mask_cp_handle: Edit mask, forwarded to both helpers.
        lambda_mix, gamma_ratio, upper_scale, lower_scale, alpha, beta,
        test_lambda, test_space_weight, test_depth_weight, test_fusion:
            Flow-estimation options forwarded unchanged to ``drag_image``
            as keyword arguments of the same name.
        fill_mode: Accepted for interface compatibility; not used by this
            function.

    Returns:
        The ``(edited_latent, mask)`` pair returned by
        ``drag_optical_flow_ratio_interp``.
    """
    optical_flow = drag_image(
        source_image, mask_cp_handle, handle_points, target_points,
        lambda_mix=lambda_mix,
        gamma_ratio=gamma_ratio,
        upper_scale=upper_scale,
        lower_scale=lower_scale,
        alpha=alpha,
        beta=beta,
        test_lambda=test_lambda,
        test_space_weight=test_space_weight,
        test_depth_weight=test_depth_weight,
        test_fusion=test_fusion,
    )

    # The warp is delegated entirely to the helper, so no sampling grid is
    # built here.
    invert_code_d, mask = drag_optical_flow_ratio_interp(invert_code, optical_flow, mask_cp_handle)

    return invert_code_d, mask
    
def run_drag(model,
             ori_image,
             mask,
             prompt,
             points,
             inversion_strength,
             model_path,
             vae_path,
             start_step,
             start_layer,
             n_inference_step,
             task_cat,
             lambda_mix=None,
             gamma_ratio=0.5,
             upper_scale=1.5,
             lower_scale=0.5,
             alpha=2.0,
             beta=2.0,
             test_lambda='linear',
             test_space_weight='linear',
             test_depth_weight='noauto',
             test_fusion='amplitude',
             fill_mode='interpolation',
             lora_path=None,
             *,
             device="cuda" if torch.cuda.is_available() else "cpu",
             save_dir="./results",):
    """Run one drag edit: invert the image, warp the latent, and regenerate.

    Args:
        model: Pipeline to use. When ``None``, a ``DragPipeline`` is loaded
            from ``model_path`` with a DDIM scheduler, its UNet is replaced
            by the ``"SimianLuo/LCM_Dreamshaper_v7"`` UNet, and model CPU
            offload is enabled for ``device``.
        ori_image: NumPy image whose first two axes are height and width;
            it is converted with ``preprocess_image``.
        mask: 2-D NumPy mask. It is divided by 255 and every positive value
            is set to 1 before being resized.
        prompt: Text prompt used for the embeddings, inversion and
            generation.
        points: Sequence of ``(x, y)`` points in image pixels; even indices
            are handle points and odd indices are target points.
        inversion_strength: Multiplied by ``n_inference_step`` and rounded
            to give the number of steps actually run.
        model_path: Pretrained pipeline location; only used when ``model``
            is ``None``.
        vae_path: Pretrained VAE location, or ``"default"`` to keep the
            pipeline's own VAE; only used when ``model`` is ``None``.
        start_step: Forwarded to ``MutualSelfAttentionControl``.
        start_layer: Forwarded to ``MutualSelfAttentionControl``.
        n_inference_step: Total number of scheduler steps (cast to ``int``).
        task_cat: Not used by this function.
        lambda_mix, gamma_ratio, upper_scale, lower_scale, alpha, beta,
        test_lambda, test_space_weight, test_depth_weight, test_fusion,
        fill_mode: Forwarded unchanged to ``edit_latent``.
        lora_path: When not ``None``, attention processors are loaded from
            this path into the UNet.
        device: Device for the input tensors and for CPU offload.
        save_dir: Not used by this function.

    Returns:
        The edited image as a ``uint8`` NumPy array of shape
        ``(full_h, full_w, channels)``, where ``full_h`` and ``full_w`` are
        the height and width of ``ori_image``.
    """
    if model is None:
        scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
                                  beta_schedule="scaled_linear", clip_sample=False,
                                  set_alpha_to_one=False, steps_offset=1)
        model = DragPipeline.from_pretrained(model_path, scheduler=scheduler, torch_dtype=torch.float16)
        unet = UNet2DConditionModel.from_pretrained(
            "SimianLuo/LCM_Dreamshaper_v7",
            subfolder="unet",
            torch_dtype=torch.float16,
        )
        model.unet = unet

        model.modify_unet_forward()
        if vae_path != "default":
            # Load the replacement VAE onto the device/dtype of the one it
            # replaces.
            model.vae = AutoencoderKL.from_pretrained(
                vae_path
            ).to(model.vae.device, model.vae.dtype)
        model.enable_model_cpu_offload(device=device)

    # Fixed seed applied on every call.
    seed = 42
    seed_everything(seed)

    args = SimpleNamespace()
    args.prompt = prompt
    args.n_inference_step = int(n_inference_step)
    args.n_actual_inference_step = round(inversion_strength * args.n_inference_step)
    args.guidance_scale = 1.0

    args.unet_feature_idx = [3]

    full_h, full_w = ori_image.shape[:2]
    args.sup_res_h = int(0.5*full_h)
    args.sup_res_w = int(0.5*full_w)

    print(args)

    # Binarise the mask, then shrink it to 1/8 of the image size
    # (presumably the latent resolution, going by the helper's name).
    mask = torch.from_numpy(mask).float() / 255.
    mask[mask > 0.0] = 1.0
    mask = rearrange(mask, "h w -> 1 1 h w").to(device)
    mask = resize2latent(mask, (int(full_h/8), int(full_w/8)))

    # `points` alternates handle, target, handle, target, ...  Each point is
    # rescaled from image pixels to the sup_res grid, divided by 4 and
    # rounded.
    handle_points = []
    target_points = []
    for idx, point in enumerate(points):
        cur_point = torch.tensor([point[0]/full_w*args.sup_res_w, point[1]/full_h*args.sup_res_h])
        cur_point = torch.round(cur_point/4)
        if idx % 2 == 0:
            handle_points.append(cur_point)
        else:
            target_points.append(cur_point)

    # generation
    text_embeddings = model.get_text_embeddings(prompt)
    source_image = preprocess_image(ori_image, device, dtype=torch.float16)

    if lora_path is not None:
        print("applying lora: " + lora_path)
        model.unet.load_attn_procs(lora_path)

    invert_code, latent_list, t_list, iter_cur_list, text_embeddings = model.invert(
        source_image,
        prompt,
        text_embeddings=text_embeddings,
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.n_inference_step,
        num_actual_inference_steps=args.n_actual_inference_step,
        return_intermediates=True,
    )
    # The intermediates are not used below; drop the references now so the
    # cache flush further down can actually release their memory.
    del latent_list, t_list, iter_cur_list

    updated_code, mask = edit_latent(ori_image,
                                     invert_code=invert_code,
                                     handle_points=handle_points,
                                     target_points=target_points,
                                     mask_cp_handle=mask,
                                     fill_mode=fill_mode,
                                     lambda_mix=lambda_mix,
                                     gamma_ratio=gamma_ratio,
                                     upper_scale=upper_scale,
                                     lower_scale=lower_scale,
                                     alpha=alpha,
                                     beta=beta,
                                     test_lambda=test_lambda,
                                     test_space_weight=test_space_weight,
                                     test_depth_weight=test_depth_weight,
                                     test_fusion=test_fusion,
                                     )

    torch.cuda.empty_cache()
    model.scheduler.set_timesteps(args.n_inference_step)
    text_embeddings = text_embeddings.half()
    model.unet = model.unet.half()
    updated_code = updated_code.half()

    # NOTE(review): the editor is constructed but never registered in the
    # code shown here (register_attention_editor_diffusers is not called) --
    # confirm whether that is intentional.
    editor = MutualSelfAttentionControl(start_step=start_step,
                                        start_layer=start_layer,
                                        total_steps=args.n_inference_step,
                                        guidance_scale=args.guidance_scale)

    # Handle a missing mask in one place: previously the tensor methods were
    # called unconditionally and only the later repeat() was guarded.
    batched_mask = None
    if mask is not None:
        batched_mask = mask.unsqueeze(0).to(torch.bool).repeat(2, 1, 1, 1)

    # The edited latent is fed twice (batch of 2); entry 1 of the output is
    # kept.
    gen_image = model(
        prompt=args.prompt,
        text_embeddings=torch.cat([text_embeddings, text_embeddings], dim=0),
        batch_size=2,
        latents=torch.cat([updated_code, updated_code], dim=0),
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.n_inference_step,
        num_actual_inference_steps=args.n_actual_inference_step,
        eta=0.4,
        mask=batched_mask,
    )[1].unsqueeze(dim=0)

    # Back to the input resolution, then channels-last for NumPy.
    gen_image = F.interpolate(gen_image, (full_h, full_w), mode='bilinear')
    out_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0]
    # Clip before the cast so a value outside [0, 1] saturates instead of
    # wrapping around in uint8; in-range values are unaffected.
    out_image = np.clip(out_image * 255, 0, 255).astype(np.uint8)
    return out_image
    
