import argparse
import torch
from tqdm import tqdm
import data_loader.data_loaders as module_data
from data_loader.cifar_data_loaders import TestAgnosticImbalanceCIFAR100DataLoader
import model.model as module_arch
import numpy as np
from parse_config import ConfigParser
import torch.nn.functional as F

from utils import adjusted_model_wrapper
# ========== Ray Tune 相关 ==========
import ray
from ray import tune
from ray.tune.search.hyperopt import HyperOptSearch

import itertools
import copy
import numpy as np

def main(config, posthoc_bias_correction=False):
    """Grid-search the three ``ray`` weights on every test distribution.

    Restores the checkpoint named by ``config.resume`` and, for each test
    distribution, scores every ``(w1, w2, w3)`` combination with
    ``validation``. The ten best-scoring combinations per distribution are
    printed at the end. A combination that already made the top list of an
    earlier distribution is not evaluated again for later distributions.

    Args:
        config: Parsed experiment configuration; indexed like a mapping and
            also used for ``get_logger``, ``init_obj`` and ``resume``.
        posthoc_bias_correction: When true, the log of the normalised
            test-set class counts is passed to ``adjusted_model_wrapper`` as
            ``test_bias``; otherwise ``None`` is passed.
    """
    logger = config.get_logger('test')

    # Build the model architecture.
    if 'returns_feat' in config['arch']['args']:
        model = config.init_obj('arch', module_arch, allow_override=True, returns_feat=False, use_hnet=config['use_hnet'])
    else:
        model = config.init_obj('arch', module_arch, use_hnet=config['use_hnet'])

    # Resolve the device first so the checkpoint can be mapped onto it; without
    # map_location a checkpoint saved on GPU cannot be restored on a CPU host.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    logger.info('Loading checkpoint: {} ...'.format(config.resume))
    checkpoint = torch.load(config.resume, map_location=device)
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)

    # Prepare the model for testing.
    model = model.to(device)
    model.eval()

    # The training loader is only needed for its per-class sample counts,
    # which define the many/medium/few-shot class groups.
    train_data_loader = getattr(module_data, config['data_loader']['type'])(
        config['data_loader']['args']['data_dir'],
        batch_size=128,
        shuffle=False,
        training=True,
        num_workers=8,
        imb_factor=config['data_loader']['args']['imb_factor']
    )
    train_cls_num_list = torch.tensor(train_data_loader.cls_num_list)
    many_shot = train_cls_num_list > 100
    few_shot = train_cls_num_list < 20
    medium_shot = ~many_shot & ~few_shot

    num_classes = config['arch']['args']['num_classes']

    # Distribution name -> (test_imb_factor, reverse) passed to the test loader.
    distrb = {
        'uniform': (0, False),
        'forward50': (0.02, False),
        'forward25': (0.04, False),
        'forward10': (0.1, False),
        'forward5': (0.2, False),
        'forward2': (0.5, False),
        'backward50': (0.02, True),
        'backward25': (0.04, True),
        'backward10': (0.1, True),
        'backward5': (0.2, True),
        'backward2': (0.5, True),
    }

    test_distribution_set = ["forward50", "forward25", "forward10", "forward5", "forward2", "uniform", "backward2", "backward5", "backward10", "backward25", "backward50"]

    # Candidate values for each weight: 0.0 to 1.0 in steps of 0.2.
    param_values = [round(x * 0.2, 1) for x in range(6)]
    # Number of best combinations kept per distribution.
    top_k = 10
    best_configs = [[] for _ in test_distribution_set]

    # Combinations that already entered the top list of some distribution.
    seen_configs = set()
    for dist_idx, test_distribution in enumerate(test_distribution_set):
        print(test_distribution)
        test_imb_factor, reverse = distrb[test_distribution]
        data_loader = TestAgnosticImbalanceCIFAR100DataLoader(
            config['data_loader']['args']['data_dir'],
            batch_size=128,
            shuffle=False,
            training=False,
            num_workers=2,
            test_imb_factor=test_imb_factor,
            reverse=reverse
        )
        if posthoc_bias_correction:
            test_prior = torch.tensor(data_loader.cls_num_list).float().to(device)
            test_prior = test_prior / test_prior.sum()
            test_bias = test_prior.log()
        else:
            test_bias = None
        adjusted_model = adjusted_model_wrapper(model, test_bias=test_bias)

        scored = []
        for config_tuple in itertools.product(param_values, repeat=3):
            # Skip combinations already ranked for an earlier distribution.
            if config_tuple in seen_configs:
                continue

            record = validation(data_loader, adjusted_model, num_classes, device, many_shot, medium_shot, few_shot, list(config_tuple))
            # The fourth element of the record is the overall accuracy.
            scored.append((record[3], list(config_tuple)))

        # One stable sort keeps evaluation order among equal scores, which is
        # the same ranking as re-sorting and truncating after every insert.
        best_configs[dist_idx] = sorted(scored, key=lambda item: item[0], reverse=True)[:top_k]

        # Remember this distribution's winners so later ones do not re-test them.
        seen_configs.update(tuple(cfg) for _, cfg in best_configs[dist_idx])

    # Human-readable ranking.
    print("="*30, "最终结果", "="*30)
    for dist, ranked in zip(test_distribution_set, best_configs):
        print(f"分布：{dist}")
        for rank, (score, cfg) in enumerate(ranked, start=1):
            print(f"  Top {rank}: w1={cfg[0]}, w2={cfg[1]}, w3={cfg[2]}  得分: {score:.4f}")
        print("-"*60)

    # Same ranking again, formatted as list literals for copy-pasting.
    for dist, ranked in zip(test_distribution_set, best_configs):
        print(f"分布：{dist}")
        for _, config_params in ranked:
            print(f" [{config_params[0]}, {config_params[1]}, {config_params[2]}],")
        print("-" * 60)

    # # 枚举所有 3 个超参数的组合
    # for w1, w2, w3 in itertools.product(param_values, repeat=3):
    #     record_list = []
    #     # 对每个分布测试当前参数组合的表现
    #     for test_distribution in test_distribution_set:
    #         print(test_distribution)
    #         data_loader = TestAgnosticImbalanceCIFAR100DataLoader(
    #             config['data_loader']['args']['data_dir'],
    #             batch_size=128,
    #             shuffle=False,
    #             training=False,
    #             num_workers=2,
    #             test_imb_factor=distrb[test_distribution][0],
    #             reverse=distrb[test_distribution][1]
    #         )
    #         if posthoc_bias_correction:
    #             test_prior = torch.tensor(data_loader.cls_num_list).float().to(device)
    #             test_prior = test_prior / test_prior.sum()
    #             test_bias = test_prior.log()
    #         else:
    #             test_bias = None
    #         adjusted_model = adjusted_model_wrapper(model, test_bias=test_bias)
    #         weight = [w1, w2, w3]
    #         # validation 返回的 record 是一个列表，其中第 4 个元素 record[3] 表示我们关注的分数
    #         record = validation(data_loader, adjusted_model, num_classes, device, many_shot, medium_shot, few_shot, weight)
    #         record_list.append(record)
    #     print('='*25, ' Final results ', '='*25)
    #     for i, txt in enumerate(record_list):
    #         print(test_distribution_set[i]+'\t')
    #         print(*txt)

    #     # 对每个分布，更新保存当前参数组合的得分和参数
    #     for i in range(11):
    #         score = record_list[i][3]  # 例如，第 4 个元素作为评分指标
    #         config_current = [w1, w2, w3]
    #         best_configs[i].append((score, config_current))
    #         # 排序：按得分降序排列，并保留前 10 个
    #         best_configs[i] = sorted(best_configs[i], key=lambda x: x[0], reverse=True)[:100]

    # print("="*30, "最终结果", "="*30)
    # for i in range(11):
    #     print(f"分布：{test_distribution_set[i]}")
    #     # 输出每个分布下得分最高的前 10 组参数
    #     for rank, (score, config_params) in enumerate(best_configs[i], 1):
    #         print(f"  Top {rank}: w1={config_params[0]}, w2={config_params[1]}, w3={config_params[2]}  得分: {score:.4f}")
    #     print("-" * 60)
    
    # for i in range(11):
    #     print(f"分布：{test_distribution_set[i]}")
    #     # 输出每个分布下得分最高的前 10 组参数
    #     for rank, (score, config_params) in enumerate(best_configs[i], 1):
    #         print(f" [{config_params[0]}, {config_params[1]}, {config_params[2]}],")
    #     print("-" * 60)



def mic_acc_cal(preds, labels):
    """Return the top-1 accuracy of ``preds`` against ``labels``.

    ``labels`` is either a tensor of targets, or a 3-tuple
    ``(targets_a, targets_b, lam)``. In the tuple case the matches against
    the two target sets are blended with weights ``lam`` and ``1 - lam``.
    """
    if not isinstance(labels, tuple):
        # Plain targets: fraction of exact matches, as a Python float.
        return (preds == labels).sum().item() / len(labels)

    assert len(labels) == 3
    first_targets, second_targets, mix = labels
    hits_first = preds.eq(first_targets.data).cpu().sum().float()
    hits_second = preds.eq(second_targets.data).cpu().sum().float()
    return (mix * hits_first + (1 - mix) * hits_second) / len(preds)
   

def validation(data_loader, model, num_classes, device, many_shot, medium_shot, few_shot, weight):
    """Evaluate ``model`` on ``data_loader`` and report shot-wise accuracy.

    Args:
        data_loader: Iterable yielding ``(data, target)`` batches.
        model: Callable invoked as ``model(data, ray=weight)``. Its output is
            reduced with ``argmax(dim=1)`` to obtain class predictions.
        num_classes: Number of classes; sets the confusion-matrix size.
        device: Device that inputs, the weight tensor and all accumulators
            are placed on.
        many_shot, medium_shot, few_shot: Boolean class masks used to group
            the per-class accuracies.
        weight: Values forwarded to the model as ``ray`` (converted to a
            tensor on ``device``).

    Returns:
        Tuple ``(many_shot, medium_shot, few_shot, overall)`` of accuracies
        in percent, each rounded to two decimals.

    Samples labelled ``-1`` are excluded from both the overall accuracy and
    the per-class statistics.
    """
    # Convert the weights once; they are identical for every batch.
    ray = torch.as_tensor(weight, device=device)

    # Seeded with empty tensors so an empty loader still concatenates cleanly.
    pred_chunks = [torch.empty(0, dtype=torch.long, device=device)]
    label_chunks = [torch.empty(0, dtype=torch.long, device=device)]
    with torch.no_grad():
        for data, target in tqdm(data_loader):
            data, target = data.to(device), target.to(device)
            output = model(data, ray=ray)
            pred_chunks.append(output.argmax(dim=1).view(-1))
            label_chunks.append(target.view(-1))

    preds = torch.cat(pred_chunks)
    total_labels = torch.cat(label_chunks)

    # Drop samples carrying the ignore label before computing any metric.
    labelled = total_labels != -1
    preds, total_labels = preds[labelled], total_labels[labelled]

    # Overall (micro-averaged) top-1 accuracy.
    eval_acc_mic_top1 = mic_acc_cal(preds, total_labels)

    # Confusion matrix in one pass: encode each (target, prediction) pair as
    # a flat index and count occurrences. Rows are targets, columns predictions.
    flat_index = total_labels.long() * num_classes + preds
    confusion_matrix = torch.bincount(
        flat_index, minlength=num_classes * num_classes
    ).view(num_classes, num_classes).float()

    # Per-class recall; a class with no samples in this loader yields NaN (0/0).
    acc_per_class = confusion_matrix.diag() / confusion_matrix.sum(1)
    acc = acc_per_class.cpu().numpy()
    many_shot_acc = acc[many_shot].mean()
    medium_shot_acc = acc[medium_shot].mean()
    few_shot_acc = acc[few_shot].mean()

    results = tuple(
        np.round(value * 100, decimals=2)
        for value in (many_shot_acc, medium_shot_acc, few_shot_acc, eval_acc_mic_top1)
    )
    print("Many-shot {}, Medium-shot {}, Few-shot {}, All {}".format(*results))
    return results
 

if __name__ == '__main__':
    # Command-line interface for the evaluation script.
    parser = argparse.ArgumentParser(description='PyTorch Template')
    parser.add_argument(
        '-c', '--config',
        type=str,
        default=None,
        help='config file path (default: None)',
    )
    parser.add_argument(
        '-r', '--resume',
        type=str,
        default=None,
        help='path to latest checkpoint (default: None)',
    )
    parser.add_argument(
        '-d', '--device',
        type=str,
        default=None,
        help='indices of GPUs to enable (default: all)',
    )
    parser.add_argument(
        '-l', '--log-config',
        type=str,
        default='logger/logger_config.json',
        help='logging config file path (default: logger/logger_config.json)',
    )
    parser.add_argument(
        "--posthoc_bias_correction",
        dest="posthoc_bias_correction",
        action="store_true",
        default=False,
    )

    # Options that exist for training runs; accepted here but not read by this script.
    parser.add_argument("--validate")
    parser.add_argument("--use-wandb")

    # from_args returns the experiment config together with the parsed arguments.
    config, args = ConfigParser.from_args(parser)
    main(config, args.posthoc_bias_correction)