import os
import torch
import numpy as np

# CIFAR-10 class names, in label-index order (index i <-> class i).
cifar10_classes = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]


def logmeanexp(x, dim=None, keepdim=False):
    """Stable computation of log(mean(exp(x))).

    Args:
        x: input tensor.
        dim: dimension to reduce over. If None, ``x`` is flattened and the
            reduction runs over all of its elements.
        keepdim: if True, keep the reduced dimension with size 1. When
            ``dim`` is None the result then has shape ``(1,)``; otherwise
            the result is a 0-d tensor.

    Returns:
        Tensor holding log(mean(exp(x))) along ``dim``.
    """
    import math

    if dim is None:
        # reshape (not view): view(-1) raises on non-contiguous inputs
        # such as transposed tensors.
        x, dim = x.reshape(-1), 0
    # log(mean(exp(x))) == logsumexp(x) - log(n). torch.logsumexp does the
    # max-shift internally and, unlike a hand-rolled shift, yields -inf
    # (not nan) for slices that are entirely -inf.
    return torch.logsumexp(x, dim, keepdim=keepdim) - math.log(x.size(dim))

# Disabled helper (kept for reference): flattens x when dim is None, then squeezes `dim` unless keepdim.

# def dimension_check(x, dim=None, keepdim=False):
#     if dim is None:
#         x, dim = x.view(-1), 0

#     return x if keepdim else x.squeeze(dim)


def adjust_learning_rate(optimizer, lr):
    """Set the learning rate of every parameter group in ``optimizer`` to ``lr``.

    No schedule is applied here: the caller computes the desired value
    (e.g. a step decay) and this function just assigns it.

    Args:
        optimizer: an object exposing ``param_groups`` (a sequence of dicts),
            such as a ``torch.optim.Optimizer``.
        lr: the new learning rate.
    """
    for group in optimizer.param_groups:
        group['lr'] = lr


def save_array_to_file(numpy_array, filename, fmt="%.3f"):
    """Append ``numpy_array`` to ``filename`` as a single space-separated line.

    The array is flattened, so its shape is not recorded. Every value
    (including the last) is followed by a space, then the line ends with '\\n'.

    Args:
        numpy_array: array or array-like to write.
        filename: path of the text file to append to (created if missing).
        fmt: printf-style format for each value (default ``"%.3f"``).
    """
    values = np.asarray(numpy_array).ravel()
    # ``with`` guarantees the handle is closed even if savetxt raises.
    with open(filename, 'a') as file:
        # A 1-D array makes savetxt emit one value per "row"; using " " as
        # the row terminator puts all values on one line.
        np.savetxt(file, values, newline=" ", fmt=fmt)
        file.write("\n")


def load_model(net, path, map_location=None):
    """Load a saved state dict from ``path`` into ``net`` and switch it to eval mode.

    Args:
        net: the ``torch.nn.Module`` to populate (modified in place).
        path: file previously written with ``torch.save(module.state_dict(), ...)``.
        map_location: forwarded to ``torch.load``; pass e.g. ``'cpu'`` to
            load a checkpoint saved on GPU onto a CPU-only machine.

    Returns:
        ``net`` itself, for convenience.
    """
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    state_dict = torch.load(path, map_location=map_location)
    net.load_state_dict(state_dict)
    net.eval()
    return net



def write_list(folder, results_file, list):
    """Append each item of ``list`` to ``folder/results_file``, one per line.

    ``folder`` (including missing parents) is created if needed.

    Args:
        folder: output directory (anything convertible with ``str``).
        results_file: file name inside ``folder``.
        list: iterable of items; each is written as ``str(item)``. (The
            parameter name shadows the builtin but is kept for keyword callers.)
    """
    folder = str(folder)
    # makedirs(exist_ok=True) handles nested paths and avoids the
    # exists()/mkdir() race.
    os.makedirs(folder, exist_ok=True)
    with open(folder + '/' + str(results_file), 'a') as f:
        f.writelines(str(row) + '\n' for row in list)

def write_list2(folder, results_file, list):
    """Append the items of ``list`` to ``folder/results_file`` as one line.

    Each item is written as ``str(item)`` followed by ',' (so the line ends
    with a trailing comma), then a newline. ``folder`` is created if needed,
    matching the other ``write_*`` helpers.

    Args:
        folder: output directory (anything convertible with ``str``).
        results_file: file name inside ``folder``.
        list: iterable of items. (Shadows the builtin; kept for keyword callers.)
    """
    folder = str(folder)
    os.makedirs(folder, exist_ok=True)
    with open(folder + '/' + str(results_file), 'a') as f:
        f.write(''.join(str(row) + ',' for row in list) + '\n')


def write_line(folder, results_file, line):
    """Append ``str(line)`` plus a newline to ``folder/results_file``.

    ``folder`` (including missing parents) is created if needed.

    Args:
        folder: output directory (anything convertible with ``str``).
        results_file: file name inside ``folder``.
        line: value to write; converted with ``str``.
    """
    folder = str(folder)
    os.makedirs(folder, exist_ok=True)
    with open(folder + '/' + str(results_file), 'a') as f:
        f.write(str(line) + "\n")


def write_line2(folder, results_file, line):
    """Append ``str(line)`` plus a comma (no newline) to ``folder/results_file``.

    ``folder`` (including missing parents) is created if needed.

    Args:
        folder: output directory (anything convertible with ``str``).
        results_file: file name inside ``folder``.
        line: value to write; converted with ``str``.
    """
    folder = str(folder)
    os.makedirs(folder, exist_ok=True)
    with open(folder + '/' + str(results_file), 'a') as f:
        f.write(str(line) + ",")


def write_matrix(folder, results_file, mat):
    """Append ``mat`` to ``folder/results_file`` via ``np.savetxt`` (fmt '%.2f').

    ``mat`` is iterated along its first axis and each element is passed to
    ``np.savetxt`` separately. ``folder`` (including missing parents) is
    created if needed, and ``results_file`` may be any ``str``-convertible
    value (e.g. a ``pathlib.Path``), like the other ``write_*`` helpers.

    NOTE(review): for a 2-D ``mat`` each element is a 1-D row, and savetxt
    writes a 1-D array one value per line, so the matrix ends up as a single
    column. Confirm this is intended before changing the output format.

    Args:
        folder: output directory (anything convertible with ``str``).
        results_file: file name inside ``folder``.
        mat: iterable of arrays accepted by ``np.savetxt``.
    """
    folder = str(folder)
    os.makedirs(folder, exist_ok=True)
    with open(folder + '/' + str(results_file), 'a') as f:
        for line in mat:
            np.savetxt(f, line, fmt='%.2f')


def str2bool(v):
    """Interpret ``v`` as a boolean flag.

    Strings "yes", "true", "t" and "1" (case-insensitive) map to True; any
    other string maps to False. An actual ``bool`` is returned unchanged
    instead of raising ``AttributeError`` on ``.lower()``.

    Args:
        v: a ``str`` or ``bool``.

    Returns:
        bool
    """
    if isinstance(v, bool):
        return v
    return v.lower() in ("yes", "true", "t", "1")