import asyncio
import json
from logging.config import dictConfig
import numpy as np
import random

import re
import os
import numpy as np
import random
import json
import hydra
from tqdm import tqdm
import time
from typing import List
from pathlib import Path
import base64
import copy
import matplotlib.pyplot as plt
from captioning.llm_client import LLMClient, Provider
# Pool of caption instructions. For each unique series one entry is picked at
# random and appended to the series description in the user message sent to
# the LLM; the same instruction is also stored in the saved "input_prompt".
single_ts_captions = [
    "Write a paragraph that analyzes the time series, covering its local behaviors, noise levels, periodic structures, overall trend, frequency content, and any other characteristics you consider important.",
    "Create a detailed description of the time series in one paragraph, including its trend, frequency properties, periodicity, noise, local variations, and other relevant characteristics.",
    "Provide a paragraph summarizing the time series characteristics such as noise, periodic patterns, long-term trends, frequency behavior, local anomalies, and any other significant features.",
    "Compose a detailed caption describing the frequency characteristics, noise, trends, local variations, periodic structures, and any other meaningful patterns you observe in the time series.",
    "Craft a one-paragraph summary of the time series, noting local fluctuations, periodic behavior, frequency features, trend, noise content, and any other insights you find important.",
    "Generate a descriptive paragraph detailing the time series' key attributes, including frequency structure, noise patterns, trend direction, local features, periodic elements, and other notable aspects.",
    "Give a thorough one-paragraph explanation of the time series, addressing periodicity, noise, frequency components, trend, local variations, and other relevant characteristics.",
    "Write a narrative paragraph explaining the time series, focusing on noise, frequency characteristics, periodicity, localized structures, the overall trend, and other important features you identify.",
    "Summarize the time series in a paragraph, describing its fluctuations, recurring patterns, noise levels, frequency-domain features, trend direction, and any additional traits you find significant.",
    "Develop a paragraph that captures the key features of the time series, such as frequency traits, trend, noise, periodic components, local behaviors, and other characteristics worth noting.",
    "Provide a one-paragraph caption analyzing the time series data in terms of noise, trend, periodicity, local features, frequency-related behavior, and any additional characteristics of interest.",
    "Create a rich paragraph description of the time series, including its trend, local anomalies, periodic activity, noise artifacts, spectral content, and other important descriptive elements.",
    "Write a descriptive paragraph for the time series, highlighting frequency properties, trend behavior, periodic patterns, local structures, noise, and other characteristics you consider relevant.",
    "Generate a compact yet thorough paragraph explaining the time series in terms of periodicity, trend movement, noise level, frequency details, local dynamics, and any other key aspects.",
    "Construct a one-paragraph analysis of the time series by examining its local variations, noise, trend, periodic elements, frequency spectrum, and other notable features you deem important.",
    "Write a summary paragraph that discusses the time series' periodic features, trend behavior, local patterns, noise levels, frequency domain signals, and other characteristics worth mentioning.",
    "Create a detailed one-paragraph commentary on the time series that outlines its noise characteristics, periodicity, frequency content, trends, localized behaviors, and other useful insights.",
    "Prepare a paragraph-long description of the time series covering its trend, noise, frequency-related traits, local fluctuations, periodic structures, and any additional attributes of note.",
    "Offer a one-paragraph interpretation of the time series, highlighting its frequency features, periodic nature, local patterns, noise, trend line, and any other important characteristics you observe.",
    "Compose a detailed summary in one paragraph focusing on the time series' periodic behavior, frequency spectrum, localized fluctuations, overall trend, noise, and other relevant descriptive elements."
]

# Module-level buffer; not referenced by the code in this script.
transformed_data = []

# System message sent with every captioning request.
system_prompt = "You are a time series captioner."

# --- LLM / runtime settings -------------------------------------------------
# Passed as `key=` to LLMClient; None presumably lets the client fall back to
# its own credential lookup -- TODO confirm against LLMClient.
API_KEY = None
API_BASE = None
MODEL_PATH = "gpt-4.1"  # resolved through `provider.models(MODEL_PATH)`
# The next four settings appear unused by this script (kept for compatibility).
ctx_length = 4096
num_gpus = 8
gpu_per_model = 1
MULTIPROCESS = True
ENCODING_METHOD = 'sp'
TOTAL_CNT = 10

# --- Input / output locations -------------------------------------------------
# JSONL files to read; every file listed here is loaded.
INPUT_FILES = [
    './data/MCQ/MCQ_1_TS.jsonl',
]
OUTPUT_FILE = './data/MCQ/MCQ_1_TS_caption.jsonl'  # opened in append mode
EXP = 'vision_template'  # rendered figures go to exp/<EXP>/fig/
# Spelling kept as-is: the @hydra.main decorator references this exact name.
HYRDA_CONFIG_PATH = "configs"

# --- Batch generation settings -------------------------------------------------
CACHE_DIR = "llm_cache"
PROGRESS_FILE = Path(f"batch_progress_{TOTAL_CNT}_{ENCODING_METHOD}.json")
NUM_SAMPLES = 1
TEMPERATURE = 0.0
MAX_TOKENS = 2048
TOP_P = 0.95
BATCH_SIZE = 256

def sp_encoding(timeseries: np.ndarray):
    """Center a series and, if needed, rescale it so its peak magnitude is 3.

    The series is shifted by its mean. If any centered value has magnitude
    >= 3, the whole centered series is divided by ``max(|x|) / 3``.

    Args:
        timeseries: Array-like of numbers; plain lists (e.g. from JSON) are
            accepted and converted to a float array.

    Returns:
        A tuple ``(encoded, prompt, statistics)``:
          * ``encoded`` -- float array of shape ``(2 * size, 1)`` holding the
            scaled values interleaved with 1.0 (``[v0, 1, v1, 1, ...]``);
          * ``prompt`` -- ``"[Value Offset: <-mean>|Value Scaling: <factor>]<ts><ts/>"``
            with both numbers printed to 4 decimals;
          * ``statistics`` -- ``{'offset': -mean, 'scale_factor': factor}``.
        Empty input yields an empty ``(0, 1)`` array, offset 0.0 and scale 1.0
        instead of NaN statistics.
    """
    values = np.asarray(timeseries, dtype=float)
    # np.mean of an empty array is NaN (with a RuntimeWarning); treat it as 0.
    mean = float(np.mean(values)) if values.size else 0.0
    scaled_timeseries = values - mean  # new array: the caller's data is untouched
    scale_factor = 1.0
    if np.any(np.abs(scaled_timeseries) >= 3.0):
        scale_factor = float(np.max(np.abs(scaled_timeseries))) / 3.0
        scaled_timeseries /= scale_factor

    # `0.0 - mean` (unlike `-mean`) never produces -0.0 for a zero mean.
    offset = 0.0 - mean
    offset_text = f"{offset:.4f}"
    if offset_text == "-0.0000":
        # A tiny negative offset from float round-off would print as "-0.0000".
        offset_text = "0.0000"
    prompt = f"[Value Offset: {offset_text}|Value Scaling: {scale_factor:.4f}]<ts><ts/>"

    result_timeseries = np.stack(
        [scaled_timeseries, np.ones_like(scaled_timeseries)], axis=-1
    ).reshape(-1, 1)

    return result_timeseries, prompt, {'offset': float(offset), 'scale_factor': float(scale_factor)}

def generate_image_from_timeseries(case_idx: int, timeseries: np.ndarray) -> str:
    """Plot each series in its own subplot, save a JPEG and return it base64-encoded.

    Args:
        case_idx: Used as the file name: ``exp/<EXP>/fig/<case_idx>.jpg``.
        timeseries: A 1-D series or a 2-D array-like with one series per row
            (``np.atleast_2d`` promotes 1-D input to a single row).

    Returns:
        The saved JPEG bytes encoded as a base64 ASCII string.
    """
    timeseries = np.atleast_2d(timeseries)
    num_series = timeseries.shape[0]
    fig_dir = Path('exp') / EXP / 'fig'
    fig_path = fig_dir / f"{case_idx}.jpg"

    fig, axes = plt.subplots(num_series, 1, figsize=(6, num_series * 1.3), squeeze=False)
    try:
        for ax, series in zip(axes[:, 0], timeseries):
            ax.plot(series, linewidth=2, color='blue')
            ax.grid(True)
        # Operate on `fig` explicitly rather than pyplot's "current figure".
        fig.tight_layout()
        fig_dir.mkdir(parents=True, exist_ok=True)
        fig.savefig(fig_path, format='jpg')
    finally:
        # Release the figure even if plotting or saving fails, so a long run
        # does not accumulate open figures.
        plt.close(fig)

    # read_bytes() opens and closes the file (the original leaked the handle).
    return base64.b64encode(fig_path.read_bytes()).decode('utf-8')

def get_caption_query_inputs(case_idx: int, timeseries: np.ndarray, question: str) -> List[dict]:
    """Build the chat messages for one captioning request.

    The series is rendered to a JPEG via ``generate_image_from_timeseries`` and
    attached as a base64 ``data:`` URL in an ``image_url`` content part.

    Args:
        case_idx: Identifier used to name the rendered figure file.
        timeseries: Series to plot; a list of series is plotted one per row.
        question: Text placed before the image in the user message.

    Returns:
        A two-element message list: the system prompt, then a user message whose
        content is ``[text part, image_url part]``. (The previous ``-> str``
        annotation was wrong; the value returned has always been this list.)
    """
    img_b64_str = generate_image_from_timeseries(case_idx, timeseries)
    image_url = f"data:image/jpeg;base64,{img_b64_str}"

    user_content = [
        {"type": "text", "text": question},
        {"type": "image_url", "image_url": {"url": image_url}},
    ]
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_content},
    ]

def _load_unique_items(input_files) -> list:
    """Load QA records from JSONL files and build one work item per unique description.

    Every file in ``input_files`` is read in order. The previous version kept only
    the last file's records. A record whose stripped ``description`` was already
    seen is skipped. Blank lines are ignored.

    Returns:
        A list of dicts with keys ``description``, ``time_series_prompt``,
        ``caption_prompt``, ``original_timeseries``, ``transformed_timeseries``
        and ``idx``. ``idx`` is the 0-based position among unique records and
        is used to name the rendered figure.
    """
    items = []
    seen_descriptions = set()  # O(1) membership instead of scanning a list
    for input_file in tqdm(input_files, desc='Loading files'):
        with open(input_file, 'r', encoding='utf-8') as f_in:
            for line in f_in:
                if not line.strip():
                    continue
                record = json.loads(line)
                description = record['description'].strip()
                if description in seen_descriptions:
                    continue
                seen_descriptions.add(description)
                encoded, time_series_prompt, _statistics = sp_encoding(record['series'])
                items.append({
                    "description": description,
                    "time_series_prompt": time_series_prompt,
                    "caption_prompt": random.choice(single_ts_captions),
                    "original_timeseries": record['series'],
                    "transformed_timeseries": encoded.tolist(),
                    "idx": len(items),
                })
    return items


def _build_training_prompt(item: dict) -> str:
    """Return the text prompt saved next to each generated caption.

    Format: ``"<description> The time series is: <ts prompt>. <instruction>"``.
    The space after the description was missing before, which glued the
    sentences together ("...end.The time series is: ...").
    """
    return (
        f"{item['description']} The time series is: "
        f"{item['time_series_prompt']}. {item['caption_prompt']}"
    )


def _append_jsonl(path, entries) -> None:
    """Append ``entries`` to the JSONL file at ``path``, one JSON object per line.

    The parent directory is created if needed. Nothing is written (and no file
    is created) when ``entries`` is empty.
    """
    if not entries:
        return
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, 'a', encoding='utf-8') as f:
        for entry in entries:
            f.write(json.dumps(entry, ensure_ascii=False))
            f.write('\n')


async def recaptioning():
    """Generate image-grounded captions for every unique series in ``INPUT_FILES``.

    Steps:
      1. Load and de-duplicate records, then shuffle them.
      2. Render each series to a JPEG and build a multimodal chat request.
      3. Send all requests through ``LLMClient.batch_generate_async``.
      4. For every prompt with at least one non-None caption, append an
         ``{"input_prompt", "output", "timeseries"}`` record to ``OUTPUT_FILE``.

    An exception from the batch call is printed and the run continues with
    whatever results were obtained (possibly none). Token usage is printed at
    the end.
    """
    start_time = time.time()
    llm_client_instance = LLMClient(
        provider=Provider.OPENAI,
        key=API_KEY,
        cache_dir=CACHE_DIR,
    )
    model_enum_instance = llm_client_instance.provider.models(MODEL_PATH)

    print("Loading seed QA...")
    input_list = _load_unique_items(INPUT_FILES)
    random.shuffle(input_list)
    print(f"{len(input_list)} seed QAs loaded from file.")

    # The user message is "<description> <instruction>"; the series itself is
    # attached as a rendered image. prepared_prompts[i] matches input_list[i].
    prepared_prompts = [
        get_caption_query_inputs(
            item["idx"],
            [item["original_timeseries"]],
            item["description"] + " " + item["caption_prompt"],
        )
        for item in tqdm(input_list, desc='recaptioning')
    ]

    all_generated_captions_nested: List[List[str]] = []
    all_errors_nested: List[List[Exception]] = []

    if prepared_prompts:
        print(f"Starting batch generation for {len(prepared_prompts)} prompts...")
        try:
            all_generated_captions_nested, all_errors_nested = await llm_client_instance.batch_generate_async(
                prompts=prepared_prompts,
                num_samples=NUM_SAMPLES,
                model=model_enum_instance,
                temperature=TEMPERATURE,
                max_tokens=MAX_TOKENS,
                top_p=TOP_P,
                ignore_cache_samples=False,
                expand_n_completions=False,
                batch_size=BATCH_SIZE,
                progress_file=PROGRESS_FILE,
            )
            print("Batch generation finished.")
        except Exception as e:
            # Best effort: report and continue; the result lists stay empty.
            print(f"An error occurred during batch generation: {e}")

    print("Processing and filtering results...")
    if len(all_generated_captions_nested) != len(prepared_prompts):
        print(f"Warning: Number of results ({len(all_generated_captions_nested)}) does not match number of prompts ({len(prepared_prompts)}). Processing available results.")

    new_dataset_entries = []
    # zip() stops at the shorter sequence, so surplus results cannot index past input_list.
    num_results = min(len(input_list), len(all_generated_captions_nested))
    paired_results = zip(input_list, all_generated_captions_nested)
    for i, (original_input_item, generated_captions) in enumerate(
        tqdm(paired_results, total=num_results, desc="Filtering Captions")
    ):
        prompt_errors = all_errors_nested[i] if i < len(all_errors_nested) else []
        if prompt_errors:
            print(f"Note: Errors occurred for prompt index {i}: {[str(e) for e in prompt_errors]}")

        successful_captions = [c for c in (generated_captions or []) if c is not None]
        if not successful_captions:
            continue
        new_dataset_entries.append({
            "input_prompt": _build_training_prompt(original_input_item),
            "output": random.choice(successful_captions),
            "timeseries": [original_input_item["transformed_timeseries"]],
        })

    # One open/close for the whole batch instead of reopening per record.
    _append_jsonl(OUTPUT_FILE, new_dataset_entries)

    print(f"-------------------------------------------")
    input_tokens, output_tokens = llm_client_instance.get_token_usage(model_enum_instance.value)
    print(f"Approximate Token Usage for {model_enum_instance.value}: Input={input_tokens}, Output={output_tokens}")
    print(f"{len(new_dataset_entries)} entries written in {time.time() - start_time:.1f}s.")

    print(f"-------------------------------------------")
    print(f"Finished! File saved to {OUTPUT_FILE}.")

@hydra.main(version_base=None, config_path=HYRDA_CONFIG_PATH, config_name="default")
def main(cfg) -> None:
    """Hydra entry point: run the async recaptioning pipeline to completion.

    ``cfg`` is the config Hydra composes from ``config_name="default"`` under
    ``HYRDA_CONFIG_PATH``. It is currently unused; all settings come from the
    module-level constants. The previous annotation named
    ``logging.config.dictConfig``, which is a function, not a config type.
    """
    asyncio.run(recaptioning())


if __name__ == '__main__':
    main()
