import os
import json
import torch
import multiprocessing as mp
from tqdm import tqdm
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import AutoProcessor, AutoModelForImageTextToText
from helper import generate_prompt_for_baseline


def collate_fn(batch):
    """Transpose a list of per-sample tuples into a list of per-field tuples.

    Per-sample objects (PIL images, metadata dicts, strings) pass through
    unchanged; nothing is stacked into tensors.

    Args:
        batch: Sequence of equal-length sample tuples as produced by the dataset.

    Returns:
        A list with one tuple per field, each holding that field across the batch.
    """
    fields = zip(*batch)
    return [tuple(column) for column in fields]

class TimeSeriesDataset(Dataset):
    """Per-sample loader for a plot image, its metadata JSON and its time-series text.

    For a sample ``name`` the files are read from ``data_dir`` as
    ``plots/<name>.jpeg``, ``metadata/<name>.json`` and ``time series/<name>.txt``.
    """

    def __init__(self, names, data_dir, dataset_name, transform=None):
        """
        Args:
            names: Sample identifiers (file stems), indexed by ``__getitem__``.
            data_dir: Root directory holding the ``plots``, ``metadata`` and
                ``time series`` sub-directories.
            dataset_name: Dataset label; stored for callers, not used when loading.
            transform: Optional image transform. It is stored on the instance but
                not applied: ``__getitem__`` returns the untransformed RGB PIL image.
        """
        self.names = names
        self.data_dir = data_dir
        self.dataset_name = dataset_name
        self.transform = transform

    def __len__(self):
        """Return the number of sample names."""
        return len(self.names)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(name, image, metadata, ts)``.

        ``image`` is the plot as an RGB PIL image. ``metadata`` is the parsed
        JSON. ``ts`` is the non-blank lines of the time-series file, each
        stripped, joined with ", ".
        """
        name = self.names[idx]
        img_path = os.path.join(self.data_dir, "plots", f"{name}.jpeg")
        meta_path = os.path.join(self.data_dir, "metadata", f"{name}.json")
        ts_path = os.path.join(self.data_dir, "time series", f"{name}.txt")

        # convert() returns a separate in-memory image (a copy even when the
        # source is already RGB), so the file handle can be closed immediately.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        with open(meta_path, encoding="utf-8") as f:
            metadata = json.load(f)
        with open(ts_path, encoding="utf-8") as f:
            ts = ", ".join(line.strip() for line in f if line.strip())

        return name, image, metadata, ts


def _generate_caption(processor, model, device, prompt, max_new_tokens=300):
    """Greedily generate a reply to a single text-only user prompt and return it stripped."""
    messages = [{
        "role": "user",
        "content": [{"type": "text", "text": prompt.strip()}],
    }]
    # NOTE(review): BatchFeature.to(dtype=...) only casts floating-point tensors.
    # For this text-only input that is presumably none of them, and the model is
    # loaded without an explicit torch_dtype. Confirm the intended precision.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device, dtype=torch.bfloat16)

    generated_ids = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
    # The returned sequence starts with the prompt tokens. Decode only the
    # continuation, so a reply that itself contains "Assistant:" is not cut short.
    prompt_len = inputs["input_ids"].shape[1]
    new_tokens = generated_ids[:, prompt_len:]
    return processor.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()


def process_chunk(gpu_id, chunk, data_dir, output_dir, dataset_name):
    """Caption one shard of samples with SmolVLM-Instruct on a single GPU.

    Each prompt is built by ``generate_prompt_for_baseline`` from the sample's
    metadata and time-series text. The plot image is loaded by the dataset but
    not sent to the model. The reply is written to ``<output_dir>/<name>.txt``.

    Args:
        gpu_id: CUDA device index this worker uses.
        chunk: Sample names (file stems) assigned to this worker.
        data_dir: Root containing ``plots/``, ``metadata/`` and ``time series/``.
        output_dir: Destination directory for caption files; created if missing.
        dataset_name: Dataset label forwarded to the prompt builder.
    """
    torch.cuda.set_device(gpu_id)
    device = f"cuda:{gpu_id}"
    # Without this, every per-sample open() below fails and is only logged.
    os.makedirs(output_dir, exist_ok=True)

    processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")
    model = AutoModelForImageTextToText.from_pretrained("HuggingFaceTB/SmolVLM-Instruct").to(device).eval()

    transform = transforms.Compose([
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ])

    dataset = TimeSeriesDataset(chunk, data_dir, dataset_name, transform)

    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        persistent_workers=True,
        collate_fn=collate_fn,
    )

    # NOTE(review): errors raised while loading a sample (e.g. a missing input
    # file) surface from the DataLoader iterator, outside the try below, and
    # stop this worker.
    for names, _images, metadatas, tss in tqdm(loader, desc=f"[GPU {gpu_id}]", position=gpu_id):
        name = names[0]
        try:
            prompt = generate_prompt_for_baseline(dataset_name, metadatas[0], tss[0])
            caption = _generate_caption(processor, model, device, prompt)

            with open(os.path.join(output_dir, f"{name}.txt"), "w", encoding="utf-8") as f:
                f.write(caption)

        except Exception as e:
            # Best-effort per sample: log and continue with the rest of the shard.
            tqdm.write(f"[GPU {gpu_id}] Failed: {name} — {e}")


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)

    dataset_names = [
        "air quality", "crime", "border crossing", "demography", "road injuries",
        "covid", "co2", "diet", "online retail", "walmart", "agriculture"
    ]

    data_dir = "/home/ubuntu/projects/new_data/test"
    output_base = "/home/ubuntu/projects/outputs/smolvlm_text"

    for dataset_name in dataset_names:
        captions_dir = os.path.join(data_dir, "captions")
        names = sorted([
            f.replace(".txt", "")
            for f in os.listdir(captions_dir)
            if f.startswith(f"{dataset_name}_")
        ])

        print(f"[INFO] Found {len(names)} samples for dataset: {dataset_name}")
        if not names:
            continue 

        gpu_ids = [1]
        num_gpus = len(gpu_ids)
        chunks = [names[i::num_gpus] for i in range(num_gpus)]

        processes = []
        for i, gpu_id in enumerate(gpu_ids):
            p = mp.Process(target=process_chunk, args=(gpu_id, chunks[i], data_dir, output_base, dataset_name))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()