import os
import sys
import numpy as np
import argparse
import configparser
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import time
import logging
from tqdm import tqdm
import warnings
from datetime import datetime
from pathlib import Path

# Configure the PyTorch CUDA allocator through its environment variable.
# NOTE(review): this assignment runs after `import torch` above; confirm the
# allocator still picks the setting up at that point.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Suppress every warning category for the rest of the process.
warnings.filterwarnings('ignore')

# The model classes come from the Orion_model module; without it nothing in
# this script can run, so print a message and exit with status 1.
try:
    from Orion_model import OrionModel, create_orion_model
except ImportError:
    print("Error: Cannot import Orion_model.py")
    sys.exit(1)

class OrionDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(x_h, x_w, x_d, target, time_indices)`` tuples.

    Any of the five sources may be ``None``. A missing source is represented
    in every sample by an empty placeholder tensor, so the returned tuple
    always has five entries.
    """

    def __init__(self, x_h, x_w, x_d, target, time_indices):
        # The sample count comes from x_h alone; without x_h the dataset is empty.
        self.length = 0 if x_h is None else x_h.shape[0]

        self.x_h = x_h
        self.x_w = x_w
        self.x_d = x_d
        self.target = target
        self.time_indices = time_indices

    def __len__(self):
        """Return the number of samples (first dimension of ``x_h``)."""
        return self.length

    def __getitem__(self, idx):
        """Return the five-element sample tuple for position ``idx``."""
        float_sources = (self.x_h, self.x_w, self.x_d, self.target)
        fields = [
            torch.empty(0) if source is None else source[idx]
            for source in float_sources
        ]

        # The time-index placeholder is integer typed, unlike the others.
        if self.time_indices is None:
            fields.append(torch.empty(0, dtype=torch.long))
        else:
            fields.append(self.time_indices[idx])

        return tuple(fields)

def setup_logging(config):
    """Set up file + stdout logging and return this module's logger.

    The log directory is read from ``config['Testing']['test_results_path']``
    and created if necessary. The log file name carries a timestamp taken at
    call time.
    """
    results_dir = Path(config['Testing']['test_results_path'])
    results_dir.mkdir(parents=True, exist_ok=True)

    log_file = results_dir / "orion_testing_{}.log".format(
        datetime.now().strftime("%Y%m%d_%H%M%S")
    )

    # Build the handlers first so the basicConfig call stays short.
    to_file = logging.FileHandler(log_file, encoding='utf-8')
    to_stdout = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] %(message)s',
        handlers=[to_file, to_stdout],
    )

    logger = logging.getLogger(__name__)
    logger.info("Orion Model Testing Started")
    logger.info(f"Log file: {log_file}")
    return logger

def _npz_float(data, key, default):
    """Return ``data[key]`` converted to ``float``, or ``default`` if ``key`` is absent."""
    return float(data[key]) if key in data else default


def _npz_tensor(data, key, as_long=False):
    """Return ``data[key]`` as a float (or long) tensor, or ``None`` if ``key`` is absent."""
    if key not in data:
        return None
    tensor = torch.from_numpy(data[key])
    return tensor.long() if as_long else tensor.float()


def _collect_feature_stats(data):
    """Gather per-feature statistics stored under ``train_feature_<idx>_<stat>`` keys."""
    feature_stats = {}
    for key in data.keys():
        # The ``_min`` entry is used as the marker that a feature has statistics.
        if not (key.startswith('train_feature_') and key.endswith('_min')):
            continue

        feature_idx = key.split('_')[2]
        try:
            feature_num = int(feature_idx)
        except ValueError:
            # Not a numeric feature index; ignore this key.
            continue

        prefix = f'train_feature_{feature_idx}'
        feature_stat = {
            name: float(data[f'{prefix}_{name}'])
            for name in ('min', 'max', 'mean', 'std')
        }
        # median / iqr are optional and stored independently of each other.
        for optional in ('median', 'iqr'):
            if f'{prefix}_{optional}' in data:
                feature_stat[optional] = float(data[f'{prefix}_{optional}'])

        feature_stats[feature_num] = feature_stat
    return feature_stats


def load_test_data(config, logger):
    """Load the test split and training statistics from the preprocessed ``.npz`` file.

    The archive path is derived from
    ``config['Data']['graph_signal_matrix_filename']`` (directory and base
    name) plus the ``num_of_hours`` / ``num_of_days`` / ``num_of_weeks``
    settings.

    Args:
        config: Mapping with a ``'Data'`` section providing the keys above.
        logger: Logger used for progress and error messages.

    Returns:
        A ``(test_data, stats, adj_matrix)`` tuple. ``test_data`` maps
        ``'x_h'``, ``'x_w'``, ``'x_d'``, ``'target'`` and ``'time_indices'``
        to tensors, or ``None`` for arrays absent from the archive. ``stats``
        holds ``min_flow`` / ``max_flow`` / ``mean_flow`` / ``std_flow`` and a
        ``feature_stats`` dict keyed by integer feature index. ``adj_matrix``
        is the stored array or ``None``.

    Raises:
        FileNotFoundError: If the derived archive path does not exist.
    """
    logger.info("Loading test dataset...")

    graph_signal_matrix_filename = config['Data']['graph_signal_matrix_filename']
    base_dir = os.path.dirname(graph_signal_matrix_filename)
    base_name = os.path.basename(graph_signal_matrix_filename).split('.')[0]

    num_hours = config['Data']['num_of_hours']
    num_days = config['Data']['num_of_days']
    num_weeks = config['Data']['num_of_weeks']
    data_file = os.path.join(
        base_dir, f'{base_name}_r{num_hours}_d{num_days}_w{num_weeks}_Orion.npz'
    )

    if not os.path.exists(data_file):
        logger.error(f"Data file not found: {data_file}")
        raise FileNotFoundError(f"Data file not found: {data_file}")

    logger.info(f"Loading from: {data_file}")

    # The archive keeps a file handle open until closed, so read everything
    # needed inside the ``with`` block and let it close deterministically.
    with np.load(data_file) as data:
        test_data = {
            'x_h': _npz_tensor(data, 'test_x_h'),
            'x_w': _npz_tensor(data, 'test_x_w'),
            'x_d': _npz_tensor(data, 'test_x_d'),
            'target': _npz_tensor(data, 'test_target'),
            'time_indices': _npz_tensor(data, 'test_time_indices', as_long=True),
        }

        stats = {
            'min_flow': _npz_float(data, 'train_min_flow', 0.0),
            'max_flow': _npz_float(data, 'train_max_flow', 1.0),
            'mean_flow': _npz_float(data, 'train_mean_flow', 0.0),
            'std_flow': _npz_float(data, 'train_std_flow', 1.0),
            'feature_stats': _collect_feature_stats(data),
        }

        adj_matrix = data['adj_matrix'] if 'adj_matrix' in data else None

    logger.info("Test data loaded successfully")
    for key in ('x_h', 'x_w', 'x_d', 'target'):
        if test_data[key] is not None:
            logger.info(f"Test {key} shape: {test_data[key].shape}")

    return test_data, stats, adj_matrix

def _denormalize_flow(predictions, targets, stats):
    """Undo the training-time normalization on predictions and targets.

    Uses median/IQR when feature 0 has a ``median`` statistic (falling back to
    mean/std if the IQR is missing or degenerate), otherwise min/max. Values
    are returned unchanged when the selected scale is ~0.
    """
    eps = 1e-8
    flow_stats = stats.get('feature_stats', {}).get(0, {})

    if 'median' in flow_stats:
        # ``iqr`` is stored independently of ``median`` and may be absent.
        iqr_flow = flow_stats.get('iqr')
        if iqr_flow is not None and iqr_flow > eps:
            scale, offset = iqr_flow, flow_stats['median']
        else:
            std_flow = stats.get('std_flow', 1.0)
            if std_flow <= eps:
                return predictions, targets
            scale, offset = std_flow, stats.get('mean_flow', 0.0)
    else:
        min_flow = stats.get('min_flow', 0.0)
        max_flow = stats.get('max_flow', 1.0)
        if max_flow - min_flow <= eps:
            return predictions, targets
        scale, offset = max_flow - min_flow, min_flow

    return predictions * scale + offset, targets * scale + offset


def compute_metrics(predictions, targets, stats):
    """Compute MAE, RMSE and MAPE on de-normalized predictions.

    Args:
        predictions: Tensor of normalized model outputs.
        targets: Tensor of normalized ground truth, same shape as ``predictions``.
        stats: Statistics dict (``min_flow``, ``max_flow``, ``mean_flow``,
            ``std_flow`` and optional ``feature_stats``) used to de-normalize.

    Returns:
        Dict with float entries ``'mae'``, ``'rmse'`` and ``'mape'`` (percent).
        MAPE only considers targets whose magnitude exceeds 5 and is 0.0 when
        no target qualifies.
    """
    pred_denorm, target_denorm = _denormalize_flow(predictions, targets, stats)

    mae = F.l1_loss(pred_denorm, target_denorm).item()
    rmse = torch.sqrt(F.mse_loss(pred_denorm, target_denorm)).item()

    # Ignore near-zero targets, which would blow up the relative error.
    threshold = 5
    mask = torch.abs(target_denorm) > threshold

    if mask.any():
        valid_targets = target_denorm[mask]
        valid_preds = pred_denorm[mask]
        relative_error = torch.abs((valid_targets - valid_preds) / torch.abs(valid_targets))
        mape = (torch.mean(relative_error) * 100).item()
    else:
        mape = 0.0

    if mape > 1000:
        # Implausibly large MAPE: also try a damped variant over all elements
        # and keep whichever is smaller.
        epsilon = 1.0
        mape_safe = torch.mean(
            torch.abs((target_denorm - pred_denorm) / (torch.abs(target_denorm) + epsilon))
        ) * 100
        mape = min(mape, mape_safe.item())

    return {'mae': mae, 'rmse': rmse, 'mape': mape}

def create_test_data_loaders(test_data, config, logger, num_workers=4):
    """Wrap the test tensors in an ``OrionDataset`` and return a ``DataLoader``.

    Args:
        test_data: Dict with ``'x_h'``, ``'x_w'``, ``'x_d'``, ``'target'`` and
            ``'time_indices'`` entries (tensors or ``None``). ``'x_w'`` and
            ``'x_d'`` may be replaced in place by an axis-swapped view.
        config: Mapping providing ``config['Testing']['test_batch_size']``.
        logger: Logger for progress messages.
        num_workers: Number of loader worker processes (default 4).

    Returns:
        A non-shuffling ``DataLoader`` over the test set that keeps the last
        partial batch.
    """
    logger.info("Creating test data loaders...")

    batch_size = int(config['Testing']['test_batch_size'])

    x_h = test_data['x_h']
    if x_h is not None:
        # A 5-D x_w / x_d whose leading dimension disagrees with x_h has its
        # first two axes swapped. Without x_h there is nothing to compare
        # against, so the tensors are left untouched.
        num_samples = x_h.shape[0]
        for key in ('x_w', 'x_d'):
            tensor = test_data[key]
            if tensor is not None and tensor.dim() == 5 and tensor.shape[0] != num_samples:
                test_data[key] = tensor.permute(1, 0, 2, 3, 4)
                logger.info(f"Adjusted {key} shape: {test_data[key].shape}")

    test_dataset = OrionDataset(
        test_data['x_h'],
        test_data['x_w'],
        test_data['x_d'],
        test_data['target'],
        test_data['time_indices']
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        # Only request pinned memory when CUDA is available.
        pin_memory=torch.cuda.is_available(),
        drop_last=False
    )

    logger.info(f"Test batches: {len(test_loader)}")
    logger.info(f"Batch size: {batch_size}")
    logger.info(f"Total samples: {len(test_dataset)}")

    return test_loader

def _strip_module_prefix(state_dict, prefix='module.'):
    """Return ``state_dict`` with a leading ``prefix`` removed from each key that has it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def load_trained_model(config, logger, device):
    """Build the Orion model from ``config`` and restore the best saved weights.

    Looks for ``best_model_mae.pth`` under ``config['Training']['save_path']``
    and falls back to ``best_model.pth`` in the same directory.

    Args:
        config: Mapping with ``'Data'``, ``'Model'`` and ``'Training'`` sections.
        logger: Logger for progress, warnings and errors.
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        The model on ``device`` in evaluation mode.

    Raises:
        FileNotFoundError: If neither checkpoint file exists.
    """
    logger.info("Loading trained MAE best model...")

    data_cfg = config['Data']
    model_cfg = config['Model']

    int_data_keys = ('num_of_vertices', 'in_channels', 'target_len', 'source_len')
    int_model_keys = (
        'd_model', 'd_ff_emb', 'd_ff_belt', 'd_ff_fusion', 'd_ff_reverse',
        'n_belt_block', 'head_s', 'head_t', 'head_f', 'num_time_segments',
    )
    float_model_keys = ('dropout', 'causal_threshold')

    model_config = {key: int(data_cfg[key]) for key in int_data_keys}
    model_config.update({key: int(model_cfg[key]) for key in int_model_keys})
    model_config.update({key: float(model_cfg[key]) for key in float_model_keys})

    # Optional setting: forwarded only when present in the config.
    if 'node_selection_ratio' in model_cfg:
        model_config['node_selection_ratio'] = float(model_cfg['node_selection_ratio'])

    model = create_orion_model(model_config)

    model_dir = Path(config['Training']['save_path'])
    model_path = model_dir / 'best_model_mae.pth'

    if not model_path.exists():
        model_path = model_dir / 'best_model.pth'
        logger.info(f"Loading default best model: {model_path}")

    if not model_path.exists():
        logger.error(f"Model file not found: {model_path}")
        raise FileNotFoundError(f"Model file not found: {model_path}")

    logger.info(f"Loading model weights from: {model_path}")
    checkpoint = torch.load(model_path, map_location=device, weights_only=True)

    # Remove only a *leading* 'module.' from each key. A substring replace
    # would also rewrite keys such as 'fusion_module.weight', and with
    # strict=False the resulting mismatch would go unnoticed.
    saved_state_dict = _strip_module_prefix(checkpoint['model_state_dict'])

    load_result = model.load_state_dict(saved_state_dict, strict=False)

    # strict=False tolerates mismatches; report them so a partially
    # initialised model is not evaluated silently.
    missing_keys = list(getattr(load_result, 'missing_keys', []))
    unexpected_keys = list(getattr(load_result, 'unexpected_keys', []))
    if missing_keys:
        logger.warning(
            f"{len(missing_keys)} model parameter(s) not found in checkpoint: {missing_keys}"
        )
    if unexpected_keys:
        logger.warning(
            f"{len(unexpected_keys)} checkpoint entr(ies) not used by the model: {unexpected_keys}"
        )

    model = model.to(device)
    model.eval()

    logger.info("Model loaded successfully")

    return model

def test_epoch(model, test_loader, device, logger, stats):
    """Run the model over ``test_loader`` without gradients and aggregate metrics.

    Batches that raise an exception or produce a NaN/Inf L1 loss are logged
    and skipped; a summary warning reports how many were skipped, since the
    returned metrics then cover only part of the test set.

    Args:
        model: Callable as ``model(x_h, x_w, x_d, time_indices,
            perform_intervention=False)`` returning a mapping with a
            ``'predictions'`` entry.
        test_loader: Loader yielding ``(x_h, x_w, x_d, targets, time_indices)``.
        device: Device the batch tensors are moved to.
        logger: Logger for warnings and errors.
        stats: Statistics dict forwarded to ``compute_metrics``.

    Returns:
        Dict with ``'loss'`` (sample-weighted mean L1 on normalized values),
        ``'mae'``, ``'rmse'``, ``'mape'``, and the concatenated CPU
        ``'predictions'`` / ``'targets'`` tensors (``None`` when no batch
        succeeded, in which case the metrics are ``inf``).
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0
    skipped_batches = 0

    all_predictions = []
    all_targets = []

    num_batches = len(test_loader)
    total_samples_in_epoch = len(test_loader.dataset)

    with torch.no_grad():
        pbar = tqdm(total=total_samples_in_epoch,
                    desc="Testing",
                    unit="samples",
                    leave=False)
        try:
            for batch_idx, (x_h, x_w, x_d, targets, time_indices) in enumerate(test_loader):
                try:
                    x_h = x_h.to(device, non_blocking=True)
                    x_w = x_w.to(device, non_blocking=True)
                    x_d = x_d.to(device, non_blocking=True)
                    targets = targets.to(device, non_blocking=True)
                    time_indices = time_indices.to(device, non_blocking=True)

                    # Make the time-index width match the target's last dimension:
                    # truncate when longer, repeat the final index when shorter.
                    target_len = targets.shape[-1]
                    num_indices = time_indices.shape[1]
                    if num_indices > target_len:
                        time_indices = time_indices[:, :target_len]
                    elif num_indices < target_len:
                        padding = time_indices[:, -1:].expand(-1, target_len - num_indices)
                        time_indices = torch.cat([time_indices, padding], dim=1)

                    outputs = model(x_h, x_w, x_d, time_indices, perform_intervention=False)
                    predictions = outputs['predictions']

                    pred_loss = F.l1_loss(predictions, targets)

                    if torch.isnan(pred_loss) or torch.isinf(pred_loss):
                        skipped_batches += 1
                        logger.warning(f"Invalid loss at batch {batch_idx}, skipping")
                        continue

                    batch_size = targets.size(0)
                    total_loss += pred_loss.item() * batch_size
                    total_samples += batch_size

                    all_predictions.append(predictions.cpu())
                    all_targets.append(targets.cpu())

                    pbar.update(batch_size)
                    pbar.set_postfix({
                        'Loss': f'{pred_loss.item():.4f}',
                        'Batch': f'{batch_idx+1}/{num_batches}'
                    })

                except Exception as e:
                    # Best effort: keep evaluating the remaining batches, but
                    # record the failure (message plus traceback).
                    skipped_batches += 1
                    logger.exception(f"Error in test batch {batch_idx}: {str(e)}")
                    continue
        finally:
            # Close the bar even if the loader itself raises.
            pbar.close()

    if skipped_batches:
        logger.warning(
            f"Skipped {skipped_batches} of {num_batches} batches; metrics cover "
            f"{total_samples} of {total_samples_in_epoch} samples"
        )

    avg_loss = total_loss / total_samples if total_samples > 0 else float('inf')

    if all_predictions:
        all_pred = torch.cat(all_predictions, dim=0)
        all_tgt = torch.cat(all_targets, dim=0)
        metrics = compute_metrics(all_pred, all_tgt, stats)
    else:
        all_pred = None
        all_tgt = None
        metrics = {'mae': float('inf'), 'rmse': float('inf'), 'mape': float('inf')}

    return {
        'loss': avg_loss,
        'mae': metrics['mae'],
        'rmse': metrics['rmse'],
        'mape': metrics['mape'],
        'predictions': all_pred,
        'targets': all_tgt
    }

def _denormalize_for_report(predictions, targets, stats):
    """Undo the training-time normalization on predictions and targets.

    Uses median/IQR when feature 0 has a ``median`` statistic (falling back to
    mean/std if the IQR is missing or degenerate), otherwise min/max. Values
    are returned unchanged when the selected scale is ~0.
    """
    eps = 1e-8
    flow_stats = stats.get('feature_stats', {}).get(0, {})

    if 'median' in flow_stats:
        # ``iqr`` is optional even when ``median`` is present.
        iqr_flow = flow_stats.get('iqr')
        if iqr_flow is not None and iqr_flow > eps:
            scale, offset = iqr_flow, flow_stats['median']
        else:
            std_flow = stats.get('std_flow', 1.0)
            if std_flow <= eps:
                return predictions, targets
            scale, offset = std_flow, stats.get('mean_flow', 0.0)
    else:
        min_flow = stats.get('min_flow', 0.0)
        max_flow = stats.get('max_flow', 1.0)
        if max_flow - min_flow <= eps:
            return predictions, targets
        scale, offset = max_flow - min_flow, min_flow

    return predictions * scale + offset, targets * scale + offset


def compute_detailed_metrics(predictions, targets, stats, logger):
    """Compute overall and per-timestep MAE / RMSE / MAPE and log them as a table.

    Args:
        predictions: Tensor of normalized model outputs; indexed as
            ``[:, :, t]``, so at least three dimensions are required
            (presumably samples x nodes x timesteps - TODO confirm).
        targets: Tensor of normalized ground truth, same shape as ``predictions``.
        stats: Statistics dict used for de-normalization.
        logger: Logger receiving the metric table.

    Returns:
        ``{'overall': {...}, 'per_timestep': {t: {'mae', 'rmse', 'mape'}}}``
        where ``t`` is the zero-based index along dimension 2 and all metric
        values are Python floats.
    """
    logger.info("Computing detailed metrics...")

    overall_metrics = compute_metrics(predictions, targets, stats)

    pred_denorm, target_denorm = _denormalize_for_report(predictions, targets, stats)
    pred_np = pred_denorm.cpu().numpy()
    target_np = target_denorm.cpu().numpy()

    # Same MAPE rules as the overall metric: ignore small targets, and fall
    # back to a damped variant when the result is implausibly large.
    threshold = 5
    epsilon = 1.0

    num_timesteps = pred_np.shape[2]
    per_timestep_metrics = {}

    logger.info("Per-timestep Metrics:")
    logger.info(f"{'Timestep':<10} {'MAE':<10} {'RMSE':<10} {'MAPE(%)':<10}")
    logger.info("-" * 40)

    for t in range(num_timesteps):
        step_pred = pred_np[:, :, t].flatten()
        step_target = target_np[:, :, t].flatten()
        step_error = step_pred - step_target

        mae = float(np.mean(np.abs(step_error)))
        rmse = float(np.sqrt(np.mean(step_error ** 2)))

        mask = np.abs(step_target) > threshold
        if mask.any():
            mape = float(
                np.mean(np.abs((step_target[mask] - step_pred[mask]) / np.abs(step_target[mask]))) * 100
            )
        else:
            mape = 0.0

        if mape > 1000:
            mape_safe = float(
                np.mean(np.abs((step_target - step_pred) / (np.abs(step_target) + epsilon))) * 100
            )
            mape = min(mape, mape_safe)

        per_timestep_metrics[t] = {
            'mae': mae,
            'rmse': rmse,
            'mape': mape
        }

        logger.info(f"{t+1:<10} {mae:<10.4f} {rmse:<10.4f} {mape:<10.2f}")

    logger.info("-" * 40)

    metrics = {
        'overall': overall_metrics,
        'per_timestep': per_timestep_metrics
    }

    logger.info(f"Overall MAE: {overall_metrics['mae']:.4f}")
    logger.info(f"Overall RMSE: {overall_metrics['rmse']:.4f}")
    logger.info(f"Overall MAPE: {overall_metrics['mape']:.2f}%")

    return metrics

def run_test(model, test_data, stats, config, logger, device):
    """Evaluate ``model`` on the test data and return the detailed metrics.

    Builds the test loader, runs one evaluation pass, and computes the
    detailed metrics from the collected predictions. Returns an empty dict
    (after logging an error) when the pass produced no predictions.
    """
    logger.info("Starting model testing...")

    loader = create_test_data_loaders(test_data, config, logger)
    outcome = test_epoch(model, loader, device, logger, stats)

    final_predictions = outcome['predictions']
    final_targets = outcome['targets']

    # Guard clause: nothing to score if either tensor is missing.
    if final_predictions is None or final_targets is None:
        logger.error("No valid predictions collected")
        return {}

    logger.info(f"Test completed! Processed {final_predictions.shape[0]} samples")
    return compute_detailed_metrics(final_predictions, final_targets, stats, logger)

def parse_arguments():
    """Parse the command line and return the resulting namespace.

    Options: ``--config`` (required path), ``--device`` (default ``cuda``)
    and ``--log_level`` (one of DEBUG/INFO/WARNING/ERROR, default ``INFO``).
    """
    cli = argparse.ArgumentParser(description='Orion Model Test Script')

    # Each entry is (flag, keyword arguments for add_argument).
    option_specs = (
        ('--config', {'type': str, 'required': True,
                      'help': 'Configuration file path'}),
        ('--device', {'type': str, 'default': 'cuda',
                      'help': 'Device (cuda/cpu)'}),
        ('--log_level', {'type': str, 'default': 'INFO',
                         'choices': ['DEBUG', 'INFO', 'WARNING', 'ERROR'],
                         'help': 'Log level'}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()

def load_config(config_path):
    """Read an INI configuration file and return the populated ``ConfigParser``.

    Args:
        config_path: Path to the configuration file (read as UTF-8).

    Returns:
        A ``configparser.ConfigParser`` holding the file's sections.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        OSError: If the path exists but cannot be opened as a file
            (for example, it is a directory).
    """
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Configuration file not found: {config_path}")

    config = configparser.ConfigParser()
    # ConfigParser.read() silently skips files it cannot open and would hand
    # back an empty config; opening the file explicitly surfaces the error.
    with open(config_path, encoding='utf-8') as config_file:
        config.read_file(config_file)
    return config

def main():
    """Script entry point: load config, data and model, run the test, log results.

    Exits with status 1 if any step raises or if the test run yields no
    metrics, so callers (shell scripts, CI) can detect a failed evaluation.
    """
    args = parse_arguments()

    config = load_config(args.config)

    # Only CUDA device strings are honoured, and only when CUDA is usable;
    # everything else runs on the CPU.
    wants_cuda = args.device.startswith('cuda')
    device = torch.device(args.device if wants_cuda and torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    logger = setup_logging(config)
    logger.setLevel(getattr(logging, args.log_level))

    if wants_cuda and device.type != 'cuda':
        logger.warning(f"Requested device '{args.device}' is unavailable, falling back to CPU")

    try:
        logger.info("Testing MAE best model")

        # The adjacency matrix is returned by the loader but not needed here.
        test_data, stats, _adj_matrix = load_test_data(config, logger)

        model = load_trained_model(config, logger, device)

        metrics = run_test(model, test_data, stats, config, logger, device)

    except Exception as e:
        # Top-level boundary: log message and traceback, then fail the process.
        logger.exception(f"Error during testing: {str(e)}")
        sys.exit(1)

    if not metrics:
        logger.error("Testing failed, no valid metrics generated")
        sys.exit(1)

    logger.info("Testing completed successfully!")
    logger.info("Final Overall Metrics:")
    logger.info(f"  MAE: {metrics['overall']['mae']:.4f}")
    logger.info(f"  RMSE: {metrics['overall']['rmse']:.4f}")
    logger.info(f"  MAPE: {metrics['overall']['mape']:.2f}%")

    logger.info("Testing script finished")
    
# Run the test pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()