import torch
import parsing
from utils import *


# Command-line options are parsed once, at import time; the trainers below
# read their hyper-parameters from the resulting module-level ``args``.
parser = parsing.create_parser()
args = parser.parse_args()
# Plain-dict snapshot of every parsed option (name -> value).
state = dict(args._get_kwargs())

# Default device: the first CUDA device when one is available, else the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

def sampleMix(input_tensor, rate_1, rate_2):
    """Blend every sample of a batch with another, randomly chosen sample.

    For each row ``i`` the result is
    ``rate_1 * input_tensor[i] + rate_2 * input_tensor[j]`` where ``j`` is
    drawn uniformly from the other rows of the batch (``j != i``).

    Args:
        input_tensor: Batch tensor; dimension 0 indexes the samples.
        rate_1: Weight applied to the sample itself.
        rate_2: Weight applied to the randomly selected partner sample.

    Returns:
        A new tensor with the same shape and dtype as ``input_tensor``.
        A batch with fewer than two rows has no distinct partner, so each
        row is blended with itself instead of raising.
    """
    batch_size = input_tensor.shape[0]
    own = torch.arange(batch_size, device=input_tensor.device)
    if batch_size > 1:
        # A non-zero cyclic shift in [1, batch_size - 1] reaches every other
        # row with equal probability and can never map a row onto itself.
        shift = torch.randint(1, batch_size, (batch_size,), device=input_tensor.device)
        partner = (own + shift) % batch_size
    else:
        partner = own
    mixed = rate_1 * input_tensor + rate_2 * input_tensor[partner]
    # The result always carries the dtype of the input batch.
    return mixed.to(input_tensor.dtype)

'''
***The code below is based on "On the Power of Deep but Naive Partial Label Learning (DNPL)", ICASSP 21***
'''

class T_DNPL():
    """Trainer for the naive partial-label objective (DNPL)."""

    def __init__(self):
        super(T_DNPL, self).__init__()

    def train_step(self, index, confidence, input, input_w, input_s, part_y, model, optimizer, epoch):
        """Run one optimisation step on a batch.

        The loss is the negative log of the softmax probability mass that
        falls on the candidate labels, averaged over the batch.

        Args:
            index, input_w, input_s, epoch: Accepted for a uniform trainer
                signature; not used by this method.
            confidence: Returned unchanged.
            input: Batch fed to ``model``.
            part_y: Candidate-label indicator, multiplied element-wise with
                the class probabilities.
            model: Callable returning class logits for ``input``.
            optimizer: Optimizer stepped once.

        Returns:
            Tuple ``(loss value as float, confidence)``.
        """
        optimizer.zero_grad()

        class_probs = F.softmax(model(input), dim=1)
        # Probability mass on the candidate set, kept inside [0, 1].
        candidate_mass = torch.clamp((part_y * class_probs).sum(1), 0., 1.)
        # The small constant avoids log(0).
        loss = -torch.log(candidate_mass + 1e-10).mean()

        loss.backward()
        optimizer.step()

        return loss.item(), confidence


'''
***The code below is based on "Progressive Identification of True Labels for Partial-Label Learning (PRODEN)", ICML 20***
'''


class T_PRODEN():
    """Trainer for progressive label identification (PRODEN)."""

    def __init__(self):
        super(T_PRODEN, self).__init__()

    def confidence_update(self, index, confidence, outputs, part_y):
        """Re-estimate the label confidences of the current batch.

        Entries of ``confidence[index]`` that are positive mark the candidate
        labels. The new confidence of a row is ``outputs`` restricted to those
        candidates and renormalised to sum to one.

        Args:
            index: Row indices of the batch inside ``confidence``.
            confidence: Confidence matrix; updated in place for ``index``.
            outputs: Class probabilities for the batch.
            part_y: Unused; kept for a uniform signature.

        Returns:
            The (in-place updated) ``confidence`` tensor.
        """
        previous = confidence[index, :]
        # Positive entries become 1; every other entry is left as it is.
        candidates = torch.where(previous > 0, torch.ones_like(previous), previous)
        scores = candidates * outputs.detach()
        totals = scores.sum(dim=1, keepdim=True)

        # A row whose candidate mass is not positive (e.g. underflowed to 0)
        # cannot be normalised; dividing would write NaN into the confidence
        # matrix, so such rows keep their previous confidence.
        valid = totals > 0
        safe_totals = torch.where(valid, totals, torch.ones_like(totals))
        revised = torch.where(valid, scores / safe_totals, previous)

        confidence[index, :] = revised
        return confidence

    def train_step(self, index, confidence, input, input_w, input_s, part_y, model, optimizer, epoch):
        """Run one optimisation step with the confidence-weighted loss.

        Args:
            index: Row indices of the batch inside ``confidence``.
            confidence: Per-sample label confidences used as loss weights.
            input: Batch fed to ``model``.
            input_w, input_s, epoch: Unused; kept for a uniform signature.
            part_y: Forwarded to ``confidence_update``.
            model: Callable returning class logits.
            optimizer: Optimizer stepped once.

        Returns:
            Tuple ``(loss value as float, confidence)``; the confidence is
            refreshed from the pre-step predictions when
            ``args.use_confidence`` is set.
        """
        optimizer.zero_grad()

        logits = model(input)
        # log_softmax stays finite where log(softmax(.)) would underflow to
        # -inf, which multiplied by a zero confidence yields NaN.
        log_probs = F.log_softmax(logits, dim=1)
        weighted = confidence[index, :] * log_probs

        loss = (-torch.sum(weighted)) / weighted.size(0)

        loss.backward()
        optimizer.step()

        if args.use_confidence == True:
            probs = F.softmax(logits.detach(), dim=1)
            confidence = self.confidence_update(index, confidence, probs, part_y)

        return loss.item(), confidence


'''
***The code below is based on "Exploiting Class Activation Value for Partial-Label Learning (CAVL)", ICLR 22***
'''


class T_CAVL():
    """Trainer that picks pseudo-labels by class activation value (CAVL)."""

    def __init__(self):
        super(T_CAVL, self).__init__()

    def confidence_update(self, index, confidence, outputs, part_y):
        """Replace the batch confidences by one-hot pseudo-labels.

        The class activation value ``outputs * |1 - outputs|`` is masked by
        ``part_y`` and the class with the largest value becomes the label.

        Args:
            index: Row indices of the batch inside ``confidence``.
            confidence: Confidence matrix; updated in place for ``index``.
            outputs: Raw model outputs for the batch.
            part_y: Candidate-label indicator; its second dimension gives the
                number of classes.

        Returns:
            The (in-place updated) ``confidence`` tensor.
        """
        with torch.no_grad():
            activation = outputs * torch.abs(1 - outputs)
            masked_activation = activation * part_y
            chosen = masked_activation.max(dim=1).indices
            # label_smoothing() could be used to further improve the
            # performance for some datasets
            one_hot = F.one_hot(chosen, part_y.shape[1])
            confidence[index, :] = one_hot.float()

        return confidence

    def train_step(self, index, confidence, input, input_w, input_s, part_y, model, optimizer, epoch):
        """Run one optimisation step with the confidence-weighted log-loss.

        ``input_w``, ``input_s`` and ``epoch`` are unused. Returns the tuple
        ``(loss value as float, confidence)``; the confidence is refreshed
        from the batch outputs when ``args.use_confidence`` is set.
        """
        optimizer.zero_grad()

        outputs = model(input)
        weighted_log_probs = F.log_softmax(outputs, dim=1) * confidence[index, :]
        loss = -weighted_log_probs.sum(dim=1).mean()

        loss.backward()
        optimizer.step()

        if args.use_confidence == True:
            confidence = self.confidence_update(index, confidence, outputs, part_y)

        return loss.item(), confidence


'''
***The code below is based on "Leveraged Weighted Loss for Partial Label Learning (LW)", ICML 21***
'''


class T_LW():
    """Trainer for the leveraged weighted loss (LW)."""

    def __init__(self):
        super(T_LW, self).__init__()

    @staticmethod
    def _candidate_masks(part_y, target_device):
        """Return ``(candidate, non_candidate)`` 0/1 masks on ``target_device``."""
        onezero = (part_y > 0).to(device=target_device, dtype=torch.get_default_dtype())
        return onezero, 1 - onezero

    def confidence_update(self, index, confidence, outputs, part_y):
        """Refresh the batch confidences from the current predictions.

        The softmax of ``outputs`` is normalised separately over the candidate
        labels and over the non-candidate labels; the sum of both parts is
        written to ``confidence[index]``.

        Args:
            index: Row indices of the batch inside ``confidence``.
            confidence: Confidence matrix; updated in place for ``index``.
            outputs: Raw model outputs (logits) for the batch.
            part_y: Candidate-label indicator; positive entries are candidates.

        Returns:
            The (in-place updated) ``confidence`` tensor.
        """
        with torch.no_grad():
            sm_outputs = F.softmax(outputs, dim=1)
            # Masks follow the device of the predictions rather than a global
            # default, so the update works wherever the model lives.
            onezero, counter_onezero = self._candidate_masks(part_y, sm_outputs.device)

            new_weight1 = sm_outputs * onezero
            new_weight1 = new_weight1 / (new_weight1 + 1e-8).sum(dim=1, keepdim=True)
            new_weight2 = sm_outputs * counter_onezero
            new_weight2 = new_weight2 / (new_weight2 + 1e-8).sum(dim=1, keepdim=True)

            confidence[index, :] = new_weight1 + new_weight2
            return confidence

    def train_step(self, index, confidence, input, input_w, input_s, part_y, model, optimizer, epoch):
        """Run one optimisation step with the leveraged weighted loss.

        Loss 1 is applied on candidate labels and loss 2 on non-candidate
        labels; the total is ``loss1 + args.beta * loss2``. ``args.loss``
        selects the per-class loss (``'sigmoid'`` or ``'cross_entropy'``).
        ``input_w``, ``input_s`` and ``epoch`` are unused.

        Returns:
            Tuple ``(loss value as float, confidence)``; the confidence is
            refreshed when ``args.use_confidence`` is set.

        Raises:
            Exception: If ``args.loss`` names neither supported loss.
        """
        optimizer.zero_grad()
        outputs = model(input)
        onezero, counter_onezero = self._candidate_masks(part_y, outputs.device)

        if args.loss == 'sigmoid':
            # sigmoid(-z) and sigmoid(z), evaluated by the numerically stable
            # built-in instead of a hand-written piecewise formula.
            per_class_loss1 = torch.sigmoid(-outputs)
            per_class_loss2 = torch.sigmoid(outputs)
        elif args.loss == 'cross_entropy':
            sm_outputs = F.softmax(outputs, dim=1)
            per_class_loss1 = - torch.log(sm_outputs + 1e-8)
            per_class_loss2 = - torch.log(1 - sm_outputs + 1e-8)
        else:
            raise Exception('Need to choose the loss')

        l1 = onezero * per_class_loss1
        l2 = counter_onezero * per_class_loss2
        if args.use_confidence == True:
            weights = confidence[index, :]
            l1 = weights * l1
            l2 = weights * l2

        average_loss1 = torch.sum(l1) / l1.size(0)
        average_loss2 = torch.sum(l2) / l2.size(0)
        loss = average_loss1 + args.beta * average_loss2

        loss.backward()
        optimizer.step()

        if args.use_confidence == True:
            confidence = self.confidence_update(index, confidence, outputs, part_y)

        return loss.item(), confidence


'''
***The code below is based on "Revisiting Consistency Regularization for Deep Partial Label Learning (CR)", ICML 22***
'''


class T_CR():
    """Trainer using consistency regularisation over three views (CR)."""

    def __init__(self):
        super(T_CR, self).__init__()

    @staticmethod
    def _binarize(part_y):
        """Return a copy of ``part_y`` whose positive entries are set to 1."""
        return torch.where(part_y > 0, torch.ones_like(part_y), part_y)

    def confidence_update(self, index, confidence, y_pred_aug0_probas, y_pred_aug1_probas, y_pred_aug2_probas, part_y):
        """Refresh the batch confidences from the three view predictions.

        The new confidence is the geometric mean of the three probability
        tensors, restricted to the candidate labels and renormalised per row.

        Args:
            index: Row indices of the batch inside ``confidence``.
            confidence: Confidence matrix; updated in place for ``index``.
            y_pred_aug0_probas, y_pred_aug1_probas, y_pred_aug2_probas:
                Class probabilities of the three views.
            part_y: Candidate-label indicator. It is not modified.

        Returns:
            The (in-place updated) ``confidence`` tensor.
        """
        # Work on a binarised copy so the caller's tensor is left untouched.
        candidates = self._binarize(part_y)
        probas = [p.detach() for p in (y_pred_aug0_probas, y_pred_aug1_probas, y_pred_aug2_probas)]

        exponent = 1 / (2 + 1)  # cube root: geometric mean of three views
        revisedY0 = candidates.clone()
        for p in probas:
            revisedY0 = revisedY0 * torch.pow(p, exponent)
        # Normalise over the class dimension actually present in the tensor.
        revisedY0 = revisedY0 / revisedY0.sum(dim=1, keepdim=True)

        confidence[index, :] = revisedY0.detach()

        return confidence

    def train_step(self, index, confidence, input, input_w, input_s, part_y, model, optimizer, epoch):
        """Run one optimisation step.

        The loss is a supervised term that pushes probability away from the
        non-candidate labels, plus a weighted KL consistency term between the
        current confidences and the predictions on ``input``, ``input_w`` and
        ``input_s``. The consistency weight ramps up linearly with ``epoch``
        to at most ``args.lam`` and is 0 when ``args.use_confidence`` is off.

        Args:
            part_y: Candidate-label indicator. It is not modified.

        Returns:
            Tuple ``(loss value as float, confidence)``.
        """
        # Binarised copy; the caller's tensor is not changed.
        part_y = self._binarize(part_y)
        consistency_criterion = nn.KLDivLoss(reduction='batchmean').to(device)

        optimizer.zero_grad()

        output, weak_output, strong_output = (model(view) for view in (input, input_w, input_s))

        target = confidence[index, :]
        consistency_loss, consistency_loss_weak, consistency_loss_strong = (
            consistency_criterion(torch.log_softmax(logits, dim=-1), target)
            for logits in (output, weak_output, strong_output))

        super_loss = -torch.mean(torch.sum(torch.log(1.0000001 - F.softmax(output, dim=1)) * (1 - part_y), dim=1))

        if args.use_confidence == True:
            lam = min((epoch / args.epochs) * args.lam, args.lam)
        else:
            lam = 0
        loss = super_loss + lam * (
                args.c_weight * consistency_loss + args.c_weight_w * consistency_loss_weak + args.c_weight_s * consistency_loss_strong)
        loss.backward()
        optimizer.step()
        if args.use_confidence == True:
            confidence = self.confidence_update(index, confidence, torch.softmax(output, dim=-1),
                                                torch.softmax(weak_output, dim=-1),
                                                torch.softmax(strong_output, dim=-1), part_y)
        return loss.item(), confidence


'''
***The code below is based on "PiCO: Contrastive Label Disambiguation for Partial Label Learning (PiCO)", ICLR 22***
'''


class T_PiCO():
    """Trainer combining a classification loss with a contrastive loss (PiCO)."""

    def __init__(self):
        super(T_PiCO, self).__init__()

    def train_step(self, index, confidence, input, input_w, input_s, part_y, model, optimizer, epoch, loss_fn):
        """Run one optimisation step.

        ``model`` is called with ``input`` for both of its image arguments and
        must return ``(cls_out, features_cont, pseudo_score_cont,
        partial_target_cont, score_prot)``. The total loss is
        ``loss_fn(...) + args.gamma * contrastive loss``. ``input_w``,
        ``input_s`` and ``epoch`` are unused.

        Returns:
            Tuple ``(loss value as float, confidence)``; ``confidence`` is
            returned unchanged (``loss_fn`` keeps its own state).
        """
        contrastive_criterion = SupConMocoLoss()
        cls_out, features_cont, pseudo_score_cont, partial_target_cont, score_prot = model(
            input, input, part_y, args, eval_only=False)

        # Column vector holding the arg-max pseudo-label of every entry.
        pseudo_target_cont = torch.max(pseudo_score_cont, dim=1)[1].contiguous().view(-1, 1)

        mask = None
        if args.use_confidence == True:
            loss_fn.confidence_update(temp_un_conf=score_prot, batch_index=index, batchY=part_y)
            # Warm-up ended: the positive set is formed by the entries whose
            # predicted label equals that of the batch sample.
            same_label = torch.eq(pseudo_target_cont[:args.batch_size], pseudo_target_cont.T)
            mask = same_label.float().to(device)

        # contrastive loss
        loss_cont = contrastive_criterion(features=features_cont, mask=mask, batch_size=args.batch_size)

        # classification loss
        if args.method in ("PiCO_DisambiguationDelAllCandi", "PiCO_Without_Candi"):
            loss_cls = loss_fn(cls_out, index, part_y)
        else:
            loss_cls = loss_fn(cls_out, index)

        loss = loss_cls + args.gamma * loss_cont
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item(), confidence


'''
***The code below is based on "Towards Effective Visual Representations for Partial-Label Learning (PaPi)", CVPR 23***
'''


class T_PaPI():
    """Trainer for the PaPi objective (classification + prototype similarity)."""

    def __init__(self):
        super(T_PaPI, self).__init__()

    def train_step(self, index, confidence, input, input_w, input_s, part_y, model, optimizer, epoch, loss_PaPi_func):
        """Run one optimisation step.

        Two mixup views are built from ``input`` and a random permutation of
        the same batch (no other data augmentation is used; ``input_w`` and
        ``input_s`` are ignored). ``model`` must return ``(cls_out_1,
        cls_out_2, logits_prot_1, logits_prot_2, logits_prot_1_mix,
        logits_prot_2_mix)``. The total loss is
        ``cls_loss_1 + alpha_td * sim_loss_2`` as returned by
        ``loss_PaPi_func``.

        Returns:
            Tuple ``(loss value as float, confidence)``; ``confidence`` is
            returned unchanged (``loss_PaPi_func`` keeps its own weights).
        """
        # KLDivLoss holds no parameters or buffers, so it needs no device move.
        sim_criterion = torch.nn.KLDivLoss(reduction='batchmean')

        Lambda = np.random.beta(args.alpha_mixup, args.alpha_mixup)
        # Permute the rows actually present: a batch shorter than
        # args.batch_size would otherwise be indexed out of range.
        # NOTE(review): assumes loss_PaPi_func only uses idx_rp to index
        # per-batch tensors -- confirm.
        idx_rp = torch.randperm(input.size(0))

        # use no data augmentation: both mixed views pair the batch with the
        # same permutation of itself
        input_rp = input[idx_rp]
        X_1_mix = Lambda * input + (1 - Lambda) * input_rp
        X_2_mix = Lambda * input + (1 - Lambda) * input_rp

        cls_out_1, cls_out_2, logits_prot_1, logits_prot_2, logits_prot_1_mix, logits_prot_2_mix = model(
            img_q=input,
            img_k=input,
            img_q_mix=X_1_mix,
            img_k_mix=X_2_mix,
            partial_Y=part_y,
            args=args,
            eval_only=False)

        if args.use_confidence == True:
            loss_PaPi_func.update_weight_byclsout1(cls_predicted_score=cls_out_1, batch_index=index,
                                                   batch_partial_Y=part_y,
                                                   args=args)

        cls_loss_1, sim_loss_2, alpha_td = loss_PaPi_func(cls_out_1, cls_out_2, logits_prot_1, logits_prot_2,
                                                          logits_prot_1_mix, logits_prot_2_mix, idx_rp, Lambda, index,
                                                          args, sim_criterion)

        loss = cls_loss_1 + alpha_td * sim_loss_2
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item(), confidence

def add_gaussian_noise(input_tensor, noise_std=0.2):
    """Return ``input_tensor`` plus zero-mean Gaussian noise.

    Args:
        input_tensor: Tensor to perturb; it is not modified.
        noise_std: Standard deviation of the added noise.

    Returns:
        A new tensor of the same shape as ``input_tensor``.
    """
    # One standard-normal draw per element, scaled to the requested spread.
    return input_tensor + torch.randn_like(input_tensor) * noise_std


def mask_features(input_tensor, mask_percentage=0.2):
    """
    Zeroes a random subset of the features of every sample.

    Parameters:
    input_tensor (torch.Tensor): A 3-D tensor laid out as
        (batch, channel, features). Only channel 0 is masked.
    mask_percentage (float): The fraction of features to mask, default is 0.2 (20%).

    Returns:
    torch.Tensor: A new tensor with the selected features set to zero.
    """
    _, _, num_features = input_tensor.shape
    drop_count = int(num_features * mask_percentage)

    # Start from an all-ones (keep everything) multiplicative mask
    keep = torch.ones_like(input_tensor)

    # Each sample gets its own independent random selection
    for sample_keep in keep:
        dropped = torch.randperm(num_features)[:drop_count]
        sample_keep[0, dropped] = 0

    return input_tensor * keep


# the Noise Augmentation method
def mix_beta(input_tensor, beta_parameter):
    """Noise augmentation: lightly mix each sample with another one.

    A single coefficient is drawn from ``Beta(beta_parameter,
    beta_parameter)`` and rescaled to the range [0.8, 1]. Every row ``i``
    becomes ``lam * input_tensor[i] + (1 - lam) * input_tensor[j]`` with
    ``j`` drawn uniformly from the other rows of the batch (``j != i``).

    Args:
        input_tensor: Batch tensor; dimension 0 indexes the samples.
        beta_parameter: Both shape parameters of the Beta distribution.

    Returns:
        A new tensor with the same shape and dtype as ``input_tensor``.
        A batch with fewer than two rows has no distinct partner, so it is
        mixed with itself (i.e. returned essentially unchanged) instead of
        raising.
    """
    lambda_prime = np.random.beta(beta_parameter, beta_parameter)

    # Scale the sampled results to the range [0.8, 1]
    lambda_scaled = float(0.8 + 0.2 * lambda_prime)

    batch_size = input_tensor.shape[0]
    own = torch.arange(batch_size, device=input_tensor.device)
    if batch_size > 1:
        # A non-zero cyclic shift in [1, batch_size - 1] reaches every other
        # row with equal probability and can never map a row onto itself.
        shift = torch.randint(1, batch_size, (batch_size,), device=input_tensor.device)
        partner = (own + shift) % batch_size
    else:
        partner = own

    # mixup
    mixed = lambda_scaled * input_tensor + (1 - lambda_scaled) * input_tensor[partner]
    # The result always carries the dtype of the input batch.
    return mixed.to(input_tensor.dtype)

class T_PGNA_PL():
    """Trainer with noise augmentation and prototype-guided self-distillation."""

    def __init__(self):
        super(T_PGNA_PL, self).__init__()

    def train_step(self, index, confidence, input, input_w, input_s, part_y, model, optimizer, epoch):
        """Run one optimisation step.

        The batch is augmented with ``mix_beta``; ``model`` must return
        ``(cls_out, logits_prot)``. The loss is the DNPL classification loss
        plus ``args.gamma`` times the KL divergence between the temperature-
        scaled classifier output (student) and prototype logits (teacher).
        ``index``, ``input_w``, ``input_s`` and ``epoch`` are unused.

        Returns:
            Tuple ``(loss value as float, confidence)``; ``confidence`` is
            returned unchanged.
        """
        optimizer.zero_grad()

        # Noise Augmentation
        augmented = mix_beta(input, args.beta_parameter)
        cls_out, logits_prot = model(augmented, part_y, args, eval_only=False)

        # classification loss -- DNPL loss
        candidate_mass = (part_y * F.softmax(cls_out, dim=1)).sum(1)
        candidate_mass = torch.clamp(candidate_mass, 0., 1.)
        loss_cls = -torch.log(candidate_mass + 1e-10).mean()

        # Proto-Guided self distilling
        temperature = args.temperature
        student_log_probs = F.log_softmax(cls_out / temperature, dim=-1)
        teacher_probs = F.softmax(logits_prot / temperature, dim=-1)
        distill_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')

        loss = loss_cls + args.gamma * distill_loss
        loss.backward()
        optimizer.step()
        return loss.item(), confidence

class T_PGNA_PL_wtihout_PG():
    """Ablation trainer: noise augmentation without the distillation term."""

    def __init__(self):
        super(T_PGNA_PL_wtihout_PG, self).__init__()

    def train_step(self, index, confidence, input, input_w, input_s, part_y, model, optimizer, epoch):
        """Run one optimisation step with the DNPL loss only.

        The batch is augmented with ``mix_beta``; the second value returned
        by ``model`` (the prototype logits) is ignored. ``index``,
        ``input_w``, ``input_s`` and ``epoch`` are unused.

        Returns:
            Tuple ``(loss value as float, confidence)``; ``confidence`` is
            returned unchanged.
        """
        optimizer.zero_grad()

        augmented = mix_beta(input, args.beta_parameter)
        cls_out, _ = model(augmented, part_y, args, eval_only=False)

        # classification loss -- DNPL loss
        candidate_mass = (part_y * F.softmax(cls_out, dim=1)).sum(1)
        candidate_mass = torch.clamp(candidate_mass, 0., 1.)
        loss_cls = -torch.log(candidate_mass + 1e-10).mean()

        loss_cls.backward()
        optimizer.step()
        return loss_cls.item(), confidence

class T_PGNA_PL_wtihout_NA():
    """Ablation trainer: self-distillation without noise augmentation."""

    def __init__(self):
        super(T_PGNA_PL_wtihout_NA, self).__init__()

    def train_step(self, index, confidence, input, input_w, input_s, part_y, model, optimizer, epoch):
        """Run one optimisation step on the un-augmented batch.

        ``model`` must return ``(cls_out, logits_prot)``. The loss is the
        DNPL classification loss plus ``args.gamma`` times the KL divergence
        between the temperature-scaled student and teacher distributions.
        ``index``, ``input_w``, ``input_s`` and ``epoch`` are unused.

        Returns:
            Tuple ``(loss value as float, confidence)``; ``confidence`` is
            returned unchanged.
        """
        optimizer.zero_grad()

        cls_out, logits_prot = model(input, part_y, args, eval_only=False)

        # classification loss -- DNPL loss
        candidate_mass = (part_y * F.softmax(cls_out, dim=1)).sum(1)
        candidate_mass = torch.clamp(candidate_mass, 0., 1.)
        loss_cls = -torch.log(candidate_mass + 1e-10).mean()

        # self distilling
        temperature = args.temperature
        student_log_probs = F.log_softmax(cls_out / temperature, dim=-1)
        teacher_probs = F.softmax(logits_prot / temperature, dim=-1)
        distill_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')

        loss = loss_cls + args.gamma * distill_loss
        loss.backward()
        optimizer.step()
        return loss.item(), confidence

#Compare with other NA methods
class T_PG_Other_NA__Mask():
    """Comparison trainer: feature masking as the augmentation."""

    def __init__(self):
        super(T_PG_Other_NA__Mask, self).__init__()

    def train_step(self, index, confidence, input, input_w, input_s, part_y, model, optimizer, epoch):
        """Run one optimisation step on a feature-masked batch.

        The batch is passed through ``mask_features`` with a ratio of 0.2;
        ``model`` must return ``(cls_out, logits_prot)``. The loss is the
        DNPL classification loss plus ``args.gamma`` times the KL
        self-distillation term. ``index``, ``input_w``, ``input_s`` and
        ``epoch`` are unused.

        Returns:
            Tuple ``(loss value as float, confidence)``; ``confidence`` is
            returned unchanged.
        """
        optimizer.zero_grad()

        masked_input = mask_features(input, 0.2)
        cls_out, logits_prot = model(masked_input, part_y, args, eval_only=False)

        # classification loss -- DNPL loss
        candidate_mass = (part_y * F.softmax(cls_out, dim=1)).sum(1)
        candidate_mass = torch.clamp(candidate_mass, 0., 1.)
        loss_cls = -torch.log(candidate_mass + 1e-10).mean()

        # self distilling
        temperature = args.temperature
        student_log_probs = F.log_softmax(cls_out / temperature, dim=-1)
        teacher_probs = F.softmax(logits_prot / temperature, dim=-1)
        distill_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')

        loss = loss_cls + args.gamma * distill_loss
        loss.backward()
        optimizer.step()
        return loss.item(), confidence

class T_PG_Other_NA__Gauss():
    """Comparison trainer: additive Gaussian noise as the augmentation."""

    def __init__(self):
        super(T_PG_Other_NA__Gauss, self).__init__()

    def train_step(self, index, confidence, input, input_w, input_s, part_y, model, optimizer, epoch):
        """Run one optimisation step on a Gaussian-perturbed batch.

        The batch is passed through ``add_gaussian_noise`` with its default
        settings; ``model`` must return ``(cls_out, logits_prot)``. The loss
        is the DNPL classification loss plus ``args.gamma`` times the KL
        self-distillation term. ``index``, ``input_w``, ``input_s`` and
        ``epoch`` are unused.

        Returns:
            Tuple ``(loss value as float, confidence)``; ``confidence`` is
            returned unchanged.
        """
        optimizer.zero_grad()

        noisy_input = add_gaussian_noise(input)
        cls_out, logits_prot = model(noisy_input, part_y, args, eval_only=False)

        # classification loss -- DNPL loss
        candidate_mass = (part_y * F.softmax(cls_out, dim=1)).sum(1)
        candidate_mass = torch.clamp(candidate_mass, 0., 1.)
        loss_cls = -torch.log(candidate_mass + 1e-10).mean()

        # self distilling
        temperature = args.temperature
        student_log_probs = F.log_softmax(cls_out / temperature, dim=-1)
        teacher_probs = F.softmax(logits_prot / temperature, dim=-1)
        distill_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')

        loss = loss_cls + args.gamma * distill_loss
        loss.backward()
        optimizer.step()
        return loss.item(), confidence

def add_salt_and_pepper_noise(input_tensor, salt_prob=0.1, pepper_prob=0.1):
    """Return a copy of ``input_tensor`` corrupted by salt-and-pepper noise.

    One uniform random number in [0, 1) is drawn per element. Elements whose
    draw is below ``salt_prob`` are set to 1.0 ("salt"); elements whose draw
    is above ``1 - pepper_prob`` are set to 0.0 ("pepper"). Pepper is applied
    last, so it wins if the two ranges overlap.

    Args:
        input_tensor: Tensor to corrupt; it is not modified.
        salt_prob: Probability threshold for salt.
        pepper_prob: Probability threshold for pepper.

    Returns:
        A new tensor of the same shape as ``input_tensor``.
    """
    draws = torch.rand_like(input_tensor)
    is_salt = draws < salt_prob
    is_pepper = draws > (1 - pepper_prob)
    # masked_fill returns a fresh tensor, leaving the input untouched.
    return input_tensor.masked_fill(is_salt, 1.0).masked_fill(is_pepper, 0.0)

class T_PG_Other_NA__PepperSalt():
    """Comparison trainer: salt-and-pepper noise as the augmentation."""

    def __init__(self):
        super(T_PG_Other_NA__PepperSalt, self).__init__()

    def train_step(self, index, confidence, input, input_w, input_s, part_y, model, optimizer, epoch):
        """Run one optimisation step on a salt-and-pepper corrupted batch.

        The batch is passed through ``add_salt_and_pepper_noise`` with its
        default settings; ``model`` must return ``(cls_out, logits_prot)``.
        The loss is the DNPL classification loss plus ``args.gamma`` times
        the KL self-distillation term. ``index``, ``input_w``, ``input_s``
        and ``epoch`` are unused.

        Returns:
            Tuple ``(loss value as float, confidence)``; ``confidence`` is
            returned unchanged.
        """
        optimizer.zero_grad()

        noisy_input = add_salt_and_pepper_noise(input)
        cls_out, logits_prot = model(noisy_input, part_y, args, eval_only=False)

        # classification loss -- DNPL loss
        candidate_mass = (part_y * F.softmax(cls_out, dim=1)).sum(1)
        candidate_mass = torch.clamp(candidate_mass, 0., 1.)
        loss_cls = -torch.log(candidate_mass + 1e-10).mean()

        # self distilling
        temperature = args.temperature
        student_log_probs = F.log_softmax(cls_out / temperature, dim=-1)
        teacher_probs = F.softmax(logits_prot / temperature, dim=-1)
        distill_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')

        loss = loss_cls + args.gamma * distill_loss
        loss.backward()
        optimizer.step()
        return loss.item(), confidence

#define mixup function
def mixup_origin_beta(input_tensor, label, beta_parameter):
    """Mixup of both the samples and their labels.

    A single coefficient is drawn from ``Beta(beta_parameter,
    beta_parameter)`` and rescaled to the range [0.8, 1]. Every row ``i`` of
    ``input_tensor`` and of ``label`` is mixed with the same partner row
    ``j``, drawn uniformly from the other rows of the batch (``j != i``):
    ``lam * x[i] + (1 - lam) * x[j]``.

    Args:
        input_tensor: Batch tensor; dimension 0 indexes the samples.
        label: Label tensor with the same number of rows.
        beta_parameter: Both shape parameters of the Beta distribution.

    Returns:
        Tuple ``(mixed samples, mixed labels)``; each keeps the shape and
        dtype of its source tensor. A batch with fewer than two rows has no
        distinct partner, so it is mixed with itself (i.e. returned
        essentially unchanged) instead of raising.
    """
    lambda_prime = np.random.beta(beta_parameter, beta_parameter)

    # Scale the sampled results to the range [0.8, 1]
    lambda_scaled = float(0.8 + 0.2 * lambda_prime)

    batch_size = input_tensor.shape[0]
    own = torch.arange(batch_size, device=input_tensor.device)
    if batch_size > 1:
        # A non-zero cyclic shift in [1, batch_size - 1] reaches every other
        # row with equal probability and can never map a row onto itself.
        shift = torch.randint(1, batch_size, (batch_size,), device=input_tensor.device)
        partner = (own + shift) % batch_size
    else:
        partner = own

    # mixup -- samples and labels share the same partner row
    new_tensor = lambda_scaled * input_tensor + (1 - lambda_scaled) * input_tensor[partner]
    new_label = lambda_scaled * label + (1 - lambda_scaled) * label[partner.to(label.device)]
    # Each result carries the dtype of the tensor it was derived from.
    return new_tensor.to(input_tensor.dtype), new_label.to(label.dtype)

class T_PGNA_Other_OriginMixup():
    """Comparison trainer: standard mixup of both samples and labels."""

    def __init__(self):
        super(T_PGNA_Other_OriginMixup, self).__init__()

    def train_step(self, index, confidence, input, input_w, input_s, part_y, model, optimizer, epoch):
        """Run one optimisation step on a mixed-up batch.

        ``mixup_origin_beta`` mixes both ``input`` and ``part_y``; the mixed
        labels replace ``part_y`` for the model call and for the loss.
        ``model`` must return ``(cls_out, logits_prot)``. The loss is the
        DNPL classification loss plus ``args.gamma`` times the KL
        self-distillation term. ``index``, ``input_w``, ``input_s`` and
        ``epoch`` are unused.

        Returns:
            Tuple ``(loss value as float, confidence)``; ``confidence`` is
            returned unchanged.
        """
        optimizer.zero_grad()

        # Noise Augmentation
        mixed_input, mixed_labels = mixup_origin_beta(input, part_y, args.beta_parameter)
        cls_out, logits_prot = model(mixed_input, mixed_labels, args, eval_only=False)

        # classification loss -- DNPL loss
        candidate_mass = (mixed_labels * F.softmax(cls_out, dim=1)).sum(1)
        candidate_mass = torch.clamp(candidate_mass, 0., 1.)
        loss_cls = -torch.log(candidate_mass + 1e-10).mean()

        # Proto-Guided self distilling
        temperature = args.temperature
        student_log_probs = F.log_softmax(cls_out / temperature, dim=-1)
        teacher_probs = F.softmax(logits_prot / temperature, dim=-1)
        distill_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')

        loss = loss_cls + args.gamma * distill_loss
        loss.backward()
        optimizer.step()
        return loss.item(), confidence


class T_FullySupervision():
    """Fully supervised baseline trained with cross-entropy."""

    def __init__(self):
        super(T_FullySupervision, self).__init__()
        self.criterion = torch.nn.CrossEntropyLoss()

    def train_step(self, input, origin_y, model, optimizer):
        """Run one optimisation step against the true labels.

        Args:
            input: Batch fed to ``model``.
            origin_y: Per-class label scores; the arg-max along dimension 1
                is used as the target class index.
            model: Callable returning class logits.
            optimizer: Optimizer stepped once.

        Returns:
            The loss value as a float.
        """
        optimizer.zero_grad()

        targets = origin_y.argmax(dim=1)
        loss = self.criterion(model(input), targets)

        loss.backward()
        optimizer.step()

        return loss.item()

class T_PGNA_PL_FullySupervision():
    """PGNA trainer variant supervised by the true labels."""

    def __init__(self):
        super(T_PGNA_PL_FullySupervision, self).__init__()
        self.criterion = torch.nn.CrossEntropyLoss()

    def train_step(self, index, confidence, input, input_w, input_s, origin_y, model, optimizer, epoch):
        """Run one optimisation step.

        The batch is augmented with ``mix_beta``; ``model`` receives
        ``origin_y`` in the label slot and must return
        ``(cls_out, logits_prot)``. The classification loss is cross-entropy
        against the arg-max of ``origin_y`` (used here in place of the DNPL
        loss), to which ``args.gamma`` times the KL self-distillation term
        is added. ``index``, ``input_w``, ``input_s`` and ``epoch`` are
        unused.

        Returns:
            Tuple ``(loss value as float, confidence)``; ``confidence`` is
            returned unchanged.
        """
        optimizer.zero_grad()

        # Noise Augmentation
        augmented = mix_beta(input, args.beta_parameter)
        cls_out, logits_prot = model(augmented, origin_y, args, eval_only=False)

        # classification loss -- cross entropy loss
        targets = origin_y.argmax(dim=1)
        loss_cls = self.criterion(cls_out, targets)

        # Proto-Guided self distilling
        temperature = args.temperature
        student_log_probs = F.log_softmax(cls_out / temperature, dim=-1)
        teacher_probs = F.softmax(logits_prot / temperature, dim=-1)
        distill_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')

        loss = loss_cls + args.gamma * distill_loss
        loss.backward()
        optimizer.step()
        return loss.item(), confidence
