
import os
import sys
import torch
import shutil
import string
import random
import numpy as np
import transformers
from PIL import Image
from tqdm import trange
from scipy.ndimage import zoom
from safetensors import safe_open
from safetensors.torch import save_file
from peft import PeftModel, PeftConfig

# Lookup table loaded from the current working directory. compute_histogram
# indexes it with 16-bit codes of the form (channel_a << 8) + channel_b and
# counts the looked-up values into at least 65536 bins.
# NOTE(review): presumably a permutation of range(65536) — confirm against the
# script that generates indices_num2rand.npy.
num2rand = np.load('indices_num2rand.npy')

def select_colors(x, colors):
    """Pick channels of *x* by letter and stack them along a new last axis.

    Each item of *colors* is looked up in the string 'rgb' to obtain the index
    taken from the last axis of *x*; the selected planes are stacked in the
    order given by *colors*.
    """
    planes = []
    for letter in colors:
        channel = 'rgb'.index(letter)
        planes.append(x[..., channel])
    return np.stack(planes, axis=-1)

def compute_histogram(image, colors='rg', resolution=1024):
    """Build a coarse histogram of remapped two-channel pixel codes.

    Args:
        image: Either a path to an image file or an already-loaded array whose
            last axis holds the colour channels. A path is opened, converted
            to RGB and rescaled by a random factor drawn from [0.5, 2.0] using
            nearest-neighbour interpolation; an array is used unchanged.
        colors: Letters from 'rgb' naming the channels to use. Only the first
            two selected channels are packed into the pixel code.
        resolution: The returned histogram is scaled so that its bins sum to
            ``resolution ** 2``.

    Returns:
        A ``(hist, ratio)`` tuple. ``hist`` is the scaled histogram after the
        65536 fine bins have been summed into 4096 groups; ``ratio`` is the
        zoom factor that was applied, or 1.0 when *image* was passed as an
        array and no rescaling took place.
    """
    # Default for the array path; previously `ratio` was unbound there and the
    # final return raised UnboundLocalError.
    ratio = 1.0
    if isinstance(image, str):
        # Close the file handle promptly and normalise to three channels so
        # the 3-tuple zoom factor below always matches the array rank.
        with Image.open(image) as img:
            image = np.array(img.convert('RGB'))
        ratio = random.uniform(0.5, 2.0)
        image = zoom(image, (ratio, ratio, 1), order=0)
    # Select two of the color channels
    image = select_colors(image.astype(int), colors)
    # Pack the two channel values into a single code (first channel in the
    # high byte), then remap every code through the module-level table.
    codes = (image[..., 0] << 8) + image[..., 1]
    codes = num2rand[codes].flatten()
    hist = np.bincount(codes, minlength=65536)
    # Collapse consecutive fine bins into 4096 coarse bins.
    hist = hist.reshape(4096, -1).sum(axis=1)
    hist = hist / hist.sum() * (resolution ** 2)
    return hist, ratio

# Prompt length under test, read from the first command-line argument.
num_tokens = int(sys.argv[1])
base_ckpt = '../length/ckpt/tokens2048_x/'

# The working checkpoint directory name ends in four random lowercase letters.
# Any existing directory with that name is removed before the base checkpoint
# is copied into place.
random_letters = [random.choice(string.ascii_lowercase) for _ in range(4)]
random_string = ''.join(random_letters)
peft_model_id = f'./recon_ckpt/bin_tok{num_tokens}_rand_{random_string}'
already_present = os.path.exists(peft_model_id)
if already_present:
    shutil.rmtree(peft_model_id)
shutil.copytree(base_ckpt, peft_model_id)

# Resolve the inference device once. The base pipeline and tokenizer are
# created lazily on the first usable sample and then reused: the base model
# name comes from the PEFT config stored in peft_model_id, and only
# adapter_model.safetensors inside that directory is rewritten below, so
# reloading the base model for every sample repeated identical work.
# NOTE(review): this relies on PeftModel.from_pretrained leaving the wrapped
# base model's weights untouched, which should hold for the prompt-embedding
# adapter saved here — confirm against the installed PEFT version.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipeline = None
tokenizer = None

for i in trange(300):

    ref_img = f'../gen/images_bin/tok{num_tokens}/{i}/final.png'
    ref_ckpt = f'../exp/ckpt_test/tok{num_tokens}/{i}.safetensors'

    # Skip samples for which either the image or the reference checkpoint is missing.
    if (not os.path.exists(ref_img)) or (not os.path.exists(ref_ckpt)):
        continue

    with safe_open(ref_ckpt, framework="pt", device="cpu") as f:
        ref_emb = f.get_tensor('prompt_embeddings')

    # Normalised histogram of the image; `ratio` is the random zoom that was applied.
    hist, ratio = compute_histogram(ref_img)
    pixels = torch.from_numpy(hist)
    hist = pixels / pixels.sum()

    # Grid-search the integer scale s in [4000, 6000) that minimises the L2
    # norm of log(s * hist).
    # NOTE(review): an empty histogram bin gives log(0) = -inf, which makes
    # every norm infinite — confirm the inputs never contain empty bins.
    s_values = torch.arange(4000, 6000, dtype=torch.float32)
    norms = torch.log(s_values.unsqueeze(1) * hist).norm(dim=1)
    min_index = torch.argmin(norms)
    optimal_s = s_values[min_index].item()

    # Reconstructed embedding: a single row holding log(optimal_s * hist). It
    # replaces the adapter weights inside the working checkpoint directory.
    emb = (optimal_s * hist).log().view(1, -1)
    save_file({'prompt_embeddings': emb}, os.path.join(peft_model_id, 'adapter_model.safetensors'))
    # Mean absolute difference between the reference and reconstructed embeddings.
    avg_diff = (ref_emb - emb).abs().mean().item()

    with open(f'../exp/data_test/tok{num_tokens}/{i}.txt', 'r', encoding='utf-8') as f:
        target = f.read()

    if pipeline is None:
        config = PeftConfig.from_pretrained(peft_model_id)
        pipeline = transformers.pipeline("text-generation", model=config.base_model_name_or_path, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto")
        tokenizer = pipeline.tokenizer
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Wrap the shared base model with the adapter that was just written to disk.
    model = PeftModel.from_pretrained(pipeline.model, peft_model_id)
    model.to(device)
    model.eval()

    early_stop = False
    with torch.no_grad():
        valid_ids = tokenizer("<|begin_of_text|>", return_tensors="pt").input_ids.to(device)
        # Greedy decoding, one token per step, for num_tokens + 5 steps.
        for j in trange(num_tokens+5):
            # ones_like already yields a tensor on the same device as valid_ids.
            valid_mask = torch.ones_like(valid_ids)
            outputs = model(valid_ids, attention_mask=valid_mask)
            next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1)
            valid_ids = torch.cat([valid_ids, next_token.unsqueeze(-1)], dim=-1)
            # Every 16 steps, give up if the text decoded so far is shorter
            # than the target yet does not occur anywhere inside it.
            if (j > 0) and (j % 16 == 0):
                curr_text = tokenizer.decode(valid_ids[0], skip_special_tokens=True)
                if len(curr_text) < len(target) and curr_text not in target:
                    early_stop = True
                    break

    if early_stop:
        result = 'fail'
    else:
        generated_text = tokenizer.decode(valid_ids[0], skip_special_tokens=True)
        result = 'success' if (target in generated_text) else 'fail'

    # Per-sample log: result, chosen scale, embedding error and zoom ratio, one per line.
    info = f'{result}\n{optimal_s}\n{avg_diff}\n{ratio}'
    log_path = f'../gen/images_bin/tok{num_tokens}/{i}/recon_rand.txt'
    with open(log_path, 'w', encoding='utf-8') as f:
        f.write(info)
    info = info.replace('\n', ' ')
    print(f'[tok{num_tokens}] idx {i} | {info}')
