import numpy as np
import json
import yaml
import os
import random
from thinktime.utils.encoding_utils import timeseries_encoding, timeseries_to_list
from tqdm import tqdm
import copy

# Config
TARGET_CNT = 20000  # number of QA samples generate_dataset() tries to produce
# Read the datagen config once; the context manager closes the file handle
# instead of leaving three unclosed handles behind.
with open("config/datagen_config.yaml") as _config_file:
    _DATAGEN_CONFIG = yaml.safe_load(_config_file)
ENCODING_METHOD = _DATAGEN_CONFIG['encoding_method']
OUTPUT_BASE_DIR = _DATAGEN_CONFIG['data_output_dir']
SEQ_LEN = _DATAGEN_CONFIG['seq_len']
OUTPUT_PATH = f'{OUTPUT_BASE_DIR}/ift_rlvr_{ENCODING_METHOD}.jsonl'
EVOL_LABEL_PATH = f'{OUTPUT_BASE_DIR}/evol_labels/ift_rlvr_{SEQ_LEN}_{ENCODING_METHOD}.json'
# Label files the samples are drawn from
LABEL_FILES = [
    f'{OUTPUT_BASE_DIR}/labels/mts_local_llm_None_20000_{ENCODING_METHOD}.json',
    f'{OUTPUT_BASE_DIR}/labels/mts_shape_llm_None_20000_{ENCODING_METHOD}.json',
    f'{OUTPUT_BASE_DIR}/labels/uts_llm_None_20000_{ENCODING_METHOD}.json',
]
# Optional: per-filename sampling ratios (can be exact filename or substring keys)
FILE_SAMPLING_RATIOS = {
    "uts_llm": 0.4,
    "local_llm": 0.3,
    "shape_llm": 0.3
}
# Local fluctuation type names listed (sorted) in the local-feature questions
ALL_LOCAL_TYPES = {'increase after downward spike', 'increase after upward spike', 'upward spike', 'rapid decline followed by slow rise', 'slow rise followed by rapid decline', 'continuous upward spike', 'wide downward spike', 'slow decline followed by rapid rise', 'upward convex', 'shake', 'rapid rise followed by slow decline', 'sudden increase', 'downward spike', 'sudden decrease', 'continuous downward spike', 'decrease after upward spike', 'wide upward spike', 'decrease after downward spike', 'downward convex'}

# Each generator below returns: question, answer, features, ability_type, fields
# L0: basic feature: STL shape + statistic
def generate_trend(sample):
    """Build the trend question/answer pair for a sample.

    Reads ``sample['label']['trend']`` (keys ``type``, ``start``, ``amplitude``).

    Returns:
        Tuple ``(question, answer, features, ability_type, fields)``.

    Raises:
        NotImplementedError: if the trend type is ``'multiple'``.
    """
    trend = sample['label']['trend']
    # Reject unsupported samples before formatting the answer, so a 'multiple'
    # trend always surfaces as NotImplementedError rather than as whatever the
    # f-string below would raise on missing or non-numeric fields.
    if trend['type'] == 'multiple':
        raise NotImplementedError("ift not implemented for multiple trend")

    question = 'What is the trend of this time series? Please choose from ["steady", "decreasing", "increasing"] and describe the value trend change. Answer format: steady, the starting point value is around 32.10, and the trend change value from left to right is around 0.12.'
    answer = f"{trend['type']}, the starting point value is around {trend['start']:.2f}, and the trend change value from left to right is around {trend['amplitude']:.2f}."
    features = {
        'type': trend['type'],
        'start': trend['start'],
        'amplitude': trend['amplitude']
    }
    ability_type = 'trend'
    fields = {"trend": [0]}
    return question, answer, features, ability_type, fields

def generate_season(sample):
    """Build the periodicity question/answer pair for a sample.

    Reads ``sample['label']['seasonal']`` and ``sample['label']['frequency']``.
    A seasonal type containing ``'no'`` is treated as non-periodic; otherwise the
    amplitude comes from the ``amplitude`` key, else from the first entry of
    ``segments``, else defaults to 0.

    Returns:
        Tuple ``(question, answer, features, ability_type, fields)``.
    """
    label = sample['label']
    seasonal = label['seasonal']
    question = 'What is the periodicity of this time series? Please choose from ["no periodic fluctuation", "periodic fluctuation"]. If there is periodic fluctuation, describe the fluctuation frequency and amplitude. Answer format: periodic fluctuation, each period is around 20.58 points, and the amplitude of the periodic fluctuation is around 31.51.'

    if 'no' not in seasonal['type']:
        # Pick the amplitude from whichever representation the label provides
        if 'amplitude' in seasonal:
            amplitude = seasonal['amplitude']
        elif 'segments' in seasonal:
            amplitude = seasonal['segments'][0]['amplitude']
        else:
            amplitude = 0
        period = label['frequency']['period']
        answer = f"periodic fluctuation, each period is around {period:.2f} points, and the amplitude of the periodic fluctuation is around {amplitude:.2f}."
    else:
        answer = "no periodic fluctuation"
        amplitude = 0.0

    features = dict(
        type=seasonal['type'],
        period=label["frequency"]['period'],
        amplitude=amplitude,
    )
    return question, answer, features, 'season', {"seasonal": [0]}

def generate_noise(sample):
    """Build the noise question/answer pair for a sample.

    Reads ``sample['label']['noise']`` (keys ``type`` and ``std``).

    Returns:
        Tuple ``(question, answer, features, ability_type, fields)``; the
        ``std`` stored in ``features`` is rounded to two decimals.
    """
    noise = sample['label']['noise']
    question = 'What are the noise characteristics of this time series? Please choose from ["noisy", "almost no noise"]. Answer format: noisy, the overall noise standard deviation is around 1.5.'
    noise_type = noise['type']
    noise_std = noise['std']
    answer = f"{noise_type}, the overall noise standard deviation is around {noise_std:.2f}."
    features = dict(type=noise_type, std=round(noise_std, 2))
    return question, answer, features, 'noise', {"noise": [0]}

# L1: local feature: local change type + statistic
def generate_local(sample):
    """Build the local-fluctuation question/answer pair for a sample.

    Reads ``sample['label']['local']``, a list of entries with ``type``,
    ``position_start`` and ``amplitude``. An entry's ``type`` may be a string or
    a sequence of strings; the answer text uses the first element of a
    sequence, while ``features`` keeps the whole sequence as a list.

    Returns:
        Tuple ``(question, answer, features, ability_type, fields)``.
    """
    local_events = sample['label']['local']
    question = 'What are the local characteristic fluctuations of this time series? The optional types of local characteristic fluctuations include: ["' + '", "'.join(sorted(ALL_LOCAL_TYPES)) + '"]. You need to analyze all the characteristic fluctuations that appear in this time series and answer each type, position, and amplitude in the format. Different local characteristic fluctuations should be separated by semicolons. Answer format: shake, position around point 125, amplitude 135.03; small sudden decrease, position around point 102, amplitude 31.05.'

    def primary_type(event):
        # isinstance() is the idiomatic type check (instead of type(x) == str)
        kind = event['type']
        return kind if isinstance(kind, str) else kind[0]

    if local_events:
        answer = '; '.join(
            f"{primary_type(event)}, position around point {event['position_start']}, amplitude {event['amplitude']:.2f}"
            for event in local_events
        )
    else:
        answer = 'No local characteristic fluctuations found.'

    features = [
        {
            'type': event['type'] if isinstance(event['type'], str) else list(event['type']),
            'position': event['position_start'],
            'amplitude': round(event['amplitude'], 2)
        }
        for event in local_events
    ]
    ability_type = 'local'
    fields = {"local": [0]}
    return question, answer, features, ability_type, fields

# L2: correlation and cluster
def generate_shape_correlation(sample):
    """Build a yes/no trend-similarity question for one randomly chosen metric pair.

    Reads ``sample['label']['correlations']`` (entries with ``pair`` and
    ``label``) and ``sample['label']['cols']``.

    Returns:
        Tuple ``(question, answer, features, ability_type, fields)`` where
        ``features`` is the chosen correlation entry itself and ``fields`` maps
        ``"trend"`` to the positions of the pair's metrics within ``cols``.

    Raises:
        NotImplementedError: if the sample has no correlation entries.
    """
    label = sample['label']
    if not label['correlations']:
        raise NotImplementedError("ift not implemented for shape correlation with empty correlations")

    chosen = random.choice(label['correlations'])
    first_metric, second_metric = chosen["pair"][0], chosen["pair"][1]
    question = f'From the perspective of the overall trend, do {first_metric} and {second_metric} have very similar trend characteristics? Just answer yes or no. Answer format: Yes/No'
    answer = 'Yes' if chosen['label'] else 'No'

    # Positions of the pair's metrics in the sample's column list (unknown names are skipped)
    pair_indices = [
        label['cols'].index(metric)
        for metric in chosen["pair"]
        if metric in label['cols']
    ]
    return question, answer, chosen, 'shape-correlation', {"trend": pair_indices}

def generate_local_correlation(sample):
    """Build a yes/no local-fluctuation correlation question for one metric pair.

    Reads ``sample['label']`` keys ``correlations`` (entries with ``pair`` and
    ``label``), ``position``, ``clusters`` (entries with ``cols`` and
    ``col_idx``) and ``cols``. When both positive and negative pairs exist, a
    positive pair is chosen with probability of roughly one half.

    Returns:
        Tuple ``(question, answer, features, ability_type, fields)``.
        ``features`` is a deep copy of the chosen entry; for a positive pair its
        ``pair`` becomes a list of ``[metric, fluctuation_type]``.

    Raises:
        NotImplementedError: if the sample has no correlation entries.
    """
    label = sample['label']
    correlations = label['correlations']
    # Mirror generate_shape_correlation: an empty pool is an unsupported sample,
    # not an IndexError from random.choice([]).
    if not correlations:
        raise NotImplementedError("ift not implemented for local correlation with empty correlations")

    # Choice and balance the label
    positive_pairs = [p for p in correlations if p['label']]
    negative_pairs = [p for p in correlations if not p['label']]
    if positive_pairs and (random.random() > 0.5 or not negative_pairs):
        pairs = random.choice(positive_pairs)
    else:
        pairs = random.choice(negative_pairs)

    def get_fluctuation_type(metric: str):
        # The type is read from the second element of the metric's col_idx entry
        # in the first cluster that lists the metric; None if no cluster does.
        for cluster in label['clusters']:
            if metric in cluster['cols']:
                return cluster['col_idx'][cluster['cols'].index(metric)][1]
        return None

    question = f'From the perspective of local fluctuations, do {pairs["pair"][0]} and {pairs["pair"][1]} both have fluctuations near point {sample["label"]["position"]}? Answer yes or no, the types of their correlated fluctuations (if yes). Answer format: Yes. A, shake; B, upward spike.'
    features = copy.deepcopy(pairs)
    if pairs['label']:
        # Resolve each metric's fluctuation type once and reuse it for answer and features
        typed_pair = [[m, get_fluctuation_type(m)] for m in pairs['pair']]
        answer = 'Yes. ' + '; '.join(f"{m}, {kind}" for m, kind in typed_pair)
        features['pair'] = typed_pair
    else:
        answer = 'No'
    ability_type = 'local-correlation'

    # Find indices of the metrics in the pairs
    pair_indices = [
        label['cols'].index(metric)
        for metric in pairs["pair"]
        if metric in label['cols']
    ]
    fields = {"local": pair_indices}

    return question, answer, features, ability_type, fields

def generate_shape_cluster(sample):
    """Build a question asking which metrics share the trend of a reference metric.

    A cluster is drawn at random from ``sample['label']['clusters']`` and one of
    its ``cols`` is drawn at random as the metric named in the question; the
    answer lists every metric of that cluster.

    Returns:
        Tuple ``(question, answer, features, ability_type, fields)`` where
        ``features`` is the chosen cluster itself and ``fields`` maps
        ``"trend"`` to the positions of its metrics within the sample's ``cols``.
    """
    label = sample['label']
    chosen_cluster = random.choice(label['clusters'])
    reference_metric = random.choice(chosen_cluster["cols"])
    question = f'From the perspective of the overall trend, which metric(s) have very similar trend characteristics with {reference_metric}? List the metrics that having the similar overall trend (including itself) in one sentence. Answer format: A, B, C.'
    answer = ', '.join(chosen_cluster['cols'])

    # Positions of the cluster's metrics in the sample's column list (unknown names are skipped)
    member_indices = [
        label['cols'].index(name)
        for name in chosen_cluster['cols']
        if name in label['cols']
    ]
    return question, answer, chosen_cluster, 'shape-cluster', {"trend": member_indices}

def generate_local_cluster(sample):
    """Build a question asking which metrics fluctuate near the same positions.

    A cluster is drawn at random from ``sample['label']['clusters']`` and one of
    its ``cols`` is drawn at random as the metric named in the question. The
    answer pairs every metric of the cluster with the second element of its
    ``col_idx`` entry.

    Returns:
        Tuple ``(question, answer, features, ability_type, fields)`` where
        ``features`` is the chosen cluster itself and ``fields`` maps
        ``"local"`` to the positions of its metrics within the sample's ``cols``.
    """
    label = sample['label']
    chosen_cluster = random.choice(label['clusters'])
    members = chosen_cluster['cols']
    reference_metric = random.choice(chosen_cluster["cols"])
    type_list = '", "'.join(sorted(ALL_LOCAL_TYPES))
    question = f'From the perspective of the position of local fluctuations, which metric(s) have very similar local fluctuation characteristics with {reference_metric}? The optional types of local characteristic fluctuations include: ["' + type_list + '"]. List the metrics that having the local fluctuations near the same positions (including itself) and the types of fluctuations in one sentence. Answer format: A, shake; B, upward spike; C, downward spike'

    answer_parts = []
    for position, name in enumerate(members):
        answer_parts.append(f"{name}, {chosen_cluster['col_idx'][position][1]}")
    answer = '; '.join(answer_parts)

    # Positions of the cluster's metrics in the sample's column list (unknown names are skipped)
    member_indices = [
        label['cols'].index(name)
        for name in members
        if name in label['cols']
    ]
    return question, answer, chosen_cluster, 'local-cluster', {"local": member_indices}


# Generate dataset
def generate_qa(sample, filename: str, ts_idx: int):
    """Build one instruction/question/answer record from a labelled sample.

    Candidate generator functions are selected by keywords found in the base
    name of ``filename``, and one of them is drawn at random using the candidate
    probabilities. The time series is encoded with ``timeseries_encoding``; for
    multivariate samples ``sample['timeseries']`` is overwritten in place with
    the scaled series.

    Args:
        sample: Label entry providing at least ``timeseries`` and ``label``.
        filename: Path or name of the label file the sample came from.
        ts_idx: Index stored unchanged in the returned record.

    Returns:
        dict with keys ``timeseries``, ``original_timeseries``, ``cols``,
        ``question``, ``answer``, ``type``, ``fields``, ``instruction``,
        ``sample`` and ``ts_idx``.

    Raises:
        ValueError: if no generator matches the file name.
        NotImplementedError: propagated from the chosen generator for
            unsupported samples.
    """
    # Step 1. Check data type
    # Only the base name carries the data type. Matching the whole path would
    # misfire on directory names such as "outputs" (contains "uts") or
    # "/usr/local" (contains "local").
    source_name = os.path.basename(filename)
    candidate_funcs, candidate_p = [], []
    mts_flag = False
    llm_flag = False
    if 'uts' in source_name or 'single' in source_name:
        candidate_funcs += [generate_trend, generate_season, generate_noise, generate_local]
        candidate_p += [0.2, 0.2, 0.2, 0.4]
        llm_flag = True
    if 'shape' in source_name or 'trend' in source_name:
        candidate_funcs += [generate_shape_correlation, generate_shape_cluster]
        candidate_p += [0.5, 0.5]
        mts_flag = True
        llm_flag = True
    if 'local' in source_name or 'fluctuation' in source_name:
        candidate_funcs += [generate_local_correlation, generate_local_cluster]
        candidate_p += [0.5, 0.5]
        mts_flag = True
        llm_flag = True

    if not candidate_funcs:
        # Fail with a clear message instead of an opaque numpy error on an empty pool
        raise ValueError(f"Cannot infer the data type from label file name '{filename}'")

    # Step 2. Randomly choose a data type
    candidate_p = np.array(candidate_p) / np.sum(candidate_p)
    funcs = np.random.choice(candidate_funcs, size=1, replace=False, p=candidate_p)

    # Step 3. Augmentation
    instruction = ''
    # Snapshot the raw series before the multivariate branch overwrites them in place
    original_timeseries = copy.deepcopy(sample['timeseries'])
    if mts_flag:
        timeseries = sample['timeseries']

        if llm_flag:
            cols = sample['label']['cols']
            instruction = f"You are a time series analysis expert. In a monitoring system of {sample['label']['situation']}, there are {len(timeseries)} metrics collected."
            for i in range(len(timeseries)):
                # Scalar
                cur_timeseries = np.array(timeseries[i])
                scaled_timeseries, cur_ts_prompt, _ = timeseries_encoding(cur_timeseries, ENCODING_METHOD)
                timeseries[i] = scaled_timeseries

                instruction += f"""\n "{sample['label']['cols'][i]}" is a time series with length of {len(timeseries[i])}: {cur_ts_prompt}"""
            instruction += ', please analyze the time series features and answer the following question:'
    else:
        # Scalar
        timeseries = sample['timeseries']
        scaled_timeseries, cur_ts_prompt, _ = timeseries_encoding(timeseries, ENCODING_METHOD)
        if llm_flag:
            timeseries = [scaled_timeseries]
            cols = [sample['label']['metric_name']]
            instruction = f"""You are a time series analysis expert. This time series is "{sample['label']['metric_name']}" from {sample['label']['situation']} with length of {len(timeseries[0])}: {cur_ts_prompt}, please analyze the time series features and answer the following question:"""
        else:
            timeseries = [scaled_timeseries]
            cols = ['TS1']
            instruction = f'You are a time series analysis expert. Here is a time series called TS1 of length {len(timeseries[0])}: {cur_ts_prompt}, please analyze the time series features and answer the following question:'

    answer = ''
    question = ''

    # Step 4. Generate QAs
    for func in funcs:
        cur_question, cur_answer, cur_attribute, cur_ability_type, cur_fields = func(sample)
        question += f'{cur_question}'
        answer += json.dumps({
            "answer": cur_answer,
            "attribute": cur_attribute,
            "ability_type": cur_ability_type,
            "cols": cols
        }, ensure_ascii=False)

    question += '\nNow, based on the question, please strictly follow the output format requirements and provide the answers. Your answer should be included in the \\answer{} tag. For example, if the answer is "Yes", you should respond with \\answer{Yes}.'

    # Replace special characters in the metric names (cols)
    # NOTE(review): only question/answer/cols are rewritten; the instruction text
    # keeps the original metric names - confirm this is intended.
    new_cols = []
    for col in cols:
        if any(s in col for s in [',', ';', '.', ':']):
            new_col = col.replace(',', '').replace(';', '').replace('.', '').replace(':', '')
            print(f"Warning: replace special characters in metric name '{col}' -> '{new_col}'")
            question = question.replace(col, new_col)
            answer = answer.replace(col, new_col)
            new_cols.append(new_col)
        else:
            new_cols.append(col)
    cols = new_cols

    # Step 5. Return result
    return {
        'timeseries': timeseries,
        'original_timeseries': original_timeseries,
        'cols': cols,
        'question': question,
        'answer': answer,
        'type': cur_ability_type,
        'fields': cur_fields,
        'instruction': instruction,
        'sample': sample,
        'ts_idx': ts_idx
    }

def generate_dataset():
    """Generate up to ``TARGET_CNT`` QA samples and write them to disk.

    Loads every file in ``LABEL_FILES`` (unreadable files are reported and
    treated as empty), repeatedly picks a file according to
    ``FILE_SAMPLING_RATIOS`` and a random sample from it, and turns the sample
    into a QA record with ``generate_qa``. Samples whose generation raises are
    skipped. The records are written as JSON lines to ``OUTPUT_PATH`` and the
    matching evol labels as JSON to ``EVOL_LABEL_PATH``.

    Raises:
        ValueError: if a record's attribute pool length differs from its number
            of time series.
    """
    print("Start generation...")
    # Preload labels per file and build weights
    file_labels = []  # list of (filename, labels_list)
    for file in LABEL_FILES:
        try:
            # Context manager closes the handle; JSON label files are read as UTF-8
            with open(file, encoding='utf-8') as label_file:
                label = json.load(label_file)
        except Exception as err:
            print(f"Warning: Failed to load {file}, skipping. ({err})")
            label = []
        file_labels.append((file, label))

    # Compute sampling weights per file
    def resolve_weight_for_file(path: str) -> float:
        if not FILE_SAMPLING_RATIOS:
            return 1.0
        # Exact match first
        if path in FILE_SAMPLING_RATIOS:
            return float(FILE_SAMPLING_RATIOS[path])
        # Substring match fallback
        for k, w in FILE_SAMPLING_RATIOS.items():
            if k and k in path:
                return float(w)
        return 1.0

    weights = [resolve_weight_for_file(p) for p, _ in file_labels]
    # Normalize negative or zero weights safeguarding
    weights = [w if w > 0 else 0.0 for w in weights]
    if sum(weights) == 0:
        weights = [1.0 for _ in weights]

    weights = np.array(weights) / np.sum(weights)
    print("File sampling weights:", weights)

    # Helper to sample a filename index by weights
    def sample_file_index() -> int:
        # random.choices supports weights directly
        return random.choices(range(len(file_labels)), weights=weights, k=1)[0]

    # Every generation error is skipped, so without a cap the loop would spin
    # forever (and silently) when no sample can be generated at all.
    max_consecutive_failures = 10000
    consecutive_failures = 0
    skipped = 0

    result = []
    evol_labels = []
    with tqdm(total=TARGET_CNT, desc='Generating samples') as pbar:
        while len(result) < TARGET_CNT:
            # 1) sample filename by weight
            fidx = sample_file_index()
            fname, labels = file_labels[fidx]
            if not labels:
                # If empty, fallback uniformly to any non-empty file
                non_empty = [i for i, (_, lab) in enumerate(file_labels) if lab]
                if not non_empty:
                    print("No labels available in any file. Abort.")
                    break
                fidx = random.choice(non_empty)
                fname, labels = file_labels[fidx]
            # 2) sample one sample from the chosen file
            sidx = random.randint(0, len(labels) - 1)
            sample = copy.deepcopy(labels[sidx])
            try:
                qa = generate_qa(sample, fname, len(result))
            except Exception as err:
                # Unsupported or malformed samples are skipped (best effort)
                skipped += 1
                consecutive_failures += 1
                if consecutive_failures >= max_consecutive_failures:
                    print(f"Abort: {consecutive_failures} consecutive samples failed to generate; last error: {err!r}")
                    break
                continue
            consecutive_failures = 0
            if qa is None:
                continue

            result.append(qa)

            # Create evol_label entry
            if 'attribute_pool' in qa['sample']:
                cur_attribute_pool = qa['sample']['attribute_pool']
            elif 'descriptions' in qa['sample']:
                cur_attribute_pool = qa['sample']['descriptions']
            else:
                cur_attribute_pool = [qa['sample']['label']]

            if len(cur_attribute_pool) != len(qa['timeseries']):
                print(qa['sample'].keys())
                raise ValueError(f"Attribute pool length {len(cur_attribute_pool)} does not match timeseries length {len(qa['timeseries'])} for sample idx {len(result)-1} in file {fname}")

            cur_label = {
                "fields": qa['fields'],
                "metrics": qa['cols'],
                "corr_pool": None,
                "attribute_pool": cur_attribute_pool,
                "instruction": qa['instruction'],
                "question": qa['question'],
                "ts_idx": qa['ts_idx']
            }
            evol_labels.append(cur_label)

            pbar.update(1)

    if skipped:
        print(f"Skipped {skipped} samples that could not be generated.")

    print("Saving dataset...")
    os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)
    os.makedirs(os.path.dirname(EVOL_LABEL_PATH), exist_ok=True)

    # ensure_ascii=False emits raw non-ASCII text, so pin the file encoding
    # instead of relying on the platform default.
    with open(OUTPUT_PATH, 'wt', encoding='utf-8') as f:
        for item in result:
            item = {
                'input': item['instruction'] + item['question'],
                'output': item['answer'],
                'type': item['type'],
                'timeseries': timeseries_to_list(item['timeseries']),
            }
            f.write(json.dumps(item, ensure_ascii=False) + '\n')

    with open(EVOL_LABEL_PATH, 'wt', encoding='utf-8') as f:
        json.dump(evol_labels, f, ensure_ascii=False, indent=4)

    print(f'Finished! {len(result)} samples saved to {OUTPUT_PATH}.')


# Run the dataset generation only when this file is executed as a script.
if __name__ == '__main__':
    generate_dataset()
