import os
import json
import torch
import multiprocessing as mp
from tqdm import tqdm
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import AutoProcessor, AutoModelForVision2Seq
from helper import generate_prompt_for_baseline
import warnings
from transformers.utils import logging as hf_logging

# Restrict Hugging Face transformers logging to error-level messages.
hf_logging.set_verbosity_error()
# Silence every Python warning raised in this process.
warnings.filterwarnings("ignore")
# Option string read by PyTorch's CUDA caching allocator.
# NOTE(review): presumably chosen to reduce memory fragmentation during
# generation; it must be in the environment before CUDA is first used — confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

class TimeSeriesDataset(Dataset):
    """Map-style dataset yielding the on-disk inputs for one sample per index.

    For a sample called ``name`` the following layout under ``data_dir`` is
    used:

    * ``metadata/<name>.json``    -- parsed with ``json.load``
    * ``time series/<name>.txt``  -- one value per line
    * ``plots/<name>.jpeg``       -- only the path is built; the file is not
      opened or checked for existence here

    Args:
        names: Sequence of sample names (file stems) to serve.
        data_dir: Root directory containing the sub-directories listed above.
        dataset_name: Stored on the instance; not used by the class itself.
    """

    def __init__(self, names, data_dir: str, dataset_name: str):
        self.names = names
        self.data_dir = data_dir
        self.dataset_name = dataset_name

    def __len__(self) -> int:
        return len(self.names)

    def __getitem__(self, idx: int) -> tuple:
        """Return ``(name, metadata, ts, img_path)`` for the sample at ``idx``.

        ``ts`` is the time-series file flattened into a single string: every
        non-blank line is stripped and the results are joined with ``", "``.

        Raises:
            OSError: If the metadata or time-series file cannot be opened.
            json.JSONDecodeError: If the metadata file is not valid JSON.
        """
        name = self.names[idx]
        meta_path = os.path.join(self.data_dir, "metadata", f"{name}.json")
        ts_path = os.path.join(self.data_dir, "time series", f"{name}.txt")
        img_path = os.path.join(self.data_dir, "plots", f"{name}.jpeg")

        # Decode explicitly as UTF-8 so non-ASCII metadata is read the same
        # way regardless of the locale's preferred encoding.
        with open(meta_path, encoding="utf-8") as f:
            metadata = json.load(f)
        with open(ts_path, encoding="utf-8") as f:
            stripped_lines = (line.strip() for line in f)
            ts = ", ".join(value for value in stripped_lines if value)

        return name, metadata, ts, img_path


def collate_fn(batch):
    """Transpose a batch of sample tuples into one tuple per field.

    ``[(a1, b1), (a2, b2)]`` becomes ``[(a1, a2), (b1, b2)]``; an empty batch
    yields an empty list. Fields are grouped as plain tuples and nothing is
    converted to tensors.
    """
    per_field = zip(*batch)
    return [*per_field]

def process_chunk(gpu_id, chunk, data_dir, output_dir, dataset_name):
    """Caption every sample in ``chunk`` with Idefics2-8B on one GPU.

    Loads the processor and model onto ``cuda:{gpu_id}``, then for each sample
    builds a chat prompt from its metadata and time series, attaches the plot
    image, generates greedily (at most 300 new tokens) and writes the caption
    to ``<output_dir>/<name>.txt``.

    Samples whose plot image is missing are skipped with a message. Any
    exception raised while captioning a sample is reported and the loop moves
    on to the next sample. Errors raised while the loader reads a sample's
    metadata or time-series file happen outside that handler and propagate.

    Args:
        gpu_id: CUDA device index to run on; also used as the tqdm bar position.
        chunk: Sample names to process.
        data_dir: Root directory passed to ``TimeSeriesDataset``.
        output_dir: Directory receiving one ``.txt`` caption per sample;
            created if it does not exist.
        dataset_name: Passed to ``generate_prompt_for_baseline`` and the dataset.
    """
    torch.cuda.set_device(gpu_id)
    device = f"cuda:{gpu_id}"

    processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")
    model = AutoModelForVision2Seq.from_pretrained(
        "HuggingFaceM4/idefics2-8b",
        device_map=device,
        torch_dtype=torch.bfloat16,
    ).eval()

    dataset = TimeSeriesDataset(chunk, data_dir, dataset_name)
    loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

    # Loop-invariant: create the output directory once, not once per sample.
    os.makedirs(output_dir, exist_ok=True)

    for names, metadatas, tss, img_paths in tqdm(loader, desc=f"[GPU {gpu_id}]", position=gpu_id):
        # batch_size=1, so each collated field holds exactly one entry.
        # Unpack before the try so `name` is always bound for the error message.
        name, metadata, ts, img_path = names[0], metadatas[0], tss[0], img_paths[0]

        if not os.path.exists(img_path):
            tqdm.write(f"[GPU {gpu_id}] Skipping missing image: {img_path}")
            continue

        try:
            prompt = generate_prompt_for_baseline(dataset_name, metadata, ts)

            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": prompt.strip()},
                    ],
                }
            ]
            chat_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)

            # convert() returns an independent in-memory copy, so the file
            # handle can be released as soon as the conversion is done.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert("RGB")

            inputs = processor(text=chat_prompt, images=[image], return_tensors="pt").to(device)

            outputs = model.generate(
                **inputs,
                max_new_tokens=300,
                do_sample=False,
            )

            caption = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
            # The decoded text still contains the prompt; keep only what
            # follows the last assistant marker.
            if "Assistant:" in caption:
                caption = caption.split("Assistant:")[-1].strip()

            # Explicit UTF-8 so captions with non-ASCII characters are written
            # identically regardless of the locale's preferred encoding.
            with open(os.path.join(output_dir, f"{name}.txt"), "w", encoding="utf-8") as f:
                f.write(caption)

        except Exception as e:
            # Best-effort batch job: report the failed sample and keep going.
            tqdm.write(f"[GPU {gpu_id}] Failed: {name} — {e}")


def select_sample_names(filenames, dataset_name):
    """Return the sorted sample names in ``filenames`` for ``dataset_name``.

    A file belongs to the dataset when it is named ``<dataset_name>_*.txt``.
    Only the trailing ``.txt`` is removed, so a ``.txt`` occurring elsewhere in
    the name is preserved, and files with any other extension are ignored.
    """
    prefix = f"{dataset_name}_"
    suffix = ".txt"
    return sorted(
        filename[: -len(suffix)]
        for filename in filenames
        if filename.startswith(prefix) and filename.endswith(suffix)
    )


if __name__ == "__main__":
    # "spawn" gives each worker a fresh interpreter.
    # NOTE(review): presumably required because CUDA cannot be re-initialised
    # in a forked child — confirm.
    mp.set_start_method("spawn", force=True)

    dataset_names = [
        "air quality", "crime", "border crossing", "demography", "road injuries",
        "covid", "co2", "diet", "online retail", "walmart", "agriculture"
    ]

    data_dir = "/home/ubuntu/projects/new_data/test"
    output_base = "/home/ubuntu/projects/outputs/idefics2_8b"

    # Loop-invariant setup: list the captions directory and choose GPUs once.
    captions_dir = os.path.join(data_dir, "captions")
    caption_files = os.listdir(captions_dir)
    gpu_ids = [1]
    num_gpus = len(gpu_ids)

    for dataset_name in dataset_names:
        names = select_sample_names(caption_files, dataset_name)

        print(f"[INFO] Found {len(names)} samples for dataset: {dataset_name}")
        if not names:
            continue

        # Round-robin split so every GPU gets a near-equal share of samples.
        chunks = [names[i::num_gpus] for i in range(num_gpus)]

        processes = []
        for i, gpu_id in enumerate(gpu_ids):
            p = mp.Process(target=process_chunk, args=(gpu_id, chunks[i], data_dir, output_base, dataset_name))
            p.start()
            processes.append(p)

        for gpu_id, p in zip(gpu_ids, processes):
            p.join()
            # A worker that crashed would otherwise go unnoticed.
            if p.exitcode != 0:
                print(f"[WARN] Worker on GPU {gpu_id} exited with code {p.exitcode} for dataset: {dataset_name}")