import os
import argparse
import torch
import numpy as np
import logging
import sys
import json
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
from torch.utils.data import DataLoader

from dataset import AVEDataset
from model import MultiViewNet
from fusion import UnifiedFusion 


# Route log records (message text only, at INFO and above) to stdout -- the
# same stream the summary tables below are printed to.
logging.basicConfig(
    format='%(message)s',
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

def get_args(argv=None):
    """Parse command-line options for the k-fold scale-bias evaluation.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]`` exactly as before; passing a list
            allows programmatic or test invocation.

    Returns:
        argparse.Namespace with ``dataset``, ``data_root``, ``batch_size``,
        ``gpu``, ``checkpoint_dir``, ``folds`` and ``save_dir``.
    """
    parser = argparse.ArgumentParser(
        description='K-fold scale-bias verification of multimodal fusion methods.'
    )
    parser.add_argument('--dataset', type=str, default='ave', help='Dataset name (default: ave)')
    parser.add_argument('--data_root', type=str, default='./data', help='Root directory of the dataset')
    parser.add_argument('--batch_size', type=int, default=128, help='Evaluation batch size')
    parser.add_argument('--gpu', type=int, default=0, help='CUDA device index (ignored without CUDA)')
    parser.add_argument('--checkpoint_dir', type=str, default='./kfold_checkpoints',
                        help='Directory holding <dataset>_fold<k>_best.pth checkpoints')
    parser.add_argument('--folds', type=int, default=5, help='Number of folds to evaluate (default: 5)')
    parser.add_argument('--save_dir', type=str, default='./results', help='Directory to save result json')
    return parser.parse_args(argv)

def calculate_metrics(preds, targets, probs, num_classes):
    """Compute top-1 accuracy and AUROC for one set of predictions, in percent.

    Args:
        preds: Tensor of predicted class indices, compared element-wise with
            ``targets``.
        targets: 1-D tensor of ground-truth class indices; its first dimension
            is the sample count.
        probs: Per-class scores indexed as ``probs[:, c]``. Presumably shape
            ``(N, num_classes)`` -- TODO confirm against the fusion output. For
            the multiclass path sklearn additionally requires every row to sum
            to 1.
        num_classes: Number of classes. ``2`` scores with column 1 (positive
            class); any other value uses one-vs-rest macro-averaged AUROC.

    Returns:
        ``(acc, auroc)`` in percent. When sklearn cannot compute AUROC (e.g. a
        class missing from ``targets``, or rows of ``probs`` that are not
        probabilities) a warning with the reason is logged and AUROC falls
        back to 50.0 (chance level).

    Raises:
        ValueError: if ``targets`` is empty, since accuracy is undefined.
    """
    total = targets.size(0)
    if total == 0:
        raise ValueError("calculate_metrics received no samples; accuracy is undefined.")

    correct = (preds == targets).sum().item()
    acc = 100. * correct / total

    # Convert outside the try so conversion bugs are not masked as "AUROC
    # undefined"; .float() because .numpy() does not support bfloat16.
    targets_np = targets.detach().cpu().numpy()
    probs_np = probs.detach().float().cpu().numpy()

    try:
        if num_classes == 2:
            auroc = roc_auc_score(targets_np, probs_np[:, 1])
        else:
            auroc = roc_auc_score(targets_np, probs_np, multi_class='ovr', average='macro')
    except (ValueError, IndexError) as err:
        # Keep the original best-effort fallback, but make it visible: a silent
        # 50% can otherwise be mistaken for a genuine chance-level result.
        logger.warning(f"AUROC could not be computed ({err}); falling back to 0.5.")
        auroc = 0.5

    return acc, auroc * 100.

def verify_scale_bias_single_fold(model, loader, scale_factors, methods, device, num_classes):
    """Evaluate every fusion method while rescaling the first model output by 1/T.

    For each scale factor ``T``, the first entry of ``model(batch)`` is divided
    by ``T`` and the remaining entries are passed through unchanged. The result
    is fed to a ``UnifiedFusion`` instance per method, and accuracy/AUROC are
    computed over the whole ``loader``.

    The model forward pass does not depend on ``T``, so it is run once per batch
    and its outputs are reused for every scale factor. The original ran the full
    test set once per scale.

    Args:
        model: Network whose call on a batch dict returns an indexable sequence
            of logits tensors (presumably one per view -- TODO confirm).
        loader: Iterable of dict batches containing a ``'target'`` tensor.
        scale_factors: Divisors applied to the first logits entry; duplicates
            are evaluated once.
        methods: Fusion method names passed to ``UnifiedFusion``.
        device: Device that tensors in the batch, the fusers, and (presumably)
            the model live on.
        num_classes: Forwarded to ``calculate_metrics``.

    Returns:
        ``results[T][method] == {'acc': float, 'auc': float}``, both in percent.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    # Dedupe while keeping order; duplicates would otherwise append twice per
    # batch into the same list and misalign predictions with targets.
    scales = list(dict.fromkeys(scale_factors))

    fusers = {m: UnifiedFusion(m).to(device) for m in methods}
    for fuser in fusers.values():
        fuser.eval()
    model.eval()

    # Collected on CPU so device memory does not grow with the test-set size.
    all_preds = {T: {m: [] for m in methods} for T in scales}
    all_probs = {T: {m: [] for m in methods} for T in scales}
    all_targets = []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating (all scales)", leave=False):
            for k, v in batch.items():
                if isinstance(v, torch.Tensor):
                    batch[k] = v.to(device)
            all_targets.append(batch['target'].cpu())

            # One forward pass per batch, shared by every scale factor.
            logits_list = model(batch)

            for T in scales:
                # Only the first entry is rescaled; the rest are left untouched.
                logits_biased = [logits_list[0] / T, *logits_list[1:]]
                for m in methods:
                    output = fusers[m](logits_biased)
                    # Fusers may return (probs, ...) or just probs.
                    probs = output[0] if isinstance(output, tuple) else output
                    all_preds[T][m].append(torch.argmax(probs, dim=1).cpu())
                    all_probs[T][m].append(probs.cpu())

    if not all_targets:
        raise ValueError("Loader yielded no batches; cannot compute metrics.")

    full_targets = torch.cat(all_targets)
    results = {T: {m: {} for m in methods} for T in scales}
    for T in scales:
        for m in methods:
            acc, auc = calculate_metrics(
                torch.cat(all_preds[T][m]),
                full_targets,
                torch.cat(all_probs[T][m]),
                num_classes,
            )
            results[T][m]['acc'] = acc
            results[T][m]['auc'] = auc

    return results

def _summarize_metric(global_stats, scale_factors, methods, metric_key, metric_name, dataset, n_folds):
    """Print a mean±std table for one metric and return it as a JSON-ready dict.

    Args:
        global_stats: ``global_stats[T][method][metric_key]`` -> list of per-fold values.
        scale_factors: Row order of the table.
        methods: Column order of the table.
        metric_key: ``'acc'`` or ``'auc'``.
        metric_name: Human-readable title for the table.
        dataset: Dataset name shown in the title.
        n_folds: Number of folds that actually contributed values.

    Returns:
        ``{str(T): {method: {"mean": float, "std": float} or None}}``; ``None``
        marks a cell with no evaluated folds.
    """
    print("\n" + "=" * 110)
    print(f"{metric_name} Results - {dataset.upper()} (Avg of {n_folds} Folds)")
    print("=" * 110)
    print(f"{'Scale (T)':<10} | " + " | ".join(f"{m.upper():<14}" for m in methods))
    print("-" * 110)

    metric_results = {}
    for T in scale_factors:
        cells = []
        metric_results[str(T)] = {}
        for m in methods:
            values = global_stats[T][m][metric_key]
            if values:
                mean_val = float(np.mean(values))
                # Population std (ddof=0) across folds, as in the original report.
                std_val = float(np.std(values))
                cell = f"{mean_val:.1f}±{std_val:.1f}"
                metric_results[str(T)][m] = {"mean": mean_val, "std": std_val}
            else:
                cell = "N/A"
                metric_results[str(T)][m] = None
            # Pad to the header's column width so rows line up.
            cells.append(f"{cell:<14}")
        # ':g' prints 100.0 as '100' (the old ':.2g' produced '1e+02').
        print(f"{T:<10g} | " + " | ".join(cells))
    print("=" * 110)

    return metric_results


def _load_checkpoint(model, ckpt_path, device):
    """Load ``ckpt_path`` into ``model`` non-strictly, warning about key mismatches.

    The file is passed straight to ``load_state_dict``, i.e. assumed to be a
    plain state dict -- TODO confirm for wrapped checkpoints. ``map_location``
    lets GPU-saved checkpoints load on CPU or on a different GPU index.
    """
    state = torch.load(ckpt_path, map_location=device)
    incompatible = model.load_state_dict(state, strict=False)
    # strict=False would otherwise silently evaluate a partially (or entirely)
    # uninitialized network if the key names do not match.
    if incompatible is not None:
        if incompatible.missing_keys:
            logger.warning(f"  ! {len(incompatible.missing_keys)} missing keys when loading {ckpt_path}: "
                           f"{incompatible.missing_keys[:5]}")
        if incompatible.unexpected_keys:
            logger.warning(f"  ! {len(incompatible.unexpected_keys)} unexpected keys when loading {ckpt_path}: "
                           f"{incompatible.unexpected_keys[:5]}")


def main():
    """Run the k-fold scale-bias evaluation and write a JSON summary.

    For every fold whose checkpoint exists in ``--checkpoint_dir``, the fold's
    test split is evaluated with ``verify_scale_bias_single_fold``. Per-fold
    accuracy/AUROC are reported as mean±std over the folds actually evaluated
    and saved to ``<save_dir>/<dataset>_scale_robustness.json``.

    Raises:
        ValueError: for a dataset name with no known class count.
        NotImplementedError: for a known dataset whose test split cannot be
            built by this script (only AVE is wired up).
    """
    args = get_args()
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")

    os.makedirs(args.save_dir, exist_ok=True)

    scale_factors = [0.01, 0.05, 0.1, 0.5, 1.0, 5.0, 10.0, 50.0, 100.0]
    methods = ['sum', 'mean', 'ds', 'weighted', 'ours']

    num_classes_by_dataset = {'ave': 28, 'sunrgbd': 15, 'chexpert': 7, 'mura': 2}
    if args.dataset not in num_classes_by_dataset:
        raise ValueError(f"Unknown dataset: {args.dataset}")
    num_classes = num_classes_by_dataset[args.dataset]

    # Only AVE's test split can be built with what this script imports. Fail
    # fast for the others instead of hitting an unbound dataset inside the loop.
    test_builders = {
        'ave': lambda fold: AVEDataset(args.data_root, fold=fold, split='test'),
    }
    if args.dataset not in test_builders:
        raise NotImplementedError(f"No test dataset loader is wired up for '{args.dataset}' in this script.")
    build_test_ds = test_builders[args.dataset]

    global_stats = {
        T: {m: {'acc': [], 'auc': []} for m in methods}
        for T in scale_factors
    }
    evaluated_folds = []

    logger.info("Starting Optimized K-Fold Scale Bias Verification")
    logger.info(f"Dataset: {args.dataset} | Classes: {num_classes} | Folds: {args.folds}")
    logger.info("-" * 60)

    t_max = max(scale_factors)
    for fold in range(1, args.folds + 1):
        logger.info(f"[Processing Fold {fold}/{args.folds}]")

        # Check the checkpoint first so skipped folds build no dataset or model.
        ckpt_path = os.path.join(args.checkpoint_dir, f"{args.dataset}_fold{fold}_best.pth")
        if not os.path.exists(ckpt_path):
            logger.warning(f"Checkpoint {ckpt_path} not found! Skipping this fold.")
            continue

        test_ds = build_test_ds(fold)
        loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, num_workers=4,
                            pin_memory=(device.type == 'cuda'))

        model = MultiViewNet(args.dataset, num_classes).to(device)
        _load_checkpoint(model, ckpt_path, device)

        fold_results = verify_scale_bias_single_fold(model, loader, scale_factors, methods, device, num_classes)
        evaluated_folds.append(fold)

        for T in scale_factors:
            for m in methods:
                global_stats[T][m]['acc'].append(fold_results[T][m]['acc'])
                global_stats[T][m]['auc'].append(fold_results[T][m]['auc'])

        logger.info(f"  > Fold {fold} Summary (Ours): T=1.0 ACC={fold_results[1.0]['ours']['acc']:.2f} | "
                    f"T={t_max} ACC={fold_results[t_max]['ours']['acc']:.2f}")

    if not evaluated_folds:
        logger.warning("No folds were evaluated; the saved results contain only N/A entries.")

    n_folds = len(evaluated_folds)
    final_output = {
        "dataset": args.dataset,
        "results": {
            'auc': _summarize_metric(global_stats, scale_factors, methods, 'auc', 'AUROC (%)',
                                     args.dataset, n_folds),
            'acc': _summarize_metric(global_stats, scale_factors, methods, 'acc', 'ACCURACY (%)',
                                     args.dataset, n_folds),
        },
        "folds_evaluated": evaluated_folds,
    }

    save_path = os.path.join(args.save_dir, f"{args.dataset}_scale_robustness.json")
    with open(save_path, 'w') as f:
        json.dump(final_output, f, indent=4)

    logger.info(f"\nResults saved to {save_path}")

if __name__ == '__main__':
    # Script entry point: only runs when executed directly, not on import.
    main()