import os
# Restrict CUDA to physical GPUs 1, 2 and 7. This is set before torch is
# imported so the mask is in place before CUDA initializes. Inside this process
# (and the spawned workers, which inherit os.environ) those GPUs are renumbered
# and addressed as cuda:0, cuda:1, cuda:2.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,7"

import json
import torch
import multiprocessing as mp
from tqdm import tqdm
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig
from peft import PeftModel
from helper import generate_prompt_for_baseline

def collate_fn(batch):
    """Transpose a batch of per-sample tuples into a list of per-field tuples.

    ``[(a1, b1), (a2, b2)]`` becomes ``[(a1, a2), (b1, b2)]``. Items are not
    stacked or converted, so non-tensor fields (PIL images, dicts, strings)
    pass through unchanged.
    """
    per_field = zip(*batch)
    return list(per_field)

class TimeSeriesDataset(Dataset):
    """Map-style dataset yielding ``(name, image, metadata, ts)`` per sample.

    For each sample ``name`` it reads, relative to ``data_dir``:

    * ``plots/{name}.jpeg``       -> RGB ``PIL.Image``
    * ``metadata/{name}.json``    -> parsed JSON object
    * ``time series/{name}.txt``  -> non-blank lines, stripped, joined with ", "

    Args:
        names: Sample identifiers (file stems), in iteration order.
        data_dir: Root directory containing the three sub-folders above.
        dataset_name: Stored on the instance; not used when loading samples.
        transform: Stored on the instance but not applied. The image is
            returned as a raw RGB ``PIL.Image`` so the model's own processor
            can preprocess it.
    """

    def __init__(self, names, data_dir, dataset_name, transform):
        self.names = names
        self.data_dir = data_dir
        self.dataset_name = dataset_name
        self.transform = transform

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        name = self.names[idx]
        img_path = os.path.join(self.data_dir, "plots", f"{name}.jpeg")
        meta_path = os.path.join(self.data_dir, "metadata", f"{name}.json")
        ts_path = os.path.join(self.data_dir, "time series", f"{name}.txt")

        # Close the file handle deterministically. convert() returns a fully
        # loaded in-memory image that does not depend on the open file.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        with open(meta_path, encoding="utf-8") as f:
            metadata = json.load(f)
        with open(ts_path, encoding="utf-8") as f:
            values = [line.strip() for line in f]
        # Drop blank lines and keep the remaining values in file order.
        ts = ", ".join(value for value in values if value)

        return name, image, metadata, ts


def _load_model(base_model_path, lora_path, device):
    """Load the processor, the base model wrapped with the LoRA adapter, and the generation config."""
    processor = AutoProcessor.from_pretrained(base_model_path, trust_remote_code=True)
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        # Switch to "flash_attention_2" if the flash-attn package is installed.
        _attn_implementation="eager",
    ).to(device)
    model = PeftModel.from_pretrained(base_model, lora_path).to(device).eval()
    generation_config = GenerationConfig.from_pretrained(base_model_path)
    return processor, model, generation_config


def _caption_sample(model, processor, generation_config, device, dataset_name, image, metadata, ts):
    """Generate a caption for one sample and return it with chat markers removed."""
    question = generate_prompt_for_baseline(dataset_name, metadata, ts)
    chat = [{"role": "user", "content": f"<|image_1|>{question}"}]
    prompt = processor.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

    inputs = processor(text=prompt, images=[image], return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items() if v is not None}

    generate_ids = model.generate(
        **inputs,
        max_new_tokens=512,
        generation_config=generation_config,
    )

    # Decode only the newly generated tokens; the prompt tokens come first.
    prompt_len = inputs["input_ids"].shape[1]
    caption = processor.batch_decode(
        generate_ids[:, prompt_len:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

    # Remove marker text first, then strip, so no whitespace is left at the ends.
    return caption.replace("<|end|>", "").replace("<|assistant|>", "").strip()


def process_chunk(gpu_id, chunk, data_dir, output_dir, dataset_name, base_model_path, lora_path):
    """Caption every sample in ``chunk`` on one GPU and write ``{name}.txt`` files.

    Used as a ``multiprocessing.Process`` target, one process per GPU. Each
    sample's caption is written to ``output_dir/{name}.txt`` (UTF-8).
    Exceptions raised while captioning or writing a sample are logged with
    ``tqdm.write`` and that sample is skipped. Errors raised by the DataLoader
    itself (e.g. a missing input file) still propagate and end the process.

    Args:
        gpu_id: CUDA device index as seen by this process.
        chunk: Sample names to process.
        data_dir: Root data directory passed to ``TimeSeriesDataset``.
        output_dir: Directory that receives the caption files.
        dataset_name: Dataset key forwarded to the prompt builder.
        base_model_path: Hugging Face id or path of the base model/processor.
        lora_path: Path to the PEFT/LoRA adapter checkpoint.
    """
    torch.cuda.set_device(gpu_id)
    device = f"cuda:{gpu_id}"

    processor, model, generation_config = _load_model(base_model_path, lora_path, device)

    # NOTE(review): the processor performs its own image preprocessing, so this
    # torchvision transform's output is not what the model consumes.
    transform = transforms.Compose([
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ])

    dataset = TimeSeriesDataset(chunk, data_dir, dataset_name, transform)
    # Batches hold PIL images, dicts and strings (no tensors), so pinned memory
    # would have no effect.
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4, pin_memory=False, collate_fn=collate_fn)

    for names, images, metadatas, tss in tqdm(loader, desc=f"[GPU {gpu_id}]", position=gpu_id):
        # batch_size=1: unpack before the try so ``name`` is always bound for
        # the failure message.
        name, image, metadata, ts = names[0], images[0], metadatas[0], tss[0]
        try:
            caption = _caption_sample(
                model, processor, generation_config, device, dataset_name, image, metadata, ts
            )
            with open(os.path.join(output_dir, f"{name}.txt"), "w", encoding="utf-8") as f:
                f.write(caption)
        except Exception as e:
            tqdm.write(f"[GPU {gpu_id}] Failed: {name} — {e}")


if __name__ == "__main__":
    # Each GPU worker needs its own CUDA context. PyTorch requires the "spawn"
    # (or "forkserver") start method for CUDA in subprocesses.
    mp.set_start_method("spawn", force=True)

    dataset_names = [
        "crime", "border crossing", "demography", "road injuries",
        "covid", "co2", "diet", "online retail", "walmart", "agriculture",
    ]

    data_dir = "/home/ubuntu/projects/new_data/test"
    output_base = "/home/ubuntu/projects/outputs/phi4_finetune"
    base_model_path = "microsoft/Phi-4-multimodal-instruct"
    lora_path = "/home/ubuntu/projects/time_series_main/phi4/train/phi4_captioning_output/checkpoint-998"

    # Device indices as seen through CUDA_VISIBLE_DEVICES, not physical ids.
    gpu_ids = [0, 1, 2]

    # All datasets write into the same directory; file names keep their
    # dataset prefix.
    output_dir = output_base
    os.makedirs(output_dir, exist_ok=True)

    # The captions directory only defines which samples exist; list it once.
    caption_files = os.listdir(os.path.join(data_dir, "captions"))

    for dataset_name in dataset_names:
        prefix = f"{dataset_name}_"
        names = sorted(
            fname[: -len(".txt")]
            for fname in caption_files
            if fname.startswith(prefix) and fname.endswith(".txt")
        )

        print(f"[INFO] Found {len(names)} samples for dataset: {dataset_name}")
        if not names:
            continue

        # Round-robin split across GPUs. Empty chunks are dropped so no worker
        # loads the model with nothing to do.
        # NOTE(review): workers are spawned per dataset, so the model is
        # reloaded once per dataset on every GPU.
        n_gpus = len(gpu_ids)
        work = [(gpu_id, names[i::n_gpus]) for i, gpu_id in enumerate(gpu_ids)]
        work = [(gpu_id, chunk) for gpu_id, chunk in work if chunk]

        processes = []
        for gpu_id, chunk in work:
            p = mp.Process(target=process_chunk, args=(
                gpu_id, chunk, data_dir, output_dir, dataset_name,
                base_model_path, lora_path,
            ))
            p.start()
            processes.append((gpu_id, p))

        for gpu_id, p in processes:
            p.join()
            # A crash (e.g. OOM, or a read error raised by the DataLoader,
            # which is outside the per-sample try) ends the whole worker.
            # Surface it instead of silently losing the rest of the chunk.
            if p.exitcode != 0:
                print(f"[WARN] GPU {gpu_id} worker for dataset {dataset_name} exited with code {p.exitcode}")
