import argparse
import math
import os
import random
import shutil
import time
import torch
import torch.nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.utils.data
import numpy as np
from torch.distributions import binomial
from utils.utils_networks import resnet18 as resnet18, resnet34, resnet50, convnet
from utils.utils_algo import *
from utils.utils_data import *
from utils.utils_loss import *
import higher

torch.set_printoptions(precision=2, sci_mode=False)


def _build_parser():
    """Build the command-line interface of the CLAPOR training script.

    Options are registered in the order they appear in ``--help``. Several
    options are plain strings that the ``__main__`` section later splits on
    commas (``--rho_range``, ``--gamma``, ``--sampling_gamma_range`` and
    ``-lr_decay_epochs``).
    """
    cli = argparse.ArgumentParser(
        description='PyTorch Implementation of CLAPOR using higher library')
    opt = cli.add_argument

    # --- data and experiment I/O --------------------------------------------
    opt('--dataset', type=str, default='cifar10',
        choices=['cifar10', 'cifar100'], help='dataset name')
    opt('--exp-dir', type=str, default='experiment/CIFAR-10',
        help='experiment directory for saving checkpoints and logs')
    opt('--data-dir', type=str, default='data/pre-processed-data',
        help='data directory for loading preprocessed data')
    opt('-j', '--workers', type=int, default=32,
        help='number of data loading workers (default: 32)')

    # --- optimisation schedule ----------------------------------------------
    opt('--epochs', type=int, default=1000,
        help='number of total epochs to run')
    opt('-b', '--batch-size', type=int, default=256,
        help='mini-batch size (default: 256)')
    opt('--lr', '--learning-rate', type=float, default=0.01, metavar='LR',
        dest='lr', help='initial learning rate')
    # NOTE: these two keep their single-dash spelling for CLI compatibility.
    opt('-lr_decay_epochs', type=str, default='700,800,900',
        help='where to decay lr, can be a list')
    opt('-lr_decay_rate', type=float, default=0.1,
        help='decay rate for learning rate')
    opt('--cosine', action='store_true', default=False,
        help='use cosine lr schedule')
    opt('--momentum', type=float, default=0.9, metavar='M',
        help='momentum for SGD optimizer')
    opt('--wd', '--weight-decay', type=float, default=1e-3, metavar='W',
        dest='weight_decay', help='weight decay (default: 1e-3)')

    # --- runtime ------------------------------------------------------------
    opt('-p', '--print-freq', type=int, default=100,
        help='print frequency (default: 100)')
    opt('--seed', type=int, default=None,
        help='seed for initializing training')
    opt('--gpu', type=int, default=0, help='GPU id to use')
    opt('--num-class', type=int, default=10, help='number of classes')

    # --- method hyper-parameters --------------------------------------------
    opt('--queue_length', type=int, default=64,
        help='queue size as queue_length * batch_size')
    opt('--lamd', type=float, default=3,
        help='parameter for Sinkhorn algorithm')
    opt('--eta', type=float, default=0.9,
        help='final weight for renormalized loss')
    opt('--tau', type=float, default=0.99,
        help='high confidence selection threshold')
    opt('--rho_range', type=str, default='0.2,0.8',
        help='clean label proportion (rho)')
    opt('--gamma', type=str, default='0.1,0.01',
        help='distribution distillation parameter')
    opt('--warmup_epoch', type=int, default=50,
        help='warmup epochs for unreliable samples')
    opt('--est_epochs', type=int, default=20,
        help='epochs for estimating class prior')
    opt('--partial_rate', type=float, default=0.1,
        help='ambiguity level (phi)')
    opt('--hierarchical', action='store_true',
        help='for CIFAR-100 fine-grained training')
    opt('--imb_type', default='exp', choices=['exp', 'step'],
        help='imbalance data type')
    opt('--imb_ratio', type=float, default=50,
        help='imbalance ratio for long-tail dataset generation')
    opt('--save_ckpt', action='store_true',
        help='whether to save model checkpoints')
    opt('--sampling_gamma_range', type=str, default='0.5,1.0',
        help='range for dynamic sampling weight gamma adjustment')
    opt('--meta_gamma', type=float, default=0.5,
        help='trade-off weight between standard loss and meta loss')
    opt('--noisy_rate', type=float, default=0.3,
        help='label noise rate for unreliable partial label generation')
    opt('--minority_weight_cap', type=float, default=5.0,
        help='maximum cap for minority class sampling weights')
    opt('-a', '--arch', metavar='ARCH', default='resnet18',
        choices=['resnet18', 'resnet34', 'resnet50', 'convnet'],
        help='model architecture (default: resnet18)')
    opt('--use_dirichlet', action='store_true',
        help='whether to use Dirichlet distribution for sampling')
    opt('--dirichlet_lambda', type=float, default=10.0,
        help='lambda parameter for Dirichlet distribution')
    opt('--dirichlet_beta', type=float, default=0.0,
        help='beta parameter for Dirichlet distribution')
    return cli


parser = _build_parser()

def generate_unreliable_candidate_labels(train_labels, partial_rate=0.1, noisy_rate=0, num_class=None):
    """Sample an "unreliable" partial-label (candidate-set) matrix.

    For every sample each class is switched on independently. The sample's
    true class is kept with probability ``1 - noisy_rate``, and every other
    class is added with probability ``partial_rate``. A row that comes out
    empty is redrawn, so every sample ends with at least one candidate. When
    ``noisy_rate > 0`` the true label may therefore be missing from its own
    candidate set.

    Args:
        train_labels: 1-D sequence of integer class labels (list, numpy array
            of any integer dtype, or tensor). 1-based labels (minimum == 1)
            are shifted to 0-based.
        partial_rate: probability of adding each non-true class.
        noisy_rate: probability of dropping the true class.
        num_class: number of classes K. Defaults to ``max(label) + 1`` after
            shifting, which under-counts K when the highest classes are absent.

    Returns:
        ``(partialY, avg_C)``. ``partialY`` is an ``(n, K)`` float32 tensor of
        0/1 candidate indicators. ``avg_C`` is a 0-d tensor holding the mean
        number of candidates per sample.

    Raises:
        RuntimeError: if the smallest label is greater than 1 or negative.
        ValueError: if ``train_labels`` is empty, if ``num_class`` is smaller
            than the number of classes observed, or if some row of the
            transition matrix is all zeros (the redraw loop could never end).
    """
    # as_tensor accepts lists, int32/int64 numpy arrays and tensors alike;
    # the legacy torch.LongTensor constructor rejects non-int64 arrays.
    train_labels = torch.as_tensor(train_labels, dtype=torch.long).reshape(-1)
    if train_labels.numel() == 0:
        raise ValueError('train_labels is empty')
    min_label = int(torch.min(train_labels))
    if min_label > 1 or min_label < 0:
        raise RuntimeError('Label minimum value error')
    elif min_label == 1:
        train_labels = train_labels - 1

    # After the shift the minimum label is 0, so max + 1 classes are observed.
    observed_K = int(torch.max(train_labels)) + 1
    if num_class is None:
        K = observed_K
    else:
        K = int(num_class)
        if K < observed_K:
            raise ValueError(f'num_class={K} is smaller than the {observed_K} observed classes')
    n = train_labels.shape[0]

    # Row y holds the per-class inclusion probabilities for a sample of class y:
    # the diagonal is 1 - noisy_rate and the off-diagonal entries are partial_rate.
    transition_matrix = np.eye(K) * (1 - noisy_rate)
    transition_matrix[~np.eye(K, dtype=bool)] = partial_rate
    if np.any(transition_matrix.max(axis=1) <= 0):
        raise ValueError('every class probability is zero for some label; '
                         'candidate sets could never be non-empty')
    print(transition_matrix)

    labels_np = train_labels.numpy()
    partial_np = np.zeros((n, K), dtype=np.float32)
    for j in range(n):
        probs_j = transition_matrix[labels_np[j]]
        # Redraw until at least one candidate is on. Each attempt draws one
        # (1, K) uniform sample, exactly as before, so seeded runs reproduce.
        while True:
            row = np.random.uniform(0, 1, size=(1, K)) <= probs_j
            if row.any():
                break
        partial_np[j] = row[0]
    partialY = torch.from_numpy(partial_np)

    if noisy_rate == 0:
        partialY[torch.arange(n), train_labels] = 1.0
        print('Reset true labels in candidate set')

    avg_C = torch.sum(partialY) / partialY.size(0)
    print(f"Candidate label generation complete. Average candidates: {avg_C:.2f}\n")
    return partialY, avg_C

class Trainer():
    """Two-stage CLAPOR trainer: class-prior estimation, then final training.

    ``__init__`` does the one-time setup:

    * prepares the experiment directory and seeds the RNGs;
    * loads the CIFAR data through ``load_cifar``;
    * swaps the training candidate-label matrix for freshly generated
      unreliable partial labels.

    ``train`` runs one stage; the ``__main__`` section calls it twice.
    """

    def __init__(self, args):
        """Set up data, candidate labels and bookkeeping.

        Args:
            args: parsed CLI namespace, already post-processed by ``__main__``
                (``imb_factor``, ``queue_length`` in samples, split ranges).
                ``args.seed`` (when unset) and ``args.exp_dir`` are updated
                in place.

        Raises:
            NotImplementedError: if ``args.dataset`` is not cifar10/cifar100.
        """
        self.args = args

        # Choose the seed BEFORE naming the experiment directory so the name
        # records the seed actually used. Previously an unset seed produced
        # a directory ending in "sd_None".
        if not hasattr(args, 'seed') or args.seed is None:
            args.seed = random.randint(0, 9999)
            print(f"==> No seed specified, using random seed: {args.seed}")

        model_path = '{ds}_{pr}_ql{ql}_rho{rho}_gm{gm}_t{t}_ep{we}_{ee}_imb_{it}{imf}_sd_{seed}'.format(
            ds=args.dataset,
            pr=args.partial_rate,
            ql=args.queue_length,
            rho=args.rho_range,
            it=args.imb_type,
            imf=args.imb_factor,
            seed=args.seed,
            gm=args.gamma,
            t=args.tau,
            we=args.warmup_epoch,
            ee=args.est_epochs)
        args.exp_dir = os.path.join(args.exp_dir, model_path)
        os.makedirs(args.exp_dir, exist_ok=True)

        random.seed(args.seed)
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
        cudnn.deterministic = True

        # (many-shot, low-shot) split sizes handed to AccurracyShot per dataset.
        shot_splits = {'cifar10': (3, 3), 'cifar100': (33, 33)}
        if args.dataset not in shot_splits:
            raise NotImplementedError("Unsupported dataset.")
        many_shot_num, low_shot_num = shot_splits[args.dataset]
        # The loader's own candidate matrix (2nd item) is replaced below.
        train_loader, _, test_loader, est_loader, init_label_dist, train_label_cnt = load_cifar(args=args)

        # Generate unreliable partial labels from the ground-truth labels.
        print("==> Generating unreliable partial labels...")
        full_train_dataset = train_loader.dataset
        true_labels = torch.tensor([full_train_dataset.true_labels[i] for i in range(len(full_train_dataset))])

        unreliable_labels, avg_C = generate_unreliable_candidate_labels(
            train_labels=true_labels.numpy(),
            partial_rate=args.partial_rate,
            noisy_rate=args.noisy_rate
        )

        # Only the training dataset object is updated here.
        # NOTE(review): est_loader comes from load_cifar and may hold its own
        # copy of the original candidate labels — confirm it should see these.
        full_train_dataset.given_label_matrix = unreliable_labels
        train_givenY = unreliable_labels.cuda()

        self.train_loader = train_loader
        self.test_loader = test_loader
        self.est_loader = est_loader
        self.init_label_dist = init_label_dist
        self.train_givenY = train_givenY
        self.acc_shot = AccurracyShot(train_label_cnt, args.num_class, many_shot_num, low_shot_num)
        self.low_shot_num = low_shot_num
        self.pseudo_labels_storage = torch.zeros(len(self.train_loader.dataset), args.num_class).cuda()

    def train(self, emp_dist=None, is_est_dist=False, total_epochs=0, gamma=0):
        """Build a fresh model and run one training stage.

        Args:
            emp_dist: current class-prior estimate as a column tensor (the
                form returned by ``estimate_empirical_distribution``). When
                None, ``init_label_dist.unsqueeze(1)`` is used.
            is_est_dist: True for the prior-estimation stage, which runs
                ``total_epochs`` epochs without meta-learning or checkpoints.
                False for final training, which runs ``args.epochs`` epochs.
            total_epochs: epoch count for the estimation stage; ignored when
                ``is_est_dist`` is False.
            gamma: weight of each epoch's fresh estimate in the exponential
                moving average of the class prior.

        Returns:
            The class-prior estimate after the last epoch.
        """
        print(f"=> Creating model '{self.args.arch}'")

        if self.args.arch == 'resnet18':
            model = resnet18(num_class=self.args.num_class)
        elif self.args.arch == 'resnet34':
            model = resnet34(num_class=self.args.num_class)
        elif self.args.arch == 'resnet50':
            model = resnet50(num_class=self.args.num_class, pretrained=True)
        elif self.args.arch == 'convnet':
            model = convnet(num_class=self.args.num_class)
        else:
            raise ValueError(f"Unsupported architecture: {self.args.arch}")

        model = model.cuda(self.args.gpu)

        optimizer = torch.optim.SGD(model.parameters(), self.args.lr,
                                    momentum=self.args.momentum,
                                    weight_decay=self.args.weight_decay)
        loss_fn = partial_loss(self.train_givenY)
        # FIFO of past Sinkhorn cost rows; disabled when queue_length == 0.
        queue = None
        if self.args.queue_length > 0:
            queue = torch.zeros(self.args.queue_length, self.args.num_class).cuda()

        best_acc = 0

        stage = 'Prior Estimation' if is_est_dist else 'Final Training'
        total_epochs = self.args.epochs if not is_est_dist else total_epochs
        print(f'------------- Stage: {stage} --------------')
        with open(os.path.join(self.args.exp_dir, 'result.log'), 'a+') as f:
            f.write(f'------------- Stage: {stage} --------------\n')

        if emp_dist is None:
            emp_dist = self.init_label_dist.unsqueeze(dim=1)

        for epoch in range(total_epochs):
            adjust_learning_rate(self.args, optimizer, epoch)

            print(f'\n==> Epoch [{epoch+1}/{total_epochs}] (LR: {optimizer.param_groups[0]["lr"]:.6f})')

            self.train_loop(model, loss_fn, queue, emp_dist, optimizer, epoch, is_est_dist, total_epochs)

            emp_dist_train = self.estimate_empirical_distribution(model, self.est_loader, num_class=self.args.num_class)
            # The fresh estimate is a CPU tensor. Move it to the prior's device
            # so the moving average also works when init_label_dist is on GPU.
            emp_dist = emp_dist_train.to(emp_dist.device) * gamma + emp_dist * (1 - gamma)

            acc_test, acc_many, acc_med, acc_few = self.test(model, self.test_loader)

            # Update the running best first so this epoch's log line reports a
            # best accuracy that includes the current epoch.
            is_best = acc_test > best_acc
            if is_best:
                best_acc = acc_test

            print(f'==> Epoch [{epoch+1}/{total_epochs}] Complete: Acc {acc_test:.2f}%, Best {best_acc:.2f}%, Shot [Many:{acc_many:.2f}% Med:{acc_med:.2f}% Few:{acc_few:.2f}%]')

            with open(os.path.join(self.args.exp_dir, 'result.log'), 'a+') as f:
                f.write(f'Epoch {epoch}: Acc {acc_test:.2f}, Best {best_acc:.2f}, Shot - Many {acc_many:.2f}/ Med {acc_med:.2f}/ Few {acc_few:.2f}. (LR {optimizer.param_groups[0]["lr"]:.5f})\n')

            if is_best:
                print(f'*** New best accuracy: {best_acc:.2f}% ***')

            if not is_est_dist and self.args.save_ckpt:
                self.save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': self.args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer' : optimizer.state_dict(),
                }, is_best=is_best, filename=f'{self.args.exp_dir}/checkpoint.pth.tar',
                best_file_name=f'{self.args.exp_dir}/checkpoint_best.pth.tar')

        print(f'\n==> {stage} Complete, Best Accuracy: {best_acc:.2f}%')
        return emp_dist

    def _compute_sampling_weights(self, loss_fn, emp_dist, epoch, total_epochs):
        """Return per-sample weights, with mean 1, for drawing the meta query set.

        Each class gets an inverse-frequency factor ``(1 / p_k) ** g``, where
        ``p_k`` comes from ``emp_dist`` and ``g`` ramps linearly from
        ``sampling_gamma_start`` to ``sampling_gamma_end`` over the stage.
        With ``use_dirichlet``, that factor vector is replaced by one sample
        from ``Dirichlet(dirichlet_lambda * factor + dirichlet_beta)``.

        A sample's weight is its row of ``loss_fn.confidence`` dotted with the
        class factors (assumed ``(N, K)`` — confirm in partial_loss). Weights
        are capped at ``minority_weight_cap`` times their mean and rescaled
        to sum to N. The result is a CPU tensor of shape ``(N, 1)``.
        """
        args = self.args
        current_sampling_gamma = args.sampling_gamma_start + \
                                 (args.sampling_gamma_end - args.sampling_gamma_start) * (epoch / total_epochs)

        confidence_scores = loss_fn.confidence
        emp_dist_cpu = emp_dist.squeeze().cpu()

        emp_dist_normalized = emp_dist_cpu / (emp_dist_cpu.sum() + 1e-8)
        emp_dist_inv = (1.0 / (emp_dist_normalized + 1e-6)) ** current_sampling_gamma

        # Dirichlet sampling (if enabled)
        if args.use_dirichlet:
            print(f"==> [Dirichlet Sampling] lambda={args.dirichlet_lambda}, beta={args.dirichlet_beta}")
            dirichlet_concentration = args.dirichlet_lambda * emp_dist_inv + args.dirichlet_beta
            dirichlet_concentration = torch.clamp(dirichlet_concentration, min=1e-6)
            emp_dist_inv = torch.distributions.Dirichlet(dirichlet_concentration).sample()
            print(f"==> [After Dirichlet] min={emp_dist_inv.min():.6f}, max={emp_dist_inv.max():.6f}, mean={emp_dist_inv.mean():.6f}")

        sampling_weights = (confidence_scores.cpu() * emp_dist_inv).sum(dim=1, keepdim=True)
        sampling_weights = sampling_weights + 1e-8

        weight_max = sampling_weights.mean() * args.minority_weight_cap
        sampling_weights = torch.clamp(sampling_weights, max=weight_max)
        sampling_weights = sampling_weights / sampling_weights.sum() * len(sampling_weights)

        print(f"==> Epoch {epoch}, Sampling gamma: {current_sampling_gamma:.4f}")
        print(f"==> Weight stats: min={sampling_weights.min():.6f}, max={sampling_weights.max():.6f}, mean={sampling_weights.mean():.6f}, ratio={sampling_weights.max()/sampling_weights.min():.2f}")
        return sampling_weights

    @staticmethod
    def _set_combined_grads(model, standard_grads, meta_grads=None, meta_gamma=0.0):
        """Write ``standard + meta_gamma * meta`` gradients into ``param.grad``.

        Parameters whose standard gradient is None are left untouched. The
        loop uses zip, so no index variable can leak into the caller's scope.
        """
        if meta_grads is None:
            meta_grads = [None] * len(standard_grads)
        for param, g_std, g_meta in zip(model.parameters(), standard_grads, meta_grads):
            if g_std is None:
                continue
            param.grad = g_std if g_meta is None else g_std + meta_gamma * g_meta

    def train_loop(self, model, loss_fn, queue, emp_dist, optimizer, epoch, is_est_dist, total_epochs):
        """Run one training epoch.

        For each batch:

        1. Compute Sinkhorn pseudo-labels over the prediction queue.
        2. Select, per pseudo-class, the ``rho``-fraction of lowest-loss
           samples, plus every sample whose confidence exceeds ``args.tau``.
        3. Blend the selected-sample losses (pseudo-label, strong-view
           consistency, mixup) with the candidate-set loss of the rest,
           weighted by ``eta``.
        4. In the final stage after epoch 0, add ``args.meta_gamma`` times a
           one-step ``higher`` meta-gradient to the standard gradient.
        5. Step the optimizer and update the candidate confidences.
        """
        args = self.args
        train_loader = self.train_loader

        batch_time = AverageMeter('Time', ':1.2f')
        data_time = AverageMeter('Data', ':1.2f')
        acc_cls = AverageMeter('Cls Acc', ':2.2f')
        acc_sink = AverageMeter('Sink Acc', ':2.2f')
        loss_cls_log = AverageMeter('Cls Loss', ':2.2f')
        loss_sink_log = AverageMeter('Sink Loss', ':2.2f')
        progress = ProgressMeter(
            len(train_loader),
            [batch_time, data_time, acc_cls, acc_sink, loss_cls_log, loss_sink_log],
            prefix=f"Epoch: [{epoch}]")

        model.train()

        eta = args.eta * linear_rampup(epoch, args.warmup_epoch)
        rho = args.rho_start + (args.rho_end - args.rho_start) * linear_rampup(epoch, args.warmup_epoch)

        # Meta-learning only runs in the final stage and after the first epoch.
        use_meta = not is_est_dist and epoch > 0

        end = time.time()

        sampling_weights = None
        if use_meta:
            sampling_weights = self._compute_sampling_weights(loss_fn, emp_dist, epoch, total_epochs)

        for i, (images_w, images_s, labels, true_labels, index) in enumerate(train_loader):
            data_time.update(time.time() - end)

            X_w, X_s, Y, index = images_w.cuda(), images_s.cuda(), labels.cuda(), index.cuda()
            Y_true = true_labels.long().detach().cuda()

            logits_w = model(X_w)
            logits_s = model(X_s)
            # Use the real batch size. A short final batch would otherwise
            # break the queue update and the pseudo-label slice below.
            bs = X_w.size(0)

            prediction = F.softmax(logits_w.detach(), dim=1)
            sinkhorn_cost = prediction * Y
            # Predictions renormalised onto each sample's candidate set.
            conf_rn = sinkhorn_cost / sinkhorn_cost.sum(dim=1, keepdim=True)

            # Sinkhorn pseudo-label generation. Once the queue is full, its
            # rows are prepended so the transport problem sees more samples.
            prediction_queue = sinkhorn_cost.detach()
            if queue is not None:
                if not torch.all(queue[-1, :] == 0):
                    prediction_queue = torch.cat((queue, prediction_queue))
                queue[bs:] = queue[:-bs].clone().detach()
                queue[:bs] = prediction_queue[-bs:].clone().detach()
            pseudo_label_soft, _ = sinkhorn(prediction_queue, args.lamd, r_in=emp_dist)
            pseudo_label = pseudo_label_soft[-bs:]
            pseudo_label_idx = pseudo_label.max(dim=1)[1]

            _, rn_loss_vec = loss_fn(logits_w, index)
            _, pseudo_loss_vec = loss_fn(logits_w, None, targets=pseudo_label)

            # High-quality sample selection: per pseudo-class, keep the
            # ceil(bs * prior_j * rho) smallest-loss samples (at least one).
            idx_chosen_sm = []
            sel_flags = torch.zeros(bs).cuda().detach()
            pseudo_label_idx_np = pseudo_label_idx.cpu().numpy()
            for j in range(args.num_class):
                indices = np.where(pseudo_label_idx_np == j)[0]
                if len(indices) == 0:
                    continue
                bs_j = bs * emp_dist[j]
                pseudo_loss_vec_j = pseudo_loss_vec[indices]
                sorted_idx_j = pseudo_loss_vec_j.sort()[1].cpu().numpy()
                partition_j = max(min(int(math.ceil(bs_j*rho)), len(indices)), 1)
                idx_chosen_sm.append(indices[sorted_idx_j[:partition_j]])

            idx_chosen_sm = np.concatenate(idx_chosen_sm)
            sel_flags[idx_chosen_sm] = 1
            high_conf_cond = (pseudo_label * prediction).sum(dim=1) > args.tau
            sel_flags[high_conf_cond] = 1
            idx_chosen = torch.where(sel_flags == 1)[0]
            idx_unchosen = torch.where(sel_flags == 0)[0]

            # Standard loss computation
            if epoch < 1 or idx_chosen.shape[0] == 0:
                loss = rn_loss_vec.mean()
            else:
                loss_unreliable = rn_loss_vec[idx_unchosen].mean() if idx_unchosen.shape[0] > 0 else 0
                loss_sin = pseudo_loss_vec[idx_chosen].mean()
                loss_cons, _ = loss_fn(logits_s[idx_chosen], None, targets=pseudo_label[idx_chosen])

                # Mixup among the selected samples; lam >= 0.5 keeps the
                # original sample dominant.
                lam = np.random.beta(4, 4)
                lam = max(lam, 1-lam)
                X_w_c = X_w[idx_chosen]
                pseudo_label_c = pseudo_label[idx_chosen]
                idx_t = torch.randperm(X_w_c.size(0))
                X_w_c_rand = X_w_c[idx_t]
                pseudo_label_c_rand = pseudo_label_c[idx_t]
                X_w_c_mix = lam * X_w_c + (1 - lam) * X_w_c_rand
                pseudo_label_c_mix = lam * pseudo_label_c + (1 - lam) * pseudo_label_c_rand
                logits_mix = model(X_w_c_mix)
                loss_mix, _ = loss_fn(logits_mix, None, targets=pseudo_label_c_mix)

                loss = (loss_sin + loss_mix + loss_cons) * eta + loss_unreliable * (1 - eta)

            # Snapshot the standard gradients, then clear them so the meta
            # step below starts from a clean slate.
            optimizer.zero_grad()
            loss.backward()
            standard_grads = [None if p.grad is None else p.grad.clone().detach()
                              for p in model.parameters()]
            optimizer.zero_grad()

            meta_grads = None
            if use_meta:
                # Query set D_a2: bs samples drawn without replacement using the
                # rebalancing weights; targets are the current confidences.
                idx_Da2 = torch.multinomial(sampling_weights.squeeze(), bs, replacement=False)
                X_Da2 = torch.stack([self.train_loader.dataset[int(k)][0] for k in idx_Da2]).to(args.gpu)
                Y_Da2_confidence = loss_fn.confidence[idx_Da2].to(args.gpu)

                # Support set D_a1: the current batch with its Sinkhorn pseudo-labels.
                X_Da1 = X_w.detach()
                Y_Da1 = pseudo_label.detach()

                with higher.innerloop_ctx(model, optimizer, copy_initial_weights=True, track_higher_grads=True) as (fmodel, diffopt):
                    # Inner loop: support set mixup loss
                    if idx_chosen.shape[0] > 0:
                        lam1 = np.random.beta(4, 4)
                        lam1 = max(lam1, 1-lam1)
                        X_Da1_c = X_Da1[idx_chosen]
                        Y_Da1_c = Y_Da1[idx_chosen]

                        if X_Da1_c.shape[0] > 1:
                            idx_meta1 = torch.randperm(X_Da1_c.size(0))
                            X_Da1_c_rand = X_Da1_c[idx_meta1]
                            Y_Da1_c_rand = Y_Da1_c[idx_meta1]
                            X_Da1_c_mix = lam1 * X_Da1_c + (1 - lam1) * X_Da1_c_rand
                            Y_Da1_c_mix = lam1 * Y_Da1_c + (1 - lam1) * Y_Da1_c_rand

                            logits_Da1_mix = fmodel(X_Da1_c_mix)
                            loss_meta_train, _ = loss_fn(logits_Da1_mix, None, targets=Y_Da1_c_mix)
                        else:
                            logits_Da1_c = fmodel(X_Da1_c)
                            loss_meta_train, _ = loss_fn(logits_Da1_c, None, targets=Y_Da1_c)
                    else:
                        logits_Da1_all = fmodel(X_Da1)
                        loss_meta_train, _ = loss_fn(logits_Da1_all, None, targets=Y_Da1)

                    diffopt.step(loss_meta_train)

                    # Query set mixup loss (on updated model)
                    if X_Da2.shape[0] > 1:
                        lam2 = np.random.beta(4, 4)
                        lam2 = max(lam2, 1-lam2)
                        idx_meta2 = torch.randperm(X_Da2.size(0))
                        X_Da2_rand = X_Da2[idx_meta2]
                        Y_Da2_rand = Y_Da2_confidence[idx_meta2]
                        X_Da2_mix = lam2 * X_Da2 + (1 - lam2) * X_Da2_rand
                        Y_Da2_mix = lam2 * Y_Da2_confidence + (1 - lam2) * Y_Da2_rand

                        logits_Da2_mix = fmodel(X_Da2_mix)
                        loss_meta_test, _ = loss_fn(logits_Da2_mix, None, targets=Y_Da2_mix)
                    else:
                        logits_Da2 = fmodel(X_Da2)
                        loss_meta_test, _ = loss_fn(logits_Da2, None, targets=Y_Da2_confidence)

                    # Gradient of the query loss w.r.t. the pre-update weights.
                    meta_grads = torch.autograd.grad(loss_meta_test, fmodel.parameters(time=0))

            # zip-based assignment: the previous `for i, param in enumerate(...)`
            # overwrote the batch index `i` and broke the print_freq check below.
            self._set_combined_grads(model, standard_grads, meta_grads, args.meta_gamma)

            optimizer.step()

            # Update confidence
            loss_fn.confidence_update(conf_rn, index)

            # Record metrics
            loss_sink_log.update(pseudo_loss_vec.mean().item())
            loss_cls_log.update(rn_loss_vec.mean().item())

            acc = accuracy(logits_w, Y_true)[0]
            acc_cls.update(acc[0])
            acc = accuracy(pseudo_label, Y_true)[0]
            acc_sink.update(acc[0])

            batch_time.update(time.time() - end)
            end = time.time()
            if i % args.print_freq == 0:
                progress.display(i)

        print(f'==> Train Summary - Cls Acc: {acc_cls.avg:.2f}%, Sink Acc: {acc_sink.avg:.2f}%, '
            f'Cls Loss: {loss_cls_log.avg:.3f}, Sink Loss: {loss_sink_log.avg:.3f}')

    def test(self, model, test_loader):
        """Evaluate ``model`` on ``test_loader``.

        Returns:
            ``(top1, many_shot, medium_shot, few_shot)`` accuracies as floats;
            the shot split comes from ``self.acc_shot``.
        """
        with torch.no_grad():
            print('==> Evaluating...')
            model.eval()
            pred_list = []
            true_list = []
            for _, (images, labels) in enumerate(test_loader):
                images = images.cuda()
                outputs = model(images)
                pred = F.softmax(outputs, dim=1)
                pred_list.append(pred.cpu())
                true_list.append(labels)

            pred_list = torch.cat(pred_list, dim=0)
            true_list = torch.cat(true_list, dim=0)

            acc1, acc5 = accuracy(pred_list, true_list, topk=(1, 5))
            acc_many, acc_med, acc_few = self.acc_shot.get_shot_acc(pred_list.max(dim=1)[1], true_list)
            print(f'==> Test Acc: {acc1.item():.2f}% ({acc5.item():.2f}%), [Many {acc_many.item():.2f}%, Med {acc_med.item():.2f}%, Few {acc_few.item():.2f}%]')
        return float(acc1), float(acc_many), float(acc_med), float(acc_few)

    def estimate_empirical_distribution(self, model, est_loader, num_class):
        """Estimate the class prior from model predictions on ``est_loader``.

        Each softmax output is multiplied by that sample's ``labels`` row
        (presumably the candidate-label indicators — confirm in load_cifar)
        and renormalised. The rows are then averaged over all samples and
        normalised to sum to 1. ``num_class`` is accepted for interface
        compatibility but unused.

        Returns:
            CPU tensor of shape ``(K, 1)``.
        """
        with torch.no_grad():
            print('==> Estimating empirical label distribution (soft)...')
            model.eval()
            prob_list = []

            for _, (images, labels, _) in enumerate(est_loader):
                images = images.cuda()
                outputs = model(images)
                probs = torch.softmax(outputs, dim=1)

                probs = probs * labels.cuda()
                probs = probs / (probs.sum(dim=1, keepdim=True) + 1e-12)

                prob_list.append(probs.cpu())

        probs_all = torch.cat(prob_list, dim=0)
        emp_dist = probs_all.mean(dim=0)
        emp_dist = emp_dist / emp_dist.sum()

        return emp_dist.unsqueeze(1)

    def save_checkpoint(self, state, is_best, filename='checkpoint.pth.tar', best_file_name='model_best.pth.tar'):
        """Save ``state`` to ``filename``; also copy it to ``best_file_name`` when ``is_best``."""
        torch.save(state, filename)
        if is_best:
            shutil.copyfile(filename, best_file_name)

if __name__ == '__main__':
    args = parser.parse_args()

    # Each "start,end" CLI string becomes two float attributes on args.
    float_pairs = (
        ('rho_range', ('rho_start', 'rho_end')),
        ('gamma', ('gamma1', 'gamma2')),
        ('sampling_gamma_range', ('sampling_gamma_start', 'sampling_gamma_end')),
    )
    for source_attr, (first_attr, second_attr) in float_pairs:
        first_val, second_val = [float(tok) for tok in getattr(args, source_attr).split(',')]
        setattr(args, first_attr, first_val)
        setattr(args, second_attr, second_val)

    # "700,800,900" -> [700, 800, 900]
    args.lr_decay_epochs = [int(tok) for tok in args.lr_decay_epochs.split(',')]
    # The CLI value counts batches; the queue itself is sized in samples.
    args.queue_length = args.queue_length * args.batch_size
    print(args)
    torch.cuda.set_device(args.gpu)
    args.imb_factor = 1. / args.imb_ratio

    trainer = Trainer(args)
    # Stage 1 estimates the class prior; stage 2 trains from that estimate.
    emp_dist = trainer.train(is_est_dist=True, total_epochs=args.est_epochs, gamma=args.gamma1)
    trainer.train(emp_dist=emp_dist, gamma=args.gamma2)
