import inspect
from typing import Any, Callable, Dict, List, Optional, Union
from torch.utils.data import DataLoader, Subset, Dataset

import torch
import torch.nn.functional as F
from tqdm.auto import tqdm

from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline, StableDiffusionPipelineOutput, StableDiffusionSafetyChecker

from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
    deprecate,
    logging,
    replace_example_docstring,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline


class LaionCaptions(Dataset):
    """Map-style dataset exposing a chosen subset of rows of ``dataset['train']``.

    Args:
        dataset: container indexable as ``dataset['train'][row]``, where each row
            supports ``row['caption']`` and ``row['TEXT']``.
        indices: sequence of row positions into ``dataset['train']``; item ``i`` of
            this dataset is row ``indices[i]``.

    Each item is a new dict holding only the ``'caption'`` and ``'TEXT'`` fields.
    """

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, index):
        row = self.dataset['train'][self.indices[index]]
        # Copy out just the two text fields, in a fixed key order.
        return {field: row[field] for field in ('caption', 'TEXT')}


def direct_merging(x, patch, lpos):
    """Paste ``patch`` into ``x`` over the rectangle ``lpos`` with a hard seam (no blending).

    Args:
        x: 4-D image batch indexed as ``x[:, :, rows, cols]`` (presumably
            (B, C, H, W) — TODO confirm against callers).
        patch: 4-D image batch; it is bicubically resized to the rectangle size.
            Its batch dimension must equal ``x``'s or be 1 (then the same patch
            is pasted into every image). Its channel count must match ``x``'s.
        lpos: ``(top, left, height, width)`` of the target rectangle in ``x``.

    Returns:
        A copy of ``x`` with the rectangle replaced; ``x`` itself is not modified.
    """
    top, left, height, width = lpos[0], lpos[1], lpos[2], lpos[3]
    hr_patch = F.interpolate(patch, size=(height, width), mode='bicubic')
    merged = x.clone()
    # Assign the 4-D tensor directly so broadcasting handles batch sizes > 1
    # (a reshape to (3, h, w) only works for a single RGB patch).
    merged[:, :, top:top + height, left:left + width] = hr_patch
    return merged

def patch_smoothing(x, patch, lpos, kernel_size=31):
    """Paste ``patch`` into ``x`` at ``lpos`` and feather the seam with a box-blurred mask.

    The rectangle is first filled with the bicubically resized ``patch``. The
    result is then blended with ``x`` using a box-blurred version of the
    rectangle's binary mask as weight:
    ``output = pasted * mask_weight + x * (1 - mask_weight)``.
    Deep inside a large enough rectangle the weight reaches 1, so the patch
    dominates there; the weight falls off smoothly across the rectangle border.

    Args:
        x: 4-D image batch indexed as ``x[:, :, rows, cols]`` (presumably
            (B, C, H, W) — TODO confirm against callers).
        patch: 4-D patch batch, resized to ``(lpos[2], lpos[3])``; its batch
            must equal ``x``'s or be 1, and its channels must match ``x``'s.
        lpos: ``(top, left, height, width)`` of the target rectangle in ``x``.
        kernel_size: side of the square box filter used to feather the seam.
            Must be a positive odd integer so the blurred mask keeps ``x``'s
            spatial size.

    Returns:
        ``(output, mask_weight)``: the blended image (same shape as ``x``) and
        the blending weights, shaped ``(B, C, H, W)`` with ``C = x.shape[1]``.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.

    Note:
        The blur zero-pads at the image border. A rectangle touching the border
        therefore gets weights below 1 there and is partially blended with ``x``.
    """
    if kernel_size < 1 or kernel_size % 2 == 0:
        raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")

    top, left, height, width = lpos[0], lpos[1], lpos[2], lpos[3]
    rows = slice(top, top + height)
    cols = slice(left, left + width)

    # Hard binary mask of the pasted rectangle.
    mask = torch.zeros_like(x)
    mask[:, :, rows, cols] = 1

    # Box-blur one mask channel. Transposed convolution with a constant kernel
    # acts as a box filter. Stride 1 and padding k//2 keep H and W unchanged
    # for odd k. The kernel emits one output channel per channel of x.
    channels = x.shape[1]
    kernel = x.new_ones([1, channels, kernel_size, kernel_size]) / (kernel_size ** 2)
    mask_weight = F.conv_transpose2d(
        mask[:, 0:1, :, :].detach(), kernel, bias=None, stride=1, padding=kernel_size // 2
    )

    # Hard paste: the resized patch inside the rectangle, x everywhere else.
    hr_patch = F.interpolate(patch, size=(height, width), mode='bicubic')
    pasted = x.clone()
    pasted[:, :, rows, cols] = hr_patch

    output = pasted * mask_weight + x * (1 - mask_weight)
    return output, mask_weight
    


def shard_start_index(total, num_processes, rank):
    """Return where ``rank``'s slice starts when ``total`` items are split across processes.

    Mirrors accelerate's ``PartialState.split_between_processes`` without padding.
    Every process gets ``total // num_processes`` items, and the first
    ``total % num_processes`` processes get one extra. This lets each process
    compute globally unique output indices even when a batch splits unevenly.
    """
    base, extras = divmod(total, num_processes)
    return rank * base + min(rank, extras)


if __name__ == "__main__":
    import os
    import os.path as osp

    import numpy as np
    import torch
    from PIL import Image
    from accelerate import PartialState

    # Provides ControlNetModel and AnyResStableDiffusionControlNetPipeline.
    from pipe.anyres_controlnet_pipe import *

    from torch.utils.data import DataLoader, Subset

    from dataloader.hamer_dataset import VideoDataset

    def to_pil(chw):
        """Convert a (C, H, W) float tensor in [0, 1] to a uint8 PIL image (clipped, then truncated)."""
        return Image.fromarray(
            (torch.clip(chw, min=0, max=1).numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
        )

    controlnet = ControlNetModel.from_pretrained(
        "lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16
    )
    pipe = AnyResStableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
    )

    dataset = VideoDataset(
        './datasets/ubc_fashion/', split='val', sample_size=(768, 512),
        text_prompt=True, image_finetune=True, local_body='hands',
    )

    # Same seeds on every process, so all processes draw the same index list
    # and iterate identical batches, which are then split between them.
    np.random.seed(1234)
    torch.manual_seed(1234)
    num_samples = 5000
    # NOTE(review): np.random.choice samples WITH replacement by default, so
    # indices may repeat (presumably intended if the dataset is smaller than
    # num_samples — confirm).
    idxs = np.random.choice(len(dataset), num_samples).astype(int).tolist()
    dataloader = DataLoader(Subset(dataset, idxs), batch_size=8 * 7)

    distributed_state = PartialState()
    pipe.to(distributed_state.device)
    pipe.enable_xformers_memory_efficient_attention()
    rank = distributed_state.process_index
    num_processes = distributed_state.num_processes

    sd_dir = 'human_hand_20/sd_global'
    ar_dir = 'human_hand_20/anyres_sd_global'
    sd_dir_l = 'human_hand_20/sd_local'
    ar_dir_l = 'human_hand_20/anyres_sd_local'
    for out_dir in (sd_dir, ar_dir, sd_dir_l, ar_dir_l):
        os.makedirs(out_dir, exist_ok=True)

    idx = 0  # global index of the first sample in the current batch
    # Show the progress bar on the main process only.
    for batch in tqdm(dataloader, disable=(rank != 0)):
        size = batch['pose'].shape[0]
        # Global index of this process's first sample; correct for uneven splits too.
        offset = idx + shard_start_index(size, num_processes, rank)

        with torch.inference_mode(), distributed_state.split_between_processes(batch) as shard:
            poses = shard['pose']
            lpos = shard['lpos']
            prompts = shard['text']
            n = len(prompts)

            # A process can receive an empty shard when the batch is smaller
            # than the number of processes; there is nothing to generate then.
            if n > 0:
                # The output is consumed as 2*n images: [0, n) are saved as the
                # AnyRes local results and [n, 2n) as the SD global results.
                images = pipe(prompts, poses, num_inference_steps=50, K=20, lpos=lpos).images
                imgs = torch.from_numpy(np.array(images)).float().permute(0, 3, 1, 2) / 255

                for ii in range(n):
                    name = f'{offset + ii:05d}.png'
                    pos = lpos[ii]
                    global_img = imgs[n + ii:n + ii + 1]

                    # SD results: the global image, plus its lpos crop upsampled to full size.
                    lr_patch = global_img[:, :, pos[0]:pos[0] + pos[2], pos[1]:pos[1] + pos[3]]
                    lr_patch = F.interpolate(lr_patch, size=imgs.shape[2:], mode='bicubic')
                    images[n + ii].save(osp.join(sd_dir, name))
                    to_pil(lr_patch[0]).save(osp.join(sd_dir_l, name))

                    # AnyRes results: local patch blended into the global image, plus the raw patch.
                    out_smooth, _ = patch_smoothing(global_img, imgs[ii].unsqueeze(0), lpos=pos)
                    to_pil(out_smooth[0]).save(osp.join(ar_dir, name))
                    images[ii].save(osp.join(ar_dir_l, name))

        idx += size


