import os
import json
import torch
import multiprocessing as mp
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from PIL import Image
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from helper import generate_prompt_for_baseline
import warnings

# Silence two known-noisy messages explicitly, then every remaining warning.
# The final blanket filter makes the targeted ones redundant; they are kept to
# document which messages motivated it.
for _noisy_message in (
    "FlashAttention2 is not installed",
    "InternLM2ForCausalLM has generative capabilities",
):
    warnings.filterwarnings("ignore", message=_noisy_message)
del _noisy_message
warnings.filterwarnings("ignore")

# Per-channel normalisation statistics applied by build_transform after the
# image has been converted to RGB.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    """Build the preprocessing pipeline applied to every image tile.

    The pipeline forces RGB mode, bicubically resizes to a square of side
    ``input_size``, converts to a tensor and normalises with
    ``IMAGENET_MEAN`` / ``IMAGENET_STD``.
    """
    def _ensure_rgb(img):
        # Leave RGB images untouched; convert anything else (L, RGBA, P, ...).
        return img if img.mode == 'RGB' else img.convert('RGB')

    square = (input_size, input_size)
    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize(square, interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the grid in *target_ratios* whose ``cols / rows`` best matches *aspect_ratio*.

    Candidates are scanned in order. When a later candidate ties the current
    best on the aspect-ratio gap, it replaces the best only if the source
    area ``width * height`` exceeds half the pixel area that candidate grid
    would cover at ``image_size`` pixels per tile side.

    Returns:
        The chosen element of *target_ratios*, or ``(1, 1)`` when it is empty.
    """
    source_area = width * height
    chosen = (1, 1)
    smallest_gap = float('inf')
    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        gap = abs(aspect_ratio - cols / rows)
        if gap < smallest_gap:
            smallest_gap = gap
            chosen = candidate
        elif gap == smallest_gap and source_area > 0.5 * image_size * image_size * cols * rows:
            chosen = candidate
    return chosen

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=True):
    """Cut *image* into square tiles laid out on a grid that matches its aspect ratio.

    A ``cols x rows`` grid is selected with ``find_closest_aspect_ratio`` from
    all grids whose tile count lies in ``[min_num, max_num]``. The image is
    resized to ``(image_size * cols, image_size * rows)`` and cropped row-major
    into ``cols * rows`` tiles of ``image_size`` x ``image_size``.

    Args:
        image: Image object exposing ``.size``, ``.resize`` and ``.crop``
            (a PIL image in this file).
        min_num: Minimum number of tiles in the chosen grid.
        max_num: Maximum number of tiles in the chosen grid.
        image_size: Side length in pixels of each square tile.
        use_thumbnail: When true and more than one tile was produced, append
            the whole image resized to ``image_size`` x ``image_size``.

    Returns:
        List of tile images, with the optional thumbnail as the last element.
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Candidate (cols, rows) grids ordered by tile count. Bound the tile count
    # on both sides so that min_num actually excludes grids that are too small.
    target_ratios = sorted(
        {(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1)
         if min_num <= i * j <= max_num},
        key=lambda x: x[0] * x[1]
    )
    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
    cols, rows = target_aspect_ratio[0], target_aspect_ratio[1]
    target_width = image_size * cols
    target_height = image_size * rows

    resized_img = image.resize((target_width, target_height))
    processed_images = []
    # Row-major walk over the grid: tile i sits at column i % cols, row i // cols.
    for i in range(cols * rows):
        col, row = i % cols, i // cols
        box = (
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )
        processed_images.append(resized_img.crop(box))
    if use_thumbnail and len(processed_images) != 1:
        # Whole-image thumbnail gives the model global context beside the tiles.
        processed_images.append(image.resize((image_size, image_size)))
    return processed_images

def load_image(image_file, input_size=448, max_num=12):
    """Load an image file and return its preprocessed tiles as one stacked tensor.

    Args:
        image_file: Path (or file object) accepted by ``PIL.Image.open``.
        input_size: Side length of each square tile.
        max_num: Maximum number of grid tiles passed to ``dynamic_preprocess``.

    Returns:
        ``torch.stack`` of the transformed tiles; the first dimension is the
        number of tiles, including the thumbnail when one is added.
    """
    # Open via a context manager so the file handle is released deterministically.
    # convert() loads the pixel data, so the result stays valid after close.
    with Image.open(image_file) as raw_image:
        image = raw_image.convert('RGB')
    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    return torch.stack([transform(tile) for tile in tiles])

class TimeSeriesDataset(Dataset):
    """Map-style dataset that loads one time-series sample per name.

    For a sample ``name`` it reads, under ``data_dir``:
      * ``metadata/<name>.json``   -> parsed JSON,
      * ``time series/<name>.txt`` -> non-empty lines, stripped and joined with ", ",
      * ``plots/<name>.jpeg``      -> stacked image tiles from ``load_image``.

    Args:
        names: Sequence of sample names, indexed by position.
        data_dir: Root directory containing the three sub-directories above.
        dataset_name: Stored on the instance; not read by this class itself.
        image_size: Tile side length forwarded to ``load_image``.
        max_num: Maximum tile count forwarded to ``load_image``.
    """

    def __init__(self, names, data_dir, dataset_name, image_size=448, max_num=12):
        self.names = names
        self.data_dir = data_dir
        self.dataset_name = dataset_name
        self.image_size = image_size
        self.max_num = max_num

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        """Return ``(name, metadata, ts, pixel_values)`` for the sample at *idx*."""
        name = self.names[idx]
        meta_path = os.path.join(self.data_dir, "metadata", f"{name}.json")
        ts_path = os.path.join(self.data_dir, "time series", f"{name}.txt")
        img_path = os.path.join(self.data_dir, "plots", f"{name}.jpeg")

        # Decode explicitly as UTF-8 rather than relying on the platform's
        # default locale encoding (JSON is UTF-8 by specification).
        with open(meta_path, encoding="utf-8") as f:
            metadata = json.load(f)
        with open(ts_path, encoding="utf-8") as f:
            ts = ", ".join([line.strip() for line in f if line.strip()])

        pixel_values = load_image(img_path, input_size=self.image_size, max_num=self.max_num)
        return name, metadata, ts, pixel_values

def collate_fn(batch):
    """Transpose a batch of per-sample tuples into a list of per-field tuples.

    ``[(a1, b1), (a2, b2)]`` becomes ``[(a1, a2), (b1, b2)]``. Nothing is
    stacked or merged, so each field stays a plain tuple of per-sample values.
    """
    return [*zip(*batch)]

def process_chunk(gpu_id, chunk, data_dir, output_dir, dataset_name,
                  model_path="/home/ubuntu/projects/time_series_main/models/internvl2_5_2b"):
    """Caption every sample in *chunk* with the InternVL model on one GPU.

    Used as a ``multiprocessing.Process`` target. Loads its own tokenizer and
    model onto ``cuda:<gpu_id>``. For each sample it builds a prompt with
    ``generate_prompt_for_baseline`` and writes the model's reply to
    ``<output_dir>/<name>.txt``. Any failure for a sample, including reading
    its input files, is reported with ``tqdm.write`` and the sample is
    skipped, so one bad sample does not abort the rest of the chunk.

    Args:
        gpu_id: CUDA device index; also used as the tqdm bar position.
        chunk: List of sample names to process.
        data_dir: Root data directory passed to ``TimeSeriesDataset``.
        output_dir: Directory receiving one caption ``.txt`` file per sample.
        dataset_name: Dataset identifier forwarded to the prompt builder.
        model_path: Local path of the pretrained checkpoint to load.
    """
    torch.cuda.set_device(gpu_id)
    device = f"cuda:{gpu_id}"

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(device).eval()

    dataset = TimeSeriesDataset(chunk, data_dir, dataset_name)
    generation_config = {"num_beams": 1, "max_new_tokens": 256, "do_sample": False}

    # Loop-invariant: create the output directory once, not per sample.
    os.makedirs(output_dir, exist_ok=True)

    # Index the dataset directly rather than iterating a DataLoader. A missing
    # or corrupt input file then raises inside the per-sample try below,
    # instead of escaping the loop and terminating the whole worker process.
    for idx, name in enumerate(tqdm(chunk, desc=f"[GPU {gpu_id}]", position=gpu_id)):
        try:
            _, metadata, ts, pixel_values = dataset[idx]
            pixel_values = pixel_values.to(device=device, dtype=torch.float32)

            prompt = generate_prompt_for_baseline(dataset_name, metadata, ts)
            question = f"<image>\n{prompt.strip()}"

            # Pure inference: no autograd graph is needed for generation.
            with torch.no_grad():
                caption = model.chat(tokenizer, pixel_values, question, generation_config).strip()

            with open(os.path.join(output_dir, f"{name}.txt"), "w", encoding="utf-8") as f:
                f.write(caption)

        except Exception as e:  # best-effort per sample: report and move on
            tqdm.write(f"[GPU {gpu_id}] Failed {name} — {e}")

def list_sample_names(captions_dir, dataset_name, every=3):
    """Return the sorted sample names in *captions_dir* selected for *dataset_name*.

    A file ``<dataset_name>_..._<index>.txt`` yields the sample name
    ``<dataset_name>_..._<index>``. It is kept when ``int(index) % every == 0``.
    Files that do not end in ``.txt``, or whose last ``_``-separated field is
    not an integer, are skipped instead of aborting the whole run.
    """
    prefix = f"{dataset_name}_"
    suffix = ".txt"
    names = []
    for fname in os.listdir(captions_dir):
        if not (fname.startswith(prefix) and fname.endswith(suffix)):
            continue
        stem = fname[:-len(suffix)]
        try:
            index = int(stem.rsplit("_", 1)[-1])
        except ValueError:
            continue
        if index % every == 0:
            names.append(stem)
    return sorted(names)


if __name__ == "__main__":
    # Workers use CUDA, which (per PyTorch's multiprocessing notes) requires
    # the "spawn" start method rather than "fork".
    mp.set_start_method("spawn", force=True)

    dataset_names = [
        "air quality", "crime", "border crossing", "demography", "road injuries",
        "covid", "co2", "diet", "online retail", "walmart", "agriculture"
    ]

    data_dir = "/home/ubuntu/projects/time_series_main/data"
    output_base = "/home/ubuntu/projects/outputs/internvl_vl"
    captions_dir = os.path.join(data_dir, "captions")
    gpu_ids = [0, 1]
    num_gpus = len(gpu_ids)

    for dataset_name in dataset_names:
        names = list_sample_names(captions_dir, dataset_name)

        print(f"[INFO] Found {len(names)} samples for dataset: {dataset_name}")
        if not names:
            continue

        # Round-robin split so every GPU gets a near-equal share.
        chunks = [names[i::num_gpus] for i in range(num_gpus)]
        output_dir = os.path.join(output_base, dataset_name)

        processes = []
        for gpu_id, chunk in zip(gpu_ids, chunks):
            if not chunk:
                # Fewer samples than GPUs: don't spawn a worker and load a model for nothing.
                continue
            p = mp.Process(target=process_chunk, args=(gpu_id, chunk, data_dir, output_dir, dataset_name))
            p.start()
            processes.append((gpu_id, p))

        for gpu_id, p in processes:
            p.join()
            # A crashed worker (e.g. model load failure, CUDA OOM) would otherwise go unnoticed.
            if p.exitcode != 0:
                print(f"[WARN] GPU {gpu_id} worker for dataset {dataset_name} exited with code {p.exitcode}")
