import argparse
import os
import time
import numpy as np
import torch
import csv 
from robustbench.data import load_cifar100c, load_cifar100
from data_cifar import load_svhn, load_Places365, load_LSUN, load_iSUN, load_textures
from data_cifar import load_svhn_c, load_Places365_c, load_LSUN_c, load_iSUN_c, load_textures_c
from robustbench.model_zoo.enums import ThreatModel
from robustbench.utils import load_model
from data_tinyimagenet import load_tiny_imagenet, load_tiny_imagenet_c, load_imagenet_o, load_imagenet_o_c
from data_tinyimagenet import load_Places, load_Places_c, load_iNaturalist, load_iNaturalist_c, load_SUN, load_SUN_c
import torchvision.models as models
from robustbench.model_zoo.architectures.utils_architectures import normalize_model
from metrics.online_rds import evaluate_ood_scores 
from utils import get_logger, set_random_seed
import math
from copy import deepcopy
from transformers import SwinForImageClassification, ViTForImageClassification
import torch.nn as nn

def log_method_results(logger, results, corruption, elapsed):
    """
    Log the results of every OOD detection method for one corruption type.

    Emits a PrettyTable with the OOD detection metrics of each method and,
    when the corresponding entries exist in ``results``, tables for OSCR and
    classification accuracy.

    Args:
        logger: logging object
        results: evaluation results dictionary; ``results['ood_detection']``
            maps method name -> dict with ``'batch_aurocs'``, ``'final_auroc'``,
            ``'final_fpr'``, ``'avg_auroc'`` and ``'avg_fpr'``. Optional
            ``'oscr'`` and ``'classification'`` entries are tabulated as well.
        corruption: corruption type name
        elapsed: time taken for evaluation, in seconds
    """
    logger.info(f"Results for {corruption} (took {elapsed:.2f}s):")

    # Local import: prettytable is only needed for these log tables.
    from prettytable import PrettyTable

    # OOD detection results table
    ood_table = PrettyTable()
    ood_table.field_names = ["Method", "First AUROC", "Final AUROC", "Final FPR@95TPR", "Avg AUROC", "Avg FPR@95TPR"]

    for method_name, method_results in results['ood_detection'].items():
        # Aggregated results (see get_summary_results) carry an empty
        # 'batch_aurocs' list, so the first-batch AUROC may be unavailable.
        batch_aurocs = method_results.get('batch_aurocs', [])
        first_auroc = f"{batch_aurocs[0]:.4f}" if len(batch_aurocs) > 0 else "N/A"
        ood_table.add_row([
            method_name.upper(),
            first_auroc,
            f"{method_results['final_auroc']:.4f}",
            f"{method_results['final_fpr']:.4f}",
            f"{method_results['avg_auroc']:.4f}",
            f"{method_results['avg_fpr']:.4f}"
        ])

    logger.info("\nOOD Detection Performance:")
    logger.info(f"\n{ood_table}")

    # OSCR performance table
    if 'oscr' in results:
        oscr_table = PrettyTable()
        oscr_table.field_names = ["Metric", "Final", "Average"]
        oscr_table.add_row([
            "OSCR",
            f"{results['oscr']['final_oscr']:.4f}",
            f"{results['oscr']['avg_oscr']:.4f}"
        ])
        logger.info("\nOSCR Performance:")
        logger.info(f"\n{oscr_table}")

    # Classification performance table
    if 'classification' in results:
        cls_table = PrettyTable()
        cls_table.field_names = ["Metric", "Final", "Average"]
        cls_table.add_row([
            "Accuracy",
            f"{results['classification']['final_accuracy']:.4f}",
            f"{results['classification']['avg_accuracy']:.4f}"
        ])
        logger.info("\nClassification Performance:")
        logger.info(f"\n{cls_table}")

def log_summary_results(logger, all_results, corruptions):
    """
    Log metric averages taken over every evaluated corruption type.

    Produces an OOD detection table (per-method mean of ``avg_auroc`` and
    ``avg_fpr``) and, when the first result contains them, classification
    accuracy and OSCR tables.

    Args:
        logger: logging object
        all_results: mapping of corruption type -> results dictionary
        corruptions: list of corruption types; the first entry decides which
            OOD methods are tabulated
    """
    from prettytable import PrettyTable

    logger.info("\nSummary across all corruption types:")

    method_names = list(all_results[corruptions[0]]['ood_detection'].keys())

    def mean_of(*keys):
        """Mean of the value reached by following ``keys`` in each result."""
        collected = []
        for entry in all_results.values():
            node = entry
            for key in keys:
                node = node[key]
            collected.append(node)
        return np.mean(collected)

    detection_table = PrettyTable()
    detection_table.field_names = ["Method", "Avg Batch AUROC", "Avg Batch FPR@95TPR"]
    for name in method_names:
        auroc_mean = mean_of('ood_detection', name, 'avg_auroc')
        fpr_mean = mean_of('ood_detection', name, 'avg_fpr')
        detection_table.add_row([name.upper(), f"{auroc_mean:.4f}", f"{fpr_mean:.4f}"])

    logger.info("\nOOD Detection Results:")
    logger.info(f"\n{detection_table}")

    first_entry = next(iter(all_results.values()))

    # Classification performance table
    if 'classification' in first_entry:
        accuracy_table = PrettyTable()
        accuracy_table.field_names = ["Metric", "Value"]
        accuracy_table.add_row(["Average Batch Accuracy", f"{mean_of('classification', 'avg_accuracy'):.4f}"])
        accuracy_table.add_row(["Final Accuracy", f"{mean_of('classification', 'final_accuracy'):.4f}"])
        logger.info("\nClassification Results:")
        logger.info(f"\n{accuracy_table}")

    # OSCR performance table
    if 'oscr' in first_entry:
        oscr_summary = PrettyTable()
        oscr_summary.field_names = ["Metric", "Value"]
        oscr_summary.add_row(["Average Batch OSCR", f"{mean_of('oscr', 'avg_oscr'):.4f}"])
        oscr_summary.add_row(["Final OSCR", f"{mean_of('oscr', 'final_oscr'):.4f}"])
        logger.info("\nOSCR Performance:")
        logger.info(f"\n{oscr_summary}")
    
def log_sorted_performance(logger, all_results, corruptions, method, show_trend=False):
    """
    Log per-corruption results of one OOD method, best average AUROC first.

    Args:
        logger: logging object
        all_results: mapping of corruption type -> results dictionary
        corruptions: list of corruption types to include
        method: OOD detection method name (key of ``results['ood_detection']``)
        show_trend: if True, also log the first-batch and final AUROC so the
            evolution over the online stream is visible. The first-batch value
            is shown as ``N/A`` when ``batch_aurocs`` is empty.
    """
    label = method.upper()
    ranked = sorted(
        ((c,
          all_results[c]['ood_detection'][method]['avg_auroc'],
          all_results[c]['ood_detection'][method]['avg_fpr']) for c in corruptions),
        key=lambda item: item[1],
        reverse=True,
    )

    logger.info(f"\nPerformance by corruption type (sorted by {label} avg_auroc):")
    for corruption, auroc, fpr in ranked:
        if show_trend:
            method_results = all_results[corruption]['ood_detection'][method]
            batch_aurocs = method_results.get('batch_aurocs', [])
            # Guard against aggregated results whose batch list is empty.
            first_auroc = f"{batch_aurocs[0]:.4f}" if len(batch_aurocs) > 0 else "N/A"
            logger.info(
                f"  {corruption}: {label} Avg AUROC={auroc:.4f}, "
                f"{label} First AUROC={first_auroc}, "
                f"{label} Final AUROC={method_results['final_auroc']:.4f}, "
                f"{label} Avg FPR@95TPR={fpr:.4f}"
            )
        else:
            logger.info(f"  {corruption}: {label} Avg AUROC={auroc:.4f}, {label} Avg FPR@95TPR={fpr:.4f}")

def get_summary_results(all_results):
    """
    Average every metric over several corruption types.

    The OOD methods, and whether OSCR / classification sections are present,
    are taken from the first entry of ``all_results``.

    Args:
        all_results: mapping of corruption type -> results dictionary

    Returns:
        dict: a results dictionary shaped like a single-corruption result whose
        metrics are means over ``all_results``. The per-batch lists are left
        empty because they are not aggregated.
    """
    reference = next(iter(all_results.values()))
    entries = list(all_results.values())

    def mean_of(*keys):
        """Mean of the value reached by following ``keys`` in each result."""
        picked = []
        for entry in entries:
            node = entry
            for key in keys:
                node = node[key]
            picked.append(node)
        return np.mean(picked)

    summary = {
        'ood_detection': {
            name: {
                'final_auroc': mean_of('ood_detection', name, 'final_auroc'),
                'final_fpr': mean_of('ood_detection', name, 'final_fpr'),
                'avg_auroc': mean_of('ood_detection', name, 'avg_auroc'),
                'avg_fpr': mean_of('ood_detection', name, 'avg_fpr'),
                'batch_aurocs': [],  # per-batch values are not summarised
                'batch_fprs': [],    # per-batch values are not summarised
            }
            for name in list(reference['ood_detection'].keys())
        }
    }

    if 'oscr' in reference:
        summary['oscr'] = {
            'final_oscr': mean_of('oscr', 'final_oscr'),
            'avg_oscr': mean_of('oscr', 'avg_oscr'),
            'batch_values': [],  # per-batch values are not summarised
        }

    if 'classification' in reference:
        summary['classification'] = {
            'final_accuracy': mean_of('classification', 'final_accuracy'),
            'avg_accuracy': mean_of('classification', 'avg_accuracy'),
            'batch_accuracies': [],  # per-batch values are not summarised
        }

    return summary

def save_results_to_csv(results_dict, filepath, args):
    """
    Append per-corruption performance results to a CSV file.

    The header row is written when the file does not exist yet or is empty.
    Otherwise rows are appended below the existing content.

    Args:
        results_dict: mapping of row label (corruption type) -> results dictionary
        filepath: path of the CSV file
        args: argument object providing ``init_methods``, ``ema_alpha``,
            ``confidence_threshold``, ``iqr_factor`` and ``feature_layer``
    """
    if not results_dict:
        # Nothing to write; do not create a file containing only a header.
        return

    # OOD method columns are derived from the first result
    first_result = next(iter(results_dict.values()))
    ood_methods = list(first_result['ood_detection'].keys())
    ood_metrics = ('final_auroc', 'final_fpr', 'avg_auroc', 'avg_fpr')

    base_fieldnames = [
        'init_methods', 'ema_alpha', 'confidence_threshold',
        'iqr_factor', 'feature_layer', 'corruption'
    ]
    ood_fieldnames = [f'{method}_{metric}' for method in ood_methods for metric in ood_metrics]
    other_fieldnames = ['final_oscr', 'avg_oscr', 'final_accuracy', 'avg_accuracy']
    fieldnames = base_fieldnames + ood_fieldnames + other_fieldnames

    # An existing but empty file still needs the header row.
    write_header = not os.path.isfile(filepath) or os.path.getsize(filepath) == 0

    def fmt(value):
        return f"{value:.4f}"

    # Append mode also creates the file when it is missing.
    with open(filepath, 'a', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        if write_header:
            writer.writeheader()

        for corruption, results in results_dict.items():
            row_data = {
                'init_methods': args.init_methods,
                'ema_alpha': args.ema_alpha,
                'confidence_threshold': args.confidence_threshold,
                'iqr_factor': args.iqr_factor,
                'feature_layer': args.feature_layer,
                'corruption': corruption
            }

            for method in ood_methods:
                method_results = results['ood_detection'][method]
                for metric in ood_metrics:
                    row_data[f'{method}_{metric}'] = fmt(method_results[metric])

            if 'oscr' in results:
                row_data['final_oscr'] = fmt(results['oscr']['final_oscr'])
                row_data['avg_oscr'] = fmt(results['oscr']['avg_oscr'])
            else:
                row_data['final_oscr'] = "N/A"
                row_data['avg_oscr'] = "N/A"

            if 'classification' in results:
                row_data['final_accuracy'] = fmt(results['classification']['final_accuracy'])
                row_data['avg_accuracy'] = fmt(results['classification']['avg_accuracy'])
            else:
                row_data['final_accuracy'] = "N/A"
                row_data['avg_accuracy'] = "N/A"

            writer.writerow(row_data)

def _build_model(args, device, logger):
    """
    Load the backbone selected by ``args`` and pick the layers to probe.

    For ``tiny_imagenet`` the architecture is forced to ResNet-50 and
    ``args.arch`` is overwritten accordingly.

    Args:
        args: parsed command-line arguments
        device: torch device the model is moved to
        logger: logging object

    Returns:
        tuple: ``(model, layer_list)``. The model is in eval mode on ``device``,
        and ``layer_list`` holds the module names selected by ``args.feature_layer``.

    Raises:
        ValueError: for an unsupported dataset/architecture combination or an
            unknown ``feature_layer`` value.
    """
    if args.dataset == 'cifar100':
        num_classes = 100
        if args.arch == 'Hendrycks2020AugMix_WRN':
            base_model = load_model(args.arch, args.ckpt_dir, args.dataset, ThreatModel.corruptions)
            layer_options = {
                'single': ['block3'],
                'multi': ['block1', 'block2', 'block3', 'fc'],
            }
        elif args.arch == 'vit-tiny':
            base_model = ViTForImageClassification.from_pretrained("WinKawaks/vit-tiny-patch16-224")
            base_model.classifier = nn.Linear(192, num_classes)
            ckpt_path = os.path.join(args.ckpt_dir, f'checkpoint_{args.dataset}_vit_tiny.pt')
            base_model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
            layer_options = {
                'single': ['vit.encoder.layer.11'],
                'multi': [f'vit.encoder.layer.{i}' for i in range(12)] + ['classifier'],
            }
        elif args.arch == 'swin-tiny':
            base_model = SwinForImageClassification.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
            base_model.classifier = nn.Linear(768, num_classes)
            ckpt_path = os.path.join(args.ckpt_dir, f'checkpoint_{args.dataset}_swin_tiny.pt')
            base_model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
            layer_options = {
                # Swin modules live under 'swin.*' (not 'vit.*'); 'single' uses
                # the last encoder block, i.e. the deepest block of 'multi'.
                'single': ['swin.encoder.layers.3.blocks.1'],
                'multi': [
                    'swin.encoder.layers.0.blocks.0', 'swin.encoder.layers.0.blocks.1',
                    'swin.encoder.layers.1.blocks.0', 'swin.encoder.layers.1.blocks.1',
                    'swin.encoder.layers.2.blocks.0', 'swin.encoder.layers.2.blocks.1',
                    'swin.encoder.layers.2.blocks.2', 'swin.encoder.layers.2.blocks.3',
                    'swin.encoder.layers.2.blocks.4', 'swin.encoder.layers.2.blocks.5',
                    'swin.encoder.layers.3.blocks.0', 'swin.encoder.layers.3.blocks.1',
                    'classifier',
                ],
            }
        else:
            raise ValueError(f"Unsupported architecture '{args.arch}' for dataset '{args.dataset}'")

        # Respect the CPU fallback chosen in main() instead of forcing .cuda().
        base_model = base_model.to(device)
        logger.info(f"Loaded model: {args.arch}")

    elif args.dataset == 'tiny_imagenet':
        if args.arch != 'resnet50':
            logger.info(f"tiny_imagenet only supports resnet50; overriding --arch {args.arch}")
        args.arch = 'resnet50'
        base_model = models.__dict__[args.arch](num_classes=200)
        checkpoint = torch.load(os.path.join(args.ckpt_dir, "tiny_imagenet", "model_best.pth.tar"),
                                map_location="cpu")
        base_model.load_state_dict(checkpoint["state_dict"])
        base_model = normalize_model(base_model, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)).to(device)
        layer_options = {
            'single': ['layer4'],
            'multi': ['layer1', 'layer2', 'layer3', 'layer4', 'fc'],
        }

    else:
        raise ValueError(f"Unsupported dataset '{args.dataset}'")

    if args.feature_layer not in layer_options:
        raise ValueError(f"Unsupported feature_layer '{args.feature_layer}' (expected 'single' or 'multi')")

    base_model.eval()
    return base_model, layer_options[args.feature_layer]


def _get_loader(name):
    """Return the module-level data loader function called ``name`` (replaces ``eval``)."""
    loader = globals().get(name)
    if loader is None:
        raise ValueError(f"No data loader named '{name}'")
    return loader


def _load_eval_data(args, corruption, logger):
    """
    Load the in-distribution and OOD tensors for one corruption type.

    ``corruption == 'original'`` loads the clean datasets; any other value
    loads the corrupted variants at ``args.severity``.

    Returns:
        tuple: ``(x_ind, y_ind, x_ood)`` as returned by the loaders (not yet
        moved to a device).

    Raises:
        ValueError: for an unsupported dataset.
    """
    if corruption == 'original':  # clean dataset
        logger.info(f"Evaluating on {corruption}...")
        if args.dataset == 'cifar100':
            x_ind, y_ind = _get_loader(f"load_{args.dataset}c")(args.num_ex, 1,
                                                                 args.data_dir, True, [corruption])
            x_ood, _ = _get_loader(f"load_{args.ood_dataset}")(args.num_ex, args.data_dir, size=32, shuffle=True)
        elif args.dataset == 'tiny_imagenet':
            x_ind, y_ind = _get_loader(f"load_{args.dataset}")(args.num_ex, args.data_dir, size=64, shuffle=True)
            x_ood, _ = _get_loader(f"load_{args.ood_dataset}")(args.num_ex, args.data_dir, size=64, shuffle=True)
        else:
            raise ValueError("Unsupported dataset")
    else:  # corrupted dataset
        logger.info(f"Evaluating on {corruption} (severity {args.severity})...")
        if args.dataset == 'cifar100':
            x_ind, y_ind = _get_loader(f"load_{args.dataset}c")(args.num_ex, args.severity,
                                                                 args.data_dir, True, [corruption])
            x_ood, _ = _get_loader(f"load_{args.ood_dataset}_c")(args.num_ex, args.severity, args.data_dir,
                                                                  32, True, [corruption])
        elif args.dataset == 'tiny_imagenet':
            x_ind, y_ind = _get_loader(f"load_{args.dataset}_c")(args.num_ex, args.severity,
                                                                  args.data_dir, 64, True, [corruption])
            x_ood, _ = _get_loader(f"load_{args.ood_dataset}_c")(args.num_ex, args.severity, args.data_dir,
                                                                  64, True, [corruption])
        else:
            raise ValueError("Unsupported dataset")
    return x_ind, y_ind, x_ood


def main():
    """Parse CLI arguments, evaluate every requested corruption and save the results."""
    parser = argparse.ArgumentParser(description='Evaluate Online RDS Method')
    # model
    parser.add_argument('--arch', default='Hendrycks2020AugMix_WRN',
                        choices=['Hendrycks2020AugMix_WRN', 'resnet50','vit-tiny','swin-tiny'])
    parser.add_argument('--dataset', default='cifar100', choices=['cifar100', 'tiny_imagenet'])
    parser.add_argument('--ood_dataset', default='svhn', choices=['svhn', 'Places365', 'LSUN', 'iSUN','textures','imagenet_o','Places','SUN','iNaturalist'])
    # data
    parser.add_argument('--corruption', default='gaussian_noise',
                        help='Corruption type for evaluation, or "all" for all corruptions')
    parser.add_argument('--severity', default=5, type=int, choices=[1, 2, 3, 4, 5],
                        help='Corruption severity level')
    parser.add_argument('--num_ex', default=10000, type=int,
                        help='Number of examples to use from each dataset')

    # evaluation
    parser.add_argument('--batch_size', default=100, type=int)
    parser.add_argument('--num_batches', default=0, type=int,
                        help='Number of batches for online evaluation (0 or negative for full dataset)')

    # DART setting
    parser.add_argument('--confidence_threshold', default=0.25, type=float,
                        help='Confidence threshold for updates')
    parser.add_argument('--use_confidence', action='store_true',
                        help='Enable confidence-based updates (disabled by default)')
    parser.add_argument('--use_outlier_removal', action='store_true',
                        help='Enable outlier removal (disabled by default)')
    parser.add_argument('--iqr_factor', default=1.5, type=float,
                        help='IQR factor for outlier detection')
    parser.add_argument('--ema_alpha', default=0.9, type=float,
                        help='EMA coefficient for center updates')
    parser.add_argument('--temperature', default=5.0, type=float,
                        help='Temperature scaling')
    parser.add_argument('--flip_weight', default=2.0, type=float,
                        help='Weight for id-ood distance in flip detection')
    parser.add_argument('--space', default='feature', choices=['feature', 'logit'],
                        help='Space for distance calculation (feature or logit)')
    parser.add_argument('--auto_correction', action='store_true',
                        help='Enable automatic confidence-based correction mechanism')
    # NOTE(review): the default is a list but a CLI value arrives as a plain
    # string (no nargs); confirm which form evaluate_ood_scores expects.
    parser.add_argument('--init_methods', default=['msp'], choices=['msp','energy','entropy'],
                        help='Initialization method (energy or max_prob)')
    parser.add_argument('--target_methods', default="RDS,RDS_equal,Energy,MSP,Max_logit,Entropy,GradNorm,ViM,KNN,Mahalanobis_single,Mahalanobis_ensemble,ODIN",
                        help='Target methods to evaluate')
    parser.add_argument('--feature_layer', default='single', type=str, choices=['single', 'multi'],
                        help='Utilize single/multi-level feature (default: single)')

    # other settings
    parser.add_argument('--data_dir', default='./data', help='Data directory')
    parser.add_argument('--ckpt_dir', default='./ckpt', help='Checkpoint directory')
    parser.add_argument('--save_dir', default='./results/total_scores', help='Results directory')
    parser.add_argument('--seed', default=1, type=int, help='Random seed')
    parser.add_argument('--gpu', default=0, type=int, help='GPU ID')

    args = parser.parse_args()
    args.target_methods = args.target_methods.split(',')

    # GPU selection is left to CUDA_VISIBLE_DEVICES; fall back to CPU without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup logging
    save_path = os.path.join(args.save_dir, f"{args.arch}_{args.dataset}_{args.ood_dataset}_{args.corruption}_s{args.severity}_seed{args.seed}")
    os.makedirs(save_path, exist_ok=True)
    logger = get_logger(__name__, save_path, 'log.txt')
    logger.info(f"Arguments: {args}")

    model, layer_list = _build_model(args, device, logger)

    if args.corruption.lower() == 'all':
        corruptions = [
            "gaussian_noise", "shot_noise", "impulse_noise",
            "defocus_blur", "glass_blur", "motion_blur", "zoom_blur",
            "snow", "frost", "fog", "brightness", "contrast",
            "elastic_transform", "pixelate", "jpeg_compression"
        ]
    else:
        corruptions = [args.corruption]

    all_results = {}

    for corruption in corruptions:
        set_random_seed(args.seed)
        copied_model = deepcopy(model)  # fresh copy per corruption (episodic)
        corruption_save_path = os.path.join(save_path, corruption)
        os.makedirs(corruption_save_path, exist_ok=True)

        x_ind, y_ind, x_ood = _load_eval_data(args, corruption, logger)
        x_ind, x_ood = x_ind.to(device), x_ood.to(device)

        if args.arch in ['vit-tiny', 'swin-tiny']:
            # Upsample to 224x224 for the transformer backbones.
            x_ind = torch.nn.functional.interpolate(x_ind, size=(224, 224), mode='bilinear')
            x_ood = torch.nn.functional.interpolate(x_ood, size=(224, 224), mode='bilinear')
        if y_ind is not None:
            y_ind = y_ind.to(device)

        if args.num_batches <= 0:
            # Cover the longer of the two streams completely.
            actual_num_batches = max(
                math.ceil(x_ind.size(0) / args.batch_size),
                math.ceil(x_ood.size(0) / args.batch_size)
            )
            logger.info(f"Using full dataset: {actual_num_batches} batches")
        else:
            actual_num_batches = args.num_batches
            logger.info(f"Using {args.num_batches} batches as specified")

        logger.info(f"Loaded {x_ind.size(0)} ID and {x_ood.size(0)} OOD samples")

        start_time = time.time()
        results = evaluate_ood_scores(
            args=args,
            model=copied_model,
            x_ind_all=x_ind,
            x_ood_all=x_ood,
            y_ind_all=y_ind,
            dataset=args.dataset,
            batch_size=args.batch_size,
            num_batches=actual_num_batches,
            layer_list=layer_list,
            save_dir=corruption_save_path,
            rds_confidence_threshold=args.confidence_threshold,
            rds_iqr_factor=args.iqr_factor,
            rds_ema_alpha=args.ema_alpha,
            auto_correction=args.auto_correction,
            init_methods=args.init_methods,
            target_methods=args.target_methods,
            temperature=args.temperature,
            flip_weight=args.flip_weight
        )
        elapsed = time.time() - start_time

        log_method_results(logger, results, corruption, elapsed)
        all_results[corruption] = results

    csv_path = os.path.join(save_path, "results.csv")

    if len(corruptions) > 1:
        log_summary_results(logger, all_results, corruptions)

        for method in all_results[corruptions[0]]['ood_detection'].keys():
            # RDS additionally reports first/final AUROC to show its online trend.
            log_sorted_performance(logger, all_results, corruptions, method, show_trend=(method == 'RDS'))

        save_results_to_csv(all_results, csv_path, args)

        # Blank separator row between per-corruption rows and the AVERAGE row.
        with open(csv_path, 'a', newline='', encoding='utf-8') as csvfile:
            csv.writer(csvfile).writerow([])

        save_results_to_csv({'AVERAGE': get_summary_results(all_results)}, csv_path, args)
        logger.info(f"Results with summary statistics saved to {csv_path}")
    else:
        save_results_to_csv(all_results, csv_path, args)
        logger.info(f"Results saved to {csv_path}")


# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
