import argparse
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import List

# Put the parent of this script's directory on sys.path. This has to run before
# the `src` / `FaRL` imports further down; presumably those packages live in
# that parent directory (the repository root).
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
sys.path.append(parent_dir)

import numpy as np
import torch
import torchvision
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
from einops import repeat
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
from transformers import CLIPVisionModelWithProjection

from src.models.pose_guider import PoseGuider
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d import UNet3DConditionModel
from src.pipelines.pipeline_parsing2img import Parsing2ImagePipeline
from src.utils.util import get_fps, read_frames, save_videos_grid
from FaRL import facer


# File names of the two intermediate maps written into the output directory:
# `parsing_name` is saved by get_reference_parsing, `panc_name` (parsing map
# with eye landmarks drawn, "no contour" variant) by get_target_parsing.
parsing_name = 'parsing.png'
panc_name = 'parsing_align_no_contour.png'

# Example clip pair (apparently <reference clip>_<target clip>):
# Clip+ctTlUlOBXpo+P0+C1+F550-762_Clip+ctYKkABcxcA+P0+C0+F659-885
def parse_args(argv=None):
    """Parse the command-line options of this inference script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behaviour of the original zero-argument call.

    Returns:
        argparse.Namespace with the attributes ``config``, ``reference_img``,
        ``target_img``, ``save_dir``, ``W``, ``H``, ``seed``, ``cfg`` and
        ``steps``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        default='./configs/inference/inference_stage1.yaml',
        help="path to the inference YAML config",
    )
    parser.add_argument(
        "--reference_img",
        default='data/jpgs/Clip+cyFHdWb3IF8+P2+C0+F5994-6165/00000000.jpg',
        help="path to the reference image",
    )
    parser.add_argument(
        "--target_img",
        default='data/jpgs/Clip+czDnpfNjrH4+P0+C1+F25-269/00000242.jpg',
        help="path to the target image",
    )
    parser.add_argument(
        "--save_dir",
        default=None,
        help="output directory; when omitted a dated directory under output/ is used",
    )
    parser.add_argument("-W", type=int, default=512, help="output width")
    parser.add_argument("-H", type=int, default=512, help="output height")
    parser.add_argument("--seed", type=int, default=42, help="random seed")
    parser.add_argument("--cfg", type=float, default=3.5, help="guidance scale")
    parser.add_argument("--steps", type=int, default=30, help="number of sampling steps")

    return parser.parse_args(argv)

def only_one_face(faces):
    """Reduce a detection dict to its largest face, in place.

    Every value in ``faces`` is indexed along its first dimension by face;
    ``faces['rects']`` holds one ``(x1, y1, x2, y2)`` box per face. When more
    than one face is present, each entry is replaced by the single row that
    belongs to the box with the largest area (first one wins on ties), keeping
    a leading dimension of size 1.

    Args:
        faces: dict of tensors as produced by the face detector.

    Returns:
        The same ``faces`` dict (mutated in place), so callers may either use
        the return value or keep relying on the in-place update.

    Raises:
        ValueError: if ``faces`` contains no detections.
    """
    rects = faces['rects']
    num_faces = rects.size(0)
    if num_faces == 0:
        raise ValueError('faces contains no detections')
    if num_faces == 1:
        return faces

    # Vectorised box areas; argmax returns the first index of the maximum.
    areas = (rects[:, 2] - rects[:, 0]) * (rects[:, 3] - rects[:, 1])
    largest = int(areas.argmax())

    # Iterate over a snapshot of the keys because the values are reassigned.
    for key in list(faces):
        faces[key] = faces[key][largest].unsqueeze(0)
    return faces

def _detect_and_parse(jpg_path, face_detector, face_parser, device):
    """Detect the largest face in ``jpg_path`` and build its parsing visualisation.

    Shared front half of get_reference_parsing / get_target_parsing.

    Returns:
        Tuple ``(image, faces, vis_img)``: the loaded image tensor, the
        single-face detection dict, and the class-index map scaled towards
        0..255 and summed over the face dimension.

    Raises:
        ValueError: if the detector finds no face in the image.
    """
    print(f'get {jpg_path} parsings')

    image = facer.hwc2bchw(facer.read_hwc(jpg_path)).to(device=device)  # image: 1 x 3 x h x w
    with torch.inference_mode():
        faces = face_detector(image)
    if faces['rects'].size(0) == 0:
        raise ValueError(f'{jpg_path} has no face')
    # Mutates `faces` in place so that only the largest detection remains.
    only_one_face(faces)
    with torch.inference_mode():
        parsings = face_parser(image, faces)
    seg_logits = parsings['seg']['logits']
    seg_probs = seg_logits.softmax(dim=1)  # nfaces x nclasses x h x w
    n_classes = seg_probs.size(1)
    # Map each pixel's class index to a grey level in [0, 255).
    vis_seg_probs = seg_probs.argmax(dim=1).float() / n_classes * 255
    vis_img = vis_seg_probs.sum(0, keepdim=True)
    return image, faces, vis_img


def get_reference_parsing(jpg_path, save_folder, face_detector, face_parser, face_aligner, device):
    """Save the face-parsing map of the reference image as ``parsing_name``.

    Args:
        jpg_path: path of the image to analyse.
        save_folder: existing directory that receives the PNG.
        face_detector: facer detector callable.
        face_parser: facer parser callable.
        face_aligner: unused here (the landmark result was never consumed);
            kept so the signature matches get_target_parsing.
        device: torch device the image tensor is moved to.

    Raises:
        ValueError: if no face is detected in ``jpg_path``.
    """
    _, _, vis_img = _detect_and_parse(jpg_path, face_detector, face_parser, device)
    # save parsing
    parse_img = facer.get_bhw(vis_img)
    pimage = Image.fromarray(parse_img.cpu().numpy())
    pimage.save(os.path.join(save_folder, parsing_name))


def get_target_parsing(jpg_path, save_folder, face_detector, face_parser, face_aligner, device):
    """Save the target's parsing map with eye landmarks drawn as ``panc_name``.

    Args:
        jpg_path: path of the image to analyse.
        save_folder: existing directory that receives the PNG.
        face_detector: facer detector callable.
        face_parser: facer parser callable.
        face_aligner: facer aligner callable; its ``'alignment'`` output
            supplies the landmark points that are drawn.
        device: torch device the image tensor is moved to.

    Raises:
        ValueError: if no face is detected in ``jpg_path``.
    """
    image, faces, vis_img = _detect_and_parse(jpg_path, face_detector, face_parser, device)
    with torch.inference_mode():
        alignments = face_aligner(image, faces)
    # save parsing_align_no_contour
    img = facer.get_bhw_no_contour(vis_img)
    for pts in alignments['alignment']:
        # Earlier models did not use `color`; the newly trained one needs it.
        img = facer.draw_landmarks_only_eyes(img, None, pts.cpu().numpy(), color=(105, 105, 105))
    pimage = Image.fromarray(img)
    pimage.save(os.path.join(save_folder, panc_name))

def _resolve_save_dir(args):
    """Return the output directory for this run as a Path, creating it if needed.

    Without ``--save_dir`` the directory is
    ``output/<YYYYMMDD>/<HHMM>--seed_<seed>-<W>x<H>``.
    """
    if args.save_dir is None:
        # One timestamp for both parts so date and time cannot disagree.
        now = datetime.now()
        date_str = now.strftime("%Y%m%d")
        time_str = now.strftime("%H%M")
        save_dir_name = f"{time_str}--seed_{args.seed}-{args.W}x{args.H}"
        save_dir = Path(f"output/{date_str}/{save_dir_name}")
    else:
        save_dir = Path(args.save_dir)
    # Also create a user-supplied directory; otherwise the first save fails.
    save_dir.mkdir(exist_ok=True, parents=True)
    return save_dir


def main():
    """Generate one image from a reference image and a target face parsing.

    Steps: load the config and models, extract the reference parsing map and
    the target parsing/landmark map into the output directory, run the
    Parsing2ImagePipeline, and save a side-by-side canvas
    (reference | target map | result) as ``test_test.png``.
    """
    args = parse_args()

    config = OmegaConf.load(args.config)

    weight_dtype = torch.float16 if config.weight_dtype == "fp16" else torch.float32

    # NOTE: the diffusion models are placed on "cuda" unconditionally, while
    # the facer models below fall back to CPU when CUDA is unavailable.
    print('create model')
    vae = AutoencoderKL.from_pretrained(
        config.pretrained_vae_path,
    ).to("cuda", dtype=weight_dtype)

    reference_unet = UNet2DConditionModel.from_pretrained(
        config.pretrained_base_model_path,
        subfolder="unet",
    ).to(device="cuda")

    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        config.pretrained_base_model_path,
        "",
        subfolder="unet",
        unet_additional_kwargs={
            "use_motion_module": False,
            "unet_use_temporal_attention": False,
        },
    ).to(device="cuda")

    pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
        device="cuda"
    )

    image_enc = CLIPVisionModelWithProjection.from_pretrained(
        config.image_encoder_path
    ).to(dtype=weight_dtype, device="cuda")

    sched_kwargs = OmegaConf.to_container(config.noise_scheduler_kwargs)
    scheduler = DDIMScheduler(**sched_kwargs)

    generator = torch.manual_seed(args.seed)

    width, height = args.W, args.H

    print('load model')
    # load pretrained weights
    # NOTE(security): torch.load unpickles the file; only point the config at
    # checkpoints from a trusted source.
    denoising_unet.load_state_dict(
        torch.load(config.denoising_unet_path, map_location="cpu"),
        strict=False,
    )
    reference_unet.load_state_dict(
        torch.load(config.reference_unet_path, map_location="cpu"),
    )
    pose_guider.load_state_dict(
        torch.load(config.pose_guider_path, map_location="cpu"),
    )

    pipe = Parsing2ImagePipeline(
        vae=vae,
        image_encoder=image_enc,
        reference_unet=reference_unet,
        denoising_unet=denoising_unet,
        pose_guider=pose_guider,
        scheduler=scheduler,
    )
    pipe = pipe.to("cuda", dtype=weight_dtype)

    save_dir = _resolve_save_dir(args)

    # Write the two conditioning maps into save_dir.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    face_detector = facer.face_detector('retinaface/mobilenet', device=device)
    face_parser = facer.face_parser('farl/lapa/448', device=device)  # optional "farl/celebm/448"
    face_aligner = facer.face_aligner('farl/wflw/448', device=device)
    get_reference_parsing(args.reference_img, save_dir, face_detector, face_parser, face_aligner, device)
    get_target_parsing(args.target_img, save_dir, face_detector, face_parser, face_aligner, device)

    # inference
    ref_img_pil = Image.open(args.reference_img)
    ref_pose_pil = Image.open(os.path.join(save_dir, parsing_name))
    tgt_pose_pil = Image.open(os.path.join(save_dir, panc_name))

    # Honour --steps / --cfg. These positions previously held the literals
    # 20 and 3.5, so the CLI flags had no effect; the default step count is
    # therefore now the parser default instead of 20.
    image = pipe(
        ref_img_pil,
        ref_pose_pil,
        tgt_pose_pil,
        width,
        height,
        args.steps,
        args.cfg,
        generator=generator,
    ).images
    # Take batch 0 / frame 0 and reorder to H x W x C for PIL.
    image = image[0, :, 0].permute(1, 2, 0).cpu().numpy()
    res_image_pil = Image.fromarray((image * 255).astype(np.uint8))

    # Side-by-side canvas: reference | target map | generated result.
    w, h = res_image_pil.size
    canvas = Image.new("RGB", (w * 3, h), "white")
    canvas.paste(ref_img_pil.resize((w, h)), (0, 0))
    canvas.paste(tgt_pose_pil.resize((w, h)), (w, 0))
    canvas.paste(res_image_pil, (w * 2, 0))

    ref_name = 'test'
    pose_name = 'test'
    out_file = save_dir / f'{ref_name}_{pose_name}.png'
    canvas.save(out_file)


# Run the inference only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
