

import logging
import re
import warnings
import wandb

# Route `warnings` output through the logging system.
logging.captureWarnings(True)
# Module-name regex for the "always show DeprecationWarning" filter below.
# Anchoring with `(\.|$)` matches this module itself as well as dotted
# submodules; a bare trailing "." would demand at least one more character
# after the name and therefore never match this module.
DEPRECATION_MODULE_PATTERN = r'^{0}(\.|$)'.format(re.escape(__name__))
warnings.filterwarnings('always', category=DeprecationWarning,
                        module=DEPRECATION_MODULE_PATTERN)
# NOTE(review): emits a DeprecationWarning at import time; looks like a
# leftover check that the filter above works — confirm it is still wanted.
warnings.warn("This is a DeprecationWarning",category=DeprecationWarning)

import argparse
import os
import torch
import random
import numpy as np

from cdc.args import parse_cfg, get_model, get_strong_transformations,\
    get_val_transformations, get_standard_transformations,\
    get_train_dataloader, get_val_dataloader,\
    get_train_dataset,get_val_dataset, get_optimizer
from cdc.utils.evaluate_utils import get_predictions, \
    hungarian_evaluate, calibration_evaluate, hign_conf_evaluation

from cdc.methods.calibrate_train import initialize_weights, train_cali
from cdc.backbones.models import CaliMLP
# Command-line interface: both options point at YAML config files that are
# handed to `parse_cfg` in `main`.
FLAGS = argparse.ArgumentParser(description='CDC Model')
for _opt_name, _opt_default, _opt_help in (
        ('--config_env', 'scripts/cdc/env.yaml',
         'Location of path config file'),
        ('--config_exp',
         'scripts/cdc/imagenet10/cdc_s50_c150_lr0001_bs500_su125_seed5.yaml',
         'Location of experiments config file'),
):
    FLAGS.add_argument(_opt_name, default=_opt_default, help=_opt_help)
# Drop the loop temporaries so they do not linger as module globals.
del _opt_name, _opt_default, _opt_help

def _evaluate_clustering_and_calibration(cfg, dataloader, model, cali_mlp,
                                         clu_epoch, cal_epoch,
                                         clu_predictions=None):
    """Print clustering (Hungarian) and calibration stats for one dataloader.

    Args:
        cfg: parsed experiment config.
        dataloader: loader to predict on.
        model: clustering model.
        cali_mlp: calibration head passed to `get_predictions`.
        clu_epoch: epoch label forwarded to `hungarian_evaluate`.
        cal_epoch: epoch label forwarded to `calibration_evaluate`.
        clu_predictions: precomputed clustering predictions; computed from
            `dataloader` when None.

    Returns:
        The stats returned by `calibration_evaluate`.
    """
    if clu_predictions is None:
        clu_predictions = get_predictions(cfg, dataloader, model)
    clustering_stats = hungarian_evaluate(cfg, cfg['cdc_dir'], clu_epoch, 0,
                                          clu_predictions,
                                          title=cfg['cluster_eval']['plot_title'],
                                          compute_confusion_matrix=False)
    print('CDC-Clu ', clustering_stats)

    cal_predictions = get_predictions(cfg, dataloader, model, cali_mlp=cali_mlp)
    calibration_stats = calibration_evaluate(cfg, cfg['cdc_dir'], cal_epoch, 0,
                                             cal_predictions,
                                             title=cfg['cluster_eval']['plot_title'],
                                             compute_confusion_matrix=False)
    print('CDC-Cal ', calibration_stats)
    return calibration_stats


def _evaluate_cifar20c(cfg, model, cali_mlp,
                       root='/nas/datasets/CIFAR-100-C',
                       corruption_type='pixelate',
                       severities=range(1, 6)):
    """Evaluate clustering and calibration on corrupted CIFAR-20 loaders.

    One loader is built per severity level via `get_cifar20c_dataloader`.
    """
    print("cifar-20-c")
    from cdc.data.cifar10c_dataset import get_cifar20c_dataloader

    for severity in severities:
        dataloader = get_cifar20c_dataloader(
            root,
            corruption_type=corruption_type,
            severity=severity,
            batch_size=500,
            num_workers=cfg['data']['num_workers']
        )
        print(f"Dataset size: {len(dataloader.dataset)}")
        # 200 is a fixed epoch label handed to the evaluators for these runs.
        _evaluate_clustering_and_calibration(cfg, dataloader, model, cali_mlp,
                                             clu_epoch=200, cal_epoch=200)


def _max_softmax_scores(model, dataloader):
    """Return the per-sample maximum softmax probability over `dataloader`.

    Uses entry 0 of the model's 'output' list from the 'return_all' forward
    pass (presumably the first clustering head — confirm against the model).
    """
    scores = []
    # Inference only: the script enables autograd globally, so disable it
    # here to avoid building graphs for every batch.
    with torch.no_grad():
        for batch in dataloader:
            images = batch[0]
            outputs = model(images.cuda(non_blocking=True),
                            forward_pass='return_all')['output'][0]
            probs = torch.softmax(outputs, dim=1)
            max_probs, _ = torch.max(probs, dim=1)
            scores.extend(max_probs.cpu().numpy())
    return scores


def _evaluate_ood_auroc(cfg, model, val_transformations,
                        ood_val_path='/nas/datasets/tiny-imagenet-200/val/images_new'):
    """Print AUROC and FPR95 for separating the val set from Tiny-ImageNet.

    The in-distribution validation samples are positives (label 1) and the
    Tiny-ImageNet samples are negatives (label 0). The score for both is the
    max softmax probability.
    """
    import copy
    from sklearn.metrics import roc_auc_score, roc_curve

    print("cifar-20-auroc")
    id_dataset = get_val_dataset(cfg, val_transformations)
    id_dataloader = get_val_dataloader(cfg, id_dataset, 64)

    # Work on a copy so the caller's config is not rewritten to Tiny-ImageNet.
    ood_cfg = copy.deepcopy(cfg)
    ood_cfg['data']['dataset'] = 'tinyimagenet'
    ood_cfg['data']['val_path'] = ood_val_path
    ood_dataset = get_val_dataset(ood_cfg, val_transformations)
    ood_dataloader = get_val_dataloader(ood_cfg, ood_dataset, 64)

    model.eval()
    id_scores = _max_softmax_scores(model, id_dataloader)
    ood_scores = _max_softmax_scores(model, ood_dataloader)
    labels = [1] * len(id_scores) + [0] * len(ood_scores)
    scores = id_scores + ood_scores

    auroc = roc_auc_score(labels, scores)
    fpr, tpr, _ = roc_curve(labels, scores)
    # FPR at 95% TPR: take the first ROC point whose TPR reaches 0.95
    # (`tpr` is non-decreasing). The nearest point could lie below 95% TPR.
    idx = int(np.searchsorted(tpr, 0.95, side='left'))
    fpr95 = fpr[idx]
    print(f"FPR95: {fpr95:.4f}")
    print(f"AUROC Score: {auroc:.4f}")


def main():
    """Train the CDC clustering model jointly with its calibration MLP.

    Steps:
      1. Read the config files named by --config_env and --config_exp.
      2. Build datasets, loaders, the model, the calibration MLP and both
         optimizers.
      3. Resume from cfg['cdc_checkpoint'] if that file exists.
      4. Evaluate once, then run `train_cali` for epochs
         [start_epoch, cfg['max_epochs']). After each epoch, evaluate, write
         the checkpoint, and write cfg['cdc_best_model'] whenever the
         calibration ACC improves.
      5. Run the extra corruption or OOD evaluation for 'cifar200' and
         'cifar20' configs.
    """
    args = FLAGS.parse_args()
    cfg = parse_cfg(args.config_env, args.config_exp)
    print(cfg)

    # Data
    print('Get dataset and dataloaders')
    strong_transformations = get_strong_transformations(cfg)
    standard_transformations = get_standard_transformations(cfg)
    val_transformations = get_val_transformations(cfg)

    train_dataset = get_train_dataset(cfg, {'val': val_transformations,
                                            'standard': standard_transformations,
                                            'augment': strong_transformations},
                                      split=cfg['data']['split'], augmented=True)
    val_dataset = get_val_dataset(cfg, val_transformations)
    train_dataloader = get_train_dataloader(cfg, train_dataset)
    val_dataloader = get_val_dataloader(cfg, val_dataset)
    print('Strong transforms:', strong_transformations)
    print('Standard transforms:', standard_transformations)
    print('Validation transforms:', val_transformations)
    print('Train samples %d - Val samples %d' % (len(train_dataset), len(val_dataset)))

    # Model
    print('Get model')
    model = get_model(cfg, cfg['pretext']['enable'])

    cali_mlp = CaliMLP(cfg)
    cali_mlp = torch.nn.DataParallel(cali_mlp)
    cali_mlp = cali_mlp.cuda()

    # Optimizer
    print('Get optimizer')
    optimizer_clu = get_optimizer(cfg, model)
    optimizer_cali = torch.optim.Adam(cali_mlp.parameters(), lr=cfg['optimizer']['lr'],
                                      **cfg['optimizer']['kwargs'])
    # wandb
    wandb.watch(model, log="all")

    # Checkpoint
    best_acc = -1
    if os.path.exists(cfg['cdc_checkpoint']):
        print('Restart from checkpoint {}'.format(cfg['cdc_checkpoint']))
        checkpoint = torch.load(cfg['cdc_checkpoint'], map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        cali_mlp.load_state_dict(checkpoint['cali_mlp'])
        # NOTE: optimizer states are saved in the checkpoint but not restored
        # here, so both optimizers restart from scratch on resume.
        start_epoch = checkpoint['epoch']
        # Keep the best score across restarts so a worse post-resume epoch
        # cannot overwrite the best model. Older checkpoints lack this key.
        best_acc = checkpoint.get('best_acc', -1)
    else:
        print('No checkpoint file at {}'.format(cfg['cdc_checkpoint']))
        start_epoch = 0

    # Evaluate (features are also needed for weight initialisation below)
    predictions, features = get_predictions(cfg, val_dataloader, model, return_features=True)
    _evaluate_clustering_and_calibration(cfg, val_dataloader, model, cali_mlp,
                                         clu_epoch=start_epoch, cal_epoch=0,
                                         clu_predictions=predictions)

    # Initialize weights only for a fresh run, not when resuming.
    if start_epoch == 0:
        initialize_weights(cfg, model, cali_mlp, features, val_dataloader)

    # Main loop
    print('Starting main loop')
    for epoch in range(start_epoch, cfg['max_epochs']):
        print('Epoch %d/%d' % (epoch + 1, cfg['max_epochs']))
        # Train
        print('Train ...')
        train_cali(cfg, train_dataloader, cali_mlp, model, optimizer_cali, optimizer_clu, epoch, start_epoch)

        # Evaluate every epoch; model selection uses the calibration ACC.
        print('Make prediction on validation set ...')
        calibration_stats = _evaluate_clustering_and_calibration(
            cfg, val_dataloader, model, cali_mlp, clu_epoch=epoch, cal_epoch=epoch)

        # Checkpoint
        print('Checkpoint ...')
        if best_acc < calibration_stats['ACC']:
            # Stored as a plain float so the checkpoint stays loadable with
            # torch.load(weights_only=True).
            best_acc = float(calibration_stats['ACC'])
            torch.save({'model': model.state_dict(),
                        'cali_mlp': cali_mlp.state_dict(),
                        'epoch': epoch + 1},
                       cfg['cdc_best_model'])
        torch.save({'optimizer_clu': optimizer_clu.state_dict(),
                    'optimizer_cali': optimizer_cali.state_dict(),
                    'model': model.state_dict(),
                    'cali_mlp': cali_mlp.state_dict(),
                    'epoch': epoch + 1,
                    'best_acc': float(best_acc)},
                   cfg['cdc_checkpoint'])

    dataset_name = cfg['data']['dataset']
    if dataset_name == "cifar200":
        _evaluate_cifar20c(cfg, model, cali_mlp)
    elif dataset_name == "cifar20":
        _evaluate_ood_auroc(cfg, model, val_transformations)

if __name__ == "__main__":
    SEED = 5
    # Seed every RNG the pipeline may draw from: Python, NumPy, torch CPU,
    # current CUDA device, then all CUDA devices.
    for _seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                     torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        _seed_fn(SEED)
    # Give up cuDNN autotuning in favour of run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Explicitly enable autograd globally for the whole run.
    torch.set_grad_enabled(True)
    main()
    wandb.finish()