import pandas as pd
import argparse
import numpy as np
from datetime import datetime
import os

def format_value(value):
    """Render a metric value as a compact string for the LaTeX table.

    Rules, applied in order:

    * missing values (``pd.isna``) and values whose magnitude exceeds 100
      become ``"NaN"``;
    * the sign is dropped, so only the magnitude is shown;
    * magnitudes strictly between 0 and 0.01 are multiplied by 100, printed
      with two decimals and suffixed with ``"-"`` (0.0042 -> ``"0.42-"``);
    * magnitudes of at least 10 are printed with two decimals;
    * everything else (including 0) is printed with three decimals.

    Args:
        value: The numerical value to format.

    Returns:
        The formatted string.
    """
    if pd.isna(value):
        return "NaN"

    magnitude = abs(value)
    if magnitude > 100:
        # Out-of-range results are treated like missing ones.
        return "NaN"

    if 0 < magnitude < 0.01:
        # Compact notation: shift two decimal places and mark with a trailing "-".
        return f"{magnitude * 100:.2f}-"
    if magnitude >= 10:
        return f"{magnitude:.2f}"
    return f"{magnitude:.3f}"

def highlight_best(values, is_lower_better=True):
    """Bold the two best entries in a row of formatted metric strings.

    Entries are ranked numerically. ``"NaN"``, ``"---"`` and anything that
    fails to parse rank as the worst possible value. The ``"<x>-"`` short
    form produced by ``format_value`` is read back as ``x * 1e-2``.

    The ranking is walked from best to worst, with ties kept in their
    left-to-right order. The first two entries that are not ``"NaN"`` or
    ``"---"`` are wrapped in ``\\textbf{...}``. An unparsable entry other
    than those two placeholders can still be picked when fewer than two
    real values exist.

    Args:
        values: List of formatted value strings; it is copied, not modified.
        is_lower_better: Rank the smallest numbers first if True, the
            largest first otherwise.

    Returns:
        A new list of the same length with at most two entries bolded.
    """
    placeholders = ("NaN", "---")
    worst = float('inf') if is_lower_better else float('-inf')

    def as_number(text):
        # Placeholders always sink to the bottom of the ranking.
        if text in placeholders:
            return worst
        if text.endswith('-'):
            # "0.50-" is the scaled short form of 0.005.
            try:
                return float(text[:-1]) * 1e-2
            except ValueError:
                return worst
        try:
            return float(text)
        except ValueError:
            return worst

    scores = [as_number(text) for text in values]
    ranking = sorted(
        range(len(values)),
        key=(lambda i: scores[i]) if is_lower_better else (lambda i: -scores[i]),
    )

    # Placeholders are never highlighted, even though they were ranked.
    winners = [i for i in ranking if values[i] not in placeholders][:2]

    result = values.copy()
    for i in winners:
        result[i] = f"\\textbf{{{values[i]}}}"
    return result

def get_latest_model(df, bolt_model=None, use_lora=None, fine_tune=None,
                     aggregation_strategy=None, has_regressors=None, larger_than_one=None, name_tag=None):
    """Return the most recent result row matching all of the given filters.

    Every filter left at ``None`` is skipped, with one exception:
    ``bolt_model`` is always applied. ``None`` (or NaN) selects rows whose
    ``bolt_model`` column is missing.

    Args:
        df: DataFrame of model results. The columns ``timestamp`` and
            ``bolt_model`` are always read. ``model``, ``regressor_types``,
            ``use_lora``, ``fine_tune`` and ``aggregation_strategy_name``
            are read only when the corresponding filter is set.
        bolt_model: Required ``bolt_model`` value. ``None``/NaN matches rows
            where the column is missing.
        use_lora: Required ``use_lora`` value.
        fine_tune: Required ``fine_tune`` value.
        aggregation_strategy: Required ``aggregation_strategy_name``. NaN
            matches rows where the column is missing.
        has_regressors: True keeps rows with a non-empty ``regressor_types``.
            False keeps rows where it is missing or empty.
        larger_than_one: True keeps rows whose comma-separated
            ``regressor_types`` lists more than one entry. False keeps rows
            with at most one entry, which means exactly one when combined
            with ``has_regressors=True``. Applied independently of
            ``has_regressors``.
        name_tag: Required ``model`` value.

    Returns:
        The matching row (a ``pd.Series``) that sorts first by ``timestamp``
        in descending order, or None when nothing matches.
    """
    mask = pd.Series(True, index=df.index)

    if name_tag is not None:
        mask &= df['model'] == name_tag

    if has_regressors is not None:
        regressors = df['regressor_types']
        if has_regressors:
            mask &= regressors.notna() & (regressors != "")
        else:
            mask &= regressors.isna() | (regressors == "")

    def lists_several(entry):
        # True when the comma-separated string names more than one regressor.
        if pd.isna(entry) or entry == "":
            return False
        return len(entry.split(',')) > 1

    if larger_than_one is not None:
        # astype(bool): on an empty column, apply() keeps the column's dtype,
        # and `~` would fail on a non-boolean (e.g. float) result.
        several = df['regressor_types'].apply(lists_several).astype(bool)
        mask &= several if larger_than_one else ~several

    # bolt_model is always filtered; None/NaN means "no bolt model".
    if bolt_model is None or pd.isna(bolt_model):
        mask &= df['bolt_model'].isna()
    else:
        mask &= df['bolt_model'] == bolt_model

    if use_lora is not None:
        mask &= df['use_lora'] == use_lora

    if fine_tune is not None:
        mask &= df['fine_tune'] == fine_tune

    if aggregation_strategy is not None:
        if pd.isna(aggregation_strategy):
            mask &= df['aggregation_strategy_name'].isna()
        else:
            mask &= df['aggregation_strategy_name'] == aggregation_strategy

    # Newest first; sort_values places missing timestamps last.
    latest = df[mask].sort_values('timestamp', ascending=False)
    return None if latest.empty else latest.iloc[0]

def generate_latex_table(csv_file, dataset_name=None, metrics=None):
    """Generate the LaTeX rows for one dataset block of the results table.

    The output has three parts:

    * a ``\\multirow`` label cell holding the rotated short dataset name;
    * one row per metric, listing the ``<metric>_mean`` value of the latest
      run of every configured model, with the two best values per row
      bolded;
    * a closing ``\\midrule``.

    Models or columns that are not found print as ``---``.

    Args:
        csv_file: Path to the results CSV file.
        dataset_name: Name of the dataset to report. When omitted or empty,
            the dataset of the first CSV row is used.
        metrics: Metric names to report, each read from the ``<name>_mean``
            column. Defaults to ``['MASE', 'MAPE']``.

    Returns:
        LaTeX table rows as a single newline-joined string.

    Raises:
        ValueError: If no dataset is given and the CSV has no rows.
    """
    df = pd.read_csv(csv_file)

    if not dataset_name:
        if df.empty:
            raise ValueError(f"No rows in {csv_file}; cannot infer a dataset name")
        # Default to the dataset of the first row.
        dataset_name = df['dataset'].iloc[0]
    # Always restrict to a single dataset, so the latest-run lookup below
    # cannot pick a row from a different dataset than the one in the label.
    df = df[df['dataset'] == dataset_name]

    if metrics is None:
        # 'MSE' and 'SMAPE' may also be requested, provided the CSV has the
        # corresponding *_mean columns.
        metrics = ['MASE', 'MAPE']

    # Short dataset label, e.g. "m4/hourly" -> "M4.hourly". LaTeX special
    # characters are escaped so that names such as "ett_h1" still compile.
    name_parts = str(dataset_name).split('/')
    short_name = name_parts[0].capitalize()
    if len(name_parts) > 1:
        short_name = f"{short_name}.{name_parts[1]}"
    short_name = short_name.translate(str.maketrans({
        '_': r'\_', '&': r'\&', '%': r'\%', '#': r'\#', '$': r'\$',
    }))

    # The label cell spans one table row per reported metric.
    latex_rows = [f"\\multirow{{{len(metrics)}}}{{*}}{{\\rot{{\\tiny{{{short_name}}}}}}}"]

    # Models to report, in column order.
    model_configs = [
        # HopFormer models (bolt_model set, regressor_types non-empty)
        {"name": "HopFormer 0-shot", "bolt_model": "bolt_small",
         "fine_tune": False, "use_lora": False, "has_regressors": True, "aggregation_strategy_name": "spa"},
        {"name": "HopFormer FT", "bolt_model": "bolt_small",
         "fine_tune": True, "use_lora": False, "has_regressors": True, "aggregation_strategy_name": "spa"},
        {"name": "HopFormer LoRA", "bolt_model": "bolt_small",
         "fine_tune": True, "use_lora": True, "has_regressors": True, "aggregation_strategy_name": "spa"},

        # CrossSectional models (no bolt_model, several regressors)
        {"name": "CS SPA", "bolt_model": None,
         "aggregation_strategy_name": "spa", "has_regressors": True, "larger_than_one": True},
        {"name": "CS Linear", "bolt_model": None,
         "aggregation_strategy_name": "linear", "has_regressors": True, "larger_than_one": True},
        {"name": "CS Best", "bolt_model": None,
         "aggregation_strategy_name": "singlebest", "has_regressors": True, "larger_than_one": True},
        {"name": "CS Equal", "bolt_model": None,
         "aggregation_strategy_name": "equal", "has_regressors": True, "larger_than_one": True},

        # Chronos models (bolt_model set, empty regressor_types)
        {"name": "Chronos 0-shot", "bolt_model": "bolt_small",
         "fine_tune": False, "use_lora": False, "has_regressors": False},
        {"name": "Chronos FT", "bolt_model": "bolt_small",
         "fine_tune": True, "use_lora": False, "has_regressors": False},
        {"name": "Chronos LoRA", "bolt_model": "bolt_small",
         "fine_tune": True, "use_lora": True, "has_regressors": False},

        # Other baselines, selected by their `model` column (exactly one regressor)
        {"name": "PatchTST", "bolt_model": None,
         "aggregation_strategy_name": "equal", "has_regressors": True, "larger_than_one": False, "name_tag": "PatchTST"},
        {"name": "TemporalFusionTransformer", "bolt_model": None,
         "aggregation_strategy_name": "equal", "has_regressors": True, "larger_than_one": False, "name_tag": "TemporalFusionTransformer"},
        # {"name": "SimpleFeedForward", "bolt_model": None,
        #  "aggregation_strategy_name": "equal", "has_regressors": True, "larger_than_one": False, "name_tag": "SimpleFeedForward"},
        # {"name": "DLinear", "bolt_model": None,
        #  "aggregation_strategy_name": "equal", "has_regressors": True, "larger_than_one": False, "name_tag": "DLinear"},
        {"name": "AutoARIMA", "bolt_model": None,
         "aggregation_strategy_name": "equal", "has_regressors": True, "larger_than_one": False, "name_tag": "AutoARIMA"},
        # {"name": "AutoCES", "bolt_model": None,
        #  "aggregation_strategy_name": "equal", "has_regressors": True, "larger_than_one": False, "name_tag": "AutoCES"},
        {"name": "AutoETS", "bolt_model": None,
         "aggregation_strategy_name": "equal", "has_regressors": True, "larger_than_one": False, "name_tag": "AutoETS"},
    ]

    # Latest matching run per configuration (None when there is no match).
    latest_models = [
        get_latest_model(
            df,
            bolt_model=config.get("bolt_model"),
            use_lora=config.get("use_lora"),
            fine_tune=config.get("fine_tune"),
            aggregation_strategy=config.get("aggregation_strategy_name"),
            has_regressors=config.get("has_regressors"),
            larger_than_one=config.get("larger_than_one"),
            name_tag=config.get("name_tag"),
        )
        for config in model_configs
    ]

    for metric_name in metrics:
        mean_col = f"{metric_name}_mean"
        cells = [
            format_value(model[mean_col]) if model is not None and mean_col in model else "---"
            for model in latest_models
        ]
        # All reported metrics are treated as lower-is-better.
        highlighted = highlight_best(cells, is_lower_better=True)
        latex_rows.append(f" & {metric_name} & " + " & ".join(highlighted) + " \\\\")

    latex_rows.append("\\midrule")

    return '\n'.join(latex_rows)

def main():
    """Parse command-line arguments and emit the LaTeX table rows.

    The rows are written to ``--output`` as UTF-8 when that option is
    given; otherwise they are printed to stdout.
    """
    parser = argparse.ArgumentParser(description="Convert results CSV to LaTeX table")
    parser.add_argument("--csv", type=str, required=True, help="Path to results CSV file")
    parser.add_argument("--dataset", type=str, help="Dataset name to filter for")
    parser.add_argument("--output", type=str, help="Output file path (default: print to console)")
    args = parser.parse_args()

    latex_content = generate_latex_table(args.csv, args.dataset)

    if args.output:
        # Explicit encoding: the platform default is locale-dependent and
        # may not be able to represent every character in a dataset name.
        with open(args.output, 'w', encoding='utf-8') as f:
            f.write(latex_content)
        print(f"LaTeX table written to {args.output}")
    else:
        print(latex_content)

# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()