

import logging
import re
import warnings
import wandb
from torch.utils.data import DataLoader, Subset

# Send warnings.warn(...) output through the logging module instead of stderr.
logging.captureWarnings(True)
# Use the 'always' action for DeprecationWarnings from modules whose name
# matches the pattern below (re.match against the module name).
# NOTE(review): the trailing '.' is an unescaped regex wildcard, so the pattern
# requires at least one character after this module's name and therefore does
# not match this module itself -- only names that extend it (e.g. submodules).
warnings.filterwarnings('always', category=DeprecationWarning,
                        module=r'^{0}.'.format(re.escape(__name__)))

import argparse
import os
import torch
import random
import numpy as np

from cdc.args import parse_cfg, get_model, get_strong_transformations,\
    get_val_transformations, get_standard_transformations,\
    get_train_dataloader, get_val_dataloader,\
    get_train_dataset,get_val_dataset, get_optimizer
from cdc.utils.evaluate_utils import get_predictions, \
    hungarian_evaluate, calibration_evaluate, hign_conf_evaluation, hungarian_evaluate_hard, save_top_images, save_bottom_images, generate_and_plot_tsne

from cdc.methods.calibrate_train import initialize_weights, train_cali, init_head_with_confident_samples, initialize_weights_v4, initialize_weights_v5, initialize_weights_longtail_v1, initialize_weights_bias
from cdc.backbones.models import CaliMLP
from cdc.data.waterbirds_dataset import compute_worst_acc

# Command-line flags: locations of the environment (paths) config and the
# experiment config, both consumed by parse_cfg() in main().
FLAGS = argparse.ArgumentParser(description='CDC Model')
FLAGS.add_argument('--config_env', default='scripts/cdc/env.yaml', help='Location of path config file')
FLAGS.add_argument('--config_exp', default='scripts/cdc/cifar20/cdc_ini_bias_a20_seed5.yaml', help='Location of experiments config file')

# Weights & Biases runs are always recorded offline.
# Credentials must never be committed: if runs are synced later, provide
# WANDB_API_KEY through the environment. The key that used to be hard-coded
# here is exposed in version control and should be revoked/rotated.
os.environ["WANDB_MODE"] = "offline"

def _save_checkpoint(path, model, cali_mlp, epoch, optimizer_clu=None, optimizer_cali=None):
    """Save model / calibration-head state (plus optimizer state if given) to ``path``.

    Keys written: 'optimizer_clu' and 'optimizer_cali' (only when the
    corresponding optimizer is passed), then 'model', 'cali_mlp', 'epoch'.
    These are the keys read back by the resume / best-model code in main().
    """
    state = {}
    if optimizer_clu is not None:
        state['optimizer_clu'] = optimizer_clu.state_dict()
    if optimizer_cali is not None:
        state['optimizer_cali'] = optimizer_cali.state_dict()
    state['model'] = model.state_dict()
    state['cali_mlp'] = cali_mlp.state_dict()
    state['epoch'] = epoch
    torch.save(state, path)


def _append_log(log_path, *lines):
    """Append ``lines`` (each already newline-terminated) to the text log, closing the file."""
    with open(log_path, 'a') as log_file:
        log_file.writelines(lines)


def _evaluate_cifar20c(cfg, model, cali_mlp, root='/nas/datasets/CIFAR-100-C',
                       corruption_type='pixelate', severities=range(1, 6)):
    """Evaluate the clustering and calibration heads on CIFAR-20-C corruptions.

    Args:
        cfg: Parsed experiment config (reads 'data.num_workers', 'cdc_dir',
            'cluster_eval.plot_title').
        model: Clustering model.
        cali_mlp: Calibration head.
        root: Directory holding the CIFAR-100-C files.
        corruption_type: Corruption name passed to the CIFAR-20-C loader.
        severities: Corruption severities to evaluate, one pass each.
    """
    from cdc.data.cifar10c_dataset import get_cifar20c_dataloader

    for severity in severities:
        dataloader = get_cifar20c_dataloader(
            root,
            corruption_type=corruption_type,
            severity=severity,
            batch_size=500,
            num_workers=cfg['data']['num_workers'],
        )
        print(f"Dataset size: {len(dataloader.dataset)}")

        # 200 goes in the evaluators' epoch slot.
        # NOTE(review): presumably only used to label outputs -- confirm.
        predictions = get_predictions(cfg, dataloader, model)
        clustering_stats = hungarian_evaluate(cfg, cfg['cdc_dir'], 200, 0, predictions,
                                              title=cfg['cluster_eval']['plot_title'],
                                              compute_confusion_matrix=False)
        print('CDC-Clu ', clustering_stats)

        predictions = get_predictions(cfg, dataloader, model, cali_mlp=cali_mlp)
        clustering_stats = calibration_evaluate(cfg, cfg['cdc_dir'], 200, 0, predictions,
                                                title=cfg['cluster_eval']['plot_title'],
                                                compute_confusion_matrix=False)
        print('CDC-Cal ', clustering_stats)


def _evaluate_ood_auroc(cfg, model, val_transformations,
                        ood_val_path='/nas/datasets/tiny-imagenet-200/val/images_new'):
    """Print AUROC and FPR95 for separating the validation set from Tiny-ImageNet.

    The score is the maximum softmax probability of the model's first output
    head. In-distribution samples are labelled 1 and OOD samples 0.
    ``cfg`` is not modified; the OOD dataset is built from a deep copy.
    """
    import copy
    from sklearn.metrics import roc_auc_score, roc_curve

    id_dataset = get_val_dataset(cfg, val_transformations)
    id_dataloader = get_val_dataloader(cfg, id_dataset, 64)

    # Deep copy: the overrides below must not leak into the caller's config.
    ood_cfg = copy.deepcopy(cfg)
    ood_cfg['data']['dataset'] = 'tinyimagenet'
    ood_cfg['data']['val_path'] = ood_val_path
    ood_dataset = get_val_dataset(ood_cfg, val_transformations)
    ood_dataloader = get_val_dataloader(ood_cfg, ood_dataset, 64)

    labels = []
    scores = []
    model.eval()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for is_in_distribution, loader in ((1, id_dataloader), (0, ood_dataloader)):
            for batch in loader:
                images, targets = batch[0], batch[1]
                outputs = model(images.cuda(non_blocking=True),
                                forward_pass='return_all')['output'][0]
                # Max softmax probability = prediction confidence.
                max_probs, _ = torch.max(torch.softmax(outputs, dim=1), dim=1)
                labels.extend([is_in_distribution] * len(targets))
                scores.extend(max_probs.cpu().numpy())

    auroc = roc_auc_score(labels, scores)
    fpr, tpr, _ = roc_curve(labels, scores)
    # FPR at the ROC point whose TPR is closest to 95%.
    fpr95 = fpr[np.argmin(np.abs(tpr - 0.95))]
    print(f"FPR95: {fpr95:.4f}")
    print(f"AUROC Score: {auroc:.4f}")


def main():
    """Run CDC calibration training end to end.

    1. Parse the env/experiment configs (from FLAGS). Build the datasets and
       dataloaders, the clustering model and the DataParallel calibration MLP.
    2. Evaluate. Then run ``initialize_weights`` and ``initialize_weights_bias``
       (scaled by ``cfg['alpha']``, default 1.0), evaluating after each.
       Save ``checkpoint_highinin.pth.tar``.
    3. Resume from ``cfg['cdc_checkpoint']`` if it exists (model and
       calibration head only; optimizer state is not restored). Then train
       with ``train_cali`` up to ``cfg['max_epochs']``, evaluating every epoch.
       At ``cfg['prun_epoch']``, optionally restrict the training set. Save the
       latest checkpoint, and the best one by the calibration evaluation's 'ACC'.
    4. Reload the best checkpoint (if one exists) and run the final evaluation.
       Then run the dataset-specific extras: CIFAR-20-C and OOD AUROC for
       'cifar200'; worst-group accuracy for 'waterbirds'.
    """
    args = FLAGS.parse_args()
    cfg = parse_cfg(args.config_env, args.config_exp)
    print(cfg)

    # Data
    print('Get dataset and dataloaders')
    strong_transformations = get_strong_transformations(cfg)
    standard_transformations = get_standard_transformations(cfg)
    val_transformations = get_val_transformations(cfg)

    train_dataset = get_train_dataset(cfg, {'val': val_transformations,
                                            'standard': standard_transformations,
                                            'augment': strong_transformations},
                                      split=cfg['data']['split'], augmented=True)
    val_dataset = get_val_dataset(cfg, val_transformations)
    train_dataloader = get_train_dataloader(cfg, train_dataset)
    val_dataloader = get_val_dataloader(cfg, val_dataset)

    print('Strong transforms:', strong_transformations)
    print('Standard transforms:', standard_transformations)
    print('Validation transforms:', val_transformations)
    print('Train samples %d - Val samples %d' % (len(train_dataset), len(val_dataset)))

    # Epoch at which the training set is pruned.
    # The default -1 never equals a loop epoch, so pruning is off.
    prun_epoch = cfg.get('prun_epoch', -1)
    print('Pruning epoch:', prun_epoch)

    # Model
    print('Get model')
    model = get_model(cfg, cfg['pretext']['enable'])

    cali_mlp = CaliMLP(cfg)
    cali_mlp = torch.nn.DataParallel(cali_mlp)
    cali_mlp = cali_mlp.cuda()

    # Optimizer
    print('Get optimizer')
    optimizer_clu = get_optimizer(cfg, model)
    optimizer_cali = torch.optim.Adam(cali_mlp.parameters(), lr=cfg['optimizer']['lr'],
                                      **cfg['optimizer']['kwargs'])
    # wandb
    wandb.watch(model, log="all")

    cdc_dir = cfg['cdc_dir']
    plot_title = cfg['cluster_eval']['plot_title']
    log_path = os.path.join(cdc_dir, 'training_log.log')

    # Evaluate before any re-initialization.
    predictions, features = get_predictions(cfg, val_dataloader, model, return_features=True)
    clustering_stats, indices = hungarian_evaluate_hard(cfg, cdc_dir, 0, 0, predictions,
                                                        title=plot_title,
                                                        compute_confusion_matrix=False)
    print('CDC-Clu ', clustering_stats)

    predictions = get_predictions(cfg, val_dataloader, model, cali_mlp=cali_mlp)
    clustering_stats = calibration_evaluate(cfg, cdc_dir, 0, 0, predictions,
                                            title=plot_title,
                                            compute_confusion_matrix=False)
    print('CDC-Cal ', clustering_stats)

    # Initialize weights from the validation features, then re-evaluate.
    initialize_weights(cfg, model, cali_mlp, features, val_dataloader)
    predictions, features = get_predictions(cfg, val_dataloader, model, return_features=True)
    clustering_stats, indices = hungarian_evaluate_hard(cfg, cdc_dir, 0, 0, predictions,
                                                        title=plot_title,
                                                        compute_confusion_matrix=False, easy=True)
    print('CDC-Clu-ini: ', clustering_stats)
    _append_log(log_path, f'CDC-Clu-ini: {clustering_stats}\n')

    # Bias initialization, then re-evaluate.
    # The resulting `indices` are passed to the first epoch's evaluation.
    alpha = cfg.get('alpha', 1.0)
    initialize_weights_bias(cfg, model, cali_mlp, features, val_dataloader, alpha=alpha)
    predictions, features = get_predictions(cfg, val_dataloader, model, return_features=True)
    clustering_stats, indices = hungarian_evaluate_hard(cfg, cdc_dir, 0, 0, predictions,
                                                        title=plot_title,
                                                        compute_confusion_matrix=False, easy=True)
    print('CDC-Clu-ini-bias: ', clustering_stats)

    _save_checkpoint(os.path.join(cdc_dir, "checkpoint_highinin.pth.tar"),
                     model, cali_mlp, 0,
                     optimizer_clu=optimizer_clu, optimizer_cali=optimizer_cali)
    # NOTE: alternative initializers imported at the top of the file
    # (init_head_with_confident_samples, initialize_weights_v4/_v5,
    # initialize_weights_longtail_v1) can be swapped in above.

    # Resume model + calibration head if a training checkpoint exists.
    # Optimizer state is deliberately not restored.
    if os.path.exists(cfg['cdc_checkpoint']):
        print('Restart from checkpoint {}'.format(cfg['cdc_checkpoint']))
        checkpoint = torch.load(cfg['cdc_checkpoint'], map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        cali_mlp.load_state_dict(checkpoint['cali_mlp'])
        start_epoch = checkpoint['epoch']
    else:
        print('No checkpoint file at {}'.format(cfg['cdc_checkpoint']))
        start_epoch = 0

    # Main loop
    print('Starting main loop')
    best_acc = -1

    for epoch in range(start_epoch, cfg['max_epochs']):
        print('Epoch %d/%d' % (epoch + 1, cfg['max_epochs']))
        print('Train ...')
        train_cali(cfg, train_dataloader, cali_mlp, model, optimizer_cali, optimizer_clu, epoch, start_epoch)

        # Evaluate every epoch.
        # The best-model check below uses this epoch's clustering_stats.
        with open(log_path, 'a') as log_file:
            log_file.write(f'Epoch {epoch+1} - Validation prediction\n')

            print('Make prediction on validation set ...')
            predictions = get_predictions(cfg, val_dataloader, model)
            clustering_stats, indices = hungarian_evaluate_hard(cfg, cdc_dir, epoch, 0, predictions,
                                                                title=plot_title,
                                                                compute_confusion_matrix=False,
                                                                _indices=indices)
            print('CDC-Clu ', clustering_stats)
            log_file.write(f'CDC-Clu: {clustering_stats}\n')

            predictions = get_predictions(cfg, val_dataloader, model, cali_mlp=cali_mlp)
            clustering_stats = calibration_evaluate(cfg, cdc_dir, epoch, 0, predictions,
                                                    title=plot_title,
                                                    compute_confusion_matrix=False)
            print('CDC-Cal ', clustering_stats)
            log_file.write(f'CDC-Cal: {clustering_stats}\n\n')

            print("train samples: ", len(train_dataloader.dataset))
            print("val samples: ", len(val_dataloader.dataset))
            log_file.write(f'Train samples: {len(train_dataloader.dataset)}\n')
            log_file.write(f'Val samples: {len(val_dataloader.dataset)}\n\n')

        if epoch == prun_epoch:
            _save_checkpoint(os.path.join(cdc_dir, f"checkpoint_{prun_epoch}.pth.tar"),
                             model, cali_mlp, epoch,
                             optimizer_clu=optimizer_clu, optimizer_cali=optimizer_cali)
            print('prun sample ...')
            predictions = get_predictions(cfg, val_dataloader, model)
            # The return values are overwritten by the calibration pass below.
            clustering_stats, indices = hungarian_evaluate_hard(cfg, cdc_dir, epoch, 0, predictions,
                                                                title=plot_title,
                                                                compute_confusion_matrix=False,
                                                                easy=True)
            predictions = get_predictions(cfg, val_dataloader, model, cali_mlp=cali_mlp)
            clustering_stats, indices = calibration_evaluate(cfg, cdc_dir, epoch, 0, predictions,
                                                             title=plot_title,
                                                             compute_confusion_matrix=False,
                                                             flag=False)
            # Restrict training to the de-duplicated indices returned above.
            # NOTE(review): these indices come from evaluating val_dataloader
            # but index train_dataset -- assumes both share sample order; confirm.
            total_indices = list(set(indices))
            subset = Subset(train_dataset, total_indices)
            train_dataloader = get_train_dataloader(cfg, subset)

        # Checkpoint
        print('Checkpoint ...')
        _save_checkpoint(cfg['cdc_checkpoint'], model, cali_mlp, epoch + 1,
                         optimizer_clu=optimizer_clu, optimizer_cali=optimizer_cali)
        if best_acc < clustering_stats['ACC']:
            _save_checkpoint(cfg['cdc_best_model'], model, cali_mlp, epoch + 1)
            best_acc = clustering_stats['ACC']

    # Evaluate the best model
    print('Evaluate best model at the end')
    if os.path.exists(cfg['cdc_best_model']):
        checkpoint = torch.load(cfg['cdc_best_model'], map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        cali_mlp.load_state_dict(checkpoint['cali_mlp'])
    else:
        # Reached when no epoch ran (e.g. resumed at max_epochs) and no
        # earlier run wrote a best model.
        print('No best model at {}; evaluating current weights'.format(cfg['cdc_best_model']))

    predictions, _ = get_predictions(cfg, val_dataloader, model, return_features=True)
    clustering_stats = hungarian_evaluate(cfg, cdc_dir, cfg['max_epochs'], 0, predictions,
                                          title=plot_title,
                                          class_names=val_dataloader.dataset.classes,
                                          compute_confusion_matrix=True,
                                          confusion_matrix_file=os.path.join(cdc_dir, 'confusion_matrix.png'),
                                          save_wrong=True)
    print(clustering_stats)
    _append_log(log_path, f'best EVA: {clustering_stats}\n')

    if cfg['data']['dataset'] == "cifar200":
        print("cifar-20-c")
        _evaluate_cifar20c(cfg, model, cali_mlp)
        print("cifar-20-auroc")
        _evaluate_ood_auroc(cfg, model, val_transformations)

    if cfg['data']['dataset'] == "waterbirds":
        compute_worst_acc(model, val_dataloader, cfg['method'])

if __name__ == "__main__":
    # Seed every RNG the script touches (Python, NumPy, torch CPU and CUDA).
    # Presumably kept in sync with the "seed5" default experiment config.
    seed = 5
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN autotuning speed for repeatable kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    torch.set_grad_enabled(True)
    print('seed:', seed)
    main()
    wandb.finish()