import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import minimize
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from fastdtw import fastdtw
from math import factorial
from scipy.signal import find_peaks

def savitzky_golay(y, window_size, order, deriv=0, rate=1):
    """Savitzky-Golay smoothing filter.

    Fits a polynomial of degree ``order`` by least squares inside a sliding
    window of ``window_size`` samples and returns the value (or the
    ``deriv``-th derivative) of that polynomial at every window centre.

    Parameters
    ----------
    y : array_like
        1-D signal to filter.
    window_size : int
        Window length in samples; must be odd and at least ``order + 2``.
        It is capped at ``len(y) // 4 * 2 + 1`` so short signals still work.
    order : int
        Degree of the fitted polynomial.
    deriv : int, optional
        Order of the derivative to return (0 returns the smoothed signal).
    rate : float, optional
        Scale factor applied to derivatives (multiplied in as ``rate ** deriv``).

    Returns
    -------
    numpy.ndarray
        Filtered signal with the same length as ``y``.

    Raises
    ------
    ValueError
        If ``window_size``/``order`` are not integer-like, or the (capped)
        window is non-positive, even, or too small for ``order``.
    """
    try:
        window_size = np.abs(np.int32(window_size))
        order = np.abs(np.int32(order))
    except Exception as e:
        raise ValueError("window_size and order have to be of type int", e) from e

    # Never let the window exceed roughly half the signal; the cap is always odd.
    window_size = min(window_size, len(y) // 4 * 2 + 1)

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if window_size <= 0:
        raise ValueError("Window size must be greater than zero")
    if window_size % 2 != 1:
        raise ValueError("Window size must be odd")
    if window_size < order + 2:
        raise ValueError("Window size is too small for the polynomials order")

    half_window = (window_size - 1) // 2

    # Design matrix with rows [1, k, k**2, ..., k**order] for each offset k
    # in the window (float dtype avoids integer overflow for large windows).
    offsets = np.arange(-half_window, half_window + 1, dtype=float)
    b = np.vander(offsets, int(order) + 1, increasing=True)
    # Row `deriv` of the pseudo-inverse maps a window of samples to the
    # corresponding polynomial coefficient, i.e. the convolution kernel.
    m = np.linalg.pinv(b)[deriv] * rate**deriv * factorial(deriv)

    # Pad both ends with the signal mirrored about its first/last sample so
    # the 'valid' convolution below returns exactly len(y) samples.
    firstvals = y[1:half_window + 1][::-1]
    lastvals = y[-half_window - 1:-1][::-1]
    y = np.concatenate((firstvals, y, lastvals))
    return np.convolve(m[::-1], y, mode='valid')

def load_target_curve(dat_file_path, smooth_window=50, raw_prefix=10):
    """Load a two-column ``.dat`` curve and smooth its y values.

    The file is read with ``np.loadtxt``; column 0 is x and column 1 is y.
    The first ``raw_prefix`` y samples are kept as-is and the remainder is
    smoothed with a cubic Savitzky-Golay filter of width ``smooth_window + 1``.
    Curves with at most ``raw_prefix + smooth_window + 1`` points are returned
    unsmoothed.

    Parameters
    ----------
    dat_file_path : str or path-like
        Whitespace-separated numeric file with at least two columns.
    smooth_window : int, optional
        Even number; the filter window is ``smooth_window + 1`` (must be odd).
    raw_prefix : int, optional
        Number of leading samples left unsmoothed.

    Returns
    -------
    tuple of numpy.ndarray
        ``(x, y_mean)``.
    """
    # ndmin=2 keeps a single-row file as a (1, n_cols) table; without it
    # loadtxt returns a 1-D vector and data[:, 0] raises IndexError.
    data = np.loadtxt(dat_file_path, ndmin=2)
    x = data[:, 0]
    y = data[:, 1]

    if len(y) <= raw_prefix + smooth_window + 1:
        y_mean = y  # Data too short, skip smoothing
    else:
        smoothed_part = savitzky_golay(y[raw_prefix:], smooth_window + 1, 3)
        y_mean = np.concatenate([y[:raw_prefix], smoothed_part])

    return x, y_mean

def load_multiple_target_curves(dat_files):
    all_curves = []
    
    for dat_file in dat_files:
        x, y = load_target_curve(dat_file)
        all_curves.append((x, y))
    
    return all_curves

def standardize_curves_to_common_steps(all_curves):
    """Resample every curve onto one shared integer step grid.

    The grid runs from 0 to the largest (truncated) x value found in any
    curve, never below 0. Each curve is linearly interpolated with
    ``np.interp``, which holds the edge values constant outside a curve's own
    x range.

    Returns
    -------
    tuple
        ``(list_of_resampled_y_arrays, common_steps)``.
    """
    last_step = max([0] + [int(max(xs)) for xs, _ in all_curves])
    common_steps = np.arange(0, last_step + 1)
    resampled = [np.interp(common_steps, xs, ys) for xs, ys in all_curves]
    return resampled, common_steps

def enhanced_actr_model(a, c, tau, s, k, steps=None, review_intervals=None):
    """Predict recall probability with an activation-dependent-decay ACT-R model.

    Practices ("reviews") happen every 2 steps inside each ``[start, end)``
    review interval. At time ``t`` the activation is

        m(t) = log( sum_j ((t - t_j) * k) ** (-d_j) )

    over all practices with ``t_j < t``. The decay rate of practice ``j`` is

        d_j = clip(c * exp(m(t_j)) + a, 0.1, 3.0),

    i.e. it is set from the activation reached just before that practice. It
    defaults to ``a`` when no earlier activation exists (the first practice).
    Recall probability is ``1 / (1 + exp(-(m - tau) / s))``. An undefined
    activation (no practice yet) is replaced by -10 before that transform.

    Parameters
    ----------
    a : float
        Decay intercept.
    c : float
        Decay scale.
    tau : float
        Retrieval threshold.
    s : float
        Logistic noise parameter.
    k : float
        Time-scale factor applied to elapsed steps.
    steps : iterable of numbers, optional
        Time points to evaluate (default ``0..1999``).
    review_intervals : list of (start, end), optional
        Practice windows (default: five 100-step windows starting every 300 steps).

    Returns
    -------
    numpy.ndarray
        Recall probability for each entry of ``steps``.

    Notes
    -----
    A practice's decay rate is only updated when that practice time itself
    appears in ``steps``; otherwise it keeps ``d = a``.
    """
    if steps is None:
        steps = np.arange(0, 2000)
    if review_intervals is None:
        review_intervals = [(0, 100), (300, 400), (600, 700), (900, 1000), (1200, 1300)]

    # Practices happen every 2 steps inside each review window.
    review_points = sorted(
        point for start, end in review_intervals for point in range(start, end, 2)
    )
    practice_times = np.asarray(review_points, dtype=float)
    # Position of each practice time in review_points (first occurrence when
    # overlapping windows produce duplicates) - O(1) instead of list.index.
    practice_index = {}
    for idx, point in enumerate(review_points):
        practice_index.setdefault(point, idx)

    # decay[j] is the decay rate of practice j; it stays at `a` until the
    # activation at that practice's own time has been evaluated.
    decay = np.full(len(review_points), float(a))

    activation = []
    with np.errstate(over='ignore', divide='ignore', invalid='ignore'):
        for t in steps:
            # review_points is sorted, so the practices strictly before t are
            # exactly the first n_prior entries.
            n_prior = int(np.searchsorted(practice_times, t, side='left'))
            if n_prior == 0:
                activation.append(-np.inf)
                continue

            elapsed = (t - practice_times[:n_prior]) * k
            positive = elapsed > 0
            contributions = np.power(elapsed[positive], -decay[:n_prior][positive])
            contributions = contributions[np.isfinite(contributions) & (contributions > 0)]
            if contributions.size == 0:
                activation.append(-np.inf)
                continue

            current_activation = np.log(np.sum(contributions))
            activation.append(current_activation)

            # A practice at time t takes its decay rate from the activation
            # reached just before it (practices strictly earlier than t).
            j = practice_index.get(t)
            if j is not None:
                decay[j] = np.clip(c * np.exp(current_activation) + a, 0.1, 3.0)

    activation = np.array(activation, dtype=float)

    # Replace infinite activations with finite sentinels before the logistic.
    activation = np.where(activation == -np.inf, -10, activation)
    activation = np.where(activation == np.inf, 10, activation)

    # Recall probability: p = 1 / (1 + exp(-(m - tau) / s))
    return 1 / (1 + np.exp(-(activation - tau) / s))

def chi_square_objective(target, predicted, n_points=None):
    """Calculate chi-square statistic as used in the paper.

    Each point contributes ``(target - p) ** 2 / (p * (1 - p))``. Here ``p`` is
    the prediction clipped to ``[0.001, 0.999]``, so the binomial variance
    term never reaches zero. Points whose contribution is not finite (e.g. a
    NaN in either input) are skipped.

    Parameters
    ----------
    target : array_like
        Observed values.
    predicted : array_like
        Model predictions; only the first ``len(target)`` entries are used.
    n_points : int, optional
        Accepted for backward compatibility; it does not affect the result.

    Returns
    -------
    float
        Sum of the finite per-point contributions (0.0 for empty input).
    """
    observed = np.asarray(target, dtype=float)
    # Clip so the denominator p * (1 - p) stays strictly positive.
    expected = np.clip(np.asarray(predicted, dtype=float)[:len(observed)], 0.001, 0.999)
    with np.errstate(invalid='ignore'):
        contributions = (observed - expected) ** 2 / (expected * (1 - expected))
    return float(np.sum(contributions[np.isfinite(contributions)]))

def calculate_adjusted_rmsd(target, predicted, n_params):
    """Return the RMSD corrected for the number of fitted parameters.

    The mean squared error is rescaled by ``n / (n - n_params)`` before taking
    the square root. When there are no more points than parameters, that
    correction is undefined and the plain RMSD is returned instead.
    """
    n = len(target)
    squared_error = mean_squared_error(target, predicted)
    if n > n_params:
        squared_error = squared_error * n / (n - n_params)
    return np.sqrt(squared_error)

def comprehensive_metrics(target, predicted, review_intervals=None, n_params=5):
    """Calculate comprehensive similarity metrics between two curves.

    Returned keys: 'mae', 'mse', 'pearson_corr', 'dtw_distance',
    'dtw_distance_normalized', 'chi_square', 'r_squared', 'rmsd_adjusted'.

    Parameters
    ----------
    target, predicted : array_like
        Observed and model curves of equal length.
    review_intervals : optional
        Not used by this function; kept for interface compatibility.
    n_params : int
        Number of fitted parameters, used by the adjusted RMSD.

    Returns
    -------
    dict
        Metric name -> value.
    """
    metrics = {}

    # Original metrics (preserved)
    metrics['mae'] = mean_absolute_error(target, predicted)
    metrics['mse'] = mean_squared_error(target, predicted)

    # Pearson correlation is undefined (NaN) for constant curves; report 0.
    # `except Exception` (not bare except) so Ctrl-C still interrupts a fit.
    try:
        with np.errstate(invalid='ignore', divide='ignore'):
            correlation = np.corrcoef(target, predicted)[0, 1]
        metrics['pearson_corr'] = 0.0 if np.isnan(correlation) else correlation
    except Exception:
        metrics['pearson_corr'] = 0.0

    # DTW distance; infinite if fastdtw fails on this input.
    try:
        dtw_distance, _ = fastdtw(target, predicted)
        metrics['dtw_distance'] = dtw_distance
        metrics['dtw_distance_normalized'] = dtw_distance / len(target)
    except Exception:
        metrics['dtw_distance'] = float('inf')
        metrics['dtw_distance_normalized'] = float('inf')

    # New metrics from paper
    metrics['chi_square'] = chi_square_objective(target, predicted)
    metrics['r_squared'] = r2_score(target, predicted)
    metrics['rmsd_adjusted'] = calculate_adjusted_rmsd(target, predicted, n_params)

    return metrics

def optimize_parameters(all_curves, review_intervals_list, x0=None, bounds=None):
    """Fit one shared set of ACT-R parameters to several target curves.

    All curves are first resampled onto a common integer step grid. L-BFGS-B
    then minimizes the mean ``chi_square_objective`` between each target
    curve and ``enhanced_actr_model`` evaluated with that curve's review
    intervals.

    Parameters
    ----------
    all_curves : list of (x, y)
        Raw target curves.
    review_intervals_list : list
        Review windows for each curve, in the same order as ``all_curves``.
    x0 : sequence of 5 floats, optional
        Starting point ``(a, c, tau, s, k)``; defaults to
        ``[0.177, 0.217, -0.704, 0.255, 1.0]``.
    bounds : sequence of 5 (low, high) pairs, optional
        Box constraints for L-BFGS-B; defaults to
        ``[(0.05, 0.5), (0.01, 1.0), (-2.0, 2.0), (0.1, 1.0), (0.1, 10.0)]``.

    Returns
    -------
    tuple
        ``(best_params, common_steps, standardized_curves)``.

    Raises
    ------
    ValueError
        If no curves are given, or there are fewer review schedules than
        curves. Otherwise every objective evaluation would fail and the
        starting point would be reported as optimal.
    """
    if len(all_curves) == 0:
        raise ValueError("at least one target curve is required")
    if len(review_intervals_list) < len(all_curves):
        raise ValueError(
            f"need one review schedule per curve: got {len(all_curves)} curves "
            f"but {len(review_intervals_list)} schedules"
        )

    standardized_curves, common_steps = standardize_curves_to_common_steps(all_curves)
    n_curves = len(standardized_curves)

    def chi_square_objective_function(params):
        a, c, tau, s, k = params
        try:
            total_chi2 = 0
            for i, target_curve in enumerate(standardized_curves):
                model_curve = enhanced_actr_model(
                    a, c, tau, s, k, common_steps, review_intervals_list[i]
                )
                total_chi2 += chi_square_objective(target_curve, model_curve)
            return total_chi2 / n_curves
        except Exception as e:
            # Keep the optimizer running: a large finite value steers it away
            # from parameter combinations that make the model fail.
            print(f"Error in objective function: {e}")
            return 1e6

    if bounds is None:
        bounds = [(0.05, 0.5), (0.01, 1.0), (-2.0, 2.0), (0.1, 1.0), (0.1, 10.0)]
    if x0 is None:
        x0 = [0.177, 0.217, -0.704, 0.255, 1.0]  # Based on paper values

    result = minimize(chi_square_objective_function, x0, bounds=bounds, method='L-BFGS-B')
    if not result.success:
        # The best point found is still returned, but flag it so a failed
        # fit is not mistaken for a converged one.
        print(f"warning: optimizer did not converge: {result.message}")

    return result.x, common_steps, standardized_curves

def _safe_correlation(first, second):
    """Return the Pearson correlation of two sequences, or 0 when it is undefined."""
    try:
        # Constant inputs give 0/0 inside corrcoef; silence that warning.
        with np.errstate(invalid='ignore', divide='ignore'):
            correlation = np.corrcoef(first, second)[0, 1]
    except Exception:
        return 0
    return 0 if np.isnan(correlation) else correlation


def enhanced_phase_similarity(target_phase, fitted_phase, phase_type):
    """Score the similarity of two curve segments for one learning phase.

    Scores are combinations of terms of the form ``max(0, 1 - |difference|)``
    plus the non-negative part of the Pearson correlation. An undefined
    correlation counts as 0.

    * "training": 0.6 * peak similarity + 0.4 * correlation.
    * "decay": 0.2 * slope similarity + 0.8 * correlation, where the slope is
      ``(last - first) / len(segment)``; segments of length <= 1 score 0.
    * "stability": 0.4 * mean-level similarity + 0.3 * std similarity
      + 0.3 * correlation.

    Empty segments and unknown ``phase_type`` values score 0.0.
    """
    if len(target_phase) == 0 or len(fitted_phase) == 0:
        return 0.0

    if phase_type == "training":
        peak_similarity = max(0, 1 - abs(np.max(target_phase) - np.max(fitted_phase)))
        trend_correlation = _safe_correlation(target_phase, fitted_phase)
        return 0.6 * peak_similarity + 0.4 * max(0, trend_correlation)

    if phase_type == "decay":
        if len(target_phase) <= 1:
            return 0.0
        target_slope = (target_phase[-1] - target_phase[0]) / len(target_phase)
        fitted_slope = (fitted_phase[-1] - fitted_phase[0]) / len(fitted_phase)
        slope_similarity = max(0, 1 - abs(target_slope - fitted_slope))
        shape_correlation = _safe_correlation(target_phase, fitted_phase)
        return 0.2 * slope_similarity + 0.8 * max(0, shape_correlation)

    if phase_type == "stability":
        level_similarity = max(0, 1 - abs(np.mean(target_phase) - np.mean(fitted_phase)))
        variance_similarity = max(0, 1 - abs(np.std(target_phase) - np.std(fitted_phase)))
        trend_correlation = _safe_correlation(target_phase, fitted_phase)
        return (0.4 * level_similarity + 0.3 * variance_similarity
                + 0.3 * max(0, trend_correlation))

    return 0.0

def improved_phase_differentiated_evaluation(target, fitted, training_windows):
    """Score how well ``fitted`` matches ``target`` in each learning phase.

    Phases are derived from ``training_windows`` (``(start, end)`` index
    pairs, end exclusive):

    * training: all indices inside the windows, clipped to the curve length;
    * decay: each gap between consecutive windows that is non-empty and lies
      entirely inside the curve, averaged over gaps;
    * late: everything after the end of the last window.

    Each phase is scored with ``enhanced_phase_similarity``. A phase with no
    data scores 0. ``composite_score`` weights the phases 0.2 / 0.3 / 0.5.

    Parameters
    ----------
    target, fitted : numpy.ndarray
        Curves indexed by step (fancy indexing with a list is used).
    training_windows : sequence of (int, int)
        Training windows; may be empty.

    Returns
    -------
    dict
        Keys 'training_similarity', 'decay_similarity', 'late_similarity',
        'composite_score'.
    """
    n = len(target)
    results = {}

    training_indices = [
        idx for start, end in training_windows for idx in range(start, min(end, n))
    ]
    if training_indices:
        results['training_similarity'] = enhanced_phase_similarity(
            target[training_indices], fitted[training_indices], "training")
    else:
        results['training_similarity'] = 0

    decay_scores = []
    for (_, end_curr), (start_next, _) in zip(training_windows, training_windows[1:]):
        # Only score gaps that are non-empty and fully covered by the curve.
        if end_curr < start_next <= n:
            target_decay = target[end_curr:start_next]
            fitted_decay = fitted[end_curr:start_next]
            if len(target_decay) > 0 and len(fitted_decay) > 0:
                decay_scores.append(
                    enhanced_phase_similarity(target_decay, fitted_decay, "decay"))
    results['decay_similarity'] = np.mean(decay_scores) if decay_scores else 0

    # With no windows there is no "last window" to anchor the late phase;
    # treat it as empty instead of raising IndexError on training_windows[-1].
    late_period_start = training_windows[-1][1] if len(training_windows) else n
    if late_period_start < n:
        results['late_similarity'] = enhanced_phase_similarity(
            target[late_period_start:], fitted[late_period_start:], "stability")
    else:
        results['late_similarity'] = 0

    results['composite_score'] = (
        0.2 * results['training_similarity']
        + 0.3 * results['decay_similarity']
        + 0.5 * results['late_similarity']
    )

    return results

def calculate_multiple_metrics(standardized_curves, fitted_curves, review_intervals_list):
    """Compute and print per-curve metrics plus their averages.

    For each curve, ``comprehensive_metrics`` and
    ``improved_phase_differentiated_evaluation`` are merged into one dict and
    printed under a ``Curve N:`` header.

    Returns
    -------
    tuple
        ``(all_metrics, overall_metrics)``. ``overall_metrics`` maps
        ``'avg_<key>'`` to the mean of that metric over curves where it is
        finite, or NaN when no curve has a finite value.
    """
    all_metrics = []

    for i, (target_curve, fitted_curve) in enumerate(zip(standardized_curves, fitted_curves)):
        review_intervals = review_intervals_list[i]

        metrics = comprehensive_metrics(target_curve, fitted_curve, review_intervals)
        metrics.update(
            improved_phase_differentiated_evaluation(target_curve, fitted_curve, review_intervals)
        )
        all_metrics.append(metrics)

        # Label each block so the indented lines can be attributed to a curve.
        print(f"Curve {i + 1}:")
        for key, value in metrics.items():
            print(f"  {key}: {value:.4f}")

    overall_metrics = {}
    if all_metrics:
        for key in all_metrics[0]:
            finite_values = [m[key] for m in all_metrics if np.isfinite(m[key])]
            # np.mean([]) warns and returns NaN; make the empty case explicit.
            overall_metrics[f'avg_{key}'] = (
                np.mean(finite_values) if finite_values else float('nan')
            )

    return all_metrics, overall_metrics

def plot_multiple_comparison(all_curves, common_steps, fitted_curves, optimal_params, 
                           overall_metrics, save_path=None, save_path_pdf=None, 
                           review_intervals_list=None, dat_files=None, all_metrics=None):
    """Plot each target curve against its fitted ACT-R curve in a grid.

    Each panel shows the target curve (red, solid) and the fitted curve
    (blue, dashed) over the target's own x range. Review windows are shaded
    when given, and MAE / chi-square / r-squared appear when ``all_metrics``
    is provided. Unused grid cells are hidden.

    Parameters
    ----------
    all_curves : list of (x, y)
        Target curves.
    common_steps : numpy.ndarray
        Step grid on which ``fitted_curves`` are sampled.
    fitted_curves : list of numpy.ndarray
        Model predictions aligned with ``common_steps``.
    optimal_params : sequence of 5 floats
        ``(a, c, tau, s, k)``, shown in the figure title.
    overall_metrics : dict
        Not used by the plot; kept for interface compatibility.
    save_path, save_path_pdf : str, optional
        Output files passed to ``plt.savefig``.
    review_intervals_list : list, optional
        Review windows per curve, shaded in light green.
    dat_files : list of str, optional
        Source file paths; their base names (without extension) label the curves.
    all_metrics : list of dict, optional
        Per-curve metrics to annotate.
    """
    n_files = len(all_curves)
    if n_files == 1:
        rows, cols = 1, 1
    elif n_files == 2:
        rows, cols = 1, 2
    elif n_files <= 4:
        rows, cols = 2, 2
    elif n_files <= 6:
        rows, cols = 2, 3
    else:
        # Three columns and as many rows as needed (3x3 for 7-9 curves), so
        # more than nine files no longer overflow the grid.
        rows, cols = -(-n_files // 3), 3

    # squeeze=False always yields a 2-D array of Axes, even for a 1x1 grid.
    fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 5 * rows), squeeze=False)
    axes = axes.flatten()

    for i, ((x, y), fitted_curve) in enumerate(zip(all_curves, fitted_curves)):
        ax = axes[i]
        if dat_files:
            # basename/splitext: "./data/run.dat" -> "run" (split('.') gave "").
            file_name = os.path.splitext(os.path.basename(dat_files[i]))[0]
        else:
            file_name = f'File {i+1}'

        ax.plot(x, y, '-', color='red', linewidth=2, label=f'{file_name}')

        x_min, x_max = int(min(x)), int(max(x))
        mask = (common_steps >= x_min) & (common_steps <= x_max)
        ax.plot(common_steps[mask], fitted_curve[mask], '--', color='blue', linewidth=2,
                label='Fitted ACT-R Model', alpha=0.8)

        if review_intervals_list and i < len(review_intervals_list):
            for start, end in review_intervals_list[i]:
                if start <= x_max and end >= x_min:
                    ax.axvspan(max(start, x_min), min(end, x_max), color='lightgreen', alpha=0.3)

        if all_metrics and i < len(all_metrics):
            file_metrics = all_metrics[i]
            metrics_text = f"MAE: {file_metrics.get('mae', 0):.3f} "
            metrics_text += f"χ²: {file_metrics.get('chi_square', 0):.1f} "
            metrics_text += f"r²: {file_metrics.get('r_squared', 0):.3f}"

            ax.text(0.02, 0.98, metrics_text, transform=ax.transAxes,
                    verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8),
                    fontsize=10)

        ax.set_xlabel('Steps', fontsize=12)
        ax.set_ylabel('Normalized Performance', fontsize=12)
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)
        ax.set_ylim(0, 1)
        ax.set_xlim(x_min, x_max)

    for j in range(n_files, len(axes)):
        axes[j].set_visible(False)

    fig.suptitle(f'Enhanced ACT-R Model Fitting Results\n'
                 f'(a={optimal_params[0]:.3f}, c={optimal_params[1]:.3f}, '
                 f'τ={optimal_params[2]:.3f}, s={optimal_params[3]:.3f}, k={optimal_params[4]:.3f})',
                 fontsize=16, fontweight='bold')

    plt.tight_layout()
    plt.subplots_adjust(top=0.85)

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    if save_path_pdf:
        plt.savefig(save_path_pdf, dpi=300, bbox_inches='tight')

    plt.show()
    # Release the figure; non-interactive backends otherwise keep it open.
    plt.close(fig)

def main(dat_files, review_intervals_list, save_dir=None):
    """Run the full pipeline: load, fit, score, print and plot.

    Parameters
    ----------
    dat_files : list of str
        Paths of the ``.dat`` target curves.
    review_intervals_list : list
        Review windows for each file, in the same order as ``dat_files``.
    save_dir : str, optional
        Directory for the PNG/PDF comparison figures; created if missing.
        Without it, the figures are written to the current directory.

    Returns
    -------
    tuple
        ``(optimal_params, all_metrics, overall_metrics)``.

    Raises
    ------
    ValueError
        If the number of files and review schedules differ.
    """
    if len(dat_files) != len(review_intervals_list):
        raise ValueError(
            f"expected one review schedule per data file: got {len(dat_files)} files "
            f"and {len(review_intervals_list)} schedules"
        )
    if save_dir:
        # Create the output directory before the (slow) fit, so a bad path
        # fails immediately instead of at savefig time.
        os.makedirs(save_dir, exist_ok=True)

    all_curves = load_multiple_target_curves(dat_files)

    optimal_params, common_steps, standardized_curves = optimize_parameters(all_curves, review_intervals_list)
    a, c, tau, s, k = optimal_params
    print(f" a={a:.4f}, c={c:.4f}, "
          f"τ={tau:.4f}, s={s:.4f}, k={k:.4f}")

    fitted_curves = [
        enhanced_actr_model(a, c, tau, s, k, common_steps, review_intervals)
        for review_intervals in review_intervals_list
    ]

    all_metrics, overall_metrics = calculate_multiple_metrics(standardized_curves, fitted_curves, review_intervals_list)

    print("\n=== Enhanced ACT-R ===")
    print(f"a (decay intercept): {a:.4f}")
    print(f"c (decay scale): {c:.4f}")
    print(f"τ (threshold): {tau:.4f}")
    print(f"s (noise): {s:.4f}")
    print(f"k (time scale): {k:.4f}")

    for metric, value in overall_metrics.items():
        if np.isfinite(value):
            print(f"{metric}: {value:.4f}")

    if save_dir:
        save_path = os.path.join(save_dir, "enhanced_actr_multiple_files_comparison.png")
        save_path_pdf = os.path.join(save_dir, "enhanced_actr_multiple_files_comparison.pdf")
    else:
        save_path = "enhanced_actr_multiple_files_comparison.png"
        save_path_pdf = "enhanced_actr_multiple_files_comparison.pdf"

    plot_multiple_comparison(all_curves, common_steps, fitted_curves, optimal_params,
                             overall_metrics, save_path, save_path_pdf, review_intervals_list,
                             dat_files, all_metrics)

    return optimal_params, all_metrics, overall_metrics

if __name__ == "__main__":

    # Each data file is paired with the review windows used to produce it, so
    # the two lists handed to main() cannot drift out of alignment.
    runs = [
        ("LLAMA3_4_DP_100.dat", [(0, 90), (180, 280), (380, 480), (580, 680), (780, 880)]),       # interval = 100
        ("LLAMA3_4_DP_200.dat", [(0, 90), (270, 370), (570, 670), (870, 970), (1170, 1270)]),     # interval = 200
        ("LLAMA3_4_DP_400.dat", [(0, 90), (480, 580), (980, 1080), (1480, 1580), (1980, 2080)]),  # interval = 400
        ("LLAMA3_8_DP_100.dat", [(0, 90), (180, 280), (380, 480), (580, 680), (780, 880)]),       # interval = 100
        ("LLAMA3_8_DP_200.dat", [(0, 90), (270, 370), (570, 670), (870, 970), (1170, 1270)]),     # interval = 200
        ("LLAMA3_8_DP_400.dat", [(0, 90), (480, 580), (980, 1080), (1480, 1580), (1980, 2080)]),  # interval = 400
    ]
    dat_files = [file_name for file_name, _ in runs]
    review_intervals_list = [windows for _, windows in runs]

    save_directory = "./actr/figures"

    # Only run the (slow) fit when every input file is present.
    missing_files = [path for path in dat_files if not os.path.exists(path)]
    if not missing_files:
        optimal_params, all_metrics, overall_metrics = main(dat_files, review_intervals_list, save_directory)
    else:
        print(f"warning: file is missing: {missing_files}")