import gc
import numpy as np
import os
import pickle
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Subset,DataLoader




def save_results(args, filename, results):
    """Pickle ``results`` to ``<results_folder>/<experiment>/<filename>.pkl``.

    The experiment folder name is assembled from ``args.arch``, ``args.data``,
    ``args.num_samples_per_class``, ``args.batch_size``, ``args.optimizer``
    and ``args.alpha``. Every missing directory on the way to the target file
    is created, including sub-directories embedded in ``filename``. An empty
    ``args.results_folder`` resolves relative to the current working
    directory.

    Args:
        args: Object exposing the attributes listed above plus
            ``results_folder``.
        filename: Target file name without extension; ``.pkl`` is appended.
            May contain sub-directories.
        results: Object to serialize with ``pickle.dump``.
    """
    experiment_name = '{}-{}-n{}-bs{}-{}-alpha-{}'.format(
        args.arch, args.data, args.num_samples_per_class,
        args.batch_size, args.optimizer, args.alpha)
    experiment_folder = os.path.join(args.results_folder, experiment_name)
    filepath = os.path.join(experiment_folder, filename + '.pkl')

    # exist_ok=True instead of "check, then create": concurrent runs sharing
    # a results folder could otherwise both pass the existence check and one
    # of them would die with FileExistsError.
    os.makedirs(experiment_folder, exist_ok=True)
    # `filename` may itself carry sub-directories (or be absolute), so the
    # file's own parent directory is ensured separately.
    os.makedirs(os.path.dirname(os.path.abspath(filepath)), exist_ok=True)

    with open(filepath, 'wb') as fh:
        pickle.dump(results, fh)




def save_torch_results(args, filename, results):
    """Save ``results`` with ``torch.save`` to ``<results_folder>/<experiment>/<filename>.t7``.

    The experiment folder name is assembled from ``args.arch``, ``args.data``,
    ``args.num_samples_per_class``, ``args.batch_size``, ``args.optimizer``
    and ``args.alpha``. Every missing directory on the way to the target file
    is created, including sub-directories embedded in ``filename``.

    Args:
        args: Object exposing the attributes listed above plus
            ``results_folder``.
        filename: Target file name without extension; ``.t7`` is appended.
            May contain sub-directories.
        results: Object handed unchanged to ``torch.save``.
    """
    experiment_name = '{}-{}-n{}-bs{}-{}-alpha-{}'.format(
        args.arch, args.data, args.num_samples_per_class,
        args.batch_size, args.optimizer, args.alpha)
    experiment_folder = os.path.join(args.results_folder, experiment_name)
    filepath = os.path.join(experiment_folder, filename + '.t7')

    # exist_ok=True instead of "check, then create": concurrent runs sharing
    # a results folder could otherwise both pass the existence check and one
    # of them would die with FileExistsError.
    os.makedirs(experiment_folder, exist_ok=True)
    # `filename` may itself carry sub-directories (or be absolute), so the
    # file's own parent directory is ensured separately.
    os.makedirs(os.path.dirname(os.path.abspath(filepath)), exist_ok=True)

    torch.save(results, filepath)





def load_resutls(args, filename):
    """Unpickle and return the object stored at ``<args.results_folder>/<filename>.pkl``.

    The path is built directly under ``args.results_folder``; unlike
    ``save_results`` no experiment sub-folder is inserted, so ``filename``
    must include that folder when reading back a file written there.

    Args:
        args: Object exposing ``results_folder``.
        filename: Path below the results folder, without the ``.pkl``
            extension.

    Returns:
        The unpickled object.
    """
    pickle_path = os.path.join(args.results_folder, filename + '.pkl')

    with open(pickle_path, 'rb') as handle:
        # encoding='bytes' makes 8-bit strings from Python 2 pickles load as
        # bytes objects instead of being decoded.
        return pickle.load(handle, encoding='bytes')
        





########### evaluate test data #####################

def test_eval(dataloader, model, criterion, device):
    """Evaluate ``model`` on every batch yielded by ``dataloader``.

    The model is switched to evaluation mode (and left in it) and the forward
    passes run under ``torch.no_grad()`` so no autograd graph is built.

    Args:
        dataloader: Iterable yielding ``(images, labels)`` tensor pairs.
        model: Module called as ``model(images)``; the predicted class is the
            arg-max of its output along dimension 1.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            tensor.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(loss, accuracy)``. ``loss`` is the mean of the per-batch
        criterion values (one value per batch, not weighted by batch size).
        ``accuracy`` is the number of correct predictions divided by the
        total number of samples, so a smaller final batch is not
        over-weighted.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    total_loss = 0.0
    num_correct = 0
    num_samples = 0
    num_batches = 0

    # switch to evaluate mode
    model.eval()
    # Gradients are never needed here; disabling them avoids building and
    # holding the autograd graph, which is what made per-batch garbage
    # collection tempting in the first place.
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            total_loss += criterion(outputs, labels).item()

            _, predicted = torch.max(outputs, 1)
            num_correct += (predicted == labels).sum().item()
            num_samples += labels.size(0)
            num_batches += 1

    if num_batches == 0 or num_samples == 0:
        raise ValueError('test_eval received a dataloader with no samples')

    return total_loss / num_batches, num_correct / num_samples


# Despite its name this is a library helper, not a test case: keep pytest from
# collecting it when the function is imported into a test module.
test_eval.__test__ = False



def adjust_beta(optimizer, t, beta0=55000): # t: the iteration number
    """Set ``beta`` in every parameter group to ``max(exp(t / 100), beta0)``.

    ``beta0`` acts as a floor: the stored value stays at ``beta0`` until
    ``exp(t / 100)`` exceeds it and follows the exponential afterwards.

    Args:
        optimizer: Object whose ``param_groups`` is an iterable of mappings;
            each one gets its ``'beta'`` key overwritten.
        t: Iteration number.
        beta0: Lower bound for the stored value.
    """
    exponential_term = np.exp(t / 100)
    new_beta = np.maximum(exponential_term, beta0)

    for group in optimizer.param_groups:
        group['beta'] = new_beta


def train_val_dataset(dataset, m=100):
    """Split ``dataset`` into a held-out subset and the remainder.

    The indices ``0 .. len(dataset) - 1`` are split with scikit-learn's
    ``train_test_split`` using ``test_size=m`` and a fixed
    ``random_state=42``, so the split is reproducible across calls.

    Args:
        dataset: Indexable dataset supporting ``len()``.
        m: Forwarded as ``test_size``; controls the size of the held-out
            subset.

    Returns:
        Tuple ``(val_dataset, train_dataset)`` of ``Subset`` objects. Note
        the order: the held-out subset of size ``m`` comes first, the
        remaining samples second.
    """
    all_indices = list(range(len(dataset)))
    remaining_idx, held_out_idx = train_test_split(
        all_indices, test_size=m, random_state=42)

    return Subset(dataset, held_out_idx), Subset(dataset, remaining_idx)
