#!/usr/bin/env python3
import argparse
import re
import shutil
import sys
import traceback
import yaml
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import optuna
import pandas as pd
import numpy as np
from scipy import interpolate
from lib import run_code, limit_csv_precision
from plot_result import plot_csv_file


class OptunaSimulationRunner:
    """Execute generated simulation source code from inside a fixed directory.

    Every call materializes the code as a uniquely named temporary ``.py`` file
    in ``working_dir``, hands it to ``run_code`` and deletes the file again,
    whether or not the run succeeded.
    """

    def __init__(self, working_dir: Path):
        # Directory that receives the temporary scripts and is passed to run_code.
        self.working_dir = working_dir

    def run_simulation_from_code(self, code: str) -> Tuple[bool, str, str]:
        """Run ``code`` and return ``(success, output, error)``.

        When ``run_code`` reports success, its output is returned as ``output``
        and ``error`` is empty; otherwise the same output is returned as
        ``error`` and ``output`` is empty.
        """
        stamp = datetime.now().strftime("%H%M%S_%f")
        script_path = self.working_dir / f'temp_optuna_code_{stamp}.py'
        try:
            script_path.write_text(code)
            _, ok, run_output = run_code(script_path, self.working_dir)
            return (True, run_output, '') if ok else (False, '', run_output)
        finally:
            # Always remove the temporary script, even if run_code raised.
            if script_path.exists():
                script_path.unlink()


def parse_parameters(param_string: str) -> List[Tuple[str, float, float]]:
    """Parse a search-space specification into ``(name, min, max)`` tuples.

    The expected format is one or more parenthesised triples such as
    ``"(x,-10,10),(y,-3,3)"``; whitespace around each field is ignored.

    Args:
        param_string: Raw specification string (e.g. the ``--parameters`` value).

    Returns:
        A list of ``(name, min_val, max_val)`` tuples in the order given.

    Raises:
        ValueError: If the string is empty or has no valid triple, or if a
            triple has an empty or duplicated name, non-numeric or non-finite
            bounds, or a min that is not strictly below its max.
    """
    import math

    if not param_string:
        raise ValueError('Parameter string cannot be empty')
    matches = re.findall(r'\(([^,]+),([^,]+),([^)]+)\)', param_string)
    if not matches:
        raise ValueError("No valid parameter tuples found. Expected format: '(name,min,max)'")
    parameters = []
    seen_names = set()
    for raw_name, raw_min, raw_max in matches:
        param_name = raw_name.strip()
        min_str = raw_min.strip()
        max_str = raw_max.strip()
        if not param_name:
            raise ValueError('Parameter name cannot be empty')
        if param_name in seen_names:
            # The same name twice would make the per-trial parameter dict
            # ambiguous (later entries would overwrite earlier ones).
            raise ValueError(f'Duplicate parameter name: {param_name}')
        try:
            min_val = float(min_str)
            max_val = float(max_str)
        except ValueError as e:
            raise ValueError(f'Invalid numeric values in parameter ({param_name},{min_str},{max_str}): {e}') from e
        # float() accepts 'nan' and 'inf'. NaN would slip past the ordering check
        # below (every comparison with NaN is False), and infinite bounds do not
        # describe a usable search range.
        if not (math.isfinite(min_val) and math.isfinite(max_val)):
            raise ValueError(f'Bounds must be finite for parameter {param_name}: ({min_str}, {max_str})')
        if min_val >= max_val:
            raise ValueError(
                f'Min value must be less than max value for parameter {param_name}: {min_val} >= {max_val}'
            )
        seen_names.add(param_name)
        parameters.append((param_name, min_val, max_val))
    return parameters


def parameterize_code(code: str, parameters: Dict[str, float]) -> str:
    """Substitute parameter values into the parameter region of ``code``.

    The region is the lines strictly between the first line containing
    ``END PARAMS`` and the closest preceding line containing ``BEGIN PARAMS``.
    Inside it, every line of the form ``name = value  # optional comment`` whose
    ``name`` is a key of ``parameters`` has ``value`` replaced by the new value.
    Indentation, spacing around ``=``, and the whitespace plus trailing comment
    after the value are preserved. Lines outside the region are never touched.

    Args:
        code: Full source text of the simulation program.
        parameters: Mapping of parameter name to the value to insert.

    Returns:
        The modified source text.

    Raises:
        ValueError: If either marker is missing, or if some parameter has no
            assignment line inside the region.
    """
    lines = code.split('\n')
    begin_idx = None
    end_idx = None
    for i, line in enumerate(lines):
        if 'BEGIN PARAMS' in line:
            begin_idx = i
        elif 'END PARAMS' in line:
            end_idx = i
            break
    if begin_idx is None or end_idx is None:
        raise ValueError("Code must contain 'BEGIN PARAMS' and 'END PARAMS' markers")
    # One compiled pattern per parameter, built once rather than per line.
    # Groups: (1) indentation, name and '=' with its surrounding spaces,
    # (2) the old value (lazy, so it stops before trailing whitespace),
    # (3) trailing whitespace plus an optional '#' comment, kept verbatim.
    patterns = {
        name: re.compile(rf'^(\s*{re.escape(name)}\s*=\s*)([^#\n]*?)(\s*(?:#.*)?)$')
        for name in parameters
    }
    found_params = set()
    modified_lines = list(lines)
    for i in range(begin_idx + 1, end_idx):
        # Match against the original line; if several names match, the last one wins.
        for name, value in parameters.items():
            match = patterns[name].match(lines[i])
            if match:
                modified_lines[i] = f'{match.group(1)}{value}{match.group(3)}'
                found_params.add(name)
    missing_params = set(parameters) - found_params
    if missing_params:
        raise ValueError(f'Parameters not found in code: {sorted(missing_params)}')
    return '\n'.join(modified_lines)


def compute_record_l1_distance(
    source_csv: str, target_record_path: str, target_column: str, source_column: str
) -> float:
    """Return the mean absolute difference between a simulated and a target time series.

    The target series is linearly interpolated onto the source time stamps.
    Source times outside the target's time range take the target value at the
    nearest (earliest or latest) target time.

    Args:
        source_csv: CSV text of the simulated record; must contain a ``t``
            column and ``source_column``.
        target_record_path: Path of the reference CSV file; must contain a
            ``t`` column and ``target_column``.
        target_column: Column of the target record to compare against.
        source_column: Column of the simulated record to compare.

    Returns:
        Mean over all source rows of ``|source - interpolated target|``.

    Raises:
        ValueError: For any problem reading or validating either record. The
            underlying error is chained as ``__cause__``.
    """
    from io import StringIO

    try:
        target_df = pd.read_csv(target_record_path)
        if 't' not in target_df.columns:
            raise ValueError("Target record must contain 't' column for time")
        if target_column not in target_df.columns:
            raise ValueError(f"Target column '{target_column}' not found in target record")
        source_df = pd.read_csv(StringIO(source_csv))
        if 't' not in source_df.columns:
            raise ValueError("Source record must contain 't' column for time")
        if source_column not in source_df.columns:
            raise ValueError(f"Source column '{source_column}' not found in source record")
        target_times = np.asarray(target_df['t'].values, dtype=float)
        target_values = np.asarray(target_df[target_column].values, dtype=float)
        source_times = np.asarray(source_df['t'].values, dtype=float)
        source_values = np.asarray(source_df[source_column].values, dtype=float)
        if not np.all(np.isfinite(target_values)):
            raise ValueError('Target record contains NaN or infinite values')
        if not np.all(np.isfinite(source_values)):
            raise ValueError('Source record contains NaN or infinite values')
        # A NaN time would propagate into the interpolation and silently turn
        # the whole fitness into NaN.
        if not (np.all(np.isfinite(target_times)) and np.all(np.isfinite(source_times))):
            raise ValueError("Time column 't' contains NaN or infinite values")
        if len(target_times) < 2:
            raise ValueError('Target record must have at least 2 time points')
        if len(source_times) == 0:
            raise ValueError('Source record contains no data rows')
        # interp1d would sort the samples itself, but the extrapolation fill
        # values must be those at the earliest/latest *time*, not the first/last
        # *row*. So sort explicitly and take the endpoints from the sorted data.
        order = np.argsort(target_times, kind='stable')
        target_times = target_times[order]
        target_values = target_values[order]
        interp_func = interpolate.interp1d(
            target_times,
            target_values,
            kind='linear',
            bounds_error=False,
            fill_value=(float(target_values[0]), float(target_values[-1])),
            assume_sorted=True,
        )
        interpolated_target = interp_func(source_times)
        return float(np.mean(np.abs(source_values - interpolated_target)))
    except Exception as e:
        raise ValueError(f'Error computing L1 distance: {e}') from e


class OptunaOptimizer:
    """Run an Optuna study that tunes numeric parameters of a simulation script.

    Every trial samples a value for each ``(name, min, max)`` entry in
    ``parameters``. It substitutes those values into the ``BEGIN PARAMS`` /
    ``END PARAMS`` region of ``source_file`` via ``parameterize_code``, runs the
    resulting program, and scores its CSV output with the configured fitness
    function (lower is better). Artefacts are written to:

    * ``output_dir/trials/<trial number>/`` -- parameters, fitness, program and
      output record of each successful trial;
    * ``output_dir/best-trial/`` -- copy of the best trial seen so far;
    * ``output_dir/metadata.yaml`` -- study summary and full trial trajectory.
    """

    def __init__(
        self,
        source_file: Path,
        parameters: List[Tuple[str, float, float]],
        fitness_type: str,
        n_trials: int,
        output_dir: Path,
        target_record: Optional[str] = None,
        target_column: Optional[str] = None,
        source_column: Optional[str] = None,
    ):
        self.source_file = source_file
        self.parameters = parameters
        self.fitness_type = fitness_type
        self.n_trials = n_trials
        self.output_dir = output_dir
        self.target_record = target_record
        self.target_column = target_column
        self.source_column = source_column
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.trials_dir = self.output_dir / 'trials'
        self.trials_dir.mkdir(exist_ok=True)
        self.best_trial_dir = self.output_dir / 'best-trial'
        self.best_trial_dir.mkdir(exist_ok=True)
        # The source is read once; every trial parameterizes this same text.
        self.initial_code = self.source_file.read_text()
        self.runner = OptunaSimulationRunner(Path.cwd())
        # One entry per finished trial (success or failure), dumped into metadata.yaml.
        self.trajectory = []
        # Best fitness seen so far in this process; drives the "NEW BEST" reporting.
        self.current_best_value = float('inf')
        self.current_best_trial = None

    def objective(self, trial) -> float:
        """Optuna objective: sample parameters, run the simulation, return its fitness.

        Any failure (parameter substitution, simulation error, fitness error,
        NaN fitness) is recorded in the trajectory and reported to Optuna as
        ``inf``, so the study keeps going.
        """
        trial_params = {
            param_name: trial.suggest_float(param_name, min_val, max_val)
            for (param_name, min_val, max_val) in self.parameters
        }
        try:
            parameterized_code = parameterize_code(self.initial_code, trial_params)
            (success, csv_output, error) = self.runner.run_simulation_from_code(parameterized_code)
            if not success:
                print(f'Trial {trial.number}: Simulation failed - {error}')
                return self._record_failure(trial.number, trial_params, error)
            if self.fitness_type == 'record_l1_distance':
                fitness_value = self.compute_fitness_record_l1_distance(csv_output)
            else:
                raise ValueError(f'Unknown fitness type: {self.fitness_type}')
            # NaN never compares as a new best; treat it as an explicit failure
            # instead of logging it as a successful trial.
            if np.isnan(fitness_value):
                raise ValueError('Fitness value is NaN')
            self.save_trial_results(trial.number, trial_params, fitness_value, parameterized_code, csv_output)
            if fitness_value < self.current_best_value:
                self.current_best_value = fitness_value
                self.current_best_trial = trial.number
                print(f'*** NEW BEST! Trial {trial.number}: fitness = {fitness_value:.6f} ***')
                self.save_best_results_immediate(trial.number, trial_params, fitness_value)
                self.plot_best_record()
            else:
                print(f'Trial {trial.number}: fitness = {fitness_value:.6f}')
            self.trajectory.append(
                {'trial': trial.number, 'params': trial_params.copy(), 'value': fitness_value, 'status': 'success'}
            )
            return fitness_value
        except Exception as e:
            print(f'Trial {trial.number}: Exception - {e}')
            return self._record_failure(trial.number, trial_params, str(e))

    def _record_failure(self, trial_number: int, params: Dict[str, float], error: str) -> float:
        """Append a failed-trial entry to the trajectory and return the penalty value ``inf``."""
        self.trajectory.append(
            {
                'trial': trial_number,
                'params': params.copy(),
                'value': float('inf'),
                'status': 'failed',
                'error': error,
            }
        )
        return float('inf')

    def compute_fitness_record_l1_distance(self, csv_output: str) -> float:
        """Score ``csv_output`` with ``compute_record_l1_distance`` against the target record.

        Raises:
            ValueError: If the target record, target column or source column is
                not configured, or if the distance cannot be computed.
        """
        if not all([self.target_record, self.target_column, self.source_column]):
            raise ValueError(
                'record_l1_distance fitness requires --target-record, --target-column, and --source-column'
            )
        # Narrowing for type checkers only; the explicit check above is the real validation.
        assert self.target_record is not None
        assert self.target_column is not None
        assert self.source_column is not None
        return compute_record_l1_distance(csv_output, self.target_record, self.target_column, self.source_column)

    def save_trial_results(
        self, trial_number: int, params: Dict[str, float], fitness_value: float, code: str, csv_output: str
    ):
        """Write a successful trial's artefacts to ``trials/<trial_number>/``.

        Writes ``parameters.yaml``, ``value.txt`` and ``program.py``. When
        ``csv_output`` is non-empty, it is also written to ``output_record.csv``
        after passing through ``limit_csv_precision``.
        """
        trial_dir = self.trials_dir / str(trial_number)
        trial_dir.mkdir(exist_ok=True)
        with open(trial_dir / 'parameters.yaml', 'w') as f:
            yaml.dump(params, f, default_flow_style=False)
        with open(trial_dir / 'value.txt', 'w') as f:
            f.write(str(fitness_value))
        with open(trial_dir / 'program.py', 'w') as f:
            f.write(code)
        if csv_output:
            with open(trial_dir / 'output_record.csv', 'w') as f:
                f.write(limit_csv_precision(csv_output))

    def save_best_results_immediate(self, trial_number: int, params: Dict[str, float], fitness_value: float):
        """Mirror the given trial into ``best-trial/``.

        Writes the value and parameters there, and copies the trial's program
        and record if ``save_trial_results`` produced them.
        """
        with open(self.best_trial_dir / 'best_value.txt', 'w') as f:
            f.write(str(fitness_value))
        with open(self.best_trial_dir / 'best_params.yaml', 'w') as f:
            yaml.dump(params, f, default_flow_style=False)
        trial_dir = self.trials_dir / str(trial_number)
        if trial_dir.exists():
            if (trial_dir / 'program.py').exists():
                shutil.copy2(trial_dir / 'program.py', self.best_trial_dir / 'best_program.py')
            if (trial_dir / 'output_record.csv').exists():
                shutil.copy2(trial_dir / 'output_record.csv', self.best_trial_dir / 'best_record.csv')

    def plot_best_record(self):
        """Plot ``best-trial/best_record.csv`` with ``plot_csv_file``.

        Problems are printed rather than raised, so plotting never aborts a trial.
        """
        best_record_path = self.best_trial_dir / 'best_record.csv'
        if best_record_path.exists():
            try:
                (_, success, message) = plot_csv_file(best_record_path, force=1, max_sampled_rows=1000)
                if success:
                    print(f'  └── Best record plotted: {message}')
                else:
                    print(f'  └── Plot failed: {message}')
            except Exception as e:
                print(f'  └── Plot error: {e}')
        else:
            print('  └── No best record found to plot')

    @staticmethod
    def _has_completed_trials(study) -> bool:
        """Return True if ``study`` has at least one COMPLETE trial, i.e. ``best_trial`` is defined."""
        return bool(study.get_trials(deepcopy=False, states=(optuna.trial.TrialState.COMPLETE,)))

    def save_best_results(self, study):
        """Mirror ``study.best_trial`` into ``best-trial/`` (requires a completed trial)."""
        best_trial = study.best_trial
        self.save_best_results_immediate(best_trial.number, best_trial.params, best_trial.value)

    def save_metadata(self, study):
        """Write ``metadata.yaml`` with the study summary and the trial trajectory.

        The best-trial fields are ``None`` when no trial has completed.
        """
        if self._has_completed_trials(study):
            best_trial = study.best_trial
            best_number, best_value, best_params = best_trial.number, best_trial.value, best_trial.params
        else:
            best_number = best_value = best_params = None
        metadata = {
            'optimization_info': {
                'total_trials': len(study.trials),
                'best_trial': best_number,
                'best_value': best_value,
                'best_params': best_params,
                'source_file': str(self.source_file),
                'fitness_type': self.fitness_type,
                'parameters_spec': [
                    {'name': name, 'min': min_val, 'max': max_val} for (name, min_val, max_val) in self.parameters
                ],
            },
            'trajectory': self.trajectory,
        }
        with open(self.output_dir / 'metadata.yaml', 'w') as f:
            yaml.dump(metadata, f, default_flow_style=False, allow_unicode=True)

    def run_optimization(self):
        """Run the study for ``n_trials`` trials, persist the results and print a summary.

        Best results and metadata are saved even if ``study.optimize`` is
        interrupted (e.g. by Ctrl-C); the interrupting exception still
        propagates to the caller afterwards.
        """
        print(f'Starting Optuna optimization with {self.n_trials} trials')
        print(f'Parameters: {[name for (name, _, _) in self.parameters]}')
        print(f'Fitness type: {self.fitness_type}')
        print(f'Output directory: {self.output_dir}')
        print()
        study = optuna.create_study(direction='minimize')
        try:
            study.optimize(self.objective, n_trials=self.n_trials)
        finally:
            has_best = self._has_completed_trials(study)
            if has_best:
                self.save_best_results(study)
            self.save_metadata(study)
        if not has_best:
            print('\nOptimization finished without any completed trials.')
            print(f'Results saved to: {self.output_dir}')
            return
        print('\nOptimization complete!')
        print(f'Best trial: {study.best_trial.number}')
        print(f'Best value: {study.best_trial.value:.6f}')
        print(f'Best params: {study.best_trial.params}')
        print(f'Results saved to: {self.output_dir}')
        print(f'Results saved to: {self.output_dir}')


def sanitize_name(name: str) -> str:
    """Turn a file name or path into a short identifier-safe name.

    The steps are:

    1. Take the path stem (directory and last suffix removed).
    2. Replace every character outside ``[a-zA-Z0-9_-]`` with ``_``.
    3. Collapse runs of underscores into one.
    4. Remove leading and trailing underscores.
    5. Truncate to at most 50 characters.

    Args:
        name: File name or path to sanitize.

    Returns:
        The sanitized name; it may be empty if ``name`` has no usable characters.
    """
    stem = Path(name).stem
    sanitized = re.sub('[^a-zA-Z0-9_-]', '_', stem)
    sanitized = re.sub('_+', '_', sanitized)
    # Strip again after truncating: the cut may land right after an underscore.
    return sanitized.strip('_')[:50].rstrip('_')


def main():
    """Command-line entry point: parse and validate arguments, then run the optimization.

    Exit status is 1 for invalid input or a failed optimization, 2 for argparse
    usage errors, and 130 when interrupted with Ctrl-C.
    """
    parser = argparse.ArgumentParser(
        description='Optuna-based baseline optimization for Python simulation systems',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog='',
    )
    parser.add_argument('source', help='Source Python simulation file (e.g., 2_2.py)')
    parser.add_argument('--parameters', required=True, help='Parameter space: "(x,-10,10),(y,-3,3)"')
    parser.add_argument('--fitness-type', required=True, choices=['record_l1_distance'], help='Fitness function type')
    parser.add_argument('--n-trials', type=int, default=100, help='Number of Optuna trials (default: 100)')
    parser.add_argument('--target-record', help='Target CSV file path (for record_l1_distance)')
    parser.add_argument('--target-column', help='Target column name (for record_l1_distance)')
    parser.add_argument('--source-column', help='Source column name (for record_l1_distance)')
    parser.add_argument('--output-dir', help='Output directory (default: auto-generated)')
    args = parser.parse_args()
    if args.n_trials < 1:
        # With zero trials there is no best trial to report.
        parser.error('--n-trials must be a positive integer')
    source_file = Path(args.source)
    if not source_file.is_file():
        print(f'ERROR: Source file not found: {source_file}')
        sys.exit(1)
    try:
        parameters = parse_parameters(args.parameters)
        print(f'Parsed parameters: {parameters}')
    except ValueError as e:
        print(f'ERROR: Invalid parameters: {e}')
        sys.exit(1)
    if args.fitness_type == 'record_l1_distance':
        if not all([args.target_record, args.target_column, args.source_column]):
            print('ERROR: record_l1_distance fitness requires --target-record, --target-column, and --source-column')
            sys.exit(1)
        if not Path(args.target_record).exists():
            print(f'ERROR: Target record file not found: {args.target_record}')
            sys.exit(1)
    if args.output_dir:
        output_dir = Path(args.output_dir)
    else:
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        output_dir = Path(f'output/bl_optuna/{timestamp}-{args.fitness_type}')
    try:
        optimizer = OptunaOptimizer(
            source_file=source_file,
            parameters=parameters,
            fitness_type=args.fitness_type,
            n_trials=args.n_trials,
            output_dir=output_dir,
            target_record=args.target_record,
            target_column=args.target_column,
            source_column=args.source_column,
        )
        optimizer.run_optimization()
    except KeyboardInterrupt:
        print('\n=== OPTIMIZATION INTERRUPTED ===')
        # Conventional exit status for termination by SIGINT (128 + 2).
        sys.exit(130)
    except Exception as e:
        print('\n=== OPTIMIZATION FAILED ===')
        print(f'ERROR: {e}')
        traceback.print_exc()
        # Report the failure to the calling shell/script instead of exiting 0.
        sys.exit(1)


if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
