import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from resnet_models import * 
import torch.backends.cudnn as cudnn
import random 
import torch.optim as optim
import argparse
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
import torchvision
from random import choice




def rand_bbox(size, lam):
    """Sample a random axis-aligned box for CutMix.

    The box side lengths are the extents ``size[2]`` and ``size[3]`` scaled by
    ``sqrt(1 - lam)``, so before clipping the box covers about ``1 - lam`` of
    the ``size[2] * size[3]`` area. The centre is drawn uniformly and the box
    is clipped to the valid range, so the returned box may be smaller than
    requested (possibly empty).

    Args:
        size: Indexable shape (e.g. ``tensor.size()``) with at least four
            entries; only ``size[2]`` and ``size[3]`` are read.
        lam: Mixing ratio; expected in ``[0, 1]`` so that ``1 - lam`` has a
            real square root.

    Returns:
        Tuple ``(bbx1, bby1, bbx2, bby2)`` where ``bbx1:bbx2`` is a range
        within ``[0, size[2]]`` and ``bby1:bby2`` a range within
        ``[0, size[3]]``.
    """
    # NOTE(review): the names W/H are kept from the original; they refer to
    # dims 2 and 3 respectively, whatever those mean for the caller's layout.
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # The ``np.int`` alias was removed in NumPy 1.24; builtin ``int`` performs
    # the same truncation toward zero.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Box centre, drawn uniformly over the full extent.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    # Clip each edge so the box never leaves the valid index range.
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2



def cutmix(model, criterion, inputs, targets, cutmix_prob=1.0, beta=1.0):
    """Run one forward pass, applying CutMix with probability ``cutmix_prob``.

    When CutMix is applied, a random box (from ``rand_bbox``) in every sample
    is overwritten with the same box taken from a randomly permuted sample of
    the batch, and the loss is the area-weighted combination of the losses
    against the original and the permuted targets. When it is not applied
    (``beta <= 0`` or the random draw is not below ``cutmix_prob``), the batch
    is passed through the model unchanged and the plain loss is returned.

    Note: when CutMix is applied, ``inputs`` is modified in place.

    Args:
        model: Callable mapping ``inputs`` to ``output``.
        criterion: Callable ``criterion(output, targets)`` returning the loss.
        inputs: Batch tensor with at least four dimensions; dims 2 and 3 are
            the ones the box is cut from.
        targets: Tensor of targets indexable along dim 0 by a permutation.
        cutmix_prob: Probability of applying CutMix to this batch.
        beta: Parameter of the symmetric Beta distribution that the mixing
            ratio is drawn from; values ``<= 0`` disable CutMix.

    Returns:
        Tuple ``(output, loss)``.
    """
    r = np.random.rand(1)
    if beta > 0 and r < cutmix_prob:
        # Generate the mixed sample.
        lam = np.random.beta(beta, beta)
        # Build the permutation on the CPU (same RNG stream as before) and
        # move it to wherever the batch lives instead of hard-coding CUDA.
        rand_index = torch.randperm(inputs.size()[0]).to(inputs.device)
        target_a = targets
        target_b = targets[rand_index]
        bbx1, bby1, bbx2, bby2 = rand_bbox(inputs.size(), lam)
        inputs[:, :, bbx1:bbx2, bby1:bby2] = inputs[rand_index, :, bbx1:bbx2, bby1:bby2]
        # Adjust lambda to exactly match the pixel ratio of the clipped box.
        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (inputs.size()[-1] * inputs.size()[-2]))
        output = model(inputs)
        loss = criterion(output, target_a) * lam + criterion(output, target_b) * (1. - lam)
        return output, loss

    # CutMix skipped for this batch: fall back to a plain forward pass so the
    # function always returns a defined (output, loss) pair.
    output = model(inputs)
    loss = criterion(output, targets)
    return output, loss


def cutmix_fixed(model, criterion, inputs, targets, cutmix_prob=1.0, beta=1.0):
    """Run one forward pass with CutMix on the inputs but unmixed targets.

    With probability ``cutmix_prob`` (and ``beta > 0``), a random box (from
    ``rand_bbox``) in every sample is overwritten with the same box taken from
    a randomly permuted sample of the batch. Unlike a label-mixing CutMix, the
    loss is computed against the original ``targets`` only. When the
    augmentation is not applied, the batch is passed through unchanged.

    Note: when the augmentation is applied, ``inputs`` is modified in place.

    Args:
        model: Callable mapping ``inputs`` to ``output``.
        criterion: Callable ``criterion(output, targets)`` returning the loss.
        inputs: Batch tensor with at least four dimensions; dims 2 and 3 are
            the ones the box is cut from.
        targets: Targets passed unchanged to ``criterion``.
        cutmix_prob: Probability of applying the augmentation to this batch.
        beta: Parameter of the symmetric Beta distribution that the box-size
            ratio is drawn from; values ``<= 0`` disable the augmentation.

    Returns:
        Tuple ``(output, loss)``.
    """
    r = np.random.rand(1)
    if beta > 0 and r < cutmix_prob:
        # Only the box size depends on lam here; the loss is not re-weighted.
        lam = np.random.beta(beta, beta)
        # Build the permutation on the CPU (same RNG stream as before) and
        # move it to wherever the batch lives instead of hard-coding CUDA.
        rand_index = torch.randperm(inputs.size()[0]).to(inputs.device)
        bbx1, bby1, bbx2, bby2 = rand_bbox(inputs.size(), lam)
        inputs[:, :, bbx1:bbx2, bby1:bby2] = inputs[rand_index, :, bbx1:bbx2, bby1:bby2]

    # Both paths score against the original targets; when the augmentation is
    # skipped this is a plain forward pass, so (output, loss) is always defined.
    output = model(inputs)
    loss = criterion(output, targets)
    return output, loss




def mixup(model, criterion, inputs, targets, alpha=1.0):
    """Run one forward pass on a mixup-blended batch.

    Each sample is blended with a randomly permuted sample of the same batch
    as ``lam * x + (1 - lam) * x_perm``, and the loss is the same convex
    combination of the losses against the original and permuted targets.
    ``inputs`` itself is not modified; a new blended tensor is fed to the
    model.

    Args:
        model: Callable mapping the blended inputs to ``output``.
        criterion: Callable ``criterion(output, targets)`` returning the loss.
        inputs: Batch tensor; dim 0 is permuted to pick mixing partners.
        targets: Tensor of targets indexable along dim 0 by a permutation.
        alpha: Parameter of the symmetric Beta distribution that ``lam`` is
            drawn from. If ``alpha <= 0``, ``lam`` is fixed to 1, so the
            blended batch equals ``inputs``.

    Returns:
        Tuple ``(output, loss)``.
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = inputs.size()[0]
    # Build the permutation on the CPU (same RNG stream as before) and move it
    # to wherever the batch lives instead of hard-coding CUDA.
    index = torch.randperm(batch_size).to(inputs.device)
    mixed_input = lam * inputs + (1 - lam) * inputs[index, :]
    target_a, target_b = targets, targets[index]
    output = model(mixed_input)

    loss = lam * criterion(output, target_a) + (1. - lam) * criterion(output, target_b)
    return output, loss


def get_augmentation(augmentation):
    """Look up a batch-augmentation function by name.

    Args:
        augmentation: One of ``'mixup'``, ``'cutmix'`` or ``'cutmix_fixed'``.

    Returns:
        The function of the same name, or ``None`` when ``augmentation`` does
        not match any of the known names.
    """
    if augmentation == 'mixup':
        return mixup
    if augmentation == 'cutmix':
        return cutmix
    if augmentation == 'cutmix_fixed':
        return cutmix_fixed
    # Unknown name: no augmentation function is selected.
    return None
