# This file is a partial release of the BiAL-CCL code; it mainly retains the training loop and the BiAL components.
# The complete code will be made public after the paper is accepted.
import argparse
import logging
import math
import os
import random
import shutil
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, SequentialSampler
from torch.utils.tensorboard import SummaryWriter
from dataset.cifar import DATASET_GETTERS
from utils import AverageMeter, accuracy
from utils import Logger
from progress.bar import Bar
import loss.semiConLoss as scl

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Running best top-1 accuracies. train() declares both as globals and raises them after every
# evaluation: best_acc follows the first accuracy returned by test() (combined logits),
# best_acc_b follows the fourth one (logits of the `classify1` head).
best_acc = 0
best_acc_b = 0

# ---------------- Bias utilities (BiAL-style) ----------------
def _logmeanexp(x, dim=0):
    """Return ``log(mean(exp(x)))`` reduced over ``dim``.

    The reduction goes through ``torch.logsumexp`` and then subtracts the
    log of the number of entries along ``dim``.
    """
    entries = x.size(dim)
    log_total = torch.logsumexp(x, dim=dim)
    return log_total - math.log(entries)

@torch.no_grad()
def compute_bias_theta(model, args):
    """Measure the model's class bias on an all-zero ("no information") batch.

    The model is switched to eval mode, fed a batch of zeros, and the logits
    of both heads (``model.classify`` and ``model.classify1``) are averaged.
    The result is reduced over the batch with log-mean-exp and centered to
    zero mean across classes.

    Args:
        model: network exposing ``__call__``, ``classify`` and ``classify1``.
        args: namespace providing ``img_size`` and ``device``.

    Returns:
        A detached tensor with one entry per class logit.
    """
    was_training = model.training
    model.eval()
    try:
        B = 16
        C, H, W = 3, args.img_size, args.img_size
        noinfo = torch.zeros(B, C, H, W, device=args.device)

        out = model(noinfo)
        # The forward pass may return a tuple; its first element is taken as the feature.
        feat = out[0] if isinstance(out, tuple) else out
        z = model.classify(feat)
        zb = model.classify1(feat)
        z_co = 0.5 * z + 0.5 * zb
        b = _logmeanexp(z_co, dim=0)
        b = b - b.mean()
    finally:
        # Restore the caller's mode even if the probe forward pass raised.
        if was_training:
            model.train()
    return b.detach()

def update_bias_theta(model, args, epoch):
    """Refresh the running bias estimate ``args.b_theta`` when it is due.

    A refresh happens when ``epoch`` has reached ``args.bias_start_epoch`` and
    the number of epochs since then is a multiple of ``args.bias_refresh_every``.
    The new measurement is blended into ``args.b_theta`` with momentum
    ``args.bias_m``. If no estimate exists yet (attribute missing or ``None``,
    the state ``apply_bias`` treats as "debiasing disabled"), the first
    measurement seeds the estimate instead of being multiplied with ``None``.

    Args:
        model: network handed to ``compute_bias_theta``.
        args: namespace holding the bias hyper-parameters and ``b_theta``.
        epoch: current epoch index.
    """
    since_start = epoch - args.bias_start_epoch
    if since_start < 0 or since_start % args.bias_refresh_every != 0:
        return

    b_now = compute_bias_theta(model, args)
    previous = getattr(args, 'b_theta', None)
    if previous is None:
        # Nothing to average with yet: start the EMA from the first probe.
        args.b_theta = b_now
    else:
        args.b_theta = args.bias_m * previous + (1.0 - args.bias_m) * b_now

def bias_ramp(epoch, warmup_epochs):
    """Return the linear warm-up factor ``epoch / warmup_epochs`` clipped to [0, 1].

    A non-positive ``warmup_epochs`` disables the ramp and yields ``1.0``.
    """
    if warmup_epochs <= 0:
        return 1.0
    fraction = epoch / float(warmup_epochs)
    if fraction > 1.0:
        fraction = 1.0
    return fraction if fraction > 0 else 0


def apply_bias(z, args):
    """Subtract the scaled bias estimate from logits ``z``.

    When ``args`` has no ``b_theta`` (or it is ``None``) the input is returned
    untouched; otherwise the result is ``z - args.bias_beta_eff * args.b_theta``.
    """
    bias_estimate = getattr(args, 'b_theta', None)
    if bias_estimate is not None:
        return z - args.bias_beta_eff * bias_estimate
    return z


# ----------------------------------------------------------------
def compute_py(train_loader, args):
    """Estimate the empirical class prior from the labels of a data loader.

    Labels are counted on the host in one pass per batch, rather than moving
    every label to the device and reading it back individually.

    Args:
        train_loader: iterable yielding ``(inputs, labels)`` pairs, where
            ``labels`` is a tensor of class ids.
        args: namespace providing ``device``.

    Returns:
        A float64 tensor on ``args.device`` holding the relative frequency of
        each class id that occurs, ordered by ascending class id.
    """
    from collections import Counter

    label_freq = Counter()
    for _, labell in train_loader:
        label_freq.update(int(v) for v in labell.reshape(-1).tolist())

    # NOTE(review): class ids that never occur get no slot, so positions only line up
    # with class ids when every class is present — confirm against the callers.
    counts = np.array([label_freq[key] for key in sorted(label_freq)])
    label_freq_array = counts / counts.sum()
    return torch.from_numpy(label_freq_array).to(args.device)


def save_checkpoint(state, checkpoint, filename='checkpoint.pth.tar', epoch_p=1):
    """Serialize ``state`` to ``<checkpoint>/<filename>`` with ``torch.save``.

    Args:
        state: object to serialize (the training loop passes a dict).
        checkpoint: target directory; created if it does not exist yet.
        filename: file name inside ``checkpoint``.
        epoch_p: accepted for call compatibility; not used here.
    """
    if checkpoint:
        # An empty string means "current directory", which needs no creation.
        os.makedirs(checkpoint, exist_ok=True)
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)


def set_seed(args):
    """Seed Python, NumPy and PyTorch from ``args.seed`` and pin the cuDNN flags.

    Does nothing when ``args.seed`` is ``None``.
    """
    seed = args.seed
    if seed is None:
        return

    print(f"Deterministic with seed = {seed}")
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def get_cosine_schedule_with_warmup(optimizer,
                                    num_warmup_steps,
                                    num_training_steps,
                                    num_cycles=7. / 16.,
                                    last_epoch=-1):
    """Build a ``LambdaLR`` with linear warm-up followed by cosine decay.

    During the first ``num_warmup_steps`` steps the multiplier grows linearly
    from 0 towards 1. Afterwards it is ``cos(pi * num_cycles * progress)``,
    floored at 0, where ``progress`` is the fraction of the post-warm-up steps
    already taken.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def _lr_lambda(current_step):
        if current_step >= num_warmup_steps:
            progress = float(current_step - num_warmup_steps) / decay_span
            return max(0., math.cos(math.pi * num_cycles * progress))
        return float(current_step) / warmup_span

    return LambdaLR(optimizer, _lr_lambda, last_epoch)


def compute_adjustment_by_py(py, tro, args):
    """Return ``log(py ** tro + 1e-12)`` placed on ``args.device``.

    The small constant keeps the logarithm finite for zero-probability entries.
    """
    tempered = py ** tro
    return torch.log(tempered + 1e-12).to(args.device)


def sharp(a, T):
    """Raise each row of ``a`` to the power ``T`` and renormalize it along dim 1.

    The result is detached from the autograd graph.
    """
    powered = a ** T
    row_total = powered.sum(dim=1, keepdim=True)
    return (powered / row_total).detach()


def main():
    """Declare the command-line options of the BiAL-CCL training script.

    Only the argument declarations are present in this partial release; the
    parser is built but not yet parsed or returned here.
    """
    parser = argparse.ArgumentParser(description='PyTorch BiAL-CCL Training')
    parser.add_argument('--gpu-id', default='0', type=int,
                        help='id(s) for CUDA_VISIBLE_DEVICES')
    parser.add_argument('--num-workers', type=int, default=1,
                        help='number of workers')
    parser.add_argument('--dataset', default='cifar10', type=str,
                        choices=['cifar10', 'cifar100', 'stl10', 'smallimagenet'],
                        help='dataset name')
    parser.add_argument('--num-labeled', type=int, default=4000,
                        help='number of labeled data')
    parser.add_argument('--arch', default='wideresnet', type=str,
                        choices=['wideresnet', 'resnet'],
                        help='model architecture')
    parser.add_argument('--total-steps', default=250000, type=int,
                        help='number of total steps to run')
    parser.add_argument('--eval-step', default=500, type=int,
                        help='number of eval steps to run')
    parser.add_argument('--start-epoch', default=0, type=int,
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--batch-size', default=64, type=int,
                        help='train batchsize')
    parser.add_argument('--lr', '--learning-rate', default=0.03, type=float,
                        help='initial learning rate')
    parser.add_argument('--warmup', default=0, type=float,
                        help='warmup epochs (unlabeled data based)')
    parser.add_argument('--wdecay', default=5e-4, type=float,
                        help='weight decay')
    # NOTE(review): store_true combined with default=True means the next two flags
    # can never be switched off from the command line — confirm this is intended.
    parser.add_argument('--nesterov', action='store_true', default=True,
                        help='use nesterov momentum')
    parser.add_argument('--use-ema', action='store_true', default=True,
                        help='use EMA model')
    parser.add_argument('--ema-decay', default=0.999, type=float,
                        help='EMA decay rate')
    parser.add_argument('--mu', default=1, type=int,
                        help='coefficient of unlabeled batch size')
    parser.add_argument('--T', default=1, type=float,
                        help='pseudo label temperature')
    parser.add_argument('--threshold', default=0.90, type=float,
                        help='pseudo label threshold')
    parser.add_argument('--out', default='result',
                        help='directory to output the result')
    parser.add_argument('--resume', default='', type=str,
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--seed', default=0, type=int,
                        help="random seed")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank")

    parser.add_argument('--num-max', default=500, type=int,
                        help='the max number of the labeled data')
    parser.add_argument('--num-max-u', default=4000, type=int,
                        help='the max number of the unlabeled data')
    parser.add_argument('--imb-ratio-label', default=1, type=int,
                        help='the imbalanced ratio of the labelled data')
    parser.add_argument('--imb-ratio-unlabel', default=1, type=int,
                        help='the imbalanced ratio of the unlabeled data')
    parser.add_argument('--flag-reverse-LT', default=0, type=int,
                        help='whether to reverse the distribution of the unlabeled data')
    parser.add_argument('--ema-mu', default=0.99, type=float,
                        help='mu when ema')

    parser.add_argument('--tau', default=2.0, type=float,
                        help='tau for head consistency')
    # train() compares this value against the epoch index, not the step index.
    parser.add_argument('--est-epoch', default=10, type=int,
                        help='the epoch after which the class distribution is estimated')
    parser.add_argument('--img-size', default=32, type=int,
                        help='image size for small imagenet')
    parser.add_argument('--alpha', default=0.5, type=float,
                        help='ema ratio for estimating distribution of the unlabeled data')
    parser.add_argument('--beta', default=0.5, type=float,
                        help='ema ratio for estimating distribution of the all data')
    parser.add_argument('--lambda1', default=0.7, type=float,
                        help='coefficient of final loss')
    parser.add_argument('--lambda2', default=1.0, type=float,
                        help='coefficient of final loss')

    # ---- bias removal (BiAL) ----
    parser.add_argument('--bias-beta', type=float, default=1.0,
                        help='strength beta for bias subtraction E=z-beta*b_theta')
    parser.add_argument('--bias-m', type=float, default=0.9,
                        help='EMA momentum for b_theta')
    parser.add_argument('--bias-warmup-epochs', type=int, default=10,
                        help='epochs to ramp up beta from 0 to bias-beta')
    parser.add_argument('--bias-start-epoch', type=int, default=20,
                        help='epoch to start measuring bias')
    parser.add_argument('--bias-refresh-every', type=int, default=1,
                        help='refresh period (in epochs) for bias probing')
    # NOTE(review): the None default must be resolved to bias_start_epoch before the
    # value is compared with an epoch index; that step is not part of this excerpt.
    parser.add_argument('--debiasstart', type=int, default=None,
                        help='Epoch to start training-time debias; default=bias_start_epoch')

    # The complete code will be made public after the paper is accepted




def train(args, labeled_trainloader, unlabeled_trainloader, test_loader,
          model, optimizer, ema_model, scheduler):
    """Run the training loop, evaluating and logging after every epoch.

    Each epoch performs ``args.eval_step`` optimisation steps. A step draws one
    labeled and one unlabeled batch (each with three views), computes the
    classification losses of the two heads (``model.classify`` and
    ``model.classify1``), two contrastive losses, and updates the model, the
    scheduler and (optionally) the EMA model. Running pseudo-label statistics
    are folded into the class-prior estimates ``py_unlabeled`` and ``py_all``.

    Args:
        args: run configuration. Besides hyper-parameters it must carry
            ``py_con``, ``py_uni``, ``py_all``, ``num_classes``, ``epochs``,
            ``world_size``, ``device``, ``writer``, ``logger`` and ``b_theta``.
        labeled_trainloader: loader yielding ``((x, x_s, x_s1), targets)``.
        unlabeled_trainloader: loader yielding ``((u_w, u_s, u_s1), labels)``.
        test_loader: loader handed to ``test``.
        model: network returning ``(feat, feat_mlp, center_feat)``.
        optimizer: optimizer stepped once per batch.
        ema_model: wrapper exposing ``update(model)`` and ``.ema``.
        scheduler: learning-rate scheduler stepped once per batch.
    """
    global best_acc
    global best_acc_b
    test_accs = []
    avg_time = []
    end = time.time()
    if args.world_size > 1:
        labeled_epoch = 0
        unlabeled_epoch = 0
        labeled_trainloader.sampler.set_epoch(labeled_epoch)
        unlabeled_trainloader.sampler.set_epoch(unlabeled_epoch)
    logits_la_s = compute_adjustment_by_py(args.py_con, args.tau, args)
    labeled_iter = iter(labeled_trainloader)
    unlabeled_iter = iter(unlabeled_trainloader)
    semiConLoss = scl.SemiConLoss(args.batch_size, args.batch_size, args.num_classes, args)
    semiConLoss2 = scl.softConLoss(args.batch_size, args.batch_size, args.num_classes, args)
    model.train()
    lbs = args.batch_size
    ubs = args.batch_size * args.mu
    py_labeled = args.py_con.to(args.device)
    py_unlabeled = args.py_uni.to(args.device)
    py_all = args.py_all.to(args.device)
    # Number of leading rows of the concatenated batch that go through the classifier
    # heads: the labeled view plus the three unlabeled views.
    # Assumes the loaders deliver lbs / ubs samples per batch — TODO confirm.
    cut1 = lbs + 3 * ubs
    pro = ubs / (ubs + lbs)
    # '--debiasstart' defaults to None and its help text promises a fallback to
    # bias_start_epoch; resolve that here so the epoch comparison below is always valid.
    debias_start = getattr(args, 'debiasstart', None)
    if debias_start is None:
        debias_start = args.bias_start_epoch
    for epoch in range(args.start_epoch, args.epochs):

        # refresh bias & update beta ramp
        update_bias_theta(model, args, epoch)
        args.bias_beta_eff = args.bias_beta * bias_ramp(epoch - args.bias_start_epoch, args.bias_warmup_epochs)

        print('current epoch: ', epoch + 1)
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        losses_con = AverageMeter()
        losses_cls = AverageMeter()
        losses_con2 = AverageMeter()

        bar = Bar('Training', max=args.eval_step)

        # Pseudo-label mass accumulators; they start at one for every class.
        num_unlabeled = torch.ones(args.num_classes).to(args.device)
        num_all = torch.ones(args.num_classes).to(args.device)
        for batch_idx in range(args.eval_step):
            # Restart a loader only when it is exhausted; any other error propagates.
            try:
                (inputs_x, inputs_x_s, inputs_x_s1), targets_x = next(labeled_iter)
            except StopIteration:
                if args.world_size > 1:
                    labeled_epoch += 1
                    labeled_trainloader.sampler.set_epoch(labeled_epoch)
                labeled_iter = iter(labeled_trainloader)
                (inputs_x, inputs_x_s, inputs_x_s1), targets_x = next(labeled_iter)

            try:
                (inputs_u_w, inputs_u_s, inputs_u_s1), u_real = next(unlabeled_iter)
            except StopIteration:
                if args.world_size > 1:
                    unlabeled_epoch += 1
                    unlabeled_trainloader.sampler.set_epoch(unlabeled_epoch)
                unlabeled_iter = iter(unlabeled_trainloader)
                (inputs_u_w, inputs_u_s, inputs_u_s1), u_real = next(unlabeled_iter)
            u_real = u_real.to(args.device)
            data_time.update(time.time() - end)
            inputs = torch.cat([inputs_x, inputs_u_w, inputs_u_s, inputs_u_s1, inputs_x_s, inputs_x_s1], dim=0).to(
                args.device)
            targets_x = targets_x.to(args.device)
            feat, feat_mlp, center_feat = model(inputs)
            # -----------------------------------------------------------------------------------------------------------
            logits = model.classify(feat[:cut1])  # [lbs + 3*ubs, C]
            logits_b = model.classify1(feat[:cut1])  # [lbs + 3*ubs, C]

            logits_x = logits[:lbs]
            logits_x_w, logits_x_s, logits_x_s1 = logits[lbs:].chunk(3)

            logits_x_b = logits_b[:lbs]
            logits_x_b_w, logits_x_b_s, logits_x_b_s1 = logits_b[lbs:].chunk(3)

            enable_sup = (epoch >= debias_start)
            enable_unsup = (epoch >= debias_start)

            if enable_sup:
                logits_x = apply_bias(logits_x, args)
                logits_x_b = apply_bias(logits_x_b, args)

            if enable_unsup:
                logits_x_w = apply_bias(logits_x_w, args)
                logits_x_s = apply_bias(logits_x_s, args)
                logits_x_s1 = apply_bias(logits_x_s1, args)
                logits_x_b_w = apply_bias(logits_x_b_w, args)
                logits_x_b_s = apply_bias(logits_x_b_s, args)
                logits_x_b_s1 = apply_bias(logits_x_b_s1, args)


            del logits, logits_b
            # Supervised losses of the two heads; the second head is trained on
            # logits shifted by the labeled-prior adjustment.
            l_u_s = F.cross_entropy(logits_x, targets_x, reduction='mean')
            l_b_s = F.cross_entropy(logits_x_b + logits_la_s, targets_x, reduction='mean')
            logits_la_u = (- compute_adjustment_by_py((1 - pro) * py_labeled + pro * py_all, 1.0, args) +
                           compute_adjustment_by_py(py_unlabeled, 1 + args.tau / 2, args))

            # Pseudo-labels come from the average of both heads on the weak unlabeled view.
            logits_co = 1 / 2 * (logits_x_w + logits_la_u) + 1 / 2 * logits_x_b_w
            energy = -torch.logsumexp((logits_co.detach()) / args.T, dim=1)
            pseudo_label_co = F.softmax((logits_co.detach()) / args.T, dim=1)
            pseudo_label_con = sharp(F.softmax((logits_co.detach()) / args.T, dim=1), 4.0)

            prob_co, targets_co = torch.max(pseudo_label_co, dim=-1)
            mask = prob_co.ge(args.threshold)
            mask = mask.float()

            # Targets and mask are duplicated because both strong views are trained at once.
            targets_co = torch.cat([targets_co, targets_co], dim=0).to(args.device)
            logits_b_s = torch.cat([logits_x_b_s, logits_x_b_s1], dim=0).to(args.device)
            logits_la_u_b = compute_adjustment_by_py(py_all, args.tau, args)
            mask_twice = torch.cat([mask, mask], dim=0)
            l_u_b = (F.cross_entropy(logits_b_s + logits_la_u_b, targets_co,
                                     reduction='none') * mask_twice).mean()

            logits_u_s = torch.cat([logits_x_s, logits_x_s1], dim=0).to(args.device)
            l_u_u = (F.cross_entropy(logits_u_s, targets_co,
                                     reduction='none') * mask_twice).mean()

            loss_u = max(1.5, args.mu) * l_u_u + l_u_s
            loss_b = max(1.5, args.mu) * l_u_b + l_b_s
            loss_cls = loss_u + loss_b
            # ----------------------------------------------------------------------------------------------------------
            # After dropping the first lbs rows, rows [ubs, 3*ubs) are the two strong
            # unlabeled views and rows [3*ubs, end) the two strong labeled views.
            feat_mlp = feat_mlp[lbs:]
            f3, f4 = feat_mlp[ubs:3 * ubs, :].chunk(2)
            f1, f2 = feat_mlp[3 * ubs:, :].chunk(2)

            # ----------------------------------------------------------------------------------------------------------
            feat_mlp = torch.cat([center_feat, feat_mlp[3 * ubs:, :], feat_mlp[:3 * ubs, :]], dim=0)
            center_label = torch.ones(args.num_classes, args.num_classes).to(args.device)
            one_hot_targets = F.one_hot(targets_x, num_classes=args.num_classes)
            one_hot_targets = torch.cat([one_hot_targets, one_hot_targets], dim=0).to(args.device)
            label_contrac = torch.cat([center_label, one_hot_targets], dim=0).to(args.device)
            # la = compute_adjustment_by_py(py_all, 1.0, args)
            contrac_loss = semiConLoss(feat_mlp, label_contrac)

            # ----------------------------------------------------------------------------------------------------------
            # Only unlabeled samples whose energy is at most -8.75 enter the second contrastive loss.
            maskcon = energy.le(-8.75)
            idx = torch.nonzero(maskcon).squeeze()
            # The reshapes restore a 2-D layout when squeeze() collapsed a single selected index.
            f3 = torch.reshape(f3[idx, :], (-1, f1.shape[1]))
            f4 = torch.reshape(f4[idx, :], (-1, f1.shape[1]))
            pseudo_label_con = torch.reshape(pseudo_label_con[idx, :], (-1, args.num_classes))

            label_contrac = torch.cat([center_label, one_hot_targets, pseudo_label_con, pseudo_label_con], dim=0).to(
                args.device)
            feat_all = torch.cat([center_feat, f1, f2, f3, f4], dim=0)
            contrac_loss2 = semiConLoss2(label_contrac, feat_all, args.device)

            loss = args.lambda1 * loss_cls + args.lambda2 * contrac_loss + (1 - args.lambda1) * contrac_loss2

            loss.backward()
            losses.update(loss.item())
            losses_cls.update(loss_cls.item())
            losses_con.update(contrac_loss.item())
            losses_con2.update(contrac_loss2.item())
            optimizer.step()
            scheduler.step()
            if args.use_ema:
                ema_model.update(model)
            model.zero_grad()

            mask = mask.unsqueeze(1).to(args.device)
            maskcon = maskcon.float().unsqueeze(1).to(args.device)
            num_all += torch.sum(pseudo_label_co * mask, dim=0)
            num_unlabeled += torch.sum(pseudo_label_co * maskcon, dim=0)
            # w_soft = mask.unsqueeze(1) * maskcon.unsqueeze(1).float()
            # num_unlabeled += torch.sum(pseudo_label_co * w_soft, dim=0)

            # Every 100 steps (once past est_epoch) blend the accumulated pseudo-label
            # mass into the prior estimates and reset the accumulators.
            if (batch_idx + 1) % 100 == 0 and epoch > args.est_epoch:
                py_unlabeled = args.alpha * py_unlabeled + (1 - args.alpha) * num_unlabeled / sum(num_unlabeled)
                py_all = args.beta * py_all + (1 - args.beta) * num_all / sum(num_all)
                num_unlabeled = torch.ones(args.num_classes).to(args.device)
                num_all = torch.ones(args.num_classes).to(args.device)

            batch_time.update(time.time() - end)
            end = time.time()
            bar.suffix = '({batch}/{size}) | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | ' \
                         'Loss: {loss:.4f} | Loss_cls: {loss_cls:.4f} | Loss_con: {loss_con:.4f} | Loss_con2: {loss_con2:.4f}'.format(
                batch=batch_idx + 1,
                size=args.eval_step,
                bt=batch_time.avg,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                loss=losses.avg,
                loss_cls=losses_cls.avg,
                loss_con=losses_con.avg,
                loss_con2=losses_con2.avg,
            )
            bar.next()
        bar.finish()

        # if epoch > args.est_epoch:
        #     py_unlabeled = args.alpha * py_unlabeled + (1 - args.alpha) * num_unlabeled / sum(num_unlabeled)
        #     py_all = args.beta * py_all + (1 - args.beta) * num_all / sum(num_all)
        print('\n')
        print(py_unlabeled)
        print(py_all)
        avg_time.append(batch_time.avg)

        if args.use_ema:
            test_model = ema_model.ema
        else:
            test_model = model
        test_la = - compute_adjustment_by_py(1 / 2 * py_labeled + 1 / 2 * py_all, 1.0, args)
        if args.local_rank in [-1, 0]:

            test_loss, test_acc, test_top5_acc, test_acc_b, test_top5_acc_b = test(args, test_loader,
                                                                                   test_model, epoch,
                                                                                   test_la)
            args.writer.add_scalar('train/1.train_loss', losses.avg, epoch)
            args.writer.add_scalar('test/1.test_acc', test_acc_b, epoch)
            args.writer.add_scalar('test/2.test_loss', test_loss, epoch)

            best_acc = max(test_acc, best_acc)
            best_acc_b = max(test_acc_b, best_acc_b)

            model_to_save = model.module if hasattr(model, "module") else model
            if args.use_ema:
                ema_to_save = ema_model.ema.module if hasattr(
                    ema_model.ema, "module") else ema_model.ema

            # A checkpoint is written every tenth epoch.
            if (epoch + 1) % 10 == 0:
                save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': model_to_save.state_dict(),
                    'ema_state_dict': ema_to_save.state_dict() if args.use_ema else None,
                    'acc': test_acc,
                    'best_acc': best_acc_b,
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'py_unlabeled': py_unlabeled,
                    'py_all': py_all
                }, args.out, epoch_p=epoch + 1)

            test_accs.append(test_acc_b)
            logger.info('Best top-1 acc: {:.2f}'.format(best_acc_b))
            logger.info('Mean top-1 acc: {:.2f}\n'.format(
                np.mean(test_accs[-20:])))

            args.logger.append([test_acc, test_top5_acc, best_acc, test_acc_b, test_top5_acc_b, best_acc_b])

    if args.local_rank in [-1, 0]:
        args.writer.close()


def test(args, test_loader, model, epoch, la):
    """Evaluate ``model`` on ``test_loader`` and return loss and accuracies.

    Both heads are debiased with ``apply_bias``. The "combined" prediction
    averages the ``classify`` logits shifted by ``la`` with the ``classify1``
    logits; the reported loss is the cross-entropy of the ``classify1`` head.
    The model's train/eval mode is restored on exit, so evaluating the training
    model itself does not leave it in eval mode for later epochs.

    Args:
        args: namespace providing ``device`` and the fields read by ``apply_bias``.
        test_loader: iterable of ``(inputs, targets)`` batches.
        model: network exposing ``__call__``, ``classify`` and ``classify1``.
        epoch: current epoch index (not used in the computation).
        la: logit adjustment added to the ``classify`` head.

    Returns:
        ``(loss, top1, top5, top1_b, top5_b)`` averages, where the ``_b``
        values refer to the ``classify1`` head and the others to the
        combined prediction.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    top1_b = AverageMeter()
    top5_b = AverageMeter()
    end = time.time()

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for inputs, targets in test_loader:
                data_time.update(time.time() - end)

                inputs = inputs.to(args.device)
                targets = targets.to(args.device)

                outputs_feat = model(inputs)
                # Same convention as compute_bias_theta: a tuple output carries the
                # feature in its first element.
                if isinstance(outputs_feat, tuple):
                    outputs_feat = outputs_feat[0]
                outputs = model.classify(outputs_feat)
                outputs_b = model.classify1(outputs_feat)

                # debias at inference for train-test consistency
                outputs = apply_bias(outputs, args)
                outputs_b = apply_bias(outputs_b, args)

                outputs_co = 0.5 * (outputs + la) + 0.5 * outputs_b
                loss = F.cross_entropy(outputs_b, targets)

                prec1_b, prec5_b = accuracy(outputs_b, targets, topk=(1, 5))
                prec1_co, prec5_co = accuracy(outputs_co, targets, topk=(1, 5))
                losses.update(loss.item(), inputs.shape[0])
                top1.update(prec1_co.item(), inputs.shape[0])
                top5.update(prec5_co.item(), inputs.shape[0])
                top1_b.update(prec1_b.item(), inputs.shape[0])
                top5_b.update(prec5_b.item(), inputs.shape[0])
                batch_time.update(time.time() - end)
                end = time.time()
    finally:
        # Hand the model back in the mode the caller had it in.
        if was_training:
            model.train()

    logger.info("top-1 acc: {:.2f}".format(top1_b.avg))
    logger.info("top-5 acc: {:.2f}".format(top5_b.avg))

    return losses.avg, top1.avg, top5.avg, top1_b.avg, top5_b.avg


# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
