import os
import random
import pyrallis
import json
from tqdm import tqdm

import numpy as np
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from loaders.image_loader import load_images
from models.base import load_model
from optimization.optimizer import load_optimizer
from optimization.scheduler import load_lr_scheduler
from loss.task_loss import load_task_loss
from metrics.accuracy import accuracy
from utils.train_util import AverageMeter, ProgressMeter
from utils.log_utils import log_scalar_dict, create_experiment_dir

from options import ModelBaseTrainConfig

def main():
    """Evaluate a fixed ViT-B/16 checkpoint on ImageNet-LT and print group accuracies.

    Loads the model weights from a hard-coded checkpoint path, builds the
    train and test loaders, delegates to ``test_lt`` and prints the head,
    mid and tail accuracy it returns.
    """
    # ----------------------------------------
    # device selection
    # ----------------------------------------
    cuda_available = torch.cuda.is_available()
    device = torch.device("cuda" if cuda_available else "cpu")

    # ----------------------------------------
    # model and checkpoint
    # ----------------------------------------
    model = load_model('vit_b_16', num_classes=1000).to(device)
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects, so this must only be pointed at trusted checkpoints.
    state_dict = torch.load(
        '/nfs196/wjx/projects/PMP/outputs/LT/model.pt',
        map_location=device,
        weights_only=False,
    )
    model.load_state_dict(state_dict)

    # ----------------------------------------
    # data loaders (train first, then test)
    # ----------------------------------------
    data_root = '/nfs196/wjx/datasets/Imagenet-lt'
    shared_kwargs = dict(
        batch_size=1024,
        path_prefix="/nfs196/hjc/datasets/ILSVRC2012",
    )
    train_loader = load_images(data_root, 'Imagenet-lt', data_type='train', **shared_kwargs)
    test_loader = load_images(data_root, 'Imagenet-lt', data_type='test', **shared_kwargs)

    # ----------------------------------------
    # evaluation and report
    # ----------------------------------------
    head_acc, mid_acc, tail_acc = test_lt(train_loader, test_loader, model, device)

    report = (
        ('HEAD ACC', head_acc),
        ('MID ACC', mid_acc),
        ('TAIL ACC', tail_acc),
    )
    for title, value in report:
        print(title, value)



def _group_accuracy(correct, labels, group_classes):
    """Return the accuracy over test samples whose label is in ``group_classes``.

    Returns 0.0 when the group contains no classes or no test samples.
    """
    if group_classes.numel() == 0:
        return 0.0
    mask = torch.isin(labels, group_classes)
    total = mask.sum().item()
    if total == 0:
        return 0.0
    return correct[mask].sum().item() / total


def test_lt(train_loader, test_loader, model, device):
    """Compute head / mid / tail accuracy on the test set.

    Per-class sample counts are taken from ``train_loader.dataset.samples``.
    Classes are split by training frequency: head has more than 100 samples,
    mid has 20 to 100 (inclusive) and tail has fewer than 20. Classes that
    appear in the test set but never in the training set count as tail.
    The training-set share of each group is printed before evaluation.

    Args:
        train_loader: Loader whose ``dataset`` exposes a ``samples`` sequence
            of ``(image_path, target)`` pairs.
        test_loader: Iterable yielding ``(inputs, labels)`` tensor batches.
        model: Module called on each input batch; its output is reduced with
            ``argmax(dim=1)``. It is switched to eval mode and left there.
        device: Device the input batches are moved to before the forward pass.

    Returns:
        Tuple ``(head_acc, mid_acc, tail_acc)`` of floats. A group with no
        test samples gets 0.0; an empty test set gives ``(0.0, 0.0, 0.0)``.

    Raises:
        RuntimeError: If ``train_loader.dataset`` has no ``samples`` attribute.
    """
    model.eval()

    # --------------------------------------------
    # 1. Read all training labels from train_loader.dataset
    # --------------------------------------------
    dataset = train_loader.dataset

    try:
        raw_samples = dataset.samples  # [('path/to/img1.jpg', 3), ('path/to/img2.jpg', 0), ...]
    except AttributeError as err:
        raise RuntimeError(
            "找不到 `dataset.samples` 属性，请确认您的 Dataset 定义中包含 samples 列表，"
            "其元素为 (image_path, target) 形式。"
        ) from err

    train_labels = torch.tensor(
        [target for (_, target) in raw_samples],
        dtype=torch.long,
    )

    # Number of training samples per class index (length = max label + 1).
    train_counts = torch.bincount(train_labels)

    # Split class indices by training frequency.
    head_classes = torch.nonzero(train_counts > 100, as_tuple=True)[0]
    mid_classes = torch.nonzero((train_counts >= 20) & (train_counts <= 100), as_tuple=True)[0]
    tail_classes = torch.nonzero(train_counts < 20, as_tuple=True)[0]

    total_train_samples = train_labels.numel()
    head_sample_count = train_counts[head_classes].sum().item()
    mid_sample_count = train_counts[mid_classes].sum().item()
    tail_sample_count = train_counts[tail_classes].sum().item()

    # Guard against an empty training set instead of dividing by zero.
    denominator = total_train_samples if total_train_samples > 0 else 1
    head_ratio = head_sample_count / denominator
    mid_ratio = mid_sample_count / denominator
    tail_ratio = tail_sample_count / denominator

    print("=== 训练集样本分布占比 ===")
    print(f" 头部类（>100）样本数：{head_sample_count}，占比：{head_ratio:.2%}")
    print(f" 中部类（20~100）样本数：{mid_sample_count}，占比：{mid_ratio:.2%}")
    print(f" 尾部类（<20）样本数：{tail_sample_count}，占比：{tail_ratio:.2%}")

    print("统计结束")

    # --------------------------------------------
    # 2. Run the model on the test set and collect predictions
    # --------------------------------------------
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs.to(device))
            all_preds.append(outputs.argmax(dim=1).cpu())
            # Labels are only compared on the CPU, so they never need to
            # visit the compute device.
            all_labels.append(labels.cpu())

    # torch.cat rejects an empty list, so handle an empty loader explicitly.
    if not all_preds:
        return 0.0, 0.0, 0.0

    all_preds = torch.cat(all_preds, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    if all_labels.numel() == 0:
        return 0.0, 0.0, 0.0

    # bincount stops at the largest training label, so test classes above it
    # have zero training samples; treat them as tail like any other class
    # with a zero count.
    num_counted = train_counts.numel()
    max_test_label = int(all_labels.max().item())
    if max_test_label >= num_counted:
        unseen_classes = torch.arange(
            num_counted, max_test_label + 1, dtype=tail_classes.dtype
        )
        tail_classes = torch.cat([tail_classes, unseen_classes])

    # --------------------------------------------
    # 3. Accuracy of the head / mid / tail groups
    # --------------------------------------------
    correct = (all_preds == all_labels)

    head_acc = _group_accuracy(correct, all_labels, head_classes)
    mid_acc = _group_accuracy(correct, all_labels, mid_classes)
    tail_acc = _group_accuracy(correct, all_labels, tail_classes)

    return head_acc, mid_acc, tail_acc


# The ``test_`` prefix makes pytest collect this evaluation helper whenever it
# is imported into a test module; opt it out of collection.
test_lt.__test__ = False



def init_seed(seed = 2025):
    """Seed the Python, NumPy and PyTorch RNGs and set the cuDNN determinism flags.

    Args:
        seed: Value passed to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Turn on deterministic cuDNN and turn off its benchmark mode.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    
    

# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    main()
