import numpy as np
from thinktime.utils.api_utils import AskGPTAPI
import json
import os
import re
from loguru import logger
from thinktime.ts_generator.generate import generate_random_attributes, generate_time_series, generate_controlled_attributes, attribute_to_text, all_attribute_set
from thinktime.utils.encoding_utils import timeseries_encoding, timeseries_to_list
from thinktime.utils.attribute_utils import metric_to_controlled_attributes
from thinktime.grpo.utils.mts_reason_tasks import task_to_function
from tqdm import tqdm
import yaml
import random


# Config
# AIOPS NAB UCR
TOTAL_CNT = 100
SEQ_LEN = 256
RANDOM_SAMPLE_SIZE = (20, 20)

# Parse the data-generation config once; the context manager closes the file
# handle instead of leaving it to the garbage collector.
with open("config/datagen_config.yaml") as _datagen_config_file:
    _datagen_config = yaml.safe_load(_datagen_config_file)

ENCODING_METHOD = _datagen_config['encoding_method']
DISABLE_METRIC_CONFIG = _datagen_config["disable_metric_config"]
DISABLE_EXTREME_LENGTHS = _datagen_config["disable_extreme_lengths"]
OUTPUT_PATH = f"data/mts_reason_rlvr_{TOTAL_CNT}_{RANDOM_SAMPLE_SIZE[0]}_{SEQ_LEN}.jsonl"
BASE_URL = "[OPENAI_BASE_URL]"
API_KEY = "[OPENAI_API_KEY]"
MODEL = "[OPENAI_API_MODEL]"

# All Config for TS Attributes (type & probability)
with open('config/metric_set.json', 'rt') as _metric_set_file:
    metric_config = json.load(_metric_set_file)

# Make sure the directory that will hold OUTPUT_PATH exists
os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)

# Initialize API
client = AskGPTAPI(api_key=API_KEY, base_url=BASE_URL, model=MODEL, num_workers=64, timeout=240)


def create_qa_prompt(series_ids, timeseries_list, question):
    """Render several series plus a question as a single plain-text prompt.

    Every (id, values) pair becomes one bullet line stating the id, the number
    of points and the values formatted with two decimals. Bullets are joined
    with ";\\n", closed with a period, and followed by the question. Pairing
    uses zip, so extra ids or extra series are ignored in the bullet list,
    while the header count always reflects ``len(timeseries_list)``.
    """
    def describe(name, values):
        rendered_values = ",".join(format(value, ".2f") for value in values)
        return f" - {name} is a time series of length {len(values)}: [{rendered_values}]"

    mts_description = ";\n".join(
        describe(name, values) for name, values in zip(series_ids, timeseries_list)
    )
    return f"I have {len(timeseries_list)} time series:\n{mts_description}.\n{question}"

def valid_answer(generated_answer, expected_answer):
    """
    Check if the generated answer matches the expected answer.

    Both values are converted to stripped strings first, so numeric answers
    supplied as int/float are compared the same way as their string form
    (previously a non-string expected answer always produced False).

    Comparison rules, in order:
      1. expected is yes/no      -> case-insensitive equality
      2. expected is A/B/C/D     -> case-insensitive equality
      3. both parse as floats    -> 10% relative tolerance, or an absolute
                                    tolerance of 1e-2 when |expected| <= 1e-2
      4. otherwise               -> case-insensitive string equality

    Returns:
        bool: True when the answers match; False otherwise, including when
        either argument is None.
    """
    if generated_answer is None or expected_answer is None:
        return False

    generated = str(generated_answer).strip()
    expected = str(expected_answer).strip()

    # For yes/no questions
    if expected.lower() in ('yes', 'no'):
        return generated.lower() == expected.lower()

    # For choice questions (A, B, C, D)
    if expected.upper() in ('A', 'B', 'C', 'D'):
        return generated.upper() == expected.upper()

    # For numerical answers
    try:
        generated_num = float(generated)
        expected_num = float(expected)
    except ValueError:
        # Not both numeric: fall through to plain string comparison
        pass
    else:
        if abs(expected_num) > 1e-2:
            return abs(generated_num - expected_num) / abs(expected_num) < 0.1  # 10% tolerance
        # Relative error is meaningless near zero, so use an absolute bound
        return abs(generated_num - expected_num) < 1e-2

    # For other string answers
    return generated.lower() == expected.lower()

def generate_timeseries():
    """Generate one synthetic series with its attribute pool and text caption.

    Returns a dict with the keys 'series_id' (the chosen metric name), 'data'
    (the encoded series converted via timeseries_to_list), 'caption' (a text
    description built from the attributes) and 'attribute' (the attribute
    pool returned by generate_time_series).
    """
    # Sequence length: fixed when SEQ_LEN is set, otherwise drawn from a mixture
    if SEQ_LEN is not None:
        seq_len = SEQ_LEN
    else:
        draw = random.random()
        if draw > 0.4:
            seq_len = 256
        elif draw > 0.1 or DISABLE_EXTREME_LENGTHS:
            seq_len = random.randint(64, 1024)
        elif draw > 0.05:
            seq_len = random.randint(5, 64)
        else:
            seq_len = random.randint(1024, 4096)

    # Pick a category entry, then one metric name inside it
    metric_group = random.choice(list(metric_config))
    category = metric_group['category']
    metric = random.choice(metric_group['metrics'])

    # Build the attribute pool, either fully random or constrained by the metric
    if not DISABLE_METRIC_CONFIG:
        attribute_pool = generate_controlled_attributes(
            metric_to_controlled_attributes(metric),
            seq_len=seq_len,
        )
    else:
        attribute_pool = generate_random_attributes(
            all_attribute_set['overall_attribute'],
            all_attribute_set['change'],
            seq_len=seq_len,
        )

    attribute_pool['metric_name'] = metric
    attribute_pool['situation'] = category

    timeseries, attribute_pool = generate_time_series(attribute_pool, seq_len)

    # Encode the raw values; only the first of the three returned items is used
    scaled_timeseries, _, _ = timeseries_encoding(timeseries, ENCODING_METHOD)

    # Assemble the caption: fixed intro followed by the attribute description
    caption_parts = [
        f"This is a metric called {metric} collected from {category} with length of {seq_len}. ",
        "The features of the given time series are as follows: ",
        attribute_to_text(
            timeseries,
            attribute_pool,
            include_attributes=['length', 'trend', 'periodicity', 'frequency', 'noise', 'local', 'statistic'],
            generate_values=False
        ),
    ]
    caption = "".join(caption_parts)

    return {
        'series_id': metric,
        'data': timeseries_to_list(scaled_timeseries),
        'caption': caption,
        'attribute': attribute_pool
    }

def generate_timeseries_dataset(total_cnt):
    """Return a list of `total_cnt` records, each from generate_timeseries().

    Progress is reported through a tqdm bar labelled "Generating time series".
    """
    progress = tqdm(range(total_cnt), desc="Generating time series")
    return [generate_timeseries() for _ in progress]

def create_multi_series_samples(dataset, sample_size_range=None):
    """Create multi-series samples by randomly grouping series from the dataset.

    One sample is produced per dataset entry. For each sample a group size is
    drawn uniformly from `sample_size_range` (inclusive) and that many
    distinct series are picked without replacement.

    Args:
        dataset: list of dicts with the keys 'series_id', 'data', 'caption'
            and 'attribute'.
        sample_size_range: optional ``(min_size, max_size)`` tuple. Defaults
            to the module-level RANDOM_SAMPLE_SIZE when omitted.

    Returns:
        list of dicts with the keys 'indices', 'metric_names', 'captions',
        'timeseries_list' and 'attributes', each aligned by position.

    Raises:
        ValueError: propagated from ``random.sample`` when the drawn group
            size exceeds ``len(dataset)``.
    """
    if sample_size_range is None:
        sample_size_range = RANDOM_SAMPLE_SIZE
    min_size, max_size = sample_size_range

    think_tag = '</think>'
    multi_series_samples = []

    for _ in range(len(dataset)):
        sample_size = random.randint(min_size, max_size)
        selected_indices = random.sample(range(len(dataset)), sample_size)
        selected_items = [dataset[i] for i in selected_indices]

        captions = []
        for item in selected_items:
            caption = item['caption']
            # Remove the literal trailing tag only. str.rstrip('</think>')
            # treats its argument as a character set and would also eat
            # legitimate trailing letters such as t/h/i/n/k.
            while caption.endswith(think_tag):
                caption = caption[:-len(think_tag)]
            captions.append(caption)

        # Handle duplicate metric names by adding suffixes: the first
        # occurrence keeps its name, later ones become name-1, name-2, ...
        metric_names = []
        name_counts = {}
        for item in selected_items:
            name = item['series_id']
            if name not in name_counts:
                name_counts[name] = 0
                metric_names.append(name)
            else:
                name_counts[name] += 1
                metric_names.append(f"{name}-{name_counts[name]}")

        multi_series_samples.append({
            'indices': selected_indices,
            'metric_names': metric_names,
            'captions': captions,
            'timeseries_list': [item['data'] for item in selected_items],
            'attributes': [item['attribute'] for item in selected_items]
        })

    return multi_series_samples

def main():
    """Run the full multi-series reasoning QA generation pipeline.

    Steps:
      0.   Generate TOTAL_CNT base time series.
      0.5  Group them into multi-series samples.
      1.   Ask the API to write questions for a randomly chosen task per
           sample; the raw responses are dumped to tmp/ for inspection.
      2.   Parse the responses and build QA prompts and metadata.
      3.   Ask the API to verify each QA (attribute leakage and answer
           accuracy) and keep only the ones that pass.
      4.   Write the verified QAs to OUTPUT_PATH as JSON lines.
    """
    # Create tmp directory if it doesn't exist
    os.makedirs("tmp", exist_ok=True)

    # Step 0: Generate base time series dataset
    logger.info("Step 0: Generating base time series dataset...")
    dataset = generate_timeseries_dataset(TOTAL_CNT)

    # Step 0.5: Create multi-series samples
    logger.info("Step 0.5: Creating multi-series samples...")
    multi_series_samples = create_multi_series_samples(dataset)
    logger.info(f"Created {len(multi_series_samples)} multi-series samples")

    # Step 1: Batch generate questions
    logger.info("Step 1: Generating questions in batch...")

    prompts, answer_formats = [], []
    for sample in multi_series_samples:
        # Randomly choose a task
        task = random.choice(['pattern_recognition', 'numerical_judgement', 'calculation', 'causal'])
        cur_prompt, cur_answer_format = task_to_function[task](sample['metric_names'], sample['captions'])
        sample['ability_type'] = f"mts_reason_{task}"
        prompts.append(cur_prompt)
        answer_formats.append(cur_answer_format)

    questions_responses = client.batch_ask_api(prompts, use_tqdm=True, thinking=True)

    # Save the raw question responses to the tmp folder
    questions_file = "tmp/generated_questions_mts.json"
    with open(questions_file, 'w') as f:
        json.dump(questions_responses, f, indent=2)

    logger.info(f"Generated questions saved to {questions_file}")

    # Step 2: Parse questions and create QA prompts
    logger.info("Step 2: Creating QA prompts...")
    qa_prompts = []
    qa_metadata = []

    for i, response in enumerate(questions_responses):
        try:
            questions_data = json.loads(response)
        except (json.JSONDecodeError, TypeError):
            # TypeError covers non-string responses (e.g. None)
            logger.warning(f"Failed to parse JSON response for sample {i}: {response}")
            continue

        if not isinstance(questions_data, list):
            continue

        sample = multi_series_samples[i]
        cur_answer_format = answer_formats[i]

        for q_data in questions_data:
            # Skip malformed entries instead of crashing the whole run
            if not isinstance(q_data, dict):
                continue
            if 'question' not in q_data or 'answer' not in q_data:
                continue
            if not isinstance(q_data['question'], str):
                continue

            q_data['question'] = q_data['question'] + cur_answer_format
            answer = q_data['answer']

            qa_prompt = create_qa_prompt(sample['metric_names'], sample['timeseries_list'], q_data['question'])
            qa_prompts.append(qa_prompt)
            qa_metadata.append({
                'sample_idx': i,
                'ts_indices': sample['indices'],
                'series_ids': sample['metric_names'],
                'timeseries_list': sample['timeseries_list'],
                'captions': sample['captions'],
                'question': q_data['question'],
                'expected_answer': answer.lower() if isinstance(answer, str) else str(answer),
                'qa_prompt': qa_prompt,
                'attributes': sample['attributes'],
                'ability_type': sample['ability_type']
            })

    logger.info(f"Created {len(qa_prompts)} QA prompts")

    # Step 3: Verify the generated QAs
    logger.info("Step 3: Verifying generated QAs...")

    def create_verification_prompt(question, captions, expected_answer):
        """
        Create a prompt to verify if:
        1. The question doesn't leak time series attributes
        2. The expected answer can be verified from the captions and is accurate
        """
        captions_text = "\n".join([f"Series {i+1}: {caption}" for i, caption in enumerate(captions)])

        verification_prompt = f"""You are a quality control expert for multivariate time series question-answer pairs. Your task is to verify two critical aspects:

**Task 1: Attribute Leakage Check**
Check if the question reveals or hints at any time series attributes/features such as:
- Trend information (upward, downward, stationary)
- Seasonality/periodicity details
- Noise levels
- Local fluctuation patterns (spikes, dips)
- Statistical values (mean, variance, amplitude)
- Specific numerical values from the time series
- Cross-series relationships or comparisons

**Task 2: Answer Accuracy Verification**
Verify if the expected answer can be confirmed as correct based on the caption information for all series.

**Given Information:**
- **Question:** {question}
- **Captions (ground truth attributes):** 
{captions_text}
- **Expected Answer:** {expected_answer}

**Verification Criteria:**
1. **Attribute Leakage:** The question should NOT contain explicit descriptions of time series features. It should describe scenarios or patterns without revealing the ground truth attributes.
2. **Answer Accuracy:** The expected answer should be verifiable and correct based on the caption information. Check if someone could logically arrive at this answer using only the caption data.

**Output Format:**
Provide your assessment in the following JSON format:
```json
{{
    "attribute_leakage": false,  // true if question leaks attributes, false if clean
    "answer_accurate": true,     // true if expected answer is correct based on captions, false if incorrect
    "verification_passed": true, // true only if both checks pass
    "explanation": "Brief explanation of your assessment"
}}
```

Please be strict in your evaluation. Even subtle hints about time series characteristics should be flagged as attribute leakage."""

        return verification_prompt

    def extract_verification_result(response):
        """Extract the verification dict from an API response.

        Prefers the content of a ```json fenced block; when no complete
        fenced block exists the whole response is parsed. Any response that
        is not a string, is not valid JSON, or does not decode to a dict is
        treated as a failed verification.
        """
        failed_result = {
            "attribute_leakage": True,
            "answer_accurate": False,
            "verification_passed": False,
            "explanation": "Failed to parse verification response"
        }

        if not isinstance(response, str):
            return failed_result

        # Default to the whole response; narrow to the fenced block if present
        json_str = response.strip()
        fence = '```json'
        start_idx = response.find(fence)
        if start_idx != -1:
            start_idx += len(fence)
            end_idx = response.find('```', start_idx)
            if end_idx != -1:
                json_str = response[start_idx:end_idx].strip()

        try:
            result = json.loads(json_str)
        except ValueError:
            # json.JSONDecodeError is a ValueError subclass
            return failed_result

        # Callers use .get() on the result, so anything but a dict is a failure
        if not isinstance(result, dict):
            return failed_result
        return result

    # Create verification prompts for all QAs
    verification_prompts = [
        create_verification_prompt(meta['question'], meta['captions'], meta['expected_answer'])
        for meta in qa_metadata
    ]

    # Batch call API for verification
    logger.info(f"Sending {len(verification_prompts)} verification prompts to API...")
    verification_responses = client.batch_ask_api(verification_prompts, use_tqdm=True)

    # Process verification results
    verified_qa = []
    for meta, verification_response in zip(qa_metadata, verification_responses):
        verification_result = extract_verification_result(verification_response)

        # Only include QAs that pass verification. Require a real boolean
        # True so that a truthy string such as "false" cannot slip through.
        if verification_result.get("verification_passed", False) is not True:
            continue

        verified_qa.append({
            'sample_idx': meta['sample_idx'],
            'ts_indices': meta['ts_indices'],
            'timeseries': meta['timeseries_list'],
            'captions': meta['captions'],
            'question': meta['question'],
            'answer': meta['expected_answer'],
            'cols': meta['series_ids'],
            'attributes': meta['attributes'],
            'verification_result': verification_result,
            'ability_type': meta['ability_type']
        })

    logger.info(f"Verification complete: {len(verified_qa)} out of {len(qa_metadata)} QAs passed verification")

    # Step 4: Save verified QAs to output file
    logger.info("Step 4: Saving verified QAs...")
    with open(OUTPUT_PATH, 'w') as f:
        for item in verified_qa:
            series_desc = ", ".join([f"{col} of length {len(ts)}: <ts><ts/>" for col, ts in zip(item['cols'], item['timeseries'])])
            cur_question = f"You are a time series analysis expert. These time series are: {series_desc}" + item['question']
            result = {
                "input": cur_question,
                "output": json.dumps({
                    "answer": item['answer'],
                    # NOTE(review): this stores the answer under "attribute";
                    # item['attributes'] may have been intended - confirm
                    # against the consumer of this file before changing.
                    "attribute": item['answer'],
                    "ability_type": item['ability_type'],
                    "cols": item['cols']
                }),
                # ability_type already carries the "mts_reason_" prefix, so it
                # is used as-is instead of being prefixed a second time.
                "type": item['ability_type'],
                "timeseries": item['timeseries'],
            }
            f.write(json.dumps(result) + '\n')

    logger.success(f"Saved {len(verified_qa)} verified QA pairs to {OUTPUT_PATH}")


if __name__ == '__main__':
    main()
