import os
import json
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import wandb

from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score, accuracy_score, f1_score, roc_auc_score, precision_score, recall_score
from scipy.stats import pearsonr, spearmanr


class ConfigNamespace:
    """
    Attribute-style view over a (possibly nested) dictionary

    Every key of the input dict becomes an attribute on the instance.
    Values that are themselves dicts are wrapped recursively in another
    ConfigNamespace; all other values (including lists that contain
    dicts) are stored unchanged.
    """

    def __init__(self, d):
        for key, value in d.items():
            # Only direct dict values are converted; nothing else is inspected.
            wrapped = ConfigNamespace(value) if isinstance(value, dict) else value
            setattr(self, key, wrapped)
                
                
def set_wandb(args, config_dict):
    """
    Start a wandb run whose name encodes property, model, seed and date

    The run name is "<prop>_<axis>_<model_name>_<seed>_<date_str>". When
    args.prop_type contains the substring '_x', every occurrence of '_x'
    is stripped from it and the axis tag is 'x'; otherwise the property
    type is used as-is and the axis tag is 'y'.

    Args:
        args: Object exposing prop_type, seed, date_str, model_name and
            proj_name attributes; it must also support vars()
        config_dict: Extra configuration; its entries override same-named
            attributes of args in the logged config
    """
    marker = '_x'
    prop_type = args.prop_type
    seed = args.seed
    date_str = args.date_str

    axis = 'x' if marker in prop_type else 'y'
    if axis == 'x':
        prop_type = prop_type.replace(marker, '')
    run_name = f"{prop_type}_{axis}_{args.model_name}_{seed}_{date_str}"

    wandb.init(
        project=args.proj_name,
        config={**vars(args), **config_dict},
        name=run_name,
    )

def set_seed(args):
    """
    Seed Python, NumPy and PyTorch random number generators from args.seed

    When CUDA is available the CUDA generators are seeded as well and
    cuDNN is switched to deterministic mode with benchmarking disabled.
    Finally PYTHONHASHSEED is written to the process environment.

    Args:
        args: Object exposing an integer `seed` attribute
    """
    seed_value = args.seed

    # CPU-side generators, seeded in a fixed order.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed_value)

    if torch.cuda.is_available():
        for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seeder(seed_value)
        cudnn.deterministic = True
        cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed_value)


    
def calculate_rmse(y_true, y_pred):
    """
    Compute the root mean squared error between targets and predictions

    Torch tensors are detached, moved to CPU and converted to NumPy
    first; any other input is handed to mean_squared_error unchanged.

    Args:
        y_true: Array-like of true values
        y_pred: Array-like of predicted values

    Returns:
        Square root of the mean squared error
    """
    targets = y_true.detach().cpu().numpy() if torch.is_tensor(y_true) else y_true
    predictions = y_pred.detach().cpu().numpy() if torch.is_tensor(y_pred) else y_pred

    mse = mean_squared_error(targets, predictions)
    return np.sqrt(mse)

def calculate_mae(y_true, y_pred):
    """
    Compute the mean absolute error between targets and predictions

    Torch tensors are detached, moved to CPU and converted to NumPy
    first; any other input is handed to mean_absolute_error unchanged.

    Args:
        y_true: Array-like of true values
        y_pred: Array-like of predicted values

    Returns:
        MAE value
    """
    y_true, y_pred = (
        values.detach().cpu().numpy() if torch.is_tensor(values) else values
        for values in (y_true, y_pred)
    )
    return mean_absolute_error(y_true, y_pred)

def calculate_r2(y_true, y_pred):
    """
    Compute R² (coefficient of determination)

    Torch tensors are detached, moved to CPU and converted to NumPy
    first; any other input is handed to r2_score unchanged.

    Args:
        y_true: Array-like of true values
        y_pred: Array-like of predicted values

    Returns:
        R² value
    """
    prepared = []
    for values in (y_true, y_pred):
        if torch.is_tensor(values):
            values = values.detach().cpu().numpy()
        prepared.append(values)

    return r2_score(*prepared)

def calculate_pearson(y_true, y_pred):
    """
    Calculate Pearson correlation coefficient

    Inputs are flattened to 1-D before the correlation is computed, so
    column vectors of shape (N, 1) are accepted.

    Args:
        y_true: Array-like of true values (torch tensor, ndarray, list, ...)
        y_pred: Array-like of predicted values (torch tensor, ndarray, list, ...)

    Returns:
        Pearson correlation coefficient only (the p-value is discarded)
    """
    if torch.is_tensor(y_true):
        y_true = y_true.detach().cpu().numpy()
    if torch.is_tensor(y_pred):
        y_pred = y_pred.detach().cpu().numpy()

    # np.asarray makes plain sequences (lists/tuples) usable: they have no
    # .shape attribute, so inspecting it directly would raise AttributeError.
    y_true = np.asarray(y_true).reshape(-1)
    y_pred = np.asarray(y_pred).reshape(-1)

    return pearsonr(y_true, y_pred)[0]  # Return only the correlation coefficient

def calculate_spearman(y_true, y_pred):
    """
    Calculate Spearman rank correlation coefficient

    Inputs are flattened to 1-D before the correlation is computed, so
    column vectors of shape (N, 1) are accepted.

    Args:
        y_true: Array-like of true values (torch tensor, ndarray, list, ...)
        y_pred: Array-like of predicted values (torch tensor, ndarray, list, ...)

    Returns:
        Spearman rank correlation coefficient only (the p-value is discarded)
    """
    if torch.is_tensor(y_true):
        y_true = y_true.detach().cpu().numpy()
    if torch.is_tensor(y_pred):
        y_pred = y_pred.detach().cpu().numpy()

    # np.asarray makes plain sequences (lists/tuples) usable: they have no
    # .shape attribute, so inspecting it directly would raise AttributeError.
    y_true = np.asarray(y_true).reshape(-1)
    y_pred = np.asarray(y_pred).reshape(-1)

    return spearmanr(y_true, y_pred)[0]  # Return only the correlation coefficient

def calculate_metrics(y_true, y_pred, prefix="", task='regression'):
    """
    Evaluate model performance with multiple metrics

    Both inputs are flattened to 1-D before any metric is computed.

    Args:
        y_true: Array-like of true values (torch tensor, ndarray, list, ...).
            For classification they are binarized with `> 0.5`.
        y_pred: Array-like of predicted values (regression) or predicted
            probabilities (classification). Probabilities are thresholded
            at 0.5, so raw logits must be converted by the caller first.
        prefix: Prefix for metric names; keys are built as
            f"{prefix}_<metric>", so the default "" yields e.g. "_rmse"
        task: 'classification' selects accuracy/f1/precision/recall/auroc;
            any other value selects rmse/mae/r2/pearson/spearman

    Returns:
        Dictionary of metrics
    """
    # Convert tensors to numpy
    if torch.is_tensor(y_true):
        y_true = y_true.detach().cpu().numpy()
    if torch.is_tensor(y_pred):
        y_pred = y_pred.detach().cpu().numpy()

    # Flatten; np.asarray also admits plain sequences (lists/tuples), which
    # have no .reshape method of their own.
    y_true = np.asarray(y_true).reshape(-1)
    y_pred = np.asarray(y_pred).reshape(-1)

    metrics = {}

    if task == 'classification':
        # Predicted probabilities assumed, threshold at 0.5
        y_prob = y_pred
        y_true = (y_true > 0.5).astype(int)
        y_label = (y_prob > 0.5).astype(int)

        metrics[f"{prefix}_accuracy"] = accuracy_score(y_true, y_label)
        metrics[f"{prefix}_f1"] = f1_score(y_true, y_label)
        metrics[f"{prefix}_precision"] = precision_score(y_true, y_label)
        metrics[f"{prefix}_recall"] = recall_score(y_true, y_label)

        # AUROC requires both probs and true labels
        try:
            metrics[f"{prefix}_auroc"] = roc_auc_score(y_true, y_prob)
        except ValueError:
            # NOTE(review): presumably triggered by NaN probabilities; retry on
            # the non-NaN entries only. A ValueError raised by the retry (for
            # any other cause) propagates to the caller.
            mask = ~np.isnan(y_prob)
            print(f"CAUTION: {(mask==0).sum()} values are masked ")
            metrics[f"{prefix}_auroc"] = roc_auc_score(y_true[mask], y_prob[mask])
    else:
        metrics[f"{prefix}_rmse"] = calculate_rmse(y_true, y_pred)
        metrics[f"{prefix}_mae"] = calculate_mae(y_true, y_pred)
        metrics[f"{prefix}_r2"] = calculate_r2(y_true, y_pred)
        metrics[f"{prefix}_pearson"] = calculate_pearson(y_true, y_pred)
        metrics[f"{prefix}_spearman"] = calculate_spearman(y_true, y_pred)

    return metrics

def save_results(args, results, prefix):
    """
    Print the results and write them to a JSON file

    The file is '<args.save_path>/<prefix>_results.json', written with an
    indent of 4. Every value is converted with float() beforehand so that
    NumPy/torch scalars become JSON-serializable.

    Args:
        args: Object exposing a `save_path` attribute (an existing directory)
        results: Mapping of metric name to a value accepted by float()
        prefix: File name prefix
    """
    results_path = os.path.join(args.save_path, f'{prefix}_results.json')
    results = {k: float(v) for k, v in results.items()}
    print(results)
    # The context manager flushes and closes the handle deterministically
    # instead of leaving that to garbage collection.
    with open(results_path, 'w') as f:
        json.dump(results, f, indent=4)