import numpy as np
import os
import torch
import argparse
from .metrics import ECELoss, AdaptiveECELoss, ClasswiseECELoss
import torch.nn as nn

import sys
# Put the project root (the parent of this file's directory) at the front of
# sys.path so the top-level `Data` package imported below can be resolved
# regardless of the current working directory.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from Data import cifar10_c, cifar100_c, tiny_imagenet_c, imagenet_c
from .metrics import test_classification_net_logits

# The 15 corruption types evaluated for each dataset. Each name is passed
# verbatim as the corruption identifier to the corrupted-dataset loaders
# (cifar10_c / cifar100_c / tiny_imagenet_c / imagenet_c).
corruption_types = [
    'gaussian_noise',
    'shot_noise', 
    'impulse_noise',
    'defocus_blur',
    'glass_blur',
    'motion_blur',
    'zoom_blur',
    'snow',
    'frost',
    'fog',
    'brightness',
    'contrast',
    'elastic_transform',
    'pixelate',
    'jpeg_compression'
]

# Shared metric criteria used by getMetrics. They are instantiated once, at
# import time, and moved to the GPU with .cuda().
# NOTE(review): because this runs on import, confirm the module is never
# imported on a machine without CUDA (whether .cuda() fails there depends on
# whether the criteria hold any tensors).
nll_criterion = nn.CrossEntropyLoss().cuda()
ece_criterion = ECELoss().cuda()
adaece_criterion = AdaptiveECELoss().cuda()
cece_criterion = ClasswiseECELoss().cuda()

def getMetrics(logits, labels):
    """Compute calibration and likelihood metrics for one set of predictions.

    Args:
        logits: model outputs, passed unchanged to every criterion.
        labels: ground-truth targets matching ``logits``.

    Returns:
        Tuple ``(ece, adaece, cece, nll)`` of Python scalars, in that order.
    """
    # Evaluated left-to-right, in the same order the values are returned.
    criteria = (ece_criterion, adaece_criterion, cece_criterion, nll_criterion)
    ece, adaece, cece, nll = (criterion(logits, labels).item() for criterion in criteria)
    return ece, adaece, cece, nll

def get_logits_labels(data_loader, net, device='cuda'):
    """Run ``net`` over a whole loader and collect its logits and the labels.

    ``net`` is switched to eval mode and run under ``torch.no_grad()``. Models
    whose ``architecture`` attribute equals ``'CNN'`` are expected to return a
    two-element tuple whose first element is the logits; any other model is
    expected to return the logits directly.

    Args:
        data_loader: iterable yielding ``(data, label)`` batches.
        net: model to evaluate.
        device: device for the inputs and the returned tensors. Defaults to
            ``'cuda'`` (the previously hard-coded behaviour); pass ``'cpu'``
            to evaluate without a GPU.

    Returns:
        ``(logits, labels)``: all batches concatenated along dim 0, both on
        ``device``.
    """
    logits_list = []
    labels_list = []
    # Loop-invariant: decide once how to unpack the model's output.
    returns_tuple = getattr(net, 'architecture', None) == 'CNN'
    net.eval()
    with torch.no_grad():
        for data, label in data_loader:
            data = data.to(device)
            if returns_tuple:
                logits, _ = net(data)
            else:
                logits = net(data)
            logits_list.append(logits)
            labels_list.append(label)
        # Logits are already on `device`; labels stay on the loader's device
        # until this single transfer.
        logits = torch.cat(logits_list).to(device)
        labels = torch.cat(labels_list).to(device)
    return logits, labels

# Datasets for which _build_corruption_loader knows how to build corrupted loaders.
_CORRUPTION_DATASETS = ('cifar10', 'cifar100', 'tiny_imagenet', 'imagenet')


def _build_corruption_loader(args, corruption_type, severity, input_size):
    """Build the test loader for one (corruption type, severity) of ``args.dataset``.

    Raises:
        ValueError: if ``args.dataset`` is not one of ``_CORRUPTION_DATASETS``.
    """
    if args.dataset == 'cifar10':
        return cifar10_c.get_test_loader(
            batch_size=args.test_batch_size,
            corruption_type=corruption_type,
            severity=severity,
            input_size=input_size)
    if args.dataset == 'cifar100':
        return cifar100_c.get_test_loader(
            batch_size=args.test_batch_size,
            corruption_type=corruption_type,
            severity=severity,
            input_size=input_size)
    if args.dataset == 'tiny_imagenet':
        return tiny_imagenet_c.get_tiny_imagenet_c_loader(
            root=tiny_imagenet_c.data_root,
            corruption=corruption_type,
            severity=severity,
            batch_size=args.test_batch_size,
            input_size=input_size,
            class_map_file="./dataset/tiny_imagenet/wnids.txt")
    if args.dataset == 'imagenet':
        return imagenet_c.get_imagenet_c_loader(
            root=imagenet_c.data_root,
            corruption=corruption_type,
            severity=severity,
            batch_size=args.test_batch_size,
            input_size=input_size)
    raise ValueError(f"Dataset {args.dataset} not supported for corruption testing")


def _evaluate_model(data_loader, model):
    """Evaluate ``model`` on ``data_loader``.

    Returns:
        ``(accuracy, ece, adaece, cece, nll)`` -- the same order as the metric
        axis of the array built by ``corruption_test``.
    """
    logits, labels = get_logits_labels(data_loader, model)
    ece, adaece, cece, nll = getMetrics(logits, labels)
    _, accuracy, _, _, _ = test_classification_net_logits(logits, labels)
    return accuracy, ece, adaece, cece, nll


def corruption_test(args, net, scaled_net, clean_results):
    """Evaluate a raw and a temperature-scaled model on every corruption/severity.

    Args:
        args: run configuration. Reads ``dataset``, ``test_batch_size``,
            ``model``, ``remark`` and, for models whose name starts with
            ``'vit_'``, ``vit_input_size``.
        net: the uncalibrated ("raw") model.
        scaled_net: the temperature-scaled ("ts") model.
        clean_results: mapping with keys ``'raw'`` and ``'ts'``; each value is
            an iterable of clean-test metrics in the order
            ``accuracy, ece, adaece, cece, nll``.

    Returns:
        ``np.ndarray`` of shape ``(2, len(corruption_types), 6, 5)`` indexed as
        ``[model_type, corruption_type, severity, metric]``. Severity index 0
        holds the clean results (repeated for every corruption type); indices
        1-5 hold the corrupted results.

    Side effects:
        Prints progress and a per-severity summary, and saves the array to
        ``corruption_results/<dataset>[_vit]/<model>_<remark>.npy``. For an
        unsupported dataset a single message is printed and the array (only
        the clean slots filled) is returned without being saved.
    """
    METRIC_NAMES = ['accuracy', 'ece', 'adaece', 'cece', 'nll']
    MODEL_TYPES = ['raw', 'ts']
    severities = [0, 1, 2, 3, 4, 5]  # 0 = clean data, 1-5 = corruption severities
    num_corruptions = len(corruption_types)
    num_severities = len(severities)

    results = np.zeros((len(MODEL_TYPES), num_corruptions, num_severities, len(METRIC_NAMES)))
    print(f"Results array shape: {results.shape}")
    # Derived from the actual array so the description matches its shape
    # (it previously claimed 5 severities for a 6-slot axis).
    print(f"Dimensions: [model_type({len(MODEL_TYPES)}), corruption_type({num_corruptions}), "
          f"severity({num_severities}), metrics({len(METRIC_NAMES)})]")
    print(f"Model types: {MODEL_TYPES}")
    print(f"Metrics: {METRIC_NAMES}")
    print(f"Corruption types: {corruption_types}")

    # Clean metrics do not depend on the corruption type: copy them into
    # severity slot 0 of every corruption at once.
    for model_idx, model_type in enumerate(MODEL_TYPES):
        for metric_idx, value in enumerate(clean_results[model_type]):
            results[model_idx, :, 0, metric_idx] = value

    if args.dataset not in _CORRUPTION_DATASETS:
        # Report once and stop, rather than repeating the message for every
        # corruption and then saving/summarising placeholder zeros as if they
        # were measured results.
        print(f"Dataset {args.dataset} not supported for corruption testing")
        return results

    input_size = args.vit_input_size if args.model.startswith('vit_') else None
    # Print tag and model, in the same order as MODEL_TYPES.
    models = (('Raw ', net), ('TS  ', scaled_net))
    max_severity = severities[-1]

    for corruption_idx, corruption_type in enumerate(corruption_types):
        print(f"\nTesting corruption: {corruption_type}")

        for severity_idx, severity in enumerate(severities[1:], start=1):
            print(f"  Severity {severity}/{max_severity}:")
            # Build one loader at a time so only one severity's data is alive.
            loader = _build_corruption_loader(args, corruption_type, severity, input_size)

            for model_idx, (tag, model) in enumerate(models):
                metrics = _evaluate_model(loader, model)
                for metric_idx, value in enumerate(metrics):
                    results[model_idx, corruption_idx, severity_idx, metric_idx] = value
                accuracy, ece, adaece, cece, nll = metrics
                print(f'    {tag} - Acc: {accuracy*100:.2f}%  NLL: {nll:.5f}  ECE: {ece:.5f}  AdaECE: {adaece:.5f}  CECE: {cece:.5f}')

            del loader

        torch.cuda.empty_cache()

    dataset_folder_name = (args.dataset + '_vit') if args.model.startswith('vit_') else args.dataset
    dataset_dir = os.path.join('corruption_results', dataset_folder_name)
    os.makedirs(dataset_dir, exist_ok=True)  # also creates 'corruption_results'
    np.save(os.path.join(dataset_dir, args.model + '_' + args.remark + '.npy'), results)

    for model_idx, model_type in enumerate(MODEL_TYPES):
        print(f"\n{model_type.upper()} Model Results (Average across all corruption types):")
        print("Severity\tAccuracy\tECE\t\tAdaECE\t\tCECE\t\tNLL")
        print("-" * 70)

        # Average over the corruption axis -> one row of metrics per severity.
        per_severity = results[model_idx].mean(axis=0)
        for severity_idx, (acc, ece, adaece, cece, nll) in enumerate(per_severity):
            severity_label = "Clean" if severity_idx == 0 else str(severity_idx)
            print(f"{severity_label}\t\t{acc*100:.2f}%\t\t{ece:.5f}\t{adaece:.5f}\t{cece:.5f}\t{nll:.5f}")

    return results