import os
import json
import math
import pickle
import warnings
import torch
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score, accuracy_score, f1_score, roc_auc_score, precision_score, recall_score
from scipy.stats import pearsonr, spearmanr
 

def load_model(args, model):
    """
    Restore model weights and load the pickled training deltas

    Args:
        args: Object exposing `model_path` (checkpoint readable by torch.load),
            `device` (passed to torch.load as map_location) and
            `train_deltas_path` (path to a pickle file)
        model: Module whose `load_state_dict` receives the loaded checkpoint

    Returns:
        Tuple of (model, deltas), where deltas is whatever object was pickled
    """
    state_dict = torch.load(args.model_path, map_location=args.device)
    model.load_state_dict(state_dict)

    # Pickle streams are binary: opening in text mode makes pickle.load fail on
    # Python 3. The context manager also guarantees the handle is closed.
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(args.train_deltas_path, "rb") as deltas_file:
        deltas = pickle.load(deltas_file)

    return model, deltas

def skew_batch_data(batch_x, batch_y, skew_direction, device=None):
    """
    Pair every sample with a random partner and attempt to order each pair by target

    Args:
        batch_x: Tensor of inputs, indexed along dim 0
        batch_y: Tensor of targets, indexed along dim 0
            (presumably shaped (batch, 1), hence the squeeze — TODO confirm)
        skew_direction: 'right' selects pairs where y[t1] > y[t2] for swapping,
            'left' selects pairs where y[t1] < y[t2]; any other value skips
            the swap logic entirely
        device: Device on which the index tensors are created

    Returns:
        Tuple of (batch_x, batch_y, t2_idx), where t2_idx holds the randomly
        drawn partner index for each position

    NOTE(review): the swap assignments below have the form
    `batch_x[t1_idx][swap] = ...`. Indexing a tensor with an index tensor
    produces a new tensor in PyTorch, so the masked assignment appears to land
    on a temporary and leave batch_x / batch_y unchanged. Confirm with a quick
    check before relying on the skew; a fix also needs to decide what should
    happen when t2_idx contains duplicates or self-pairs.
    """
    # t1_idx is simply 0..B-1; t2_idx is drawn with replacement, so partners
    # may repeat and a sample may be paired with itself.
    t1_idx = torch.arange(batch_x.size(0), device=device)
    t2_idx = torch.randint(0, batch_x.size(0), (batch_x.size(0),), device=device)
    
    if skew_direction == 'right':
        # Mask of pairs whose first element has the larger target.
        swap = (batch_y[t1_idx] > batch_y[t2_idx]).squeeze()
        # NOTE(review): assignment target is a temporary (see docstring).
        temp_x = batch_x[t1_idx][swap].clone()
        batch_x[t1_idx][swap] = batch_x[t2_idx][swap]
        batch_x[t2_idx][swap] = temp_x
        
        # Same swap pattern applied to the targets.
        temp_y = batch_y[t1_idx][swap].clone()
        batch_y[t1_idx][swap] = batch_y[t2_idx][swap]
        batch_y[t2_idx][swap] = temp_y
    elif skew_direction == 'left':
        # Mask of pairs whose first element has the smaller target.
        swap = (batch_y[t1_idx] < batch_y[t2_idx]).squeeze()
        # NOTE(review): assignment target is a temporary (see docstring).
        temp_x = batch_x[t1_idx][swap].clone()
        batch_x[t1_idx][swap] = batch_x[t2_idx][swap]
        batch_x[t2_idx][swap] = temp_x
        
        # Same swap pattern applied to the targets.
        temp_y = batch_y[t1_idx][swap].clone()
        batch_y[t1_idx][swap] = batch_y[t2_idx][swap]
        batch_y[t2_idx][swap] = temp_y
    
    return batch_x, batch_y, t2_idx
            

class AvgMeter:
    """Running weighted average of a scalar metric, labelled with a display name."""

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Zero the running average, the weighted sum and the sample count."""
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, count=1):
        """Fold in `val`, weighted by `count`, and refresh the average."""
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        # Rendered as "<name>: <avg to four decimals>".
        return "{}: {:.4f}".format(self.name, self.avg)
    
    
def calculate_rmse(y_true, y_pred):
    """
    Root Mean Squared Error between targets and predictions

    Torch tensors are detached and moved to CPU NumPy arrays first; any other
    input is handed to scikit-learn unchanged.

    Args:
        y_true: Array-like of true values
        y_pred: Array-like of predicted values

    Returns:
        RMSE value
    """
    targets = y_true.detach().cpu().numpy() if torch.is_tensor(y_true) else y_true
    preds = y_pred.detach().cpu().numpy() if torch.is_tensor(y_pred) else y_pred

    mse = mean_squared_error(targets, preds)
    return np.sqrt(mse)

def calculate_mae(y_true, y_pred):
    """
    Mean Absolute Error between targets and predictions

    Torch tensors are detached and moved to CPU NumPy arrays first; any other
    input is handed to scikit-learn unchanged.

    Args:
        y_true: Array-like of true values
        y_pred: Array-like of predicted values

    Returns:
        MAE value
    """
    prepared = []
    for values in (y_true, y_pred):
        if torch.is_tensor(values):
            values = values.detach().cpu().numpy()
        prepared.append(values)

    return mean_absolute_error(*prepared)

def calculate_r2(y_true, y_pred):
    """
    R² (coefficient of determination) of predictions against targets

    Torch tensors are detached and moved to CPU NumPy arrays first; any other
    input is handed to scikit-learn unchanged.

    Args:
        y_true: Array-like of true values
        y_pred: Array-like of predicted values

    Returns:
        R² value
    """
    y_true, y_pred = (
        item.detach().cpu().numpy() if torch.is_tensor(item) else item
        for item in (y_true, y_pred)
    )
    score = r2_score(y_true, y_pred)
    return score

def calculate_pearson(y_true, y_pred):
    """
    Calculate Pearson correlation coefficient

    Args:
        y_true: Array-like of true values (torch tensor, ndarray, list, ...)
        y_pred: Array-like of predicted values with the same number of elements

    Returns:
        Pearson correlation coefficient only; the p-value computed by
        scipy is discarded
    """
    if torch.is_tensor(y_true):
        y_true = y_true.detach().cpu().numpy()
    if torch.is_tensor(y_pred):
        y_pred = y_pred.detach().cpu().numpy()

    # np.asarray accepts plain lists/tuples (which have no .shape attribute);
    # flattening lets (N, 1) columns and (N,) vectors be compared directly.
    y_true = np.asarray(y_true).reshape(-1)
    y_pred = np.asarray(y_pred).reshape(-1)

    return pearsonr(y_true, y_pred)[0]  # Return only the correlation coefficient

def calculate_spearman(y_true, y_pred):
    """
    Calculate Spearman rank correlation coefficient

    Args:
        y_true: Array-like of true values (torch tensor, ndarray, list, ...)
        y_pred: Array-like of predicted values with the same number of elements

    Returns:
        Spearman rank correlation coefficient only; the p-value computed by
        scipy is discarded
    """
    if torch.is_tensor(y_true):
        y_true = y_true.detach().cpu().numpy()
    if torch.is_tensor(y_pred):
        y_pred = y_pred.detach().cpu().numpy()

    # np.asarray accepts plain lists/tuples (which have no .shape attribute);
    # flattening lets (N, 1) columns and (N,) vectors be compared directly.
    y_true = np.asarray(y_true).reshape(-1)
    y_pred = np.asarray(y_pred).reshape(-1)

    return spearmanr(y_true, y_pred)[0]  # Return only the correlation coefficient



def calculate_metrics(y_true, y_pred, prefix="", task='regression'):
    """
    Evaluate model performance with multiple metrics

    Args:
        y_true: Array-like of true values (torch tensor, ndarray, list, ...)
        y_pred: Array-like of predicted values; for classification these are
            treated as probabilities and thresholded at 0.5
        prefix: Prefix for metric names; keys are built as f"{prefix}_<metric>",
            so the default empty prefix yields keys such as "_rmse"
        task: 'classification' selects accuracy/f1/precision/recall/auroc;
            any other value selects rmse/mae/r2/pearson/spearman

    Returns:
        Dictionary of metrics
    """
    # Convert tensors to numpy
    if torch.is_tensor(y_true):
        y_true = y_true.detach().cpu().numpy()
    if torch.is_tensor(y_pred):
        y_pred = y_pred.detach().cpu().numpy()

    # np.asarray lets plain lists/tuples through (they have no .reshape);
    # both inputs are flattened so (N, 1) and (N,) layouts are interchangeable.
    y_true = np.asarray(y_true).reshape(-1)
    y_pred = np.asarray(y_pred).reshape(-1)

    metrics = {}

    if task == 'classification':
        # Predicted probabilities assumed, threshold at 0.5
        y_prob = y_pred

        # Targets are binarised with the same 0.5 threshold as the predictions.
        y_true = (y_true > 0.5).astype(int)
        y_label = (y_prob > 0.5).astype(int)

        metrics[f"{prefix}_accuracy"] = accuracy_score(y_true, y_label)
        metrics[f"{prefix}_f1"] = f1_score(y_true, y_label)
        metrics[f"{prefix}_precision"] = precision_score(y_true, y_label)
        metrics[f"{prefix}_recall"] = recall_score(y_true, y_label)

        # AUROC requires both probs and true labels
        try:
            metrics[f"{prefix}_auroc"] = roc_auc_score(y_true, y_prob)
        except ValueError:
            # Retry with NaN probabilities dropped. A ValueError raised for any
            # other reason will propagate from this second call.
            mask = ~np.isnan(y_prob)
            print(f"CAUTION: {(mask==0).sum()} values are masked ")
            metrics[f"{prefix}_auroc"] = roc_auc_score(y_true[mask], y_prob[mask])
    else:
        metrics[f"{prefix}_rmse"] = calculate_rmse(y_true, y_pred)
        metrics[f"{prefix}_mae"] = calculate_mae(y_true, y_pred)
        metrics[f"{prefix}_r2"] = calculate_r2(y_true, y_pred)
        metrics[f"{prefix}_pearson"] = calculate_pearson(y_true, y_pred)
        metrics[f"{prefix}_spearman"] = calculate_spearman(y_true, y_pred)

    return metrics


def save_results(args, results, prefix):
    """
    Print the results and write them to `<save_path>/<prefix>_results.json`

    Args:
        args: Object exposing `save_path`, the directory to write into
        results: Mapping of metric name to a value convertible with float()
        prefix: Used to build the output file name `{prefix}_results.json`

    Returns:
        None
    """
    results_path = os.path.join(args.save_path, f'{prefix}_results.json')
    # Coerce NumPy/torch scalars to plain floats so json can serialise them.
    results = {k: float(v) for k, v in results.items()}
    print(results)
    # The context manager closes (and therefore flushes) the file
    # deterministically instead of relying on garbage collection.
    with open(results_path, 'w') as results_file:
        json.dump(results, results_file, indent=4)



def update_temperature(cfg, epoch, initial_temperature, anneal_epochs, final_temperature, initial_temp_calc):
    """
    Compute the annealed temperature for the given epoch

    Args:
        cfg: Config object; only cfg.annealing.schedule_type is read
            ('linear', 'exponential' or 'cosine'; any other value emits a
            warning when epoch == 0 and is treated as linear)
        epoch: Current epoch; epoch 0 maps to progress 0 and
            epoch anneal_epochs - 1 maps to progress 1
        initial_temperature: Start temperature for the linear and cosine
            schedules, and reference value for the final clamp
        anneal_epochs: Number of epochs over which to anneal
        final_temperature: Temperature reached at the end of annealing and
            used for every later epoch
        initial_temp_calc: Start temperature used by the exponential schedule's
            log-space interpolation

    Returns:
        Temperature for this epoch, clamped so it never passes final_temperature
    """
    def _linear(fraction):
        # Straight-line interpolation from the initial to the final temperature.
        return initial_temperature - fraction * (initial_temperature - final_temperature)

    if epoch < anneal_epochs:
        # Fraction of the annealing window completed, 0 at the first epoch and
        # 1 at the last; a one-epoch window jumps straight to the end.
        if anneal_epochs > 1:
            fraction = epoch / (anneal_epochs - 1)
        else:
            fraction = 1.0

        schedule = cfg.annealing.schedule_type

        if schedule == 'exponential' and initial_temp_calc > 1e-6 and final_temperature > 1e-6:
            # Interpolate in log space; the guard above keeps log() away from 0.
            log_start = math.log(initial_temp_calc)
            log_end = math.log(final_temperature)
            current_temperature = math.exp(log_start - fraction * (log_start - log_end))
        elif schedule == 'cosine':
            # Half-cosine easing from the initial down to the final temperature.
            swing = 1 + math.cos(math.pi * fraction)
            current_temperature = final_temperature + 0.5 * (initial_temperature - final_temperature) * swing
        else:
            # Reached by 'linear', by 'exponential' with near-zero temperatures
            # (silent fallback), and by unknown schedule names (warned once).
            if schedule not in ('linear', 'exponential') and epoch == 0:
                warnings.warn(f"Unknown anneal_schedule_type: {cfg.annealing.schedule_type}. Using linear.")
            current_temperature = _linear(fraction)

        # Never overshoot the final temperature, whichever direction we anneal.
        bound = max if initial_temperature >= final_temperature else min
        current_temperature = bound(current_temperature, final_temperature)

    elif epoch >= anneal_epochs:
        # Annealing window is over: hold the final temperature.
        current_temperature = final_temperature

    return current_temperature
    