import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as opt
import argparse
import scipy.sparse as sp
import dgl
import pickle, gzip
import sys
import itertools
from pyscipopt import SCIP_PARAMSETTING, Model, Branchrule, Nodesel, SCIP_RESULT
from torch.utils.data import Dataset
from treelib import Tree
sys.path.append('../')
from src.logger import Logger
from src.model import GCNN_Net, PD_Net
import time
logger = Logger.logger

def extract_state_new_32(model, buffer=None, *, node_dim=8, mip_dim=53):
    """Build the branching observation for the current SCIP node.

    The observation has two parts:

    * a bipartite-graph view of the LP: constraint, edge and variable feature
      dicts, each with ``'names'`` and ``'values'``; the edge dict also has
      ``'indices'``, a ``(2, n_edges)`` array whose first row holds
      constraint-row ids and whose second row holds variable ids;
    * tree/MIP statistics from ``model.getMyNodeState(node_dim)`` and
      ``model.getMyMIPState(mip_dim)``.

    Each LP row contributes one constraint node per finite side. A finite lhs
    adds a copy with negated coefficients, bias and objective cosine; a finite
    rhs adds an un-negated copy. Two-sided rows therefore appear twice.

    Caching: ``buffer`` is a dict reused across calls.
    * Computed once and stored in ``buffer['state']``: the objective norm,
      variable type/objective features, row cosine/bias features, which rows
      have a finite lhs/rhs, and the normalized constraint-matrix edges.
    * Recomputed on every call: all solution-dependent features.
    * The raw SCIP state is kept in ``buffer['scip_state']`` and handed back to
      ``model.getState`` on the next call.
    * The buffer is cleared in place at the root node
      (``model.getNNodes() == 1``).

    NOTE(review): the cache assumes LP rows/columns (count and order) do not
    change between calls sharing a buffer, e.g. no rows added after the
    root. Confirm.

    Args:
        model: SCIP model exposing ``getNNodes``, ``getState``,
            ``getMyNodeState`` and ``getMyMIPState``. These look like
            project-specific pyscipopt extensions; verify.
        buffer: Cache dict reused across calls, or ``None`` to start fresh.
            Pass the same dict on every call of a solve to benefit from the
            cache; it is updated in place.
        node_dim: Length requested from ``model.getMyNodeState``.
        mip_dim: Length requested from ``model.getMyMIPState``.

    Returns:
        ``(mystate, state)``:
        * ``mystate = (my_node_state, my_mip_state)``, dicts with ``'names'``
          (a string) and ``'values'``;
        * ``state = (constraint_features, edge_features, variable_features)``.
    """
    if buffer is None:
        buffer = {}
    elif model.getNNodes() == 1:
        # Reset in place rather than rebinding, so the caller's dict is
        # refreshed at the root. Otherwise it would keep features cached
        # from an earlier solve.
        buffer.clear()

    # Refresh the raw SCIP state, passing the previous one (if any) to getState.
    # Alternative extractors used in experiments:
    # s = model.getStateNew32_without_pc(buffer['scip_state'] if 'scip_state' in buffer else None)
    # s = model.getStateNew1(buffer['scip_state'] if 'scip_state' in buffer else None)
    s = model.getState(buffer['scip_state'] if 'scip_state' in buffer else None)
    buffer['scip_state'] = s

    if 'state' in buffer:
        obj_norm = buffer['state']['obj_norm']
    else:
        obj_norm = np.linalg.norm(s['col']['coefs'])
        obj_norm = 1 if obj_norm <= 0 else obj_norm

    # Avoid division by zero for zero-norm rows. row_norms is the same array
    # object as s['row']['norms'], so this also modifies the state stored in
    # buffer['scip_state'].
    row_norms = s['row']['norms']
    row_norms[row_norms == 0] = 1

    # Column features -> variable features
    n_cols = len(s['col']['types'])

    if 'state' in buffer:
        # Static column features cached from the first call; the dynamic
        # entries below are overwritten in this same dict.
        col_feats = buffer['state']['col_feats']
    else:
        col_feats = {}
        col_feats['type'] = np.zeros((n_cols, 4))  # BINARY INTEGER IMPLINT CONTINUOUS
        col_feats['type'][np.arange(n_cols), s['col']['types']] = 1
        col_feats['coef_normalized'] = s['col']['coefs'].reshape(-1, 1) / obj_norm

    # NaN bounds are treated as "no bound".
    col_feats['has_lb'] = ~np.isnan(s['col']['lbs']).reshape(-1, 1)
    col_feats['has_ub'] = ~np.isnan(s['col']['ubs']).reshape(-1, 1)
    col_feats['sol_is_at_lb'] = s['col']['sol_is_at_lb'].reshape(-1, 1)
    col_feats['sol_is_at_ub'] = s['col']['sol_is_at_ub'].reshape(-1, 1)
    col_feats['sol_frac'] = s['col']['solfracs'].reshape(-1, 1)
    col_feats['sol_frac'][s['col']['types'] == 3] = 0  # continuous have no fractionality
    col_feats['basis_status'] = np.zeros((n_cols, 4))  # LOWER BASIC UPPER ZERO
    col_feats['basis_status'][np.arange(n_cols), s['col']['basestats']] = 1
    col_feats['reduced_cost'] = s['col']['redcosts'].reshape(-1, 1) / obj_norm
    col_feats['age'] = s['col']['ages'].reshape(-1, 1) / (s['stats']['nlps'] + 5)
    col_feats['sol_val'] = s['col']['solvals'].reshape(-1, 1)

    # Disabled experimental column features, kept for reference. The "data:"
    # lines appear to record how the SCIP-side values were computed.
    # s['col']['incvals'] = np.nan_to_num(s['col']['incvals'], nan=0.0)
    # s['col']['avgincvals'] = np.nan_to_num(s['col']['avgincvals'], nan=0.0)
    # col_feats['inc_val'] = s['col']['incvals'].reshape(-1, 1)
    # col_feats['avg_inc_val'] = s['col']['avgincvals'].reshape(-1, 1)

    # # depth 17 18
    # col_feats['depth_1'] = s['col']['depth_1'].reshape(-1, 1)
    # col_feats['depth_2'] = s['col']['depth_2'].reshape(-1, 1)
    # data:
    # col_depth_1[col_i] = 1 - (SCIPvarGetAvgBranchdepthCurrentRun(var, SCIP_BRANCHDIR_UPWARDS) / scip_now_max_depth)
    # col_depth_2[col_i] = 1 - (SCIPvarGetAvgBranchdepthCurrentRun(var, SCIP_BRANCHDIR_DOWNWARDS) / scip_now_max_depth)

    # # pc state 19 20 21 22
    # col_feats['conflictscore'] = s['col']['conflictscore'].reshape(-1, 1)
    # col_feats['conflengthscore'] = s['col']['conflengthscore'].reshape(-1, 1)
    # col_feats['inferencescore'] = s['col']['inferencescore'].reshape(-1, 1)
    # col_feats['cutoffscore'] = s['col']['cutoffscore'].reshape(-1, 1)
    # col_feats['pscostscor'] = s['col']['pscostscor'].reshape(-1, 1)
    # data:
    # col_conflictscore[col_i] = self.getVarScore(SCIPgetVarConflictScore(scip, var), SCIPgetAvgConflictScore(scip))
    # col_conflengthscore[col_i] = self.getVarScore(SCIPgetVarConflictlengthScore(scip, var), SCIPgetAvgConflictlengthScore(scip))
    # col_inferencescore[col_i] = self.getVarScore(SCIPgetVarAvgInferenceScore(scip, var), SCIPgetAvgInferenceScore(scip))
    # col_cutoffscore[col_i] = self.getVarScore(SCIPgetVarAvgCutoffScore(scip, var), SCIPgetAvgCutoffScore(scip))
    # col_pscostscor[col_i] = self.getVarScore(SCIPgetVarPseudocostScore(scip, var, lpcandssol[i]), SCIPgetAvgPseudocostScore(scip))

    # # implications 23
    # col_feats['implications_1'] = s['col']['implications_1'].reshape(-1, 1)
    # col_feats['implications_2'] = s['col']['implications_2'].reshape(-1, 1)
    # data:
    # col_implications_1[col_i] = SCIPvarGetNImpls(var, 0)
    # col_implications_2[col_i] = SCIPvarGetNImpls(var, 1)

    # # cliques 25
    # col_feats['cliques_1'] = s['col']['cliques_1'].reshape(-1, 1)
    # col_feats['cliques_2'] = s['col']['cliques_2'].reshape(-1, 1)
    # data:
    # col_cliques_1[col_i] = SCIPvarGetNCliques(var, 0) / SCIPgetNCliques(scip)
    # col_cliques_2[col_i] = SCIPvarGetNCliques(var, 1) / SCIPgetNCliques(scip)

    # # cutoffs 27
    # col_feats['cutoffs_1'] = s['col']['cutoffs_1'].reshape(-1, 1)
    # col_feats['cutoffs_2'] = s['col']['cutoffs_2'].reshape(-1, 1)
    # data:
    # col_cutoffs_1[col_i] = self.gNormMax(SCIPgetVarAvgCutoffsCurrentRun(scip, var, SCIP_BRANCHDIR_UPWARDS))
    # col_cutoffs_2[col_i] = self.gNormMax(SCIPgetVarAvgCutoffsCurrentRun(scip, var, SCIP_BRANCHDIR_DOWNWARDS))

    # # conflict length 29
    # col_feats['conflict_1'] = s['col']['conflict_1'].reshape(-1, 1)
    # col_feats['conflict_2'] = s['col']['conflict_2'].reshape(-1, 1)
    # data:
    # col_conflict_1[col_i] = self.gNormMax(SCIPgetVarAvgConflictlengthCurrentRun(scip, var, SCIP_BRANCHDIR_UPWARDS))
    # col_conflict_2[col_i] = self.gNormMax(SCIPgetVarAvgConflictlengthCurrentRun(scip, var, SCIP_BRANCHDIR_DOWNWARDS))

    # # inferences 32
    # col_feats['inferences_1'] = s['col']['inferences_1'].reshape(-1, 1)
    # col_feats['inferences_2'] = s['col']['inferences_2'].reshape(-1, 1)
    # data:
    # col_inferences_1[col_i] = self.gNormMax(SCIPgetVarAvgInferencesCurrentRun(scip, var, SCIP_BRANCHDIR_UPWARDS))
    # col_inferences_2[col_i] = self.gNormMax(SCIPgetVarAvgInferencesCurrentRun(scip, var, SCIP_BRANCHDIR_DOWNWARDS))

    # Multi-column features get one name per column, e.g. "type_0".."type_3".
    col_feat_names = [[k, ] if v.shape[1] == 1 else [f'{k}_{i}' for i in range(v.shape[1])] for k, v in col_feats.items()]
    col_feat_names = [n for names in col_feat_names for n in names]
    col_feat_vals = np.concatenate(list(col_feats.values()), axis=-1)

    variable_features = {
        'names': col_feat_names,
        'values': col_feat_vals,}

    # Row features -> constraint features
    if 'state' in buffer:
        row_feats = buffer['state']['row_feats']
        has_lhs = buffer['state']['has_lhs']
        has_rhs = buffer['state']['has_rhs']
    else:
        row_feats = {}
        # Indices of rows with a finite (non-NaN) lhs / rhs.
        has_lhs = np.nonzero(~np.isnan(s['row']['lhss']))[0]
        has_rhs = np.nonzero(~np.isnan(s['row']['rhss']))[0]
        row_feats['obj_cosine_similarity'] = np.concatenate((
            -s['row']['objcossims'][has_lhs],
            +s['row']['objcossims'][has_rhs])).reshape(-1, 1)
        row_feats['bias'] = np.concatenate((
            -(s['row']['lhss'] / row_norms)[has_lhs],
            +(s['row']['rhss'] / row_norms)[has_rhs])).reshape(-1, 1)

    row_feats['is_tight'] = np.concatenate((
        s['row']['is_at_lhs'][has_lhs],
        s['row']['is_at_rhs'][has_rhs])).reshape(-1, 1)

    row_feats['age'] = np.concatenate((
        s['row']['ages'][has_lhs],
        s['row']['ages'][has_rhs])).reshape(-1, 1) / (s['stats']['nlps'] + 5)

    # TODO: check whether this normalization helps; an alternative tried was
    # s['row']['dualsols'] * row_norms / obj_norm.
    tmp = s['row']['dualsols'] / (row_norms * obj_norm)
    row_feats['dualsol_val_normalized'] = np.concatenate((
            -tmp[has_lhs],
            +tmp[has_rhs])).reshape(-1, 1)

    row_feat_names = [[k, ] if v.shape[1] == 1 else [f'{k}_{i}' for i in range(v.shape[1])] for k, v in row_feats.items()]
    row_feat_names = [n for names in row_feat_names for n in names]
    row_feat_vals = np.concatenate(list(row_feats.values()), axis=-1)

    constraint_features = {
        'names': row_feat_names,
        'values': row_feat_vals,}

    # Edge features: row-normalized constraint coefficients, with rows stacked
    # in the same lhs-then-rhs order as the constraint features.
    if 'state' in buffer:
        edge_row_idxs = buffer['state']['edge_row_idxs']
        edge_col_idxs = buffer['state']['edge_col_idxs']
        edge_feats = buffer['state']['edge_feats']
    else:
        coef_matrix = sp.csr_matrix(
            (s['nzrcoef']['vals'] / row_norms[s['nzrcoef']['rowidxs']],
            (s['nzrcoef']['rowidxs'], s['nzrcoef']['colidxs'])),
            shape=(len(s['row']['ages']), len(s['col']['types'])))
        coef_matrix = sp.vstack((
            -coef_matrix[has_lhs, :],
            coef_matrix[has_rhs, :])).tocoo(copy=False)

        edge_row_idxs, edge_col_idxs = coef_matrix.row, coef_matrix.col
        edge_feats = {}

        edge_feats['coef_normalized'] = coef_matrix.data.reshape(-1, 1)

    edge_feat_names = [[k, ] if v.shape[1] == 1 else [f'{k}_{i}' for i in range(v.shape[1])] for k, v in edge_feats.items()]
    edge_feat_names = [n for names in edge_feat_names for n in names]
    edge_feat_indices = np.vstack([edge_row_idxs, edge_col_idxs])
    edge_feat_vals = np.concatenate(list(edge_feats.values()), axis=-1)

    edge_features = {
        'names': edge_feat_names,
        'indices': edge_feat_indices,
        'values': edge_feat_vals,}

    if 'state' not in buffer:
        # First call on this buffer: cache the static parts for later nodes.
        buffer['state'] = {
            'obj_norm': obj_norm,
            'col_feats': col_feats,
            'row_feats': row_feats,
            'has_lhs': has_lhs,
            'has_rhs': has_rhs,
            'edge_row_idxs': edge_row_idxs,
            'edge_col_idxs': edge_col_idxs,
            'edge_feats': edge_feats,
        }

    state = constraint_features, edge_features, variable_features

    # Tree-level statistics provided by the model.
    my_node_state = model.getMyNodeState(node_dim)
    my_mip_state = model.getMyMIPState(mip_dim)
    my_node_state = {
        'names': 'my_node_state',
        'values': my_node_state,}
    my_mip_state = {
        'names': 'my_mip_state',
        'values': my_mip_state,}

    mystate = my_node_state, my_mip_state

    return mystate, state

def graph_transform_new_32(state):
    """Convert a branching observation into a DGL heterograph.

    Node types:
      * ``'t'``: a single "tree" node whose feature row concatenates the
        node-state and MIP-state vectors (8 + 53 = 61 values in the default
        configuration);
      * ``'v'``: one node per variable, features ``state[1][2]['values']``;
      * ``'c'``: one node per constraint row, features ``state[1][0]['values']``.

    Edge types:
      * ``('t', 't2v', 'v')``: the tree node to every variable, feature 1.0;
      * ``('v', 'v2c', 'c')`` and ``('c', 'c2v', 'v')``: every nonzero of the
        constraint matrix in both directions. Both directions share the same
        normalized-coefficient feature tensor.

    All node and edge features are stored under the key ``'h'``.

    Args:
        state: ``((node_state, mip_state), (constraint_features,
            edge_features, variable_features))``. Each element is a dict with
            ``'values'``; ``edge_features`` also has ``'indices'``, a
            ``(2, n_edges)`` array of [constraint ids; variable ids].

    Returns:
        The constructed ``dgl`` heterograph.
    """
    (node_state, mip_state), graph_state = state
    constraint_features = graph_state[0]
    edge_features = graph_state[1]
    variable_features = graph_state[2]

    tree_state = np.hstack([node_state['values'], mip_state['values']])

    # indices[0] holds constraint-row ids, indices[1] holds variable ids.
    # Node IDs are built as int64 explicitly instead of relying on DGL to
    # cast other dtypes (np.zeros would otherwise yield float64 IDs).
    indices = edge_features['indices']
    cons_ids = np.asarray(indices[0, :], dtype=np.int64)
    var_ids = np.asarray(indices[1, :], dtype=np.int64)

    n_vars = len(variable_features['values'])
    tree_src = np.zeros(n_vars, dtype=np.int64)  # the single tree node, id 0
    tree_dst = np.arange(n_vars, dtype=np.int64)  # every variable

    # Edge, constraint and variable feature tensors.
    edge_feats = torch.tensor(edge_features['values'], dtype=torch.float).view(-1, 1)
    c_feats = torch.tensor(constraint_features['values'], dtype=torch.float)
    v_feats = torch.tensor(variable_features['values'], dtype=torch.float)
    t2v_edge_feats = torch.ones((n_vars, 1), dtype=torch.float)
    tree_feats = torch.tensor(tree_state, dtype=torch.float).view(1, -1)

    num_nodes_dict = {'t': 1, 'c': c_feats.shape[0], 'v': v_feats.shape[0]}

    # A heterograph is a set of relations; each relation is keyed by a
    # (source node type, edge type, destination node type) triple.
    graph = dgl.heterograph({
        ('t', 't2v', 'v'): (tree_src, tree_dst),
        ('v', 'v2c', 'c'): (var_ids, cons_ids),
        ('c', 'c2v', 'v'): (cons_ids, var_ids),
    }, num_nodes_dict)

    # Attach features to the heterograph.
    graph.edges['v2c'].data['h'] = edge_feats
    graph.edges['c2v'].data['h'] = edge_feats
    graph.nodes['v'].data['h'] = v_feats
    graph.nodes['c'].data['h'] = c_feats

    graph.edges['t2v'].data['h'] = t2v_edge_feats
    graph.nodes['t'].data['h'] = tree_feats

    return graph

class BranchDataset_new_32(Dataset):
    """In-memory dataset of branching samples stored as gzip-compressed pickles.

    Each file must unpickle to a dict whose ``'data'`` entry is the 6-tuple
    ``(gcn_state, bestcand, action_set, scores, lp_scores_0, lp_scores_1)``.

    Items are returned reordered as
    ``(gcn_state, action_set, bestcand, scores, lp_scores_0, lp_scores_1)``,
    i.e. ``(data, candidates, label, scores, lp_scores_0, lp_scores_1)``.

    NOTE(security): ``pickle.load`` can execute arbitrary code. Only load
    sample files produced by your own pipeline.
    """

    def __init__(self, dirs):
        """Eagerly load every sample file into memory.

        Args:
            dirs: Iterable of paths to gzip-compressed pickle files. Despite
                the name, each entry is opened as a file, not a directory.
        """
        super().__init__()
        self.data = [self._load_sample(path) for path in dirs]

    @staticmethod
    def _load_sample(path):
        """Read one sample file and return it in dataset item order."""
        # The context manager guarantees the gzip handle is closed.
        with gzip.open(path, 'rb') as f:
            sample = pickle.load(f)['data']
        gcn_state, bestcand, action_set, scores, lp_scores_0, lp_scores_1 = sample
        return (gcn_state, action_set, bestcand, scores, lp_scores_0, lp_scores_1)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def dgl_collate_new_32(batch):
    """Collate dataset items into one batched DGL graph plus per-sample lists.

    Args:
        batch: Sequence of ``(state, candidates, label, scores, lp_scores_0,
            lp_scores_1)`` tuples.

    Returns:
        ``(batched_graph, candidates, labels, scores, lp_scores_0,
        lp_scores_1)``. ``labels`` is a ``torch.LongTensor``; the remaining
        per-sample fields are Python lists.
    """
    columns = [list(field) for field in zip(*batch)]
    states, candidates, labels, scores, lp_scores_0, lp_scores_1 = columns
    batched_graph = dgl.batch([graph_transform_new_32(st) for st in states])
    return batched_graph, candidates, torch.LongTensor(labels), scores, lp_scores_0, lp_scores_1

def train_new_32(args, model, trainData, optimizer, epoch, log_interval, device):
    """Train ``model`` for one epoch and periodically log loss and accuracies.

    Args:
        args: Namespace with two fields:
            * ``model_id`` in {'sb', 'lp', 'all'} selects which score fields
              are read from the output graph's ``'v'`` nodes: ``'s'`` for
              'sb', ``'s_0'``/``'s_1'`` for 'lp', all three for 'all'.
            * ``loss`` in {'ce', 'mse', 'mse_sb_lp', 'ce_mse', 'ce_slmse',
              'all'} selects how the partial losses are combined.
        model: Callable mapping a batched DGL graph to a batched graph
            carrying those score fields.
        trainData: Iterable of ``(graph, candidates, label, scores,
            lp_scores_0, lp_scores_1)`` batches with a ``len``.
        optimizer: Optimizer over ``model``'s parameters.
        epoch: Zero-based epoch index (logged as ``epoch + 1``).
        log_interval: Emit a log line every ``log_interval`` batches.
        device: Device that the graph, labels and regression targets are
            moved to.

    Raises:
        ValueError: If ``args.model_id`` or ``args.loss`` is unknown, or if
            the loss uses the SB-vs-LP consistency term, which only
            ``model_id='all'`` computes.
    """
    valid_model_ids = ('sb', 'lp', 'all')
    valid_losses = ('ce', 'mse', 'mse_sb_lp', 'ce_mse', 'ce_slmse', 'all')
    # These losses use loss_mse_sb_lp, which exists only for model_id 'all'.
    needs_sb_lp = ('mse_sb_lp', 'ce_slmse', 'all')
    if args.model_id not in valid_model_ids:
        raise ValueError(f'Unknown model_id {args.model_id!r}; expected one of {valid_model_ids}')
    if args.loss not in valid_losses:
        raise ValueError(f'Unknown loss {args.loss!r}; expected one of {valid_losses}')
    if args.loss in needs_sb_lp and args.model_id != 'all':
        raise ValueError(f"loss {args.loss!r} requires model_id 'all', got {args.model_id!r}")

    model.train()
    batch_loss = 0.0
    total_count = 0.0
    correct_count = 0.0
    hard_correct_count = 0.0

    def as_target(values):
        # Per-sample regression targets arrive as plain sequences.
        return torch.FloatTensor(values).to(device)

    for batch_idx, (bg, candidates, label, scores, lp_scores_0, lp_scores_1) in enumerate(trainData):
        optimizer.zero_grad()
        bg = bg.to(device)
        label = label.to(device)
        count = len(label)
        total_count += count
        graphs = dgl.unbatch(model(bg))

        if args.model_id in ('sb', 'all'):
            # Head approximating strong-branching scores.
            logits_sb = [g.nodes['v'].data['s'][candidates[i]] for i, g in enumerate(graphs)]
            loss_mse_sb = sum(F.mse_loss(as_target(scores[i]), logits_sb[i].T[0]) for i in range(count)) / count
        if args.model_id in ('lp', 'all'):
            # Two heads approximating intermediate scores; their product ranks candidates.
            logits_0 = [g.nodes['v'].data['s_0'][candidates[i]] for i, g in enumerate(graphs)]
            logits_1 = [g.nodes['v'].data['s_1'][candidates[i]] for i, g in enumerate(graphs)]
            logits_lp = [torch.mul(logits_0[i], logits_1[i]) for i in range(count)]
            # reshape(-1) rather than squeeze(): with a single candidate,
            # squeeze() yields a 0-d tensor that mismatches the (1,) target.
            loss_mse_lp_1 = sum(F.mse_loss(as_target(lp_scores_0[i]), logits_0[i].reshape(-1)) for i in range(count)) / count
            loss_mse_lp_2 = sum(F.mse_loss(as_target(lp_scores_1[i]), logits_1[i].reshape(-1)) for i in range(count)) / count

        if args.model_id == 'sb':
            logits = logits_sb
            loss_mse = loss_mse_sb
        elif args.model_id == 'lp':
            logits = logits_lp
            loss_mse = 0.5 * loss_mse_lp_1 + 0.5 * loss_mse_lp_2
        else:  # 'all'
            logits = [0.5 * logits_sb[i] + 0.5 * logits_lp[i] for i in range(count)]
            loss_mse = 0.5 * loss_mse_sb + 0.25 * loss_mse_lp_1 + 0.25 * loss_mse_lp_2
            # Consistency term pulling the SB and LP heads toward each other.
            loss_mse_sb_lp = sum(F.mse_loss(logits_sb[i], logits_lp[i]) for i in range(count)) / count

        # Each logits[i] is (k, 1); transposed to (1, k) it scores the k candidates.
        loss_ce = sum(F.cross_entropy(logits[i].T, label[i:i + 1]) for i in range(count)) / count

        if args.loss == 'ce':
            loss = loss_ce
        elif args.loss == 'mse':
            loss = loss_mse
        elif args.loss == 'mse_sb_lp':
            loss = loss_mse_sb_lp
        elif args.loss == 'ce_mse':
            loss = 0.5 * loss_mse + 0.5 * loss_ce
        elif args.loss == 'ce_slmse':
            loss = 0.5 * loss_mse_sb_lp + 0.5 * loss_ce
        else:  # 'all'
            loss = loss_ce + 0.1 * loss_mse + 0.5 * loss_mse_sb_lp

        loss.backward()
        optimizer.step()
        batch_loss += loss.item() * count

        # Plain ints index both lists and numpy arrays of scores safely.
        pred = [int(logits[i].argmax(dim=0)) for i in range(count)]
        label_ids = label.tolist()
        # Hard accuracy: the prediction matches the recorded best candidate.
        hard_correct_count += sum(p == y for p, y in zip(pred, label_ids))
        # Accuracy: the predicted candidate attains the maximum score, which
        # also credits ties.
        correct_count += sum(scores[i][pred[i]] == max(scores[i]) for i in range(count))

        if (batch_idx + 1) % log_interval == 0:
            logger.info('Train epoch {}:\t [{}/{}({:.0f}%)]\tTrain Loss: {:.6f}\t Accuracy: {:.6f}\t Hard Accuracy: {:.6f}'.format(
                epoch + 1, batch_idx + 1, len(trainData), 100. * (batch_idx + 1) / len(trainData), batch_loss / total_count,
                correct_count / total_count, hard_correct_count / total_count))

def test_new_32(args, model, testData, optimizer, epoch, device):
    """Evaluate ``model`` on ``testData`` without gradients and log metrics.

    Args:
        args: Namespace with two fields:
            * ``model_id`` in {'sb', 'lp', 'all'} selects which score fields
              are read from the output graph's ``'v'`` nodes: ``'s'`` for
              'sb', ``'s_0'``/``'s_1'`` for 'lp', all three for 'all'.
            * ``loss`` in {'ce', 'mse', 'mse_sb_lp', 'ce_mse', 'ce_slmse',
              'all'} selects how the partial losses are combined.
        model: Callable mapping a batched DGL graph to a batched graph
            carrying those score fields.
        testData: Iterable of ``(graph, candidates, label, scores,
            lp_scores_0, lp_scores_1)`` batches.
        optimizer: Unused; kept for signature compatibility.
        epoch: Zero-based epoch index (logged as ``epoch + 1``).
        device: Device that the graph, labels and regression targets are
            moved to.

    Returns:
        The fraction of samples whose predicted candidate attains the maximum
        score (ties count as correct), or ``0.0`` if ``testData`` yields no
        samples.

    Raises:
        ValueError: If ``args.model_id`` or ``args.loss`` is unknown, or if
            the loss uses the SB-vs-LP consistency term, which only
            ``model_id='all'`` computes.
    """
    valid_model_ids = ('sb', 'lp', 'all')
    valid_losses = ('ce', 'mse', 'mse_sb_lp', 'ce_mse', 'ce_slmse', 'all')
    # These losses use loss_mse_sb_lp, which exists only for model_id 'all'.
    needs_sb_lp = ('mse_sb_lp', 'ce_slmse', 'all')
    if args.model_id not in valid_model_ids:
        raise ValueError(f'Unknown model_id {args.model_id!r}; expected one of {valid_model_ids}')
    if args.loss not in valid_losses:
        raise ValueError(f'Unknown loss {args.loss!r}; expected one of {valid_losses}')
    if args.loss in needs_sb_lp and args.model_id != 'all':
        raise ValueError(f"loss {args.loss!r} requires model_id 'all', got {args.model_id!r}")

    model.eval()
    batch_loss = 0
    correct_count = 0
    total_count = 0
    hard_correct_count = 0

    def as_target(values):
        # Per-sample regression targets arrive as plain sequences.
        return torch.FloatTensor(values).to(device)

    with torch.no_grad():
        for bg, candidates, label, scores, lp_scores_0, lp_scores_1 in testData:
            bg = bg.to(device)
            label = label.to(device)
            count = len(label)
            total_count += count
            graphs = dgl.unbatch(model(bg))

            if args.model_id in ('sb', 'all'):
                # Head approximating strong-branching scores.
                logits_sb = [g.nodes['v'].data['s'][candidates[i]] for i, g in enumerate(graphs)]
                loss_mse_sb = sum(F.mse_loss(as_target(scores[i]), logits_sb[i].T[0]) for i in range(count)) / count
            if args.model_id in ('lp', 'all'):
                # Two heads approximating intermediate scores; their product ranks candidates.
                logits_0 = [g.nodes['v'].data['s_0'][candidates[i]] for i, g in enumerate(graphs)]
                logits_1 = [g.nodes['v'].data['s_1'][candidates[i]] for i, g in enumerate(graphs)]
                logits_lp = [torch.mul(logits_0[i], logits_1[i]) for i in range(count)]
                # reshape(-1) rather than squeeze(): with a single candidate,
                # squeeze() yields a 0-d tensor that mismatches the (1,) target.
                loss_mse_lp_1 = sum(F.mse_loss(as_target(lp_scores_0[i]), logits_0[i].reshape(-1)) for i in range(count)) / count
                loss_mse_lp_2 = sum(F.mse_loss(as_target(lp_scores_1[i]), logits_1[i].reshape(-1)) for i in range(count)) / count

            if args.model_id == 'sb':
                logits = logits_sb
                loss_mse = loss_mse_sb
            elif args.model_id == 'lp':
                logits = logits_lp
                loss_mse = 0.5 * loss_mse_lp_1 + 0.5 * loss_mse_lp_2
            else:  # 'all'
                logits = [0.5 * logits_sb[i] + 0.5 * logits_lp[i] for i in range(count)]
                loss_mse = 0.5 * loss_mse_sb + 0.25 * loss_mse_lp_1 + 0.25 * loss_mse_lp_2
                # Consistency term between the SB and LP heads.
                loss_mse_sb_lp = sum(F.mse_loss(logits_sb[i], logits_lp[i]) for i in range(count)) / count

            # Each logits[i] is (k, 1); transposed to (1, k) it scores the k candidates.
            loss_ce = sum(F.cross_entropy(logits[i].T, label[i:i + 1]) for i in range(count)) / count

            if args.loss == 'ce':
                loss = loss_ce
            elif args.loss == 'mse':
                loss = loss_mse
            elif args.loss == 'mse_sb_lp':
                loss = loss_mse_sb_lp
            elif args.loss == 'ce_mse':
                loss = 0.5 * loss_mse + 0.5 * loss_ce
            elif args.loss == 'ce_slmse':
                loss = 0.5 * loss_mse_sb_lp + 0.5 * loss_ce
            else:  # 'all'
                loss = loss_ce + 0.1 * loss_mse + 0.5 * loss_mse_sb_lp

            batch_loss += loss.item() * count

            # Plain ints index both lists and numpy arrays of scores safely.
            pred = [int(logits[i].argmax(dim=0)) for i in range(count)]
            label_ids = label.tolist()
            hard_correct_count += sum(p == y for p, y in zip(pred, label_ids))
            correct_count += sum(scores[i][pred[i]] == max(scores[i]) for i in range(count))

    if total_count == 0:
        # Avoid ZeroDivisionError on an empty loader.
        logger.info('Test epoch {}:\t no test samples'.format(epoch + 1))
        return 0.0

    loss = batch_loss / total_count
    accuracy = correct_count / total_count
    hard_acc = hard_correct_count / total_count
    logger.info('Test epoch {}:\t Test Loss: {:.6f}\t Accuracy: {:.6f}\t Hard Accuracy: {:.6f}'.format(epoch + 1, loss, accuracy, hard_acc))
    return accuracy
