import sys

import pandas as pd
import numpy as np
import torch
from sipbuild.generator import outputs

from transformers import AutoModelForCausalLM, AutoTokenizer


def clean_text(input_text, prefixes=("Cutting Knowledge Date:", "Today Date:")):
    """Remove every line whose stripped form starts with one of ``prefixes``.

    Matching ignores leading/trailing whitespace on each line, so indented
    lines are removed too. Lines that are kept are returned unchanged.

    Args:
        input_text: Text to filter.
        prefixes: A single string or an iterable of strings. A line is dropped
            if ``line.strip()`` starts with any of them. Defaults to the
            "Cutting Knowledge Date:" and "Today Date:" header lines.

    Returns:
        The remaining lines joined with a newline character. Because the text
        is split with ``str.splitlines``, line endings are normalized to a bare
        newline and a trailing newline is not preserved.
    """
    # str.startswith accepts a tuple of candidates, replacing a Python-level
    # any() loop; a lone string is wrapped so it is not iterated per character.
    if isinstance(prefixes, str):
        prefixes = (prefixes,)
    else:
        prefixes = tuple(prefixes)

    kept_lines = [
        line
        for line in input_text.splitlines()
        if not line.strip().startswith(prefixes)
    ]
    return "\n".join(kept_lines)


def describe(in_data, data_name, tokenizer):
    """Render a numeric series as a plain-text prompt for the LLM.

    Each value is shifted by +3, clipped to [0, 6], multiplied by 100 and
    rounded to the nearest integer, so the prompt lists integers in [0, 600].
    Inputs below -3 or above 3 saturate at 0 / 600; the shift presumably
    assumes a roughly z-normalized series (TODO confirm with callers).

    Args:
        in_data: 1-D sequence of numbers (list, ndarray, pandas Series, ...).
        data_name: Dataset label. Currently unused; kept for interface
            compatibility.
        tokenizer: Currently unused; kept for interface compatibility.

    Returns:
        The instruction line, a newline, then
        "The values were v1, v2, ..., vn."

    Raises:
        ValueError: If any value is NaN (it cannot be converted to an int).
    """
    # asarray lets plain lists work too (list + 3 would raise TypeError).
    shifted = np.clip(np.asarray(in_data, dtype=float) + 3, 0, 6)

    # round() rather than int(): int() truncates, so representation error such
    # as 0.57 * 100 == 56.99999999999999 would be printed as 56.
    values_str = ", ".join(str(round(float(value) * 100)) for value in shifted)

    # NOTE: an earlier variant wrapped this text with
    # tokenizer.apply_chat_template(...) and filtered the result through
    # clean_text(); the plain-text prompt below is the one actually used.
    return (
        "Extract features from the series to forecast future values \n"
        f"The values were {values_str}."
    )


if __name__ == "__main__":
    # --- Create a Sample Time Series ---
    # This data is designed to look like the chart you provided:
    # An initial value, a dip, a positive trend, and increasing volatility.
    np.random.seed(0)
    time_index = pd.to_datetime(
        pd.date_range(start="2020-01-01", periods=100, freq="M")
    )

    # Base trend
    trend = np.linspace(20, 30, 100)

    # Seasonal component with increasing amplitude
    seasonality = [i / 20 * np.sin(i / 3) * 5 for i in range(100)]

    # Random noise
    noise = np.random.normal(0, 1.5, 100)

    # Combine components and add an initial dip
    data = trend + seasonality + noise
    data[0] = 30
    data[1:10] -= np.linspace(5, 0, 9)

    # Create pandas Series
    time_series = pd.Series(data, index=time_index, name="Sample Data")
    time_series = (time_series - time_series.mean()) / time_series.std()

    model_name = "meta-llama/Llama-3.2-1B-Instruct"
    llm = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="cuda",
    )
    llm.config.use_cache = False

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # --- Generate and Print Descriptions ---
    print(time_series.values)
    desp = describe(time_series.values, "ETTh1", tokenizer)
    inputs = tokenizer(
        desp,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to("cuda")
    print(desp)
    with torch.no_grad():
        outputs = llm(**inputs, output_hidden_states=True)

    last_hidden_state = outputs.hidden_states[-1]
    attn_mask = inputs["attention_mask"].unsqueeze(-1)
    pooled_batch = (last_hidden_state * attn_mask).sum(1) / attn_mask.sum(1).clamp(
        min=1
    )
    print(pooled_batch)
