import io
import os
import secrets  # Python stdlib CSPRNG; used for the share-tunnel token
import sys

import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
import uvicorn
from diffusers.models import AutoencoderKL
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
from gradio import networking
from PIL import Image

# Adjust import paths so the project-local modules below can be imported:
# presumably dense_aligner lives in the 'aligner' subdirectory next to this
# file and vision_encoder_wrapper in this file's own directory. Each append
# must run before the import that depends on it.
sys.path.append(os.path.join(os.path.dirname(__file__), 'aligner'))
from dense_aligner import ClipToLatentAligner
sys.path.append(os.path.dirname(__file__))
from vision_encoder_wrapper import VisionTransformerWrapper
# Fix the seeds of the torch (CPU), Python, NumPy and all-CUDA-device RNGs
# so runs are repeatable.
torch.manual_seed(42)
import random
random.seed(42)
np.random.seed(42)
torch.cuda.manual_seed_all(42)  # For multi-GPU setups

# Device: CUDA when available, otherwise CPU. Every model and tensor below is
# placed here.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Configuration
# Alternative encoders (keep FEATURE_DIM in sync when switching):
# VISION_MODEL = "google/gemma-3-12b-it"
# VISION_MODEL = "openclip/ViT-L-14"
VISION_MODEL = "Qwen/Qwen2.5-VL-7B-Instruct"
# Placeholder: must be replaced with a real aligner checkpoint. load_models()
# reads it with torch.load and expects a 'state_dict' entry in it.
CHECKPOINT_PATH = "REPLACE_WITH_YOUR_CHECKPOINT_PATH"
# Image size handed to VisionTransformerWrapper.
IMAGE_SIZE = 224
# Presumably the latent grid side length, assuming the VAE downsamples 8x — confirm.
GRID_SIZE = IMAGE_SIZE // 8
# Passed to ClipToLatentAligner; presumably its number of layers — confirm.
LAYERS = 12
# Token feature width produced by VISION_MODEL; passed to ClipToLatentAligner.
FEATURE_DIM = 1280 # gemma: 1152, openclip: 1024, qwen2.5: 1280

# Load models once
def load_models():
    """Construct the encoder, aligner and reference VAE and place them on DEVICE.

    The aligner weights come from CHECKPOINT_PATH. Every occurrence of the
    'aligner_net.' substring is stripped from the checkpoint's 'state_dict'
    keys before loading.

    Returns:
        tuple: ``(vision_encoder, aligner_net, vae_ref)``. ``aligner_net`` and
        ``vae_ref`` are switched to eval mode.
    """
    # Fourth positional flag for the wrapper. It is hard-wired to 1 (the
    # VIT_FEATURES_ONLY environment-variable toggle is disabled).
    features_only = 1
    encoder = VisionTransformerWrapper(
        VISION_MODEL,
        "/path/to/your/ckpt",
        IMAGE_SIZE,
        features_only,
    )
    encoder.move_to(DEVICE)

    vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(DEVICE)
    vae.eval()

    aligner = ClipToLatentAligner(None, FEATURE_DIM, 512, GRID_SIZE, LAYERS).to(DEVICE)
    ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    # Presumably the checkpoint was saved from a wrapper holding the aligner
    # as `aligner_net`, so the keys carry that prefix — confirm.
    weights = {}
    for param_name, param_value in ckpt['state_dict'].items():
        weights[param_name.replace('aligner_net.', '')] = param_value
    aligner.load_state_dict(weights)
    aligner.eval()

    return encoder, aligner, vae

# Load every model once at import time. The reconstruction functions below
# read these module-level instances on each call.
vision_encoder, aligner_net, vae_ref = load_models()
# Preprocessing pipeline supplied by the encoder wrapper (used by reconstruct_from_image).
img_transform = vision_encoder.image_transform

# Core reconstruction logic
def decode_latent(latent_tensor):
    """Decode a latent with the module-level ``vae_ref`` into a PIL image.

    The decoder output is rescaled from [-1, 1] to [0, 1], clamped to that
    range and moved to the CPU before conversion. ``squeeze(0)`` drops the
    leading dimension only when it has size 1.
    NOTE(review): ToPILImage expects a 2-D/3-D tensor, so a batch of more than
    one latent presumably fails here — confirm callers only pass batch size 1.
    """
    vae_output = vae_ref.decode(latent_tensor)
    pixels = vae_output.sample.squeeze(0)
    # [-1, 1] -> [0, 1]
    pixels = pixels.mul(0.5).add(0.5).clamp(0, 1)
    to_pil = transforms.ToPILImage()
    return to_pil(pixels.cpu())

def reconstruct_from_image(img: Image.Image) -> Image.Image:
    """Reconstruct ``img`` via the encoder -> aligner -> VAE pipeline.

    Args:
        img: Input PIL image, or ``None`` (e.g. when the Gradio input is empty).

    Returns:
        The decoded PIL image, or ``None`` when ``img`` is ``None``.
    """
    if img is None:
        return None
    # Preprocess and add a batch dimension of 1.
    inp = img_transform(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        tokens = vision_encoder.encode_image(inp)
        # All-False boolean mask over the first two dims of `tokens`
        # (presumably batch x tokens, i.e. nothing masked — confirm). It is
        # allocated directly on DEVICE instead of on the host and then copied.
        mask = torch.zeros(
            (tokens.size(0), tokens.size(1)), dtype=torch.bool, device=DEVICE
        )
        _, latent_data = aligner_net.encode(tokens, mask)
        # Take the distribution's mode rather than sampling, so output is deterministic.
        latent = latent_data.latent_dist.mode()
        rec_img = decode_latent(latent)
    return rec_img

def reconstruct_from_tensor(tensor: torch.Tensor) -> Image.Image:
    """Reconstruct an image from precomputed encoder tokens.

    Args:
        tensor: Token tensor, or ``None``. Tensors with fewer than 3 dims get
            one leading batch dimension (presumably ``(num_tokens, dim)`` ->
            ``(1, num_tokens, dim)`` — confirm the expected layout).

    Returns:
        The decoded PIL image, or ``None`` when ``tensor`` is ``None``.
    """
    if tensor is None:
        return None
    # Move and cast in a single call.
    inp = tensor.to(device=DEVICE, dtype=torch.float32)
    # ensure batch dim
    if inp.ndim < 3:
        inp = inp.unsqueeze(0)
    with torch.no_grad():
        # All-False mask over the first two dims, allocated directly on DEVICE.
        mask = torch.zeros(inp.shape[:2], dtype=torch.bool, device=DEVICE)
        _, latent_data = aligner_net.encode(inp, mask)
        # Take the distribution's mode rather than sampling, so output is deterministic.
        latent = latent_data.latent_dist.mode()
        return decode_latent(latent)

# FastAPI app
# Shared ASGI application: the REST endpoints below are registered on it, and
# the Gradio UI is mounted onto it afterwards.
app = FastAPI()

@app.post('/reconstruct_image')
async def api_reconstruct_image(file: UploadFile = File(...)):
    """Reconstruct an uploaded image and stream the result back as PNG.

    Args:
        file: Multipart upload in any format PIL can open. It is converted to
            RGB before reconstruction.

    Returns:
        StreamingResponse with ``image/png`` content.

    Raises:
        HTTPException: 400 when the upload cannot be read or decoded as an image.
    """
    try:
        content = await file.read()
        img = Image.open(io.BytesIO(content)).convert("RGB")
    except Exception as err:
        raise HTTPException(status_code=400, detail='Invalid image file') from err
    # Inference is synchronous and heavy. Running it in a worker thread keeps
    # the event loop free, so other requests (including the Gradio UI served
    # from the same app) are not stalled for the whole reconstruction.
    # NOTE(review): API requests can now run inference concurrently, as
    # Gradio's threaded handlers already may; watch GPU memory.
    out = await run_in_threadpool(reconstruct_from_image, img)
    buf = io.BytesIO()
    out.save(buf, format='PNG')
    buf.seek(0)
    return StreamingResponse(buf, media_type='image/png')

@app.post('/reconstruct_tensor')
async def api_reconstruct_tensor(file: UploadFile = File(...)):
    """Reconstruct an image from an uploaded ``.pt`` file and stream it as PNG.

    The file may contain a bare tensor or a dict holding it under ``'tensor'``.

    Returns:
        StreamingResponse with ``image/png`` content.

    Raises:
        HTTPException: 400 when the upload is not a loadable tensor file or
            does not contain a tensor.
    """
    try:
        # SECURITY: torch.load unpickles its input, and this upload is
        # untrusted. weights_only=True restricts loading to tensors and plain
        # containers, so a crafted file cannot execute code. Load to CPU;
        # reconstruct_from_tensor moves the tensor to DEVICE itself.
        data = torch.load(
            io.BytesIO(await file.read()), map_location='cpu', weights_only=True
        )
    except Exception as err:
        raise HTTPException(status_code=400, detail='Invalid tensor file') from err
    tensor = data['tensor'] if isinstance(data, dict) and 'tensor' in data else data
    if not isinstance(tensor, torch.Tensor):
        raise HTTPException(status_code=400, detail='Invalid tensor file')
    # Keep the blocking inference off the event loop.
    out = await run_in_threadpool(reconstruct_from_tensor, tensor)
    buf = io.BytesIO()
    out.save(buf, format='PNG')
    buf.seek(0)
    return StreamingResponse(buf, media_type='image/png')

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown('# Align+VAE Reconstruction')
    with gr.Tab('Image Input'):
        inp_img = gr.Image(type='pil')
        out_img = gr.Image(type='pil')
        # reconstruct_from_image returns None for an empty input, which clears the output.
        inp_img.change(fn=reconstruct_from_image, inputs=inp_img, outputs=out_img)
    with gr.Tab('Tensor Input (.pt)'):
        inp_file = gr.File(file_types=['.pt'])
        out_tensor_img = gr.Image(type='pil')

        def _gr_from_tensor(f):
            """Load an uploaded ``.pt`` file and reconstruct it.

            Returns None (empty output) when there is no file or the file
            holds no tensor, either bare or under a ``'tensor'`` key.
            """
            if f is None:
                return None
            # SECURITY: the file is user-supplied (and may arrive over the
            # public share tunnel). weights_only=True prevents torch.load from
            # unpickling arbitrary objects. Load to CPU;
            # reconstruct_from_tensor moves the tensor to DEVICE.
            data = torch.load(f.name, map_location='cpu', weights_only=True)
            tensor = data['tensor'] if isinstance(data, dict) and 'tensor' in data else data
            if not isinstance(tensor, torch.Tensor):
                return None
            return reconstruct_from_tensor(tensor)

        inp_file.change(fn=_gr_from_tensor, inputs=inp_file, outputs=out_tensor_img)

# Mount Gradio at the root path "/". The FastAPI routes registered earlier take
# precedence (Starlette matches routes in registration order), so
# /reconstruct_image and /reconstruct_tensor remain reachable.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    # Single source of truth for the address: the tunnel and uvicorn must agree.
    server_name = "127.0.0.1"
    server_port = 7860
    # Unguessable token identifying this share tunnel.
    share_token = secrets.token_urlsafe(32)
    # Forward a Gradio share URL to the local server. Passing None for the
    # share-server settings presumably selects Gradio's defaults — confirm.
    # NOTE(review): share URLs are publicly reachable, which exposes both the
    # REST API and the UI; confirm this exposure is intended.
    share_url = networking.setup_tunnel(
        local_host=server_name,
        local_port=server_port,
        share_token=share_token,
        share_server_address=None,
        share_server_tls_certificate=None,
    )
    print(f"Share URL: {share_url}")
    # Serve on exactly the host/port the tunnel forwards to.
    uvicorn.run(app, host=server_name, port=server_port)
