import os
import json
import torch
import multiprocessing as mp
from tqdm import tqdm
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import (
    LlavaForConditionalGeneration,
    AutoTokenizer,
    CLIPImageProcessor,
    LlavaProcessor
)
from helper import generate_prompt_for_baseline
import warnings
from transformers.utils import logging as hf_logging

# Keep console output readable: transformers reports errors only and all
# Python warnings are suppressed.
hf_logging.set_verbosity_error()
warnings.filterwarnings("ignore")

# CUDA environment knobs. They only take effect if set before CUDA is
# initialised; spawned worker processes inherit them (and re-execute this
# module-level code on import). With only device "1" visible, device index 0
# inside this script refers to physical GPU 1.
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "CUDA_VISIBLE_DEVICES": "1",
    }
)

class TimeSeriesDataset(Dataset):
    """Map-style dataset yielding the raw inputs needed to caption one sample.

    Each sample ``name`` is looked up under ``data_dir`` as
    ``metadata/<name>.json``, ``time series/<name>.txt`` and
    ``plots/<name>.jpeg``. Only the first two files are read here. The image
    path is returned without checking that it exists, so the caller decides
    how to handle a missing plot.
    """

    def __init__(self, names, data_dir, dataset_name):
        self.names = names
        self.data_dir = data_dir
        # Stored for callers/inspection; not used when loading items.
        self.dataset_name = dataset_name

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        """Return ``(name, metadata, ts, img_path)`` for sample ``idx``.

        ``metadata`` is the parsed JSON content. ``ts`` is the non-blank lines
        of the time-series file, each stripped and joined with ``", "``.

        Raises:
            FileNotFoundError: if the metadata or time-series file is missing.
            json.JSONDecodeError: if the metadata file is not valid JSON.
        """
        name = self.names[idx]
        meta_path = os.path.join(self.data_dir, "metadata", f"{name}.json")
        ts_path = os.path.join(self.data_dir, "time series", f"{name}.txt")
        img_path = os.path.join(self.data_dir, "plots", f"{name}.jpeg")

        # Explicit UTF-8: JSON is UTF-8 by spec, and relying on the platform
        # default encoding (e.g. cp1252 or a C/POSIX locale) can mis-decode
        # or reject non-ASCII text.
        with open(meta_path, encoding="utf-8") as f:
            metadata = json.load(f)
        with open(ts_path, encoding="utf-8") as f:
            stripped = (line.strip() for line in f)
            ts = ", ".join(value for value in stripped if value)

        return name, metadata, ts, img_path

def collate_fn(batch):
    """Transpose a batch of per-sample tuples into per-field tuples.

    ``[(a1, b1), (a2, b2)]`` becomes ``[(a1, a2), (b1, b2)]``. No tensor
    conversion or stacking is done, so dicts and strings pass through as-is.
    """
    return [*zip(*batch)]

def _generate_caption(model, processor, device, dataset_name, metadata, ts, img_path, max_new_tokens):
    """Run greedy generation for one sample and return only the model's reply text."""
    prompt = generate_prompt_for_baseline(dataset_name, metadata, ts)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt.strip()},
            ],
        }
    ]
    chat_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)

    # The context manager closes the file handle once the RGB copy is made.
    with Image.open(img_path) as img:
        image = img.convert("RGB")

    inputs = processor(text=chat_prompt, images=[image], return_tensors="pt").to(device, torch.bfloat16)

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
    )

    # generate() returns the prompt tokens followed by the continuation.
    # Dropping the prompt keeps the caption to the reply alone, regardless of
    # what role markers the chat template uses.
    prompt_len = inputs["input_ids"].shape[-1]
    return processor.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)[0].strip()


def process_chunk(gpu_id, chunk, data_dir, output_dir, dataset_name,
                  model_id="llava-hf/llava-v1.6-34b-hf", max_new_tokens=300):
    """Caption every sample in ``chunk`` on a single GPU.

    Meant as a ``multiprocessing.Process`` target. It loads its own copy of
    the model onto ``cuda:{gpu_id}`` and writes one ``<name>.txt`` caption per
    sample into ``output_dir``. Per-sample failures (missing or unreadable
    input files, prompt or generation errors) are reported via ``tqdm.write``
    and skipped, so one bad sample does not stop the worker.

    Args:
        gpu_id: CUDA device index (relative to CUDA_VISIBLE_DEVICES); also
            used as the tqdm bar position.
        chunk: Sample names (file stems) to caption.
        data_dir: Root containing ``metadata/``, ``time series/`` and ``plots/``.
        output_dir: Directory receiving caption files; created if needed.
        dataset_name: Dataset label forwarded to the prompt builder.
        model_id: Hugging Face checkpoint in transformers' LLaVA-NeXT format.
        max_new_tokens: Upper bound on generated tokens per caption.
    """
    # LLaVA 1.6 is LLaVA-NeXT in transformers. Its anyres image tiling requires
    # the LlavaNext* classes rather than the LLaVA-1.5 Llava* classes.
    from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor

    torch.cuda.set_device(gpu_id)
    device = f"cuda:{gpu_id}"

    # liuhaotian/* checkpoints use the original LLaVA repo format and do not
    # load into transformers. The llava-hf conversion ships the matching
    # processor (image preprocessing + chat template) alongside the weights.
    processor = LlavaNextProcessor.from_pretrained(model_id)
    model = LlavaNextForConditionalGeneration.from_pretrained(
        model_id,
        device_map=device,
        torch_dtype=torch.bfloat16,
    ).eval()

    os.makedirs(output_dir, exist_ok=True)
    dataset = TimeSeriesDataset(chunk, data_dir, dataset_name)

    for idx in tqdm(range(len(dataset)), desc=f"[GPU {gpu_id}]", position=gpu_id):
        # Resolve the name before anything can fail, so the error message
        # below always reports the right sample.
        name = dataset.names[idx]
        try:
            # Load inside the try, so a missing or corrupt metadata or
            # time-series file skips this sample instead of killing the worker.
            _, metadata, ts, img_path = dataset[idx]

            if not os.path.exists(img_path):
                tqdm.write(f"[GPU {gpu_id}] Skipping missing image: {img_path}")
                continue

            caption = _generate_caption(
                model, processor, device, dataset_name, metadata, ts, img_path, max_new_tokens
            )

            with open(os.path.join(output_dir, f"{name}.txt"), "w", encoding="utf-8") as f:
                f.write(caption)

        except Exception as e:
            # Deliberate per-sample best effort: report and continue.
            tqdm.write(f"[GPU {gpu_id}] Failed: {name} — {e}")

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)

    dataset_names = [
        "air quality", "crime", "border crossing", "demography", "road injuries",
        "covid", "co2", "diet", "online retail", "walmart", "agriculture"
    ]

    data_dir = "/home/ubuntu/projects/new_data/test"
    output_base = "/home/ubuntu/projects/outputs/llava_34b"

    for dataset_name in dataset_names:
        captions_dir = os.path.join(data_dir, "captions")
        names = sorted([
            f.replace(".txt", "")
            for f in os.listdir(captions_dir)
            if f.startswith(f"{dataset_name}_")
        ])

        print(f"[INFO] Found {len(names)} samples for dataset: {dataset_name}")
        if not names:
            continue

        gpu_ids = [0]
        num_gpus = len(gpu_ids)
        chunks = [names[i::num_gpus] for i in range(num_gpus)]

        processes = []
        for i, gpu_id in enumerate(gpu_ids):
            p = mp.Process(target=process_chunk, args=(gpu_id, chunks[i], data_dir, output_base, dataset_name))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()