
import os
import sys
import torch
import shutil
import random
import numpy as np
import transformers
from PIL import Image
from tqdm import trange
from scipy.ndimage import zoom
from safetensors import safe_open
from safetensors.torch import save_file
from peft import PeftModel, PeftConfig

def select_colors(x, colors):
    """Extract the named colour channels from the last axis of ``x``.

    Args:
        x: Array indexed as ``x[..., channel]`` with channels stored in
            r, g, b order on the last axis.
        colors: Iterable of channel letters drawn from ``'r'``, ``'g'``,
            ``'b'``; the output channels follow this order.

    Returns:
        A new array whose last axis has one entry per requested letter.

    Raises:
        ValueError: If a letter is not one of ``'r'``, ``'g'``, ``'b'``.
    """
    channel_order = 'rgb'
    picked = []
    for letter in colors:
        channel = channel_order.index(letter)
        picked.append(x[..., channel])
    return np.stack(picked, axis=-1)

def compute_histogram(image, bins=64, colors='rg', resolution=1024):
    """Build a scaled joint colour histogram of an image.

    Args:
        image: Either a path to an image file or an already-loaded array
            whose last axis holds the colour channels in r, g, b order.
            A path is loaded and then resized by a random factor drawn
            uniformly from [0.5, 2.0] (nearest-neighbour, same factor on
            both spatial axes); an array is used exactly as given.
        bins: Number of quantisation levels per channel. Values are
            quantised with ``image // (256 // bins)``, which assumes
            8-bit channel values.
        colors: Channel letters (from ``'rgb'``) that form the histogram
            axes.
        resolution: The histogram is scaled so that its entries sum to
            ``resolution ** 2``.

    Returns:
        ``(hist, ratio)`` where ``hist`` is the flattened joint histogram
        (``bins ** len(colors)`` entries) and ``ratio`` is the resize
        factor that was applied (``1.0`` when ``image`` was an array).

    Raises:
        AssertionError: If the flattened histogram does not have exactly
            4096 entries (the defaults give ``64 ** 2 == 4096``).
    """
    # No resize is applied to array input, so report a neutral factor.
    # (Previously ``ratio`` was unbound on this path and the return failed.)
    ratio = 1.0
    if isinstance(image, str):
        # Close the file handle deterministically once the pixels are read.
        with Image.open(image) as img:
            image = np.array(img).astype(np.uint8)
        ratio = random.uniform(0.5, 2.0)
        # order=0 is nearest-neighbour, so no new colour values are created.
        image = zoom(image, (ratio, ratio, 1), order=0)
    # Quantize input image
    quantized_image = (image // (256 // bins)).astype(int)
    quantized_image = select_colors(quantized_image, colors)
    # Compute histogram over the full quantised range of every channel
    full_range = [(0, bins)] * len(colors)
    full_hist = np.histogramdd(quantized_image.reshape(-1, len(colors)),
                               bins=bins, range=full_range)[0].flatten()
    # The callers in this file rely on a 4096-entry vector.
    assert full_hist.shape[0] == 4096
    hist = full_hist / full_hist.sum() * (resolution ** 2)
    return hist, ratio

# Number of tokens for this run, read from the first command-line argument;
# it selects the per-run input and output paths used below.
num_tokens = int(sys.argv[1])
base_ckpt = '../length/ckpt/tokens2048_x/'
peft_model_id = f'./recon_pt_quant_tok{num_tokens}_rand'
# Start from a fresh copy of the base checkpoint directory: any working copy
# left over from a previous run is deleted first.
if os.path.exists(peft_model_id):
    shutil.rmtree(peft_model_id)
shutil.copytree(src=base_ckpt, dst=peft_model_id)

# For each sample index 0..299: rebuild a prompt embedding from the colour
# histogram of the sample's image, load it as the adapter weights, decode
# greedily, and record whether the target text is reproduced.
for i in trange(300):

    ref_img = f'../gen/images_quant/tok{num_tokens}/{i}/final.png'
    ref_ckpt = f'../exp/ckpt_test/tok{num_tokens}/{i}.safetensors'

    # Skip samples for which either the image or the reference checkpoint
    # is missing.
    if (not os.path.exists(ref_img)) or (not os.path.exists(ref_ckpt)):
        continue

    # Reference embedding, used only to report the reconstruction error.
    with safe_open(ref_ckpt, framework="pt", device="cpu") as f:
        ref_emb = f.get_tensor('prompt_embeddings')
    # Histogram of the (randomly rescaled) image, renormalised to sum to 1.
    hist, ratio = compute_histogram(ref_img)
    pixels = torch.from_numpy(hist)
    hist = pixels / pixels.sum()

    # Grid search over integer scales s in [4000, 6000): pick the s that
    # minimises the L2 norm of log(s * hist).
    # NOTE(review): if any histogram bin is 0, log() yields -inf, every norm
    # is inf and the argmin is not meaningful (emb then contains -inf too) --
    # confirm the input images never produce empty bins.
    s_values = torch.arange(4000, 6000, dtype=torch.float32)
    norms = torch.log(s_values.unsqueeze(1) * hist).norm(dim=1)
    min_index = torch.argmin(norms)
    optimal_s = s_values[min_index].item()
    # print(f'Optimal s: {optimal_s}, with norm: {norms[min_index].item():.4f}')

    # Reconstructed embedding: a single row holding log(optimal_s * hist).
    # It overwrites the adapter weights in the working checkpoint directory
    # so that the PeftModel loaded below picks it up.
    emb = (optimal_s * hist).log().view(1, -1)
    save_file({'prompt_embeddings': emb}, os.path.join(peft_model_id, 'adapter_model.safetensors'))
    # print('ref_emb:', ref_emb[0][:10])
    # print('recon_emb:', emb[0][:10])
    # print('avg_diff:', (ref_emb - emb).abs().mean().item())
    # Mean absolute difference between reference and reconstruction.
    # NOTE(review): presumably ref_emb has the same (1, N) shape as emb;
    # otherwise this subtraction broadcasts -- verify.
    avg_diff = (ref_emb - emb).abs().mean().item()

    # Text the model is expected to reproduce for this sample.
    # NOTE(review): unlike ref_img/ref_ckpt, the existence of this file is
    # not checked before opening.
    with open(f'../exp/data_test/tok{num_tokens}/{i}.txt', 'r') as f:
        target = f.read()

    # NOTE(review): the pipeline (and therefore the base model) is rebuilt on
    # every iteration although only the adapter file changes; hoisting it out
    # of the loop looks possible -- confirm PeftModel.from_pretrained does not
    # mutate the base model before doing so.
    config = PeftConfig.from_pretrained(peft_model_id)
    pipeline = transformers.pipeline("text-generation", model=config.base_model_name_or_path, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto")
    model = pipeline.model
    tokenizer = pipeline.tokenizer
    # Use the end-of-sequence token as the padding token.
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model = PeftModel.from_pretrained(model, peft_model_id)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    model.eval()
    early_stop = False
    with torch.no_grad():
        # Greedy decoding: start from the tokenised "<|begin_of_text|>"
        # string and append the argmax token for num_tokens + 5 steps. The
        # whole sequence is fed through the model again at every step.
        valid_ids = tokenizer("<|begin_of_text|>", return_tensors="pt").input_ids.to(device)
        valid_mask = torch.ones_like(valid_ids).to(device)
        for j in range(num_tokens+5):
            outputs = model(valid_ids, attention_mask=valid_mask)
            next_token_logits = outputs.logits[:, -1, :]
            next_token = torch.argmax(next_token_logits, dim=-1)
            valid_ids = torch.cat([valid_ids, next_token.unsqueeze(-1)], dim=-1)
            valid_mask = torch.ones_like(valid_ids).to(device)
            # Every 16 steps (except step 0), decode the text so far. While
            # it is still shorter than the target, abort as soon as it is no
            # longer a substring of the target.
            if (j > 0) and(j % 16 == 0):
                curr_text = tokenizer.decode(valid_ids[0], skip_special_tokens=True)
                if len(curr_text) < len(target):
                    early_stop = (curr_text not in target)
                if early_stop:
                    break
        if early_stop:
            result = 'fail'
        else:
            # Success means the target appears somewhere in the decoded text.
            generated_text = tokenizer.decode(valid_ids[0], skip_special_tokens=True)
            result = 'success' if (target in generated_text) else 'fail'
            # print(f"Generated text: {generated_text}")
        
        # Write result, chosen scale, embedding error and resize ratio (one
        # value per line) next to the image, then echo them on one line.
        log_path = f'../gen/images_quant/tok{num_tokens}/{i}/recon_rand.txt'
        with open(log_path, 'w') as f:
            info = f'{result}\n{optimal_s}\n{avg_diff}\n{ratio}'
            f.write(info)
        info = info.replace('\n', ' ')
        print(f'[tok{num_tokens}] idx {i} | {info}')
