import os
import json
import torch
import multiprocessing as mp
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import AutoProcessor, MllamaForConditionalGeneration
from PIL import Image
from helper import generate_prompt_for_baseline
import re
import warnings
from transformers.utils import logging
# Module-level runtime configuration. These statements run at import time.

# Only let the transformers logger emit messages at error level.
logging.set_verbosity_error()
# Hide the specific "attention mask is not set" warning.
warnings.filterwarnings("ignore", message="The attention mask is not set*")
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is presumably a debugging aid that makes
# kernel launches synchronous (and slower) — confirm it is still wanted.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# NOTE(review): allocator option for PyTorch's CUDA caching allocator;
# presumably set to reduce fragmentation — confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Blanket filter: ignores every warning category and message, which makes the
# message-specific filter above redundant.
warnings.filterwarnings("ignore")
# Expose three physical GPUs to this process. The `gpu_ids = [0,1,2]` used in
# the __main__ block presumably index into this list — verify when changing it.
os.environ["CUDA_VISIBLE_DEVICES"] = "2,6,5"


class TimeSeriesDataset(Dataset):
    """Map-style dataset over the on-disk samples identified by ``names``.

    For a sample called ``name`` the following files are addressed, all
    relative to ``data_dir``:

    * ``metadata/{name}.json``    -- parsed with ``json.load``
    * ``time series/{name}.txt``  -- one value per line
    * ``plots/{name}.jpeg``       -- only the path is built; the image is
      not opened (or checked for existence) here
    """

    def __init__(self, names, data_dir, dataset_name):
        """Store the sample names and locations; no file is touched yet.

        Args:
            names: Sequence of sample names (file stems).
            data_dir: Root directory holding the ``metadata``,
                ``time series`` and ``plots`` sub-directories.
            dataset_name: Name of the dataset the samples belong to. Stored
                on the instance but not otherwise used by this class.
        """
        self.names = names
        self.data_dir = data_dir
        self.dataset_name = dataset_name

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.names)

    def __getitem__(self, idx):
        """Load the sample at position ``idx``.

        Returns:
            A ``(name, metadata, ts, img_path)`` tuple where ``metadata`` is
            the decoded JSON object, ``ts`` is the non-blank lines of the
            series file (stripped) joined with ``", "``, and ``img_path`` is
            the expected location of the plot image.

        Raises:
            OSError: If the metadata or time-series file cannot be opened.
            json.JSONDecodeError: If the metadata file is not valid JSON.
        """
        name = self.names[idx]
        meta_path = os.path.join(self.data_dir, "metadata", f"{name}.json")
        ts_path = os.path.join(self.data_dir, "time series", f"{name}.txt")
        img_path = os.path.join(self.data_dir, "plots", f"{name}.jpeg")

        # Explicit UTF-8 so decoding does not depend on the machine's locale.
        with open(meta_path, encoding="utf-8") as f:
            metadata = json.load(f)
        with open(ts_path, encoding="utf-8") as f:
            # Strip each line once and drop the blank ones.
            ts = ", ".join(value for value in map(str.strip, f) if value)

        return name, metadata, ts, img_path


def collate_fn(batch):
    """Transpose a batch of sample tuples into a list of per-field tuples.

    ``[(a1, b1), (a2, b2)]`` becomes ``[(a1, a2), (b1, b2)]``. No tensor
    stacking or type conversion is performed. An empty batch yields ``[]``;
    samples of unequal length are truncated to the shortest one, as with
    ``zip``.
    """
    per_field = zip(*batch)
    return [*per_field]

def process_chunk(gpu_id, chunk, data_dir, output_dir, dataset_name):
    """Caption every sample in ``chunk`` on one GPU, one text file per sample.

    Loads the Llama 3.2 11B Vision Instruct processor and model onto
    ``cuda:{gpu_id}``, then for each sample builds a prompt with
    ``generate_prompt_for_baseline``, generates up to 300 new tokens
    conditioned on the sample's plot image, and writes the decoded caption
    to ``{output_dir}/{name}.txt``.

    A sample that cannot be loaded, whose image is missing or unreadable, or
    whose generation fails is reported through ``tqdm.write`` and skipped;
    the remaining samples are still processed.

    Args:
        gpu_id: CUDA device index to run on; also used as the position of
            this worker's progress bar.
        chunk: Sample names (file stems) to process.
        data_dir: Root directory containing ``metadata``, ``time series``
            and ``plots``.
        output_dir: Directory receiving the caption files; created if absent.
        dataset_name: Dataset the samples belong to; forwarded to the prompt
            builder.
    """
    from transformers.utils import logging as hf_logging
    hf_logging.set_verbosity_error()
    warnings.filterwarnings("ignore", message="The attention mask is not set*")

    torch.cuda.set_device(gpu_id)
    device = f"cuda:{gpu_id}"

    model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
    # Take the Hub token from the environment instead of hard-coding it.
    # ``None`` lets the Hugging Face client fall back to its stored login.
    token = os.environ.get("HF_TOKEN") or None

    processor = AutoProcessor.from_pretrained(model_id, token=token)
    model = MllamaForConditionalGeneration.from_pretrained(
        model_id,
        device_map=device,
        torch_dtype=torch.bfloat16,
        token=token
    ).eval()

    dataset = TimeSeriesDataset(chunk, data_dir, dataset_name)

    # Loop-invariant: create the output directory once, not per sample.
    os.makedirs(output_dir, exist_ok=True)

    # Index the dataset directly so that a failure while *loading* a sample
    # (missing/corrupt metadata or series file) is caught per sample below
    # instead of escaping the loop and killing the whole worker.
    for idx in tqdm(range(len(dataset)), desc=f"[GPU {gpu_id}]", position=gpu_id):
        name = dataset.names[idx]
        try:
            _, metadata, ts, img_path = dataset[idx]

            if not os.path.exists(img_path):
                tqdm.write(f"[GPU {gpu_id}] Skipping missing image: {img_path}")
                continue

            prompt = generate_prompt_for_baseline(dataset_name, metadata, ts)

            try:
                # ``convert`` returns an independent in-memory image, so the
                # underlying file handle can be released immediately.
                with Image.open(img_path) as img_file:
                    image = img_file.convert("RGB")
            except Exception as e:
                tqdm.write(f"[GPU {gpu_id}] Error opening image {img_path}: {e}")
                continue

            messages = [
                {"role": "user", "content": [
                    {"type": "image"},
                    {"type": "text", "text": prompt.strip()}
                ]}
            ]
            input_text = processor.apply_chat_template(messages, add_generation_prompt=True)

            inputs = processor(
                images=image,
                text=input_text,
                add_special_tokens=False,
                return_tensors="pt"
            ).to(device)

            outputs = model.generate(
                **inputs,
                max_new_tokens=300
            )

            # The output sequence starts with the prompt tokens; keep only
            # what was generated after them.
            input_len = inputs["input_ids"].shape[-1]
            generated_ids = outputs[0][input_len:]
            caption = processor.decode(generated_ids, skip_special_tokens=True).strip()

            # Explicit UTF-8: captions may contain non-ASCII characters.
            with open(os.path.join(output_dir, f"{name}.txt"), "w", encoding="utf-8") as f:
                f.write(caption)

        except Exception as e:
            # Best-effort batch job: report the sample and move on.
            tqdm.write(f"[GPU {gpu_id}] Failed: {name} — {e}")

def list_sample_names(captions_dir, dataset_name):
    """Return the sorted sample names of one dataset found in ``captions_dir``.

    A sample name is the file stem of an entry named
    ``{dataset_name}_*.txt``. Only the final ``.txt`` extension is removed,
    and entries with any other extension are ignored.

    Raises:
        OSError: If ``captions_dir`` cannot be listed.
    """
    prefix = f"{dataset_name}_"
    return sorted(
        os.path.splitext(entry)[0]
        for entry in os.listdir(captions_dir)
        if entry.startswith(prefix) and entry.endswith(".txt")
    )


if __name__ == "__main__":
    # CUDA cannot be re-initialised in forked children, so use "spawn".
    mp.set_start_method("spawn", force=True)

    dataset_names = [
        "air quality", "diet", "border crossing", "demography", "road injuries", "crime",
        "covid", "co2",  "online retail", "walmart", "agriculture"
    ]

    data_dir = "/home/ubuntu/projects/new_data/test"
    output_base = "/home/ubuntu/projects/outputs/llama_vlm_11b_finetune"

    # Loop-invariant settings, computed once.
    captions_dir = os.path.join(data_dir, "captions")
    gpu_ids = [0, 1, 2]
    num_gpus = len(gpu_ids)

    # NOTE: every worker loads the model from scratch, so the model is
    # loaded once per GPU for each dataset processed here.
    for dataset_name in dataset_names:
        names = list_sample_names(captions_dir, dataset_name)

        print(f"[INFO] Found {len(names)} samples for dataset: {dataset_name}")
        if not names:
            continue

        # Round-robin split so each GPU gets a near-equal share.
        chunks = [names[i::num_gpus] for i in range(num_gpus)]

        workers = []
        for gpu_id, chunk in zip(gpu_ids, chunks):
            if not chunk:
                # Fewer samples than GPUs: do not load a model for nothing.
                continue
            p = mp.Process(target=process_chunk, args=(gpu_id, chunk, data_dir, output_base, dataset_name))
            p.start()
            workers.append((gpu_id, p))

        for gpu_id, p in workers:
            p.join()
            if p.exitcode != 0:
                # Surface crashed workers (e.g. OOM) instead of silently
                # continuing with missing captions.
                print(f"[WARN] Worker on GPU {gpu_id} exited with code {p.exitcode} for dataset: {dataset_name}")