import os
import json
import torch
import multiprocessing as mp
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from dataclasses import dataclass, field
from transformers import AutoModelForCausalLM
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from helper import generate_prompt_for_baseline
import warnings
# Targeted filters for two messages that show up as noise when loading the model.
for _noisy_msg in (
    "FlashAttention2 is not installed",
    "InternLM2ForCausalLM has generative capabilities",
):
    warnings.filterwarnings("ignore", message=_noisy_msg)
del _noisy_msg
# NOTE(review): this blanket filter silences *every* warning, which makes the two
# targeted filters above redundant; remove it to see unrelated warnings again.
warnings.filterwarnings("ignore")

# collate function: transposes a batch of (name, metadata, ts) tuples into per-field tuples
def collate_fn(batch):
    """Transpose a batch of per-sample tuples into one tuple per field.

    Example: ``[(a1, b1), (a2, b2)]`` -> ``[(a1, a2), (b1, b2)]``.
    """
    return [column for column in zip(*batch)]

# dataset that loads each sample's metadata and time series from disk
class TSData(Dataset):
    """Map-style dataset yielding ``(name, metadata, ts)`` for each sample.

    For a sample ``name`` it reads:
      * ``<data_dir>/metadata/<name>.json``  -> parsed JSON object ``metadata``
      * ``<data_dir>/time series/<name>.txt`` -> non-blank lines, stripped and
        joined with ``", "`` into the string ``ts``
    """

    def __init__(self, names, data_dir, dataset_name):
        self.names = names  # list of sample file names (without extension)
        self.data_dir = data_dir
        # Stored for callers; not read anywhere in this class.
        self.dataset_name = dataset_name

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        name = self.names[idx]
        meta_path = os.path.join(self.data_dir, "metadata", f"{name}.json")
        ts_path = os.path.join(self.data_dir, "time series", f"{name}.txt")

        # Explicit UTF-8 so non-ASCII text (units, place names, ...) decodes the
        # same regardless of the process locale.
        with open(meta_path, encoding="utf-8") as f:
            metadata = json.load(f)
        with open(ts_path, encoding="utf-8") as f:
            stripped = (line.strip() for line in f)
            ts = ", ".join(value for value in stripped if value)

        return name, metadata, ts

# worker process: runs inference for one chunk of samples on a single GPU (one process per GPU)
def process_chunk(gpu_id, chunk, data_dir, output_dir, dataset_name,
                  model_name_or_path="/home/ubuntu/projects/time_series_main/models/InternVL2_5-8B"):
    """Generate one caption per sample in ``chunk`` on GPU ``gpu_id``.

    Meant to run as the target of a separate process: it loads its own
    tokenizer and fp16 model onto ``cuda:{gpu_id}``, builds a text-only prompt
    per sample, and writes the stripped response to ``<output_dir>/<name>.txt``.
    A failure on one sample is reported and the worker moves on to the next.

    Args:
        gpu_id: CUDA device index for this worker.
        chunk: sample names (without extension) to process.
        data_dir: root folder containing ``metadata`` and ``time series``.
        output_dir: folder for caption files; created if it does not exist.
        dataset_name: forwarded to ``generate_prompt_for_baseline``.
        model_name_or_path: model to load (defaults to the original local path).
    """
    torch.cuda.set_device(gpu_id)
    device = f"cuda:{gpu_id}"
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True).half().to(device).eval()

    # Without this, a missing output folder makes every write below fail.
    os.makedirs(output_dir, exist_ok=True)

    dataset = TSData(chunk, data_dir, dataset_name)

    # Greedy decoding, capped at 256 new tokens.
    generation_config = {"num_beams": 1, "max_new_tokens": 256, "do_sample": False}

    # Index the dataset per item (equivalent to a batch_size=1 DataLoader) so
    # that a missing/corrupt sample file is caught below instead of raising from
    # the loader iterator and killing the whole worker.
    for idx, name in enumerate(tqdm(chunk, desc=f"[GPU {gpu_id}]", position=gpu_id)):
        try:
            _, metadata, ts = dataset[idx]

            # generate prompt
            question = generate_prompt_for_baseline(dataset_name, metadata, ts).strip()

            # run model inference (text-only: no pixel values); grads are never needed
            with torch.no_grad():
                response = model.chat(tokenizer, None, question, generation_config)
            caption = response.strip()

            with open(os.path.join(output_dir, f"{name}.txt"), "w", encoding="utf-8") as f:
                f.write(caption)

        except Exception as e:
            tqdm.write(f"[GPU {gpu_id}] Failed {name} — {e}")

def list_sample_names(captions_dir, dataset_name):
    """Return the sorted sample names in ``captions_dir`` for ``dataset_name``.

    A sample is a ``<dataset_name>_*.txt`` file. Only the trailing ``.txt`` is
    removed from the returned names, and files with other extensions are ignored.
    """
    prefix = f"{dataset_name}_"
    return sorted(
        os.path.splitext(fname)[0]
        for fname in os.listdir(captions_dir)
        if fname.startswith(prefix) and fname.endswith(".txt")
    )


if __name__ == "__main__":
    # CUDA does not support the "fork" start method in child processes, so use "spawn".
    mp.set_start_method("spawn", force=True)

    dataset_names = [
        "air quality", "crime", "border crossing", "demography", "road injuries",
        "covid", "co2", "diet", "online retail", "walmart", "agriculture"
    ]

    data_dir = "/home/ubuntu/projects/new_data/test"
    output_base = "/home/ubuntu/projects/outputs1/internvl_7b_text"
    captions_dir = os.path.join(data_dir, "captions")
    gpu_ids = [1, 2]

    for dataset_name in dataset_names:
        names = list_sample_names(captions_dir, dataset_name)

        print(f"[INFO] Found {len(names)} samples for dataset: {dataset_name}")
        if not names:
            continue  # skip empty datasets

        # Round-robin split so every GPU gets a near-equal share of samples.
        chunks = [names[i::len(gpu_ids)] for i in range(len(gpu_ids))]

        # spawn a process per GPU for parallel generation
        processes = []
        for gpu_id, chunk in zip(gpu_ids, chunks):
            if not chunk:
                continue  # fewer samples than GPUs: don't load a model for nothing
            p = mp.Process(target=process_chunk, args=(gpu_id, chunk, data_dir, output_base, dataset_name))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()
            # Surface worker crashes (e.g. CUDA OOM during model load) instead of
            # silently moving on to the next dataset.
            if p.exitcode != 0:
                print(f"[WARN] Worker {p.name} for dataset {dataset_name} exited with code {p.exitcode}")