import torch
import random
import numpy as np
import os
import pandas as pd
import seaborn as sn
import matplotlib.pyplot as plt

# Shared binary-cross-entropy-with-logits criterion (default 'mean' reduction),
# used by DSL_bceloss below.
bceWithLogitsLoss = torch.nn.BCEWithLogitsLoss()

def DSL_bceloss(predict, target, *, eps=0.0001):
    """Binary cross-entropy on the top-scoring entry of every row.

    Along dim 1 the position with the largest ``predict`` value is chosen
    (``argmax``); the loss is the BCE between that chosen value and the
    ``target`` entry at the same position.

    Args:
        predict: tensor whose dim 1 holds the candidate scores. The selected
            values go through ``torch.logit``, so they are presumably
            probabilities in [0, 1] — TODO confirm with callers.
        target: float tensor indexed like ``predict`` (``torch.mean`` and
            ``BCEWithLogitsLoss`` both require a floating-point dtype).
        eps: clamp bound forwarded to ``torch.logit``; inputs are clamped to
            ``[eps, 1 - eps]`` so exact 0/1 values still give finite logits.
            Defaults to the previously hard-coded ``0.0001``.

    Returns:
        ``(loss, accuracy)``: the mean BCE as a scalar tensor, and the mean of
        the target values at the chosen positions as a Python float (the hit
        rate when targets are 0/1).
    """
    # Indices carry no gradient; only the gathered prediction is differentiated.
    chosen = torch.argmax(predict, dim=1, keepdim=True).detach()
    chosen_predict = torch.gather(predict, dim=1, index=chosen)
    chosen_target = torch.gather(target, dim=1, index=chosen)
    accuracy = torch.mean(chosen_target).item()
    # Convert the (clamped) probability back to a logit so the numerically
    # stable BCE-with-logits criterion can be used.
    chosen_logit = torch.logit(chosen_predict, eps)
    return bceWithLogitsLoss(chosen_logit, chosen_target), accuracy

def get_accuracy(predict, target):
    """Return the mean ``target`` value at each row's argmax position in ``predict``.

    The argmax is taken along dim 1. With 0/1 float targets this is the
    fraction of rows whose top-scoring position is a positive. ``target``
    must be floating point, since ``mean`` is not defined for integer tensors.
    """
    top_idx = predict.argmax(dim=1, keepdim=True).detach()
    return target.gather(1, top_idx).mean().item()

def seedBasic(seed=42):
    """Seed Python's ``random`` module and NumPy's global RNG, and export PYTHONHASHSEED.

    Args:
        seed: value given to ``random.seed`` and ``np.random.seed``; its
            ``str()`` form is stored in ``os.environ['PYTHONHASHSEED']``.
    """
    random.seed(seed)
    # NOTE: the running interpreter's hash seed is fixed at start-up, so this
    # only takes effect for child processes that inherit the environment.
    os.environ.update(PYTHONHASHSEED=str(seed))
    np.random.seed(seed)
    
def seedTorch(seed=42):
    """Seed PyTorch's RNGs and ask cuDNN for deterministic behaviour.

    Args:
        seed: seed passed to ``torch.manual_seed`` and to the CUDA RNGs.

    CUDA seeding is best-effort: if it fails, a warning is emitted instead of
    aborting. The cuDNN flags are applied in every case.
    """
    torch.manual_seed(seed)
    try:
        # Seed every visible GPU, not only the current device.
        torch.cuda.manual_seed_all(seed)
    except Exception as err:  # deliberate best-effort, but no longer silent
        import warnings
        warnings.warn('seedTorch: could not seed CUDA RNGs: {}'.format(err))
    # Trade cuDNN autotuning speed for run-to-run reproducibility; these
    # settings should not depend on whether CUDA seeding succeeded.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
def seed(seed=42):
    """Seed every RNG this module knows about: Python/NumPy first, then PyTorch."""
    for seeder in (seedBasic, seedTorch):
        seeder(seed)

def visualize_confusion(confusion, name, save_dir='figures'):
    """Draw a confusion matrix as an annotated heatmap, save it as a PDF and show it.

    Args:
        confusion: 2-D array-like of integer counts (``fmt='d'`` requires
            integers). Rows are labelled as the original label and columns as
            the predicted label. Both axes are numbered ``0..n-1``, taken from
            the matrix shape, so sizes other than 10x10 are supported.
        name: plot title and file-name prefix.
        save_dir: directory receiving ``{name}_confusion.pdf``. It is created
            if missing. Defaults to ``'figures'`` relative to the working
            directory, as before.
    """
    n_rows, n_cols = np.shape(confusion)
    df_cm = pd.DataFrame(confusion, index=np.arange(n_rows), columns=np.arange(n_cols))
    plt.figure(figsize=(5, 5), dpi=128)
    ax = sn.heatmap(df_cm, annot=True, cmap='Oranges', fmt='d', cbar=False)
    # The 0.5 offsets put one tick at the centre of every heatmap cell.
    ax.set_xticks(np.arange(n_cols) + 0.5)
    ax.set_yticks(np.arange(n_rows) + 0.5)
    ax.set_yticklabels(np.arange(n_rows))
    ax.set_xticklabels(np.arange(n_cols))
    ax.set_title(name)
    ax.set_xlabel('predicted label')
    ax.set_ylabel('original MNIST label')
    ax.tick_params(axis='both', which='both', length=3)
    # savefig fails with FileNotFoundError if the target directory is absent.
    os.makedirs(save_dir, exist_ok=True)
    plt.savefig(os.path.join(save_dir, '{}_confusion.pdf'.format(name)))
    plt.show()