import torch
import numpy as np
import pandas as pd

import sys
import os

sys.path.insert(
    0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src"))
)

from model.llm.OpenTSLMFlamingo import OpenTSLMFlamingo
from prompt.text_prompt import TextPrompt
from prompt.text_time_series_prompt import TextTimeSeriesPrompt
from prompt.full_prompt import FullPrompt
from time_series_datasets.util import (
    extend_time_series_to_match_patch_size_and_aggregate,
)

# 1. Load the model
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

print(f"Using device: {device}")
model = OpenTSLMFlamingo(
    device=device,
    llm_id="google/gemma-2b",  # or whatever you used for training
)

model.load_from_file("../models/best_model.pt")


# 2. Load the M4 series-M42150 from CSV
csv_path = "../data/M4TimeSeriesCaptionDataset/m4_series_Monthly.csv"
df = pd.read_csv(csv_path)
row = df[df["id"] == "series-M42150"].iloc[0]
series_str = row["series"]
# Remove brackets and split
series = [
    float(x)
    for x in series_str.strip("[]").replace("\n", "").replace(" ", "").split(",")
    if x
]
series = np.array(series, dtype=np.float32)
mean = series.mean()
std = series.std()
normalized_series = (series - mean) / std if std > 0 else series - mean

# 3. Build the prompt
pre_prompt = TextPrompt("You are an expert in time series analysis.")
ts_text = f"This is the time series, it has mean {mean:.4f} and std {std:.4f}:"
ts_prompt = TextTimeSeriesPrompt(ts_text, normalized_series.tolist())
post_prompt = TextPrompt(
    "Please generate a detailed caption for this time-series, describing it as accurately as possible."
)

# 4. Build the batch (list of dicts)
prompt = FullPrompt(
    pre_prompt,
    [ts_prompt],
    post_prompt,
)

print(model.eval_prompt(prompt))
