import inspect
from typing import Any, Callable, Dict, List, Optional, Union
from torch.utils.data import DataLoader, Subset, Dataset

import torch
import torch.nn.functional as F
from tqdm.auto import tqdm

from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline, StableDiffusionPipelineOutput, StableDiffusionSafetyChecker

from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
    deprecate,
    logging,
    replace_example_docstring,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline


class LaionCaptions(Dataset):
    """Caption-only view over selected rows of a dataset's ``'train'`` split.

    ``dataset`` must support ``dataset['train'][i]`` and each row must provide
    the ``'caption'`` and ``'TEXT'`` keys. ``indices`` selects which rows are
    exposed and in which order.
    """

    # Keys copied from each underlying row, in output order.
    _FIELDS = ('caption', 'TEXT')

    def __init__(self, dataset, indices):
        self.indices = indices
        self.dataset = dataset

    def __len__(self):
        """Return the number of selected rows."""
        return len(self.indices)

    def __getitem__(self, index):
        """Return the caption fields of the ``index``-th selected row."""
        row = self.dataset['train'][self.indices[index]]
        return {field: row[field] for field in self._FIELDS}


def direct_merging(x, patch, lpos):
    """Paste a bicubically resized ``patch`` into a rectangular region of ``x``.

    Args:
        x: 4-D tensor indexed as ``x[:, :, rows, cols]``. It is not modified;
            the result is written to a clone.
        patch: 4-D floating-point tensor accepted by ``F.interpolate``. Its
            batch and channel dimensions must be equal to, or broadcastable
            to, those of ``x`` (e.g. a single patch shared by the whole batch,
            or one patch per sample).
        lpos: ``[top, left, height, width]`` of the target region inside the
            last two dimensions of ``x``.

    Returns:
        A copy of ``x`` whose ``lpos`` region is overwritten by ``patch``
        resized to ``(height, width)``.
    """
    top, left, height, width = lpos[0], lpos[1], lpos[2], lpos[3]

    hr_patch = F.interpolate(patch, size=(height, width), mode='bicubic')

    merged = x.clone()
    # Assign the resized patch directly and let broadcasting handle a single
    # shared patch; reshaping to a fixed (3, h, w) would reject batched
    # patches and non-RGB channel counts.
    merged[:, :, top:top + height, left:left + width] = hr_patch
    return merged

def patch_smoothing(x, patch, lpos, kernel_size=31):
    """Blend a resized ``patch`` into ``x`` with a box-filtered region mask.

    A 0/1 mask of the ``lpos`` region is averaged with a uniform
    ``kernel_size x kernel_size`` kernel, giving per-pixel weights in
    ``[0, 1]``. Inside the region the result is
    ``patch * weight + x * (1 - weight)``; outside the region both blend terms
    are ``x``, so ``x`` is kept.

    Args:
        x: 4-D floating-point tensor indexed as ``x[:, :, rows, cols]``. It is
            not modified.
        patch: 4-D tensor accepted by ``F.interpolate``; its batch and channel
            dimensions must be equal to, or broadcastable to, those of ``x``.
        lpos: ``[top, left, height, width]`` of the target region inside the
            last two dimensions of ``x``.
        kernel_size: Side length of the averaging kernel. Must be a positive
            odd integer so the filtered mask keeps the spatial size of ``x``.

    Returns:
        ``(output, mask_weight)``: the blended tensor and the filtered mask,
        both with the shape of ``x``.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """
    if kernel_size < 1 or kernel_size % 2 == 0:
        raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")

    top, left, height, width = lpos[0], lpos[1], lpos[2], lpos[3]
    rows = slice(top, top + height)
    cols = slice(left, left + width)
    batch, channels = x.shape[0], x.shape[1]

    # Single-channel 0/1 mask of the target region.
    mask = x.new_zeros((batch, 1, x.shape[2], x.shape[3]))
    mask[:, :, rows, cols] = 1

    # conv_transpose2d weights are laid out (in_channels, out_channels, kH, kW):
    # one mask channel is averaged and replicated to every channel of ``x``
    # (the channel count is taken from ``x`` rather than assumed to be 3).
    box = x.new_ones((1, channels, kernel_size, kernel_size)) / (kernel_size ** 2)
    mask_weight = F.conv_transpose2d(mask, box, bias=None, stride=1, padding=kernel_size // 2)

    hr_patch = F.interpolate(patch, size=(height, width), mode='bicubic')

    # Start from a copy of ``x`` and overwrite only the region with the patch.
    output = x.clone()
    output[:, :, rows, cols] = hr_patch
    output = output * mask_weight + x * (1 - mask_weight)

    return output, mask_weight
    


if __name__ == "__main__":
    # Script entry point: draw a seeded subset of LAION-COCO captions, run the
    # AnyRes Stable Diffusion pipeline on them with the prompts of each batch
    # split across processes, and save the generated images as PNG files.
    import torch
    import datasets
    import numpy as np
    import os
    import os.path as osp
    from PIL import Image

    import matplotlib.pyplot as plt
    from datasets import load_dataset
    from datasets.utils import DownloadConfig
    from tqdm.auto import tqdm
    from torch.utils.data import DataLoader, Subset
    from torch.nn import DataParallel
    from accelerate import PartialState

    from pipe.anyres_pipe import *

    config = DownloadConfig()
    config.resume_download = True
    dataset = load_dataset("guangyil/laion-coco-aesthetic", revision="refs/convert/parquet", download_config=config)

    model_id = "stabilityai/stable-diffusion-2"
    pipe = AnyResStableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)

    np.random.seed(1234)
    torch.manual_seed(1234)
    # NOTE: np.random.choice samples with replacement by default, so indices
    # may repeat. Drawing 20000 and keeping the first 10000 is preserved as-is
    # so the seeded sample set stays identical to earlier runs.
    idxs = np.random.choice(len(dataset['train']), 20000).astype(int).tolist()
    idxs = idxs[:10000]

    sub_dataset = LaionCaptions(dataset, idxs)
    dataloader = DataLoader(sub_dataset, batch_size=4*6)

    distributed_state = PartialState()
    pipe.to(distributed_state.device)
    pipe.enable_xformers_memory_efficient_attention()
    rank = distributed_state.process_index

    sd_dir = 'samples/sd_2_global'
    ar_dir = 'samples/anyres_sd_2_global_no_vae'
    sd_dir_l = 'samples/sd_2_local'
    ar_dir_l = 'samples/anyres_sd_2_local_no_vae'
    for out_dir in (sd_dir, ar_dir, sd_dir_l, ar_dir_l):
        os.makedirs(out_dir, exist_ok=True)

    # Region passed to the pipeline as ``lpos``: the centered square whose
    # side is half of ``image_size``, given as [offset, offset, size, size].
    # These values do not change between batches, so compute them once.
    image_size = 768
    zoom_in_size = image_size // 2
    lpos = [image_size//4, image_size//4, zoom_in_size, zoom_in_size]

    idx = 0  # global index of the first sample of the current batch
    # Show the progress bar only on the main process (rank 0).
    for batch in tqdm(dataloader, disable=(rank != 0)):
        prompts = batch['TEXT']
        size = len(prompts)

        # Pair every prompt with its global sample index before splitting, so
        # output filenames stay unique and gap-free even when a batch does not
        # divide evenly between processes.
        indexed_prompts = [(idx + pos, text) for pos, text in enumerate(prompts)]
        with torch.inference_mode(), distributed_state.split_between_processes(indexed_prompts) as shard:
            if shard:
                sample_ids = [sample_id for sample_id, _ in shard]
                prompt = [text for _, text in shard]
                images = pipe(prompt, K=25, lpos=lpos, noise_scale=False, vae_sr=True, guided_lr=False).images

                # The first len(prompt) images are saved to the "local"
                # directory and the next len(prompt) images to the "global"
                # one (presumably the zoomed patch and the full image; confirm
                # against the pipeline's output ordering).
                count = len(prompt)
                local_images = images[:count]
                global_images = images[count:2 * count]
                for sample_id, local_img, global_img in zip(sample_ids, local_images, global_images):
                    filename = f'{sample_id:05d}.png'
                    global_img.save(osp.join(ar_dir, filename))
                    local_img.save(osp.join(ar_dir_l, filename))

        idx += size

