import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from scipy.optimize import curve_fit

import torch
from torch.utils.data import DataLoader

# Compliance Rate Calculation Functions
# Single-curve trend actions -> (model that must fit best, sign of its rate
# parameter). Sign +1 means "parameter >= delta"; -1 means "parameter <= -delta".
_TREND_ACTIONS = {
    'Linear Trend Up': ('linear', 1),
    'Linear Trend Down': ('linear', -1),
    'Exponential Growth': ('exponential', 1),
    'Exponential Decay': ('exponential', -1),
    'Logarithmic Growth': ('logarithmic', 1),
    'Logarithmic Decay': ('logarithmic', -1),
}

# Two-segment actions -> required signs of the (first, second) segment slopes.
_PIECEWISE_ACTIONS = {
    'Linear Growth and Linear Decay': (1, -1),
    'Linear Decay and Linear Growth': (-1, 1),
}

_STABLE_ACTION = 'Keep Stable'


def _fit_line(x, y):
    """Least-squares fit of ``y = m * x + c``; return ``(m, R^2)``."""
    x_col = x.reshape(-1, 1)
    model = LinearRegression().fit(x_col, y)
    return model.coef_[0], r2_score(y, model.predict(x_col))


def _exp_curve(t, a, b):
    """Exponential model ``a * exp(b * t)``."""
    return a * np.exp(b * t)


def _fit_exponential(t, y):
    """Fit ``y = a * exp(b * t)``; return ``(b, R^2)``, or ``None`` if the fit fails.

    A failed fit covers non-convergence (RuntimeError), fewer data points than
    parameters (TypeError), invalid data (ValueError), and fits whose
    predictions are non-finite (which ``r2_score`` would reject with an error).
    """
    try:
        # Candidate parameters with large b * t overflow exp(); such fits are
        # rejected below rather than emitting overflow warnings.
        with np.errstate(over='ignore', invalid='ignore'):
            popt, _ = curve_fit(_exp_curve, t, y, maxfev=10000)
            pred = _exp_curve(t, *popt)
    except (RuntimeError, TypeError, ValueError):
        return None
    if not np.all(np.isfinite(pred)):
        return None
    return popt[1], r2_score(y, pred)


def _candidate_fits(t, y):
    """Fit the single-curve models; map model name -> ``(rate parameter, R^2)``.

    Insertion order (linear, exponential, logarithmic) matters: ``max`` returns
    the first of tied R^2 scores, so ties favour linear, then exponential.
    """
    fits = {'linear': _fit_line(t, y)}
    exp_fit = _fit_exponential(t, y)
    if exp_fit is not None:
        fits['exponential'] = exp_fit
    # y = a * log(t) + c is linear in log(t), so an ordinary line fit suffices.
    fits['logarithmic'] = _fit_line(np.log(t), y)
    return fits


def _series_is_compliant(t, y, action, delta, t_transition):
    """Return whether the single series ``y`` (sampled at times ``t``) satisfies ``action``."""
    if action == _STABLE_ACTION:
        # Here delta is the maximum acceptable standard deviation.
        return bool(np.std(y) <= delta)

    if action in _PIECEWISE_ACTIONS:
        sign_first, sign_second = _PIECEWISE_ACTIONS[action]
        exp_fit = _fit_exponential(t, y)
        exp_r2 = exp_fit[1] if exp_fit is not None else -np.inf
        slope_first, r2_first = _fit_line(t[:t_transition], y[:t_transition])
        slope_second, r2_second = _fit_line(t[t_transition:], y[t_transition:])
        # Each linear segment must explain its data at least as well as one
        # exponential fitted to the whole series, and slope the required way.
        return bool(
            r2_first >= exp_r2
            and r2_second >= exp_r2
            and sign_first * slope_first >= delta
            and sign_second * slope_second >= delta
        )

    required_model, sign = _TREND_ACTIONS[action]
    fits = _candidate_fits(t, y)
    best_model = max(fits, key=lambda name: fits[name][1])
    return best_model == required_model and bool(sign * fits[best_model][0] >= delta)


def compute_compliance_rate_with_trend_verification(predictions, action, delta=0, params=None):
    """
    Compute the compliance rate for trend-based actions with trend verification.

    Each series is fitted against time steps ``t = 1..T`` and counts as
    compliant when:

    * single-curve actions (linear trend up/down, exponential or logarithmic
      growth/decay): that model has the highest R^2 among the linear,
      exponential (``a * exp(b * t)``) and logarithmic (``a * log(t) + c``)
      fits, and its rate parameter (slope, ``b`` or ``a``) is ``>= delta``
      for upward actions or ``<= -delta`` for downward ones;
    * 'Keep Stable': the standard deviation of the series is ``<= delta``;
    * two-segment linear actions: lines fitted separately to the points before
      and from index ``t_transition`` each reach an R^2 at least that of a
      whole-series exponential fit (when one can be fitted), and their slopes
      are ``>= delta`` / ``<= -delta`` as the action requires.

    Args:
        predictions (np.ndarray): Model predictions of shape (batch_size, T);
            a single 1-D series of length T is also accepted.
        action (str): The action to evaluate.
        delta (float): Threshold for slope or rate parameter (maximum standard
            deviation for 'Keep Stable').
        params (dict, optional): Additional parameters; the two-segment actions
            require 't_transition', the index splitting the two segments.

    Returns:
        float: Compliance rate as a percentage.

    Raises:
        ValueError: If ``predictions`` is empty, ``action`` is not recognized,
            or 't_transition' is missing or would leave a segment empty.
    """
    predictions = np.atleast_2d(np.asarray(predictions, dtype=float))
    if predictions.size == 0:
        raise ValueError("predictions must contain at least one non-empty series.")

    # Validate once, before any (expensive) curve fitting.
    known_action = isinstance(action, str) and (
        action == _STABLE_ACTION or action in _TREND_ACTIONS or action in _PIECEWISE_ACTIONS
    )
    if not known_action:
        raise ValueError(f"Action '{action}' is not recognized.")

    batch_size, T = predictions.shape
    t = np.arange(1, T + 1, dtype=float)

    t_transition = None
    if action in _PIECEWISE_ACTIONS:
        if params is None or 't_transition' not in params:
            raise ValueError("Parameter 't_transition' is required for this action.")
        t_transition = params['t_transition']
        # Both segments must contain at least one point to be fitted.
        if not 0 < t_transition < T:
            raise ValueError(
                f"'t_transition' must satisfy 0 < t_transition < {T}, got {t_transition}."
            )

    compliance_count = sum(
        _series_is_compliant(t, y, action, delta, t_transition) for y in predictions
    )
    return (compliance_count / batch_size) * 100

def compute_compliance_rate_amplitude(predictions_with_action, baseline_predictions, action, A, epsilon=1e-5, rtol=0.0):
    """
    Compute the compliance rate for amplitude adjustment actions over a series of length T.

    The expected output is the baseline scaled by ``(1 + A)`` for
    'Increase Amplitude' or ``(1 - A)`` for 'Decrease Amplitude'. An element
    is compliant when ``|prediction - expected| <= epsilon + rtol * |expected|``.
    The rate is the percentage of compliant elements.

    Args:
        predictions_with_action (np.ndarray): Predictions with the action applied, shape (batch_size, T).
        baseline_predictions (np.ndarray): Baseline predictions without the action, same shape.
        action (str): 'Increase Amplitude' or 'Decrease Amplitude'.
        A (float): Scaling factor for amplitude adjustment (A > 0).
        epsilon (float, optional): Absolute deviation threshold.
        rtol (float, optional): Relative deviation threshold, scaled by
            ``|expected|``. The default 0.0 keeps a purely absolute tolerance.

    Returns:
        float: Compliance rate as a percentage.

    Raises:
        ValueError: If ``action`` is unknown, ``A`` is None, the two arrays
            differ in shape, or they are empty.
    """
    if action == 'Increase Amplitude':
        factor_sign = 1
    elif action == 'Decrease Amplitude':
        factor_sign = -1
    else:
        raise ValueError("Action must be 'Increase Amplitude' or 'Decrease Amplitude'.")

    if A is None:
        raise ValueError("Scaling factor 'A' is required for amplitude actions.")

    predictions_with_action = np.asarray(predictions_with_action)
    baseline_predictions = np.asarray(baseline_predictions)
    # Refuse implicit broadcasting: it would silently compare the wrong elements.
    if predictions_with_action.shape != baseline_predictions.shape:
        raise ValueError(
            f"Shape mismatch: predictions {predictions_with_action.shape} "
            f"vs baseline {baseline_predictions.shape}."
        )
    if predictions_with_action.size == 0:
        raise ValueError("Cannot compute a compliance rate for empty predictions.")

    factor = (1 + A) if factor_sign > 0 else (1 - A)
    expected = baseline_predictions * factor

    # Element-wise check at every time step t = 1..T.
    tolerance = epsilon + rtol * np.abs(expected)
    compliance = np.abs(predictions_with_action - expected) <= tolerance

    return float(np.count_nonzero(compliance) / compliance.size * 100)

# Dataset-level driver: run the model over a data iterator and score compliance
def calculate_compliance_rate(model, data_iterator, action, delta=0, params=None, A=None, epsilon=1e-5, device='cpu'):
    """
    Calculate the compliance rate over the dataset provided by the data iterator.

    The model is run in evaluation mode under ``torch.no_grad()``. If it was in
    training mode beforehand, training mode is restored afterwards. Amplitude
    actions are scored element-wise against baseline predictions. All other
    actions are scored with the trend-verification checks.

    Args:
        model: The trained model for making predictions.
        data_iterator: An iterable yielding batches, either sequences whose
            first element is the model input (e.g. ``(inputs, targets)``) or
            bare input tensors.
        action (str): The action to evaluate.
        delta (float, optional): Threshold for slope or rate parameter.
        params (dict, optional): Additional parameters (e.g., 't_transition').
        A (float, optional): Scaling factor for amplitude adjustment; required
            for 'Increase Amplitude' / 'Decrease Amplitude'.
        epsilon (float, optional): Acceptable deviation threshold.
        device (str, optional): Device to run the computations on ('cpu' or 'cuda').

    Returns:
        float: Compliance rate as a percentage.

    Raises:
        ValueError: If an amplitude action is requested without ``A``, or the
            iterator yields no batches.
    """
    is_amplitude_action = action in ('Increase Amplitude', 'Decrease Amplitude')
    if is_amplitude_action and A is None:
        # Fail fast instead of after running inference over the whole dataset.
        raise ValueError("Scaling factor 'A' is required for amplitude actions.")

    all_predictions = []
    all_baseline_predictions = []

    was_training = getattr(model, 'training', False)
    model.eval()  # Set model to evaluation mode
    try:
        with torch.no_grad():
            for batch in data_iterator:
                inputs = _batch_inputs(batch).to(device)

                # NOTE(review): the model is called directly on the inputs; how the
                # action gets applied is not visible in this code — confirm.
                all_predictions.append(model(inputs).cpu().numpy())

                if is_amplitude_action:
                    # NOTE(review): get_baseline_predictions is not defined in this
                    # section; assumed to return a tensor shaped like the model
                    # output — confirm.
                    baseline = get_baseline_predictions(inputs, model, device)
                    all_baseline_predictions.append(baseline.cpu().numpy())
    finally:
        # Don't leak the eval-mode switch into the caller's training loop.
        if was_training:
            model.train()

    if not all_predictions:
        raise ValueError("data_iterator yielded no batches.")

    predictions = np.concatenate(all_predictions, axis=0)  # Shape: (total_samples, T)

    if is_amplitude_action:
        baselines = np.concatenate(all_baseline_predictions, axis=0)  # Shape: (total_samples, T)
        return compute_compliance_rate_amplitude(predictions, baselines, action, A, epsilon)
    return compute_compliance_rate_with_trend_verification(predictions, action, delta, params)


def _batch_inputs(batch):
    """Return the model inputs of a batch: the batch itself if it is a tensor, else its first element."""
    if isinstance(batch, torch.Tensor):
        return batch
    return batch[0]

