import os
import sys
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
import argparse
import configparser
from pathlib import Path
import logging
from datetime import datetime
import json
from tqdm import tqdm

sys.path.append('./')
from Orion_model import create_orion_model

def setup_logging(network_name):
    """Configure logging for one network's evaluation run and return its logger.

    Log records go both to a timestamped file under ``./DREAM3_auc_results``
    and to stdout.

    ``force=True`` is required because ``logging.basicConfig`` is otherwise a
    no-op once the root logger has handlers: when several networks are
    evaluated in the same process (``test_all_networks``), every network after
    the first would keep writing into the first network's log file. Forcing
    also closes the previous run's file handler instead of leaking it.

    Args:
        network_name: Network identifier (e.g. ``'Ecoli1'``). Used in the log
            file name and as the last ``_``-separated component of the
            returned logger's name.

    Returns:
        A ``logging.Logger`` named ``f"{__name__}_{network_name}"``.
    """
    log_dir = Path('./DREAM3_auc_results')
    log_dir.mkdir(parents=True, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = log_dir / f"auc_test_{network_name}_{timestamp}.log"

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] %(message)s',
        handlers=[
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler(sys.stdout),
        ],
        force=True,
    )

    # extract_causal_matrix_from_orion recovers the network name with
    # ``logger.name.split("_")[-1]``, so the name must end with it.
    # (A bare ``__name__`` such as "__main__" yielded "main__".)
    return logging.getLogger(f"{__name__}_{network_name}")

def load_dream3_data(network_name):
    """Load the Orion input archive and the gold-standard matrix for one network.

    Files are looked up in ``./DREAM3_Orion/<network_name>/``. The archive
    ``<network_name>_r1_d0_w0_Orion.npz`` is preferred; when it does not exist,
    ``<network_name>.npz`` is loaded instead.

    Returns:
        ``(data, gold_standard)``: ``data`` is whatever ``np.load`` returns for
        the ``.npz`` archive (it is not closed here), and ``gold_standard`` is
        the array stored in ``<network_name>_gold_standard.npy``.
    """
    network_dir = Path(f'./DREAM3_Orion/{network_name}')
    preferred_archive = network_dir / f"{network_name}_r1_d0_w0_Orion.npz"
    archive_path = (
        preferred_archive
        if preferred_archive.exists()
        else network_dir / f"{network_name}.npz"
    )

    archive = np.load(archive_path)
    truth = np.load(network_dir / f"{network_name}_gold_standard.npy")
    return archive, truth

def _read_orion_model_config(config_path):
    """Parse the [Data]/[Model] sections of an Orion .conf file into create_orion_model kwargs."""
    config = configparser.ConfigParser()
    # ConfigParser.read() silently skips files it cannot open, which used to
    # surface later as an opaque KeyError('Data'); fail with a clear error.
    if not config.read(config_path):
        raise FileNotFoundError(f"Config file not found or unreadable: {config_path}")

    data_section = config['Data']
    model_section = config['Model']

    model_config = {
        'num_of_vertices': int(data_section['num_of_vertices']),
        'in_channels': int(data_section['in_channels']),
        'target_len': int(data_section['target_len']),
        'source_len': int(data_section['source_len']),
        'd_model': int(model_section['d_model']),
        'd_ff_emb': int(model_section['d_ff_emb']),
        'd_ff_belt': int(model_section['d_ff_belt']),
        'd_ff_fusion': int(model_section['d_ff_fusion']),
        'd_ff_reverse': int(model_section['d_ff_reverse']),
        'n_belt_block': int(model_section['n_belt_block']),
        'head_s': int(model_section['head_s']),
        'head_t': int(model_section['head_t']),
        'head_f': int(model_section['head_f']),
        'num_time_segments': int(model_section['num_time_segments']),
        'dropout': float(model_section['dropout']),
        'causal_threshold': float(model_section['causal_threshold']),
    }
    # Optional entry: only forwarded when the config provides it.
    if 'node_selection_ratio' in model_section:
        model_config['node_selection_ratio'] = float(model_section['node_selection_ratio'])
    return model_config

def _strip_data_parallel_prefix(state_dict):
    """Remove the leading 'module.' added by DataParallel/DDP when every key carries it."""
    prefix = 'module.'
    # Strip only a genuine prefix, and only when all keys have it. The old
    # str.replace() also mangled interior names ('enc.submodule.w' -> 'enc.subw')
    # and indexing the first key crashed on an empty state dict.
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        return {key[len(prefix):]: value for key, value in state_dict.items()}
    return state_dict

def load_trained_orion_model(config_path, model_path, device, logger=None):
    """Build an Orion model from a config file and load trained weights if present.

    Args:
        config_path: Path to a .conf file with [Data] and [Model] sections.
        model_path: Checkpoint path. The checkpoint may be a raw state dict or a
            dict holding it under 'model_state_dict'. If the file does not
            exist, a warning is logged and the randomly initialised model is
            returned.
        device: Torch device used for ``map_location`` and for the returned model.
        logger: Logger to report to. Defaults to the module-global ``logger``
            (assigned by ``test_single_network``) or, when that has not been
            set, ``logging.getLogger(__name__)``.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        FileNotFoundError: if ``config_path`` cannot be read.
        KeyError, ValueError: if a required config entry is missing or malformed.
    """
    if logger is None:
        logger = globals().get('logger') or logging.getLogger(__name__)

    model = create_orion_model(_read_orion_model_config(config_path))

    if os.path.exists(model_path):
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
        state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
        state_dict = _strip_data_parallel_prefix(state_dict)

        # strict=False tolerates architecture drift, but silently skipped
        # weights would make the reported AUC meaningless, so surface them.
        incompatible = model.load_state_dict(state_dict, strict=False)
        missing = list(getattr(incompatible, 'missing_keys', []))
        unexpected = list(getattr(incompatible, 'unexpected_keys', []))
        if missing:
            logger.warning(f"{len(missing)} model parameters not found in checkpoint (first few: {missing[:5]})")
        if unexpected:
            logger.warning(f"{len(unexpected)} checkpoint entries not used by the model (first few: {unexpected[:5]})")
        logger.info(f"Successfully loaded model weights: {model_path}")
    else:
        logger.warning(f"Model file not found: {model_path}, using randomly initialized model")

    model = model.to(device)
    model.eval()

    return model

def _listed_causal_to_2d(cm):
    """Convert one tensor from a list of causal matrices into an owned 2-D numpy array, or None.

    Rank 4: index 0 of the leading axis is taken (presumably the size-1 batch
    axis, since samples are fed one at a time), then it is handled as rank 3.
    Rank 3: averaged over its leading axis. Rank 2: copied unchanged.
    Any other rank returns None so the caller can try the next list entry.
    """
    cm_np = cm.detach().cpu().numpy()
    if cm_np.ndim == 4:
        cm_np = cm_np[0]
    if cm_np.ndim == 3:
        return cm_np.mean(axis=0)
    if cm_np.ndim == 2:
        # On CPU, .numpy() shares storage with the tensor; copy so the caller's
        # in-place diagonal zeroing cannot write back into model output/state.
        return cm_np.copy()
    return None

def _tensor_causal_to_2d(cm):
    """Average a causal-matrix tensor over leading axes until 2-D; returns an owned numpy array."""
    # Copy first: on CPU the numpy view would alias the tensor's storage.
    cm_np = cm.detach().cpu().numpy().copy()
    while cm_np.ndim > 2:
        cm_np = cm_np.mean(axis=0)
    return cm_np

def extract_causal_matrix_from_orion(model, data, device, logger, network_name=None):
    """Run Orion on every test sample and average the per-sample causal matrices.

    Each index along axis 0 of ``test_x_h`` is fed through ``model`` on its own
    (``perform_intervention=False``). From each output the first usable causal
    matrix is reduced to 2-D and its diagonal is zeroed. The resulting matrices
    are then averaged. If the network's gold standard exists, a per-sample AUC
    is also computed and summary statistics are logged. The best-scoring
    sample's matrix is saved to
    ``./DREAM3_auc_results/best_batch_causal_<network>.npy``.

    Args:
        model: Orion model; called as
            ``model(x_h, x_w, x_d, time_indices, perform_intervention=False)``
            and expected to return a mapping that may hold ``'causal_matrices'``
            (a tensor or a list of tensors).
        data: ``.npz`` archive (as returned by ``np.load``) containing
            ``test_x_h`` and optionally ``test_x_w``, ``test_x_d`` and
            ``test_time_indices``. Missing ``test_x_w``/``test_x_d`` fall back
            to ``test_x_h``.
        device: Torch device the inputs are moved to.
        logger: Logger for progress and diagnostics.
        network_name: Network identifier used to locate the gold standard and
            to name the saved file. When ``None``, the last ``_``-separated
            component of ``logger.name`` is used (legacy behaviour).

    Returns:
        The averaged 2-D causal matrix as a numpy array, or ``None`` when
        ``test_x_h`` is absent or no sample produced a usable matrix.
    """
    logger.info("Starting Orion causal matrix extraction...")

    if 'test_x_h' not in data.files:
        logger.error("test_x_h not found in data")
        return None

    if network_name is None:
        network_name = logger.name.split("_")[-1]

    test_x_h = torch.from_numpy(data['test_x_h']).float().to(device)
    test_x_w = torch.from_numpy(data['test_x_w']).float().to(device) if 'test_x_w' in data.files else test_x_h
    test_x_d = torch.from_numpy(data['test_x_d']).float().to(device) if 'test_x_d' in data.files else test_x_h
    test_time_indices = torch.from_numpy(data['test_time_indices']).long().to(device) if 'test_time_indices' in data.files else None

    batch_size = test_x_h.shape[0]
    num_nodes = test_x_h.shape[1]

    logger.info(f"Data dimensions: batch_size={batch_size}, nodes={num_nodes}")

    # For 5-D x_w / x_d whose leading axis is not the sample axis, swap the
    # first two axes so samples can be sliced along axis 0 like test_x_h.
    if test_x_w.dim() == 5 and test_x_w.shape[0] != batch_size:
        test_x_w = test_x_w.permute(1, 0, 2, 3, 4)

    if test_x_d.dim() == 5 and test_x_d.shape[0] != batch_size:
        test_x_d = test_x_d.permute(1, 0, 2, 3, 4)

    if test_time_indices is None:
        target_len = test_x_h.shape[-1]
        test_time_indices = torch.arange(target_len).unsqueeze(0).repeat(batch_size, 1).to(device)

    gold_standard = None
    gold_file = Path(f'./DREAM3_Orion/{network_name}/{network_name}_gold_standard.npy')
    if gold_file.exists():
        gold_standard = np.load(gold_file)
        calculate_batch_auc = True
    else:
        calculate_batch_auc = False
        logger.warning("Gold Standard not found, cannot calculate batch-level AUC")

    all_batch_causal_matrices = []
    batch_auc_scores = []
    best_batch_auc = -1
    best_batch_idx = -1
    best_batch_causal = None

    def _record(matrix, batch_idx):
        """Store one sample's 2-D matrix and update the batch-level AUC tracking."""
        nonlocal best_batch_auc, best_batch_idx, best_batch_causal
        all_batch_causal_matrices.append(matrix)
        if not calculate_batch_auc or matrix.shape != gold_standard.shape:
            return
        batch_auc = calculate_batch_auc_score(matrix, gold_standard)
        batch_auc_scores.append(batch_auc)
        if batch_auc > best_batch_auc:
            best_batch_auc = batch_auc
            best_batch_idx = batch_idx
            best_batch_causal = matrix.copy()
        if (batch_idx + 1) % 10 == 0:
            logger.info(f"Batch {batch_idx+1}: AUC = {batch_auc:.4f}")

    logger.info(f"Processing {batch_size} batches...")

    with torch.no_grad():
        for batch_idx in tqdm(range(batch_size), desc="Extracting causal matrices"):
            try:
                batch_x_h = test_x_h[batch_idx:batch_idx + 1]
                batch_x_w = test_x_w[batch_idx:batch_idx + 1]
                batch_x_d = test_x_d[batch_idx:batch_idx + 1]
                batch_time = test_time_indices[batch_idx:batch_idx + 1]

                # Make the time-index length match the input's last axis:
                # truncate, or pad by repeating the final index.
                target_len = batch_x_h.shape[-1]
                if batch_time.shape[1] > target_len:
                    batch_time = batch_time[:, :target_len]
                elif batch_time.shape[1] < target_len:
                    padding_length = target_len - batch_time.shape[1]
                    padding = batch_time[:, -1:].expand(-1, padding_length)
                    batch_time = torch.cat([batch_time, padding], dim=1)

                outputs = model(batch_x_h, batch_x_w, batch_x_d, batch_time,
                                perform_intervention=False)

                if 'causal_matrices' not in outputs or outputs['causal_matrices'] is None:
                    continue
                causal_matrices = outputs['causal_matrices']

                if isinstance(causal_matrices, list):
                    # Use only the first list entry that reduces to a 2-D matrix.
                    for cm in causal_matrices:
                        if not isinstance(cm, torch.Tensor):
                            continue
                        matrix = _listed_causal_to_2d(cm)
                        if matrix is None:
                            continue
                        np.fill_diagonal(matrix, 0)
                        _record(matrix, batch_idx)
                        break
                elif isinstance(causal_matrices, torch.Tensor):
                    matrix = _tensor_causal_to_2d(causal_matrices)
                    np.fill_diagonal(matrix, 0)
                    _record(matrix, batch_idx)

            except Exception as e:
                # Best effort: one failing sample should not abort the whole run.
                logger.warning(f"Batch {batch_idx} processing failed: {e}")
                continue

    if (calculate_batch_auc and all_batch_causal_matrices
            and all(m.shape != gold_standard.shape for m in all_batch_causal_matrices)):
        logger.warning(
            f"Batch-level AUC skipped: causal matrix shape {all_batch_causal_matrices[0].shape} "
            f"does not match gold standard shape {gold_standard.shape}"
        )

    if batch_auc_scores:
        logger.info("Batch-level AUC statistics:")
        logger.info(f"Processed batches: {len(batch_auc_scores)}")
        logger.info(f"Average AUC: {np.mean(batch_auc_scores):.4f}")
        logger.info(f"Std deviation: {np.std(batch_auc_scores):.4f}")
        logger.info(f"Min AUC: {np.min(batch_auc_scores):.4f}")
        logger.info(f"Max AUC: {np.max(batch_auc_scores):.4f}")
        logger.info(f"Best batch index: {best_batch_idx}, Best AUC: {best_batch_auc:.4f}")

    if not all_batch_causal_matrices:
        logger.error("Failed to extract valid causal matrix from any batch")
        return None

    final_causal_matrix = np.mean(all_batch_causal_matrices, axis=0)

    logger.info(f"Successfully extracted causal matrices from {len(all_batch_causal_matrices)} batches")
    logger.info(f"Final causal matrix shape: {final_causal_matrix.shape}")
    logger.info(f"Value range: [{final_causal_matrix.min():.6f}, {final_causal_matrix.max():.6f}]")
    logger.info(f"Non-zero elements: {np.sum(final_causal_matrix != 0)} / {final_causal_matrix.size}")

    if best_batch_causal is not None:
        save_dir = Path('./DREAM3_auc_results')
        save_dir.mkdir(parents=True, exist_ok=True)
        np.save(save_dir / f'best_batch_causal_{network_name}.npy', best_batch_causal)
        logger.info("Best batch causal matrix saved")

    return final_causal_matrix

def calculate_batch_auc_score(causal_matrix, gold_standard):
    """Return the ROC-AUC of ``causal_matrix`` scores against ``gold_standard`` labels, excluding the diagonal.

    Both arrays are indexed with the same n x n off-diagonal mask, where n is
    ``causal_matrix.shape[0]``; callers in this file only pass matrices whose
    shape equals the gold standard's.

    Returns 0.5 (chance level) when the off-diagonal labels contain a single
    class, or when scikit-learn rejects the inputs (e.g. NaN scores or
    non-binary labels).
    """
    n = causal_matrix.shape[0]
    off_diagonal = ~np.eye(n, dtype=bool)

    scores = causal_matrix[off_diagonal]
    labels = gold_standard[off_diagonal]

    n_positive = np.sum(labels)
    n_negative = len(labels) - n_positive
    if n_positive == 0 or n_negative == 0:
        return 0.5

    try:
        return roc_auc_score(labels, scores)
    except (ValueError, TypeError):
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit are
        # no longer swallowed; sklearn signals bad inputs with ValueError.
        return 0.5

def calculate_auc_metrics(predicted_matrix, gold_standard, logger):
    """Compute ROC-AUC and PR-AUC of a predicted causal matrix against the gold standard.

    The diagonal is excluded. If the shapes differ, both matrices are cropped
    to their top-left ``k x k`` block, where ``k`` is the smaller leading
    dimension. AUPR is the trapezoidal area (``sklearn.metrics.auc``) under
    ``precision_recall_curve``.

    Args:
        predicted_matrix: Predicted edge scores, or ``None``.
        gold_standard: Binary ground-truth adjacency matrix.
        logger: Logger for diagnostics.

    Returns:
        ``{'auc': float, 'aupr': float}``. Fallbacks:
        ``{'auc': 0.0, 'aupr': 0.0}`` when ``predicted_matrix`` is ``None``;
        AUC 0.5 with AUPR equal to the positive rate when the labels are single-class;
        ``{'auc': 0.5, 'aupr': 0.0}`` when there are no off-diagonal entries.
        When an individual metric fails, AUC falls back to 0.5 and AUPR to the positive rate.
    """
    if predicted_matrix is None:
        logger.error("Predicted matrix is None")
        return {'auc': 0.0, 'aupr': 0.0}

    if predicted_matrix.shape != gold_standard.shape:
        logger.warning(f"Dimension mismatch: predicted={predicted_matrix.shape}, gold={gold_standard.shape}")
        min_size = min(predicted_matrix.shape[0], gold_standard.shape[0])
        predicted_matrix = predicted_matrix[:min_size, :min_size]
        gold_standard = gold_standard[:min_size, :min_size]

    n = predicted_matrix.shape[0]
    off_diagonal = ~np.eye(n, dtype=bool)

    pred_flat = predicted_matrix[off_diagonal]
    true_flat = gold_standard[off_diagonal]

    n_total = true_flat.size
    if n_total == 0:
        # A 0x0 or 1x1 matrix has no off-diagonal edges; previously this fell
        # through to a 0/0 division and reported NaN.
        logger.warning("No off-diagonal entries to evaluate, cannot calculate valid AUC")
        return {'auc': 0.5, 'aupr': 0.0}

    n_positive = np.sum(true_flat)
    n_negative = n_total - n_positive
    positive_ratio = float(n_positive / n_total)
    logger.info(f"Positive samples: {n_positive}, Negative samples: {n_negative}")
    logger.info(f"Positive ratio: {positive_ratio:.4f}")

    if n_positive == 0 or n_negative == 0:
        logger.warning("Gold Standard all zeros or all ones, cannot calculate valid AUC")
        return {'auc': 0.5, 'aupr': positive_ratio}

    try:
        auc_score = float(roc_auc_score(true_flat, pred_flat))
    except Exception as e:
        logger.warning(f"AUC calculation failed: {e}")
        auc_score = 0.5

    try:
        precision, recall, _ = precision_recall_curve(true_flat, pred_flat)
        aupr_score = float(auc(recall, precision))
    except Exception as e:
        logger.warning(f"AUPR calculation failed: {e}")
        # A random ranking's expected precision equals the positive rate.
        aupr_score = positive_ratio

    logger.info(f"Final metrics: AUC={auc_score:.4f}, AUPR={aupr_score:.4f}")

    # Plain floats keep the results JSON-friendly regardless of numpy dtype.
    return {
        'auc': auc_score,
        'aupr': aupr_score
    }

def test_single_network(network_name, model_path=None):
    """Evaluate one DREAM3 network end to end and return its AUC/AUPR.

    Sets up logging, which also rebinds the module-global ``logger`` used by
    ``load_trained_orion_model``. It then loads the data and the model,
    extracts the averaged causal matrix, and scores it against the gold
    standard.

    Args:
        network_name: Network identifier, e.g. ``'Ecoli1'``.
        model_path: Checkpoint path. When ``None``, the first existing of
            ``./DREAM3_models/<net>/best_model_mae.pth`` and
            ``./DREAM3_models/<net>/best_model.pth`` is used. If neither exists,
            the first is passed on and the loader warns and keeps random weights.

    Returns:
        ``{'auc': ..., 'aupr': ...}``; ``{'auc': 0.0, 'aupr': 0.0}`` on any failure.
    """
    global logger
    logger = setup_logging(network_name)

    logger.info(f"Testing network: {network_name}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    data = None
    try:
        logger.info("Loading data...")
        data, gold_standard = load_dream3_data(network_name)
        logger.info(f"NPZ file keys: {list(data.files)}")
        logger.info(f"Gold Standard shape: {gold_standard.shape}")

        logger.info("Loading Orion model...")
        config_path = f'./DREAM3_configs/{network_name}_config.conf'
        if model_path is None:
            candidates = [
                f'./DREAM3_models/{network_name}/best_model_mae.pth',
                f'./DREAM3_models/{network_name}/best_model.pth',
            ]
            # First existing checkpoint wins; otherwise keep the first candidate
            # so the loader reports it as missing.
            model_path = next((p for p in candidates if os.path.exists(p)), candidates[0])

        model = load_trained_orion_model(config_path, model_path, device)

        logger.info("Extracting Orion causal matrix...")
        causal_matrix = extract_causal_matrix_from_orion(model, data, device, logger)

        if causal_matrix is None:
            logger.error("Causal matrix extraction failed")
            return {'auc': 0.0, 'aupr': 0.0}

        logger.info("Calculating evaluation metrics...")
        return calculate_auc_metrics(causal_matrix, gold_standard, logger)

    except Exception as e:
        # logger.exception records the traceback in the log file too;
        # traceback.print_exc() only reached stderr.
        logger.exception(f"Test failed: {e}")
        return {'auc': 0.0, 'aupr': 0.0}
    finally:
        # np.load keeps the .npz archive's file handle open until closed.
        close = getattr(data, 'close', None)
        if close is not None:
            close()

def test_all_networks():
    """Evaluate each DREAM3 in-silico network, print a summary and save it as JSON.

    Networks whose AUC is not strictly positive (the 0.0 failure sentinel) are
    excluded from the averages. The averages are compared against Casper's
    reported AUC of 0.6325. The summary is written to
    ``./DREAM3_auc_results/auc_results_<timestamp>.json``.

    Returns:
        Dict mapping each network name to its ``{'auc', 'aupr'}`` metrics.
    """
    casper_auc = 0.6325
    results = {}

    print("\nDREAM3 In Silico Network Challenge - Orion AUC Evaluation")
    print("-" * 60)

    for name in ('Ecoli1', 'Ecoli2', 'Yeast1', 'Yeast2', 'Yeast3'):
        print(f"\nTesting {name}...")
        try:
            scores = test_single_network(name)
            results[name] = scores
            print(f"  AUC: {scores['auc']:.4f}, AUPR: {scores['aupr']:.4f}")
        except Exception as err:
            print(f"  Test failed: {err}")
            results[name] = {'auc': 0.0, 'aupr': 0.0}

    succeeded = [m for m in results.values() if m['auc'] > 0]
    avg_auc = np.mean([m['auc'] for m in succeeded]) if succeeded else 0.0
    avg_aupr = np.mean([m['aupr'] for m in succeeded]) if succeeded else 0.0

    print("\nTest Results Summary")
    print("-" * 60)

    print("\nIndividual network performance:")
    for name, scores in results.items():
        print(f"  {name}: AUC={scores['auc']:.4f}, AUPR={scores['aupr']:.4f}")

    print("\nAverage performance:")
    print(f"  Average AUC: {avg_auc:.4f}")
    print(f"  Average AUPR: {avg_aupr:.4f}")

    print("\nComparison with Casper:")
    print("  Casper reported AUC: 0.6325")
    print(f"  Orion AUC: {avg_auc:.4f}")

    if avg_auc > casper_auc:
        print(f"  Orion performs better than Casper (+{(avg_auc - casper_auc):.4f})")
    else:
        print(f"  Performance gap: {(casper_auc - avg_auc):.4f}")

    out_dir = Path('./DREAM3_auc_results')
    out_dir.mkdir(parents=True, exist_ok=True)

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_file = out_dir / f'auc_results_{stamp}.json'

    summary = {
        'individual_results': results,
        'average_auc': avg_auc,
        'average_aupr': avg_aupr,
        'timestamp': stamp
    }
    with open(out_file, 'w') as fh:
        json.dump(summary, fh, indent=2)

    print(f"\nResults saved to: {out_file}")

    return results

if __name__ == "__main__":
    # Command-line entry: evaluate one named network, or all five with 'all'.
    # Note: --model_path is only honoured for a single network.
    cli = argparse.ArgumentParser(description='DREAM3 AUC Evaluation')
    cli.add_argument('--network', type=str, default='all',
                     help='Network name (Ecoli1/Ecoli2/Yeast1/Yeast2/Yeast3/all)')
    cli.add_argument('--model_path', type=str, default=None,
                     help='Model weights file path')
    opts = cli.parse_args()

    if opts.network != 'all':
        result = test_single_network(opts.network, opts.model_path)
        print(f"\n{opts.network} Results:")
        print(f"  AUC: {result['auc']:.4f}")
        print(f"  AUPR: {result['aupr']:.4f}")
    else:
        test_all_networks()