import argparse
import torch
from tqdm import tqdm
import data_loader.data_loaders as module_data
from data_loader.cifar_data_loaders import TestAgnosticImbalanceCIFAR100DataLoader
import model.model as module_arch
import numpy as np
from parse_config import ConfigParser
import torch.nn.functional as F

from utils import adjusted_model_wrapper
# ========== Ray Tune-related imports ==========
import ray
from ray import tune
from ray.tune.search.hyperopt import HyperOptSearch

import itertools
import copy
import numpy as np

def main(config, posthoc_bias_correction=False):
    """Evaluate a trained checkpoint on every test-time label distribution.

    Builds the network described by ``config['arch']``, restores the weights
    stored at ``config.resume`` and, for each imbalanced / uniform CIFAR-100
    test split listed in ``test_distribution_set``, runs :func:`validation`
    with the aggregation weights ``config['aggregation_weight1..3']``.

    Args:
        config: the parsed ``ConfigParser`` for this run.
        posthoc_bias_correction: when True, the log of the test split's
            class prior is passed as ``test_bias`` to
            ``adjusted_model_wrapper``; otherwise ``test_bias`` is None.

    Returns:
        ``(average_top1, record_list)``. ``record_list`` holds one
        ``(many, medium, few, all)`` accuracy tuple (percent) per test
        distribution, in ``test_distribution_set`` order. ``average_top1``
        is the mean of the ``all`` entries.
    """
    logger = config.get_logger('test')

    # Build the network; ``returns_feat`` is forced off when the architecture
    # exposes it (presumably so the forward pass yields logits only).
    if 'returns_feat' in config['arch']['args']:
        model = config.init_obj('arch', module_arch, allow_override=True, returns_feat=False, use_hnet=config['use_hnet'])
    else:
        model = config.init_obj('arch', module_arch, use_hnet=config['use_hnet'])

    logger.info('Loading checkpoint: {} ...'.format(config.resume))
    # Deserialize onto the CPU so a checkpoint saved from a GPU run can also be
    # restored on a CPU-only host; the model is moved to ``device`` below.
    checkpoint = torch.load(config.resume, map_location='cpu')
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)

    # Prepare the model for inference.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    # The training loader is only used for its per-class sample counts, which
    # split the classes into many-shot (>100), few-shot (<20) and medium-shot.
    train_data_loader = getattr(module_data, config['data_loader']['type'])(
        config['data_loader']['args']['data_dir'],
        batch_size=128,
        shuffle=False,
        training=True,
        num_workers=8,
        imb_factor=config['data_loader']['args']['imb_factor']
    )
    train_cls_num_list = torch.tensor(train_data_loader.cls_num_list)
    many_shot = train_cls_num_list > 100
    few_shot = train_cls_num_list < 20
    medium_shot = ~many_shot & ~few_shot

    num_classes = config._config["arch"]["args"]["num_classes"]

    # name -> (test_imb_factor, reverse), forwarded to the test data loader.
    distrb = {
        'uniform': (0, False),
        'forward50': (0.02, False),
        'forward25': (0.04, False),
        'forward10': (0.1, False),
        'forward5': (0.2, False),
        'forward2': (0.5, False),
        'backward50': (0.02, True),
        'backward25': (0.04, True),
        'backward10': (0.1, True),
        'backward5': (0.2, True),
        'backward2': (0.5, True),
    }

    # The aggregation weights do not depend on the test split; read them once.
    weight = [config['aggregation_weight1'], config['aggregation_weight2'], config['aggregation_weight3']]

    record_list = []

    test_distribution_set = ["forward50", "forward25", "forward10", "forward5", "forward2", "uniform", "backward2", "backward5", "backward10", "backward25", "backward50"]
    for test_distribution in test_distribution_set:
        print(test_distribution)
        test_imb_factor, reverse = distrb[test_distribution]
        data_loader = TestAgnosticImbalanceCIFAR100DataLoader(
            config['data_loader']['args']['data_dir'],
            batch_size=128,
            shuffle=False,
            training=False,
            num_workers=2,
            test_imb_factor=test_imb_factor,
            reverse=reverse,
        )

        if posthoc_bias_correction:
            # Log of the empirical class prior of this test split.
            test_prior = torch.tensor(data_loader.cls_num_list).float().to(device)
            test_prior = test_prior / test_prior.sum()
            test_bias = test_prior.log()
        else:
            test_bias = None

        adjusted_model = adjusted_model_wrapper(model, test_bias=test_bias)

        record = validation(data_loader, adjusted_model, num_classes, device, many_shot, medium_shot, few_shot, weight)
        record_list.append(record)

    print('=' * 25, ' Final results ', '=' * 25)
    for name, record in zip(test_distribution_set, record_list):
        print(name + '\t')
        print(*record)

    # Index 3 of each record is the overall top-1 accuracy ("All"), i.e. the
    # last value returned by ``validation``; average it over all splits.
    eval_acc_mic_top1_list = [record[3] for record in record_list]
    average_eval_acc_mic_top1 = np.mean(eval_acc_mic_top1_list)
    return average_eval_acc_mic_top1, record_list

def mic_acc_cal(preds, labels):
    """Return the top-1 (micro) accuracy of ``preds`` against ``labels``.

    ``labels`` is either a label tensor, or a ``(targets_a, targets_b, lam)``
    triple. For the triple, the result is the ``lam``-weighted mix of the hit
    counts against both target sets, divided by ``len(preds)``. That branch
    returns a tensor; the plain-label branch returns a Python float.
    """
    if not isinstance(labels, tuple):
        n_correct = (preds == labels).sum().item()
        return n_correct / len(labels)

    assert len(labels) == 3
    targets_a, targets_b, lam = labels
    hits_a = preds.eq(targets_a.data).cpu().sum().float()
    hits_b = preds.eq(targets_b.data).cpu().sum().float()
    return (lam * hits_a + (1 - lam) * hits_b) / len(preds)
   

def validation(data_loader, model, num_classes, device, many_shot, medium_shot, few_shot, weight):
    """Run ``model`` over ``data_loader`` and report grouped top-1 accuracy.

    Args:
        data_loader: iterable of ``(data, target)`` batches.
        model: called as ``model(data, ray=weight_tensor)``; its output is
            treated as per-class scores with classes along dim 1.
        num_classes: number of classes (side length of the confusion matrix).
        device: device that batches, the weight tensor and accumulators use.
        many_shot, medium_shot, few_shot: boolean class masks of length
            ``num_classes`` (torch tensors or array-likes).
        weight: aggregation weights (sequence or tensor) passed as ``ray``.

    Returns:
        ``(many, medium, few, all)`` accuracies in percent, each rounded to
        two decimals. The group values are mean per-class accuracies within
        the group, ignoring classes with no test samples. ``all`` is the
        sample-level top-1 accuracy. Samples labelled ``-1`` are excluded
        everywhere.
    """
    # Convert the weights once and place them on ``device``. Doing this inside
    # the batch loop would re-wrap an existing tensor on every batch.
    ray = torch.as_tensor(weight, device=device)

    batch_preds, batch_labels = [], []
    with torch.no_grad():
        for data, target in tqdm(data_loader):
            data, target = data.to(device), target.to(device)
            output = model(data, ray=ray)
            # Only predictions and labels are needed downstream. Collecting
            # them in lists and concatenating once avoids re-copying all
            # previous results on every batch.
            batch_preds.append(output.argmax(dim=1).view(-1))
            batch_labels.append(target.view(-1).long())

    if batch_preds:
        preds = torch.cat(batch_preds)
        labels = torch.cat(batch_labels)
    else:
        preds = labels = torch.empty(0, dtype=torch.long, device=device)

    # Label -1 marks samples to ignore. Dropping them here keeps the confusion
    # matrix consistent with the overall accuracy; otherwise row -1 would
    # silently count them as the last class.
    valid = labels != -1
    preds, labels = preds[valid], labels[valid]

    # Overall (micro) top-1 accuracy.
    eval_acc_mic_top1 = mic_acc_cal(preds, labels)

    # Confusion matrix (rows: true class, cols: predicted class) in one pass.
    confusion_matrix = torch.bincount(
        labels * num_classes + preds, minlength=num_classes * num_classes
    ).reshape(num_classes, num_classes).float()

    acc_per_class = confusion_matrix.diag() / confusion_matrix.sum(1)
    acc = acc_per_class.cpu().numpy()

    def _group_mean(mask):
        # Explicit numpy bool mask; classes without samples (0/0 = NaN) are
        # skipped instead of turning the whole group mean into NaN.
        idx = np.asarray(torch.as_tensor(mask).cpu(), dtype=bool)
        return np.nanmean(acc[idx])

    many_shot_acc = _group_mean(many_shot)
    medium_shot_acc = _group_mean(medium_shot)
    few_shot_acc = _group_mean(few_shot)

    rounded = tuple(
        np.round(value * 100, decimals=2)
        for value in (many_shot_acc, medium_shot_acc, few_shot_acc, eval_acc_mic_top1)
    )
    print("Many-shot {}, Medium-shot {}, Few-shot {}, All {}".format(*rounded))
    return rounded
 

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch Template')
    parser.add_argument('-c', '--config', default=None, type=str,
                        help='config file path (default: None)')
    parser.add_argument('-r', '--resume', default=None, type=str,
                        help='path to latest checkpoint (default: None)')
    parser.add_argument('-d', '--device', default=None, type=str,
                        help='indices of GPUs to enable (default: all)')
    parser.add_argument('-l', '--log-config', default='logger/logger_config.json', type=str,
                        help='logging config file path (default: logger/logger_config.json)')
    parser.add_argument("--posthoc_bias_correction", dest="posthoc_bias_correction",
                        action="store_true", default=False)

    # Accepted but unused here, so training-time command lines still parse.
    parser.add_argument("--validate")
    parser.add_argument("--use-wandb")

    config, args = ConfigParser.from_args(parser)

    # Candidate values for every aggregation weight: 0.0, 0.1, ..., 1.0.
    grid = [round(step * 0.1, 1) for step in range(11)]
    weight_keys = ('aggregation_weight1', 'aggregation_weight2', 'aggregation_weight3')

    best_score, best_config = -np.inf, None
    results = []

    # Exhaustive sweep over all 11**3 weight triples. NOTE: each main() call
    # reloads the checkpoint and rebuilds the data loaders.
    for combo in itertools.product(grid, repeat=3):
        w1, w2, w3 = combo
        # Each trial mutates its own deep copy so settings never leak between trials.
        trial_config = copy.deepcopy(config)
        for key, value in zip(weight_keys, combo):
            trial_config._config[key] = value

        # Mean overall top-1 accuracy across all test distributions.
        score, record_list = main(trial_config, args.posthoc_bias_correction)
        results.append((w1, w2, w3, score, record_list))
        print(f"测试组合: aggregation_weight1={w1}, aggregation_weight2={w2}, aggregation_weight3={w3} --> 得分: {score:.4f}")

        if score > best_score:
            best_score, best_config = score, combo

    print("="*30, "最终结果", "="*30)
    print(f"最佳超参数组合: aggregation_weight1={best_config[0]}, aggregation_weight2={best_config[1]}, aggregation_weight3={best_config[2]}, 得分: {best_score:.4f}")

    # Dump every trial together with its per-distribution records.
    print("\n全部组合结果:")
    for w1, w2, w3, score, record_list in results:
        print(f"组合: w1={w1}, w2={w2}, w3={w3}  得分: {score:.4f}")
        print("详细性能记录:")
        for rec in record_list:
            print(rec)
        print("-" * 60)
