import os
import torch
import logging
import hashlib
import hmac
from base64 import b64encode, b64decode
from fastapi import FastAPI
from pydantic import BaseModel
import requests
import json
import multiprocessing
import threading
import numpy as np
import torch
import torchvision
import torch.distributed as dist

import wan
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, WAN_CONFIGS
from server.idm import IDM


class Request(BaseModel):
    """Request body for the ``POST /`` endpoint."""

    # Text prompt passed to ``wan_ti2v.generate``.
    prompt: str
    # Conditioning images, passed unchanged as ``img=`` to ``wan_ti2v.generate``
    # (element format is defined by that method; presumably encoded images - confirm).
    imgs: list
    # Number of conditioning frames; ``generate`` is called with
    # ``frame_num = num_conditional_frames + num_new_frames``.
    num_conditional_frames: int = 1
    # Number of additional frames added to ``frame_num``.
    num_new_frames: int = 16
    # Sampling seed passed to ``generate``.
    seed: int = 1234
    # Passed to ``generate`` as ``sampling_steps`` (solver is 'unipc').
    num_sampling_step: int = 5
    # Passed to ``generate`` as ``guide_scale``.
    guide_scale: float = 5.0
    # Its SHA-256 hex digest is checked against a fixed value; on mismatch the
    # endpoint returns ``{}``.
    password: str = ""
    # When true, the response also carries JPEG frames and masked frames.
    return_imgs: bool = False
    # Forwarded to ``generate`` as ``clean_cache``.
    clean_cache: bool = False


def sha256(text):
    """Return the lowercase hexadecimal SHA-256 digest of ``text``.

    The string is encoded as UTF-8 before hashing.
    """
    return hashlib.sha256(text.encode("utf-8")).hexdigest()


def init():
    """Initialise the distributed group, the IDM and the Wan TI2V server.

    Environment variables read:
        RANK, WORLD_SIZE, LOCAL_RANK: distributed launch info (defaults 0, 1, 0).
            A NCCL process group is created only when WORLD_SIZE > 1.
        ULYSSES_SIZE: sequence-parallel degree (default 1); ``use_sp`` is
            enabled when it is greater than 1.
        MODEL: forwarded to ``wan.WanTI2VCausalServer`` as ``pt_dir``
            (``None`` when unset).
        IDM: path to an IDM checkpoint; loaded only when it is set and names
            an existing file, otherwise the IDM keeps its constructor weights.

    Populates the module globals ``cfg``, ``ulysses_size``, ``processor``,
    ``mask_processor``, ``idm`` and ``wan_ti2v``.
    """
    global wan_ti2v
    global ulysses_size
    global cfg
    global processor
    global mask_processor
    global idm
    cfg = WAN_CONFIGS["ti2v-5B"]

    print(f"Current PID: {os.getpid()}")
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    ulysses_size = int(os.getenv("ULYSSES_SIZE", 1))
    pt_dir = os.getenv("MODEL", None)
    idm_path = os.getenv("IDM", None)
    # Integer device id: `.to(device)` places the IDM on that CUDA device.
    device = local_rank

    if world_size > 1:
        torch.cuda.set_device(local_rank)
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=rank,
            world_size=world_size)

    # IDM input preprocessing: resize to 518x518, then normalise with the
    # ImageNet channel mean/std.
    processor = torchvision.transforms.Compose([
        torchvision.transforms.Resize((518, 518)),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Resizes IDM masks to (736, 640) - presumably the (H, W) of frames
    # generated with SIZE_CONFIGS["640*736"]; confirm.
    mask_processor = torchvision.transforms.Resize((736, 640))
    idm = IDM(model_name="mask", output_dim=14).to(device)
    # `idm_path` is None when IDM is unset, and os.path.isfile(None) raises
    # TypeError, so test for a value first.
    if idm_path and os.path.isfile(idm_path):
        # map_location="cpu": by default torch.load restores tensors onto the
        # device they were saved from (e.g. cuda:0) on every rank;
        # load_state_dict then copies them onto this rank's device.
        # weights_only=False unpickles arbitrary objects - only point IDM at
        # trusted checkpoints.
        loaded_dict = torch.load(idm_path, map_location="cpu", weights_only=False)
        idm.load_state_dict(loaded_dict["model_state_dict"])
        print("IDM loaded")
    else:
        print(f"IDM checkpoint not loaded ({idm_path!r}); IDM keeps its initial weights")
    idm.eval()

    wan_ti2v = wan.WanTI2VCausalServer(
        config=cfg,
        checkpoint_dir="./Wan2.2-TI2V-5B",
        pt_dir=pt_dir,
        device_id=device,
        rank=rank,
        t5_fsdp=False,
        dit_fsdp=False,
        use_sp=(ulysses_size > 1),
        t5_cpu=False,
        convert_model_dtype=True,
    )


def batch_tensor_to_jpeg_message(tensor):
    """Encode a batch of float images as base64 JPEG strings.

    Args:
        tensor: Image batch indexed along dim 0; each ``tensor[i]`` is given to
            ``torchvision.io.encode_jpeg``, so it must be a (C, H, W) image
            with 1 or 3 channels. Values are expected in [0, 1].

    Returns:
        list[str]: One base64-encoded JPEG payload per image, in batch order.
    """
    # Scale, round to nearest and clamp before the uint8 cast (the conversion
    # torchvision.utils.save_image uses). A bare cast truncates, biasing every
    # pixel down, and out-of-range floats would not saturate.
    tensor = tensor.mul(255).add_(0.5).clamp_(0, 255).to("cpu", torch.uint8)
    return [
        b64encode(torchvision.io.encode_jpeg(image).numpy().tobytes()).decode("utf-8")
        for image in tensor
    ]


def idm_pred(request, imgs):
    """Run the IDM on ``imgs`` and package its output for the HTTP response.

    Args:
        request: Only ``request.return_imgs`` is read; it is also passed to
            the IDM as ``return_mask``.
        imgs: Frame batch. It is normalised by ``processor`` for the IDM and,
            when images are requested, JPEG-encoded via
            ``batch_tensor_to_jpeg_message`` (which scales by 255, so values
            are presumably in [0, 1]).

    Returns:
        dict: ``"actions"`` holds the IDM actions as a JSON-encoded string.
        When ``request.return_imgs`` is true it also holds ``"imgs"`` (base64
        JPEG frames) and ``"masks"`` (frames whose pixels with a resized mask
        value below 0.5 are replaced by 1).
    """
    want_imgs = request.return_imgs
    with torch.no_grad():
        actions, masks = idm(processor(imgs), return_mask=want_imgs)
    result = {"actions": json.dumps(actions.cpu().numpy().tolist())}
    if not want_imgs:
        return result

    result['imgs'] = batch_tensor_to_jpeg_message(imgs)
    resized_masks = mask_processor(masks)
    # Keep pixels where the mask is on; fill the rest with 1.
    kept = torch.where(resized_masks >= 0.5, imgs, 1)
    result['masks'] = batch_tensor_to_jpeg_message(kept)
    return result


def get_pred(request):
    """Generate a video continuation for ``request`` and run the IDM on it.

    Called on every rank. When torch.distributed is initialised, rank 0's
    ``request.seed`` is broadcast and used by all ranks.

    Args:
        request: A ``Request``; its prompt, conditioning images, frame counts,
            sampling settings and ``clean_cache`` flag go to
            ``wan_ti2v.generate``.

    Returns:
        The dict produced by ``idm_pred`` for the generated frames.

    Raises:
        ValueError: if sequence parallelism is enabled (``ulysses_size > 1``)
            and ``cfg.num_heads`` is not divisible by ``ulysses_size``.
    """
    # A real exception rather than `assert`, which `python -O` strips.
    if ulysses_size > 1 and cfg.num_heads % ulysses_size != 0:
        raise ValueError(
            f"`{cfg.num_heads=}` cannot be divided evenly by `{ulysses_size=}`.")

    seed = request.seed
    if dist.is_initialized():
        # All ranks must sample with one seed: broadcast rank 0's value and
        # use what was received.
        rank = int(os.getenv("RANK", 0))
        base_seed = [seed] if rank == 0 else [None]
        dist.broadcast_object_list(base_seed, src=0)
        seed = base_seed[0]

    frame_num = request.num_conditional_frames + request.num_new_frames
    video = wan_ti2v.generate(
        request.prompt,
        img=request.imgs,
        size=SIZE_CONFIGS["640*736"],
        max_area=MAX_AREA_CONFIGS["640*736"],
        frame_num=frame_num,
        num_conditional_frames=request.num_conditional_frames,
        clean_cache=request.clean_cache,
        shift=cfg.sample_shift,
        sample_solver='unipc',
        sampling_steps=request.num_sampling_step,
        guide_scale=request.guide_scale,
        seed=seed,
        offload_model=False,
    )
    # Add a leading batch dim and clamp to the value_range used below.
    video = video[None].clamp(-1, 1)
    # Slice along dim 2 of the batched tensor (dim 1 of what generate returned;
    # presumably time, i.e. (C, T, H, W) - confirm). make_grid with
    # normalize=True maps each slice from (-1, 1) to [0, 1]; with a one-image
    # batch there is no tiling. The result stacks one image per slice along dim 0.
    frames = torch.stack(
        [torchvision.utils.make_grid(u, nrow=8, normalize=True, value_range=(-1, 1))
         for u in video.unbind(2)],
        dim=1,
    ).permute(1, 0, 2, 3)
    return idm_pred(request, frames)


# The FastAPI application; routes are attached to it with decorators.
api = FastAPI()

# Process-wide state, filled in by init().
wan_ti2v = ulysses_size = cfg = idm = processor = mask_processor = None

# Runs at import time, so all models are loaded before any request is served.
init()


def _forward_to_workers(request):
    """Send ``request`` to the other ranks' HTTP servers without waiting.

    Rank ``i`` (``1 <= i < WORLD_SIZE``) is expected to listen on
    ``http://localhost:{WORKER_PORT + i}``. Each POST runs in a daemon thread
    so the caller can start its own share of the work immediately; responses
    are not collected.

    Does nothing when WORLD_SIZE is unset or 1 (matching ``init()``'s default),
    so single-process runs do not need WORKER_PORT.
    """
    num_processes = int(os.getenv("WORLD_SIZE", 1))
    if num_processes <= 1:
        return
    port = int(os.environ["WORKER_PORT"])
    headers = {
        "Content-Type": "application/json",
    }
    body = json.dumps(dict(request))
    for i in range(1, num_processes):
        # Threads, not forked processes: the POST is pure I/O, and forking a
        # process that holds CUDA/NCCL state and server threads on every
        # request is heavy and unsafe (the children were never joined either).
        threading.Thread(
            target=requests.post,
            kwargs={"url": f"http://localhost:{port + i}", "headers": headers, "data": body, "verify": False},
            daemon=True,
        ).start()


@api.post("/")
async def predict(request: Request):
    """Handle a generation request.

    If the password matches, rank 0 (or the only process) first forwards the
    request to the other ranks, then this process runs ``get_pred``.

    Returns:
        The value returned by ``get_pred``, or ``{}`` when the password's
        SHA-256 digest does not match.
    """
    # Deliberately `async def`: the blocking GPU work below runs on the event
    # loop, so requests are processed one at a time. A plain `def` endpoint
    # would run in FastAPI's threadpool and could overlap requests on the same
    # model (and, under torch.distributed, interleave collectives across ranks).
    print("Request:", request.prompt, request.num_conditional_frames, request.num_new_frames, request.seed)
    expected_digest = "d43e76d9cad30d53805246aa72cc25b04ce2cbe6c7086b53ac6fb5028c48d307"
    # Constant-time comparison of the untrusted digest.
    if not hmac.compare_digest(sha256(request.password), expected_digest):
        return {}
    if not dist.is_initialized() or dist.get_rank() == 0:
        _forward_to_workers(request)
    return get_pred(request)
