import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

import shutil
import random
import numpy as np
from enum import Enum
from pathlib import Path


def set_random_seed(seed_num=1):
    """Seed the Python, NumPy and PyTorch random number generators.

    The same seed is passed, in order, to ``random.seed``,
    ``np.random.seed``, ``torch.manual_seed``, ``torch.cuda.manual_seed``
    and ``torch.cuda.manual_seed_all``. Afterwards the cuDNN flags are set
    to ``deterministic = True`` and ``benchmark = False``.

    Args:
        seed_num: Seed value handed to every generator (default ``1``).
    """
    seed_functions = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seed_functions:
        seed_fn(seed_num)

    # Request deterministic cuDNN kernels and turn off its auto-tuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def setup(rank, world_size, master_addr='localhost', master_port='12356',
          backend='nccl'):
    """Initialize the default ``torch.distributed`` process group.

    Writes the rendezvous address and port into the ``MASTER_ADDR`` and
    ``MASTER_PORT`` environment variables (overwriting any existing values)
    and then calls ``dist.init_process_group``.

    Args:
        rank: Rank of the calling process.
        world_size: Total number of processes taking part.
        master_addr: Value stored in ``MASTER_ADDR`` (default
            ``'localhost'``).
        master_port: Value stored in ``MASTER_PORT``; may be an int or a
            string (default ``'12356'``). Pass a different port to run
            several jobs on one host without colliding.
        backend: Backend name forwarded to ``init_process_group``
            (default ``'nccl'``).
    """
    os.environ['MASTER_ADDR'] = master_addr
    # Environment values must be strings; accept integer ports as well.
    os.environ['MASTER_PORT'] = str(master_port)

    # initialize the process group
    dist.init_process_group(backend, rank=rank, world_size=world_size)


def cleanup():
    """Destroy the default process group if one is currently initialized.

    Does nothing when ``torch.distributed`` is unavailable or no process
    group has been initialized, so it is safe to call from a ``finally``
    block or more than once.
    """
    # destroy_process_group() errors out when there is no default group,
    # so only call it when there is something to tear down.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def save_checkpoint(state, is_best, dirname='result', filename='checkpoint.pth'):
    """Serialize ``state`` with ``torch.save`` into ``dirname``.

    When ``is_best`` is false the object is written to
    ``dirname/filename``. When it is true the object is written only to the
    "best" file, whose name is the stem of ``filename`` followed by
    ``_best`` and the original suffix (``checkpoint.pth`` becomes
    ``checkpoint_best.pth``).

    Args:
        state: Object to serialize.
        is_best: Selects the "best" file name instead of ``filename``.
        dirname: Target directory; created (with parents) if missing.
        filename: Base file name of the checkpoint.
    """
    target_dir = Path(dirname)
    # torch.save does not create missing directories, so make sure it exists.
    target_dir.mkdir(parents=True, exist_ok=True)

    if is_best:
        base = Path(filename)
        filename = f'{base.stem}_best{base.suffix}'

    torch.save(state, target_dir / filename)


def lr_decay(optimizer, lr, decay_rate=.2):
    """Multiply ``lr`` by ``decay_rate`` and apply it to the optimizer.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts;
            the ``'lr'`` entry of each one is overwritten.
        lr: Current learning rate.
        decay_rate: Multiplicative factor (default ``0.2``).

    Returns:
        The decayed learning rate, ``lr * decay_rate``.
    """
    decayed_lr = lr * decay_rate
    for group in optimizer.param_groups:
        group['lr'] = decayed_lr
    return decayed_lr

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Blend the losses of ``pred`` against two target sets.

    Args:
        criterion: Callable invoked as ``criterion(pred, target)``.
        pred: Predictions passed to ``criterion``.
        y_a: First targets, weighted by ``lam``.
        y_b: Second targets, weighted by ``1 - lam``.
        lam: Mixing coefficient.

    Returns:
        ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

def mixup_data(x, y, alpha=1.0, device=None):
    """Return mixed inputs, the two target sets, and the mixing coefficient.

    A coefficient ``lam`` is drawn from ``Beta(alpha, alpha)`` via NumPy
    when ``alpha > 0``; otherwise ``lam`` is ``1``. The batch is then
    combined with a random permutation of itself along the first dimension.

    Args:
        x: Input tensor whose first dimension is the batch; any rank >= 1.
        y: Targets, indexable along the first dimension by an index tensor.
        alpha: Beta distribution parameter; non-positive values turn
            mixing off (``lam = 1``).
        device: If given, the permutation index is moved to this device
            before it is used for indexing.

    Returns:
        Tuple ``(mixed_x, y_a, y_b, lam)`` where
        ``mixed_x = lam * x + (1 - lam) * x[index]``, ``y_a`` is ``y`` and
        ``y_b`` is ``y[index]``.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    batch_size = x.size(0)
    index = torch.randperm(batch_size)
    if device is not None:
        index = index.to(device)

    # Index only the leading dimension so 1-D batches work as well; for
    # higher-rank inputs this selects the same rows as ``x[index, :]``.
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_process(out, target_reweighted, lam):
    """Mix ``out`` and its targets with a shared random permutation.

    One permutation of the first dimension is drawn with
    ``np.random.permutation`` and applied to both arguments.

    Args:
        out: Tensor to mix; its first dimension sets the permutation size.
        target_reweighted: Targets, indexable by the same permutation.
        lam: Weight of the unshuffled entries; the shuffled entries get
            ``1 - lam``.

    Returns:
        Tuple ``(mixed_out, mixed_target)``.
    """
    perm = np.random.permutation(out.size(0))

    shuffled_out = out[perm]
    mixed_out = out * lam + shuffled_out * (1 - lam)

    shuffled_target = target_reweighted[perm]
    mixed_target = target_reweighted * lam + shuffled_target * (1 - lam)

    return mixed_out, mixed_target

class TwoCropTransform:
    """Callable that applies one transform to the same input twice."""

    def __init__(self, transform):
        # Callable invoked twice per input in __call__.
        self.transform = transform

    def __call__(self, x):
        """Return a two-element list holding two separate transform results for ``x``."""
        first_view = self.transform(x)
        second_view = self.transform(x)
        return [first_view, second_view]