import torch.nn as nn
import torch.nn.functional as F
import torch
import losses as L
import numpy as np
import torchvision.transforms.functional as TF
from torch.autograd import Variable
import random
from kornia import augmentation

def kdloss(y, teacher_scores, temperature=20):
    """Return the temperature-scaled KL-divergence distillation loss.

    Both inputs are divided by ``temperature`` and normalised over ``dim=1``
    (log-softmax for ``y``, softmax for ``teacher_scores``).  The KL
    divergence is summed over every element, multiplied by
    ``temperature ** 2`` and divided by ``y.shape[0]``.

    Args:
        y: Student logits; dimension 1 is the one that gets normalised.
        teacher_scores: Teacher logits, same shape as ``y``.
        temperature: Softening temperature applied to both sets of logits.

    Returns:
        A scalar tensor holding the loss.
    """
    p = F.log_softmax(y / temperature, dim=1)
    q = F.softmax(teacher_scores / temperature, dim=1)
    # reduction="sum" is the supported spelling of the deprecated
    # size_average=False and yields the same value without a warning.
    l_kl = F.kl_div(p, q, reduction="sum") * (temperature ** 2) / y.shape[0]
    return l_kl


def adjust_learning_rate(optimizer, epoch, learing_rate):
    """Apply a three-stage step schedule to every parameter group.

    The base rate is used unchanged while ``epoch < 1600``, scaled by 0.1
    while ``epoch < 3200`` and scaled by 0.01 afterwards.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts
            whose ``'lr'`` entry is overwritten.
        epoch: Current epoch index.
        learing_rate: Base learning rate.
    """
    if epoch < 1600:
        new_lr = learing_rate
    else:
        # Second stage lasts until epoch 3200, third stage is everything after.
        factor = 0.1 if epoch < 3200 else 0.01
        new_lr = factor * learing_rate

    for group in optimizer.param_groups:
        group['lr'] = new_lr

def adjust_learning_rate_200(optimizer, epoch, learing_rate):
    """Apply a three-stage step schedule with boundaries at 200 and 400.

    The base rate is used unchanged while ``epoch < 200``, scaled by 0.1
    while ``epoch < 400`` and scaled by 0.01 afterwards.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts
            whose ``'lr'`` entry is overwritten.
        epoch: Current epoch index.
        learing_rate: Base learning rate.
    """
    if epoch < 200:
        new_lr = learing_rate
    else:
        # Second stage lasts until epoch 400, third stage is everything after.
        factor = 0.1 if epoch < 400 else 0.01
        new_lr = factor * learing_rate

    for group in optimizer.param_groups:
        group['lr'] = new_lr

class AvgrageMeter(object):
    """Accumulate a weighted running average.

    Attributes are plain values updated in place:
    ``sum`` is the weighted total, ``cnt`` the total weight and ``avg``
    their ratio (0 while no weight has been accumulated).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the accumulated total, count and average."""
        self.avg = 0
        self.sum = 0
        self.cnt = 0

    def update(self, val, n=1):
        """Add ``val`` with weight ``n`` and refresh the average.

        Args:
            val: Value to add (e.g. a per-batch mean).
            n: Weight of ``val`` (e.g. the batch size).
        """
        self.sum += val * n
        self.cnt += n
        # A zero total weight (e.g. an empty first batch) used to raise
        # ZeroDivisionError; report 0 until something has been counted.
        self.avg = self.sum / self.cnt if self.cnt else 0


def accuracy(output, target, topk=(1, )):
    """Compute top-k accuracy percentages for each ``k`` in ``topk``.

    Args:
        output: Score tensor; the top entries are taken along dimension 1.
        target: Tensor of target indices, one per row of ``output``.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        A list with one single-element float tensor per ``k``, holding the
        percentage of rows whose target is among the ``k`` highest scores.
    """
    deepest = max(topk)
    n_samples = target.size(0)

    # Indices of the highest scores, transposed so that row r holds the
    # (r + 1)-th ranked prediction of every sample.
    ranked = output.topk(deepest, dim=1, largest=True, sorted=True).indices.t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))

    to_percent = 100.0 / n_samples
    return [
        hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(to_percent)
        for k in topk
    ]



def getgrad(net):
    """Return all current parameter gradients of ``net`` as one flat vector.

    Parameters that do not require gradients, or whose ``.grad`` is
    ``None``, are skipped.

    Args:
        net: Module whose ``named_parameters()``/``parameters()`` are read.

    Returns:
        A detached 1-D tensor with the gradients concatenated in parameter
        order.  ``torch.cat`` allocates fresh storage, so the result is not
        affected by later in-place changes to the ``.grad`` tensors.

    Raises:
        RuntimeError: From ``torch.cat`` when no parameter has a gradient.
    """
    # detach().reshape(-1) replaces torch.tensor(param.grad), which warns
    # about copy-constructing from a tensor and copies the data twice;
    # reshape also copes with non-contiguous gradients where view(-1) fails.
    flat_grads = [
        param.grad.detach().reshape(-1)
        for param in net.parameters()
        if param.requires_grad and param.grad is not None
    ]
    return torch.cat(flat_grads, dim=0).detach()

def harmonicgradloss(loss_localization, loss_classification, net, opt):
    """Combine two losses, damping them when their gradients conflict.

    Each loss is back-propagated on its own (with ``retain_graph=True`` so
    the graph stays usable) and the resulting parameter gradients of
    ``net`` are flattened with ``getgrad``.  If the dot product of the two
    gradient vectors is negative, each loss is reduced by
    ``(g1 . g2) / ||g_i||^2`` times itself; otherwise the plain sum is
    returned.

    Args:
        loss_localization: First scalar loss tensor.
        loss_classification: Second scalar loss tensor.
        net: Module whose parameter gradients are compared.
        opt: Optimizer; ``opt.zero_grad()`` is called before each backward
            pass and once more at the end, so gradients are cleared on
            return.

    Returns:
        The combined loss tensor.
    """
    opt.zero_grad()
    loss_localization.backward(retain_graph=True)
    g1 = getgrad(net)

    opt.zero_grad()
    loss_classification.backward(retain_graph=True)
    g2 = getgrad(net)
    opt.zero_grad()

    combined = loss_localization + loss_classification
    inner_product = torch.dot(g1, g2)

    # Only compute the correction when the gradients actually conflict.
    # Evaluating it unconditionally and multiplying by 0 turned the whole
    # loss into NaN whenever one gradient was all zeros (0 / 0 * 0 = NaN).
    # A negative dot product implies both squared norms are non-zero.
    if not inner_product < 0:
        return combined

    g1_l2 = torch.dot(g1, g1)
    g2_l2 = torch.dot(g2, g2)
    correction = (
        (inner_product / g1_l2) * loss_localization
        + (inner_product / g2_l2) * loss_classification
    )
    return combined - correction


def generate_hee(model, x, steps_hee=10, lr_hee=0.03):
    """Perturb ``x`` by signed-gradient ascent on ``L.Entropy_Loss``.

    Starting from ``x`` plus Gaussian noise of scale 0.001, each step moves
    the input by ``lr_hee`` in the direction of the sign of the gradient of
    ``L.Entropy_Loss(reduction="mean")(model(x_hee))`` and clamps the result
    to ``[0, 1]``.

    Args:
        model: Module evaluated on the perturbed input.  It is switched to
            eval mode for the duration of the call and then returned to the
            mode it was in before (previously it was always left in train
            mode, even if the caller had it in eval mode).
        x: Input batch; values are presumably expected in ``[0, 1]`` given
            the clamp — confirm against the caller.
        steps_hee: Number of ascent steps.
        lr_hee: Step size of each ascent step.

    Returns:
        The perturbed batch, detached from the autograd graph.
    """
    criterion = L.Entropy_Loss(reduction="mean")

    was_training = model.training
    model.eval()
    try:
        # Draw the noise directly on x's device instead of copying from CPU.
        x_hee = x.detach() + 0.001 * torch.randn(x.shape, device=x.device)
        for _ in range(steps_hee):
            x_hee.requires_grad_()
            # Callers may be inside torch.no_grad(); gradients are needed here.
            with torch.enable_grad():
                loss = criterion(model(x_hee))
            grad = torch.autograd.grad(loss, [x_hee])[0]
            x_hee = x_hee.detach() + lr_hee * torch.sign(grad.detach())
            x_hee = torch.clamp(x_hee, 0.0, 1.0)
    finally:
        model.train(was_training)
    return x_hee



def generate_hee_l1(model, x, teacher_rob, steps_hee=10, lr_hee=0.03):
    """Perturb ``x`` by signed-gradient ascent on entropy plus teacher L1 gap.

    The objective maximised at each step is
    ``L.Entropy_Loss(reduction="mean")(model(x_hee))
    + 0.3 * L1(model(x_hee), teacher_rob(x_hee))``.
    Starting from ``x`` plus Gaussian noise of scale 0.001, each step moves
    the input by ``lr_hee`` along the sign of the gradient and clamps the
    result to ``[0, 1]``.

    Args:
        model: Module evaluated on the perturbed input.  It is switched to
            eval mode for the duration of the call and then returned to the
            mode it was in before (previously it was always left in train
            mode, even if the caller had it in eval mode).
        x: Input batch; values are presumably expected in ``[0, 1]`` given
            the clamp — confirm against the caller.
        teacher_rob: Second module whose output is compared with ``model``'s;
            its train/eval mode is not touched here.
        steps_hee: Number of ascent steps.
        lr_hee: Step size of each ascent step.

    Returns:
        The perturbed batch, detached from the autograd graph.
    """
    entropy_criterion = L.Entropy_Loss(reduction="mean")
    l1_criterion = nn.L1Loss()

    was_training = model.training
    model.eval()
    try:
        # Draw the noise directly on x's device instead of copying from CPU.
        x_hee = x.detach() + 0.001 * torch.randn(x.shape, device=x.device)
        for _ in range(steps_hee):
            x_hee.requires_grad_()
            # Callers may be inside torch.no_grad(); gradients are needed here.
            with torch.enable_grad():
                # One forward pass feeds both loss terms (was run twice).
                student_out = model(x_hee)
                loss = entropy_criterion(student_out) + 0.3 * l1_criterion(
                    student_out, teacher_rob(x_hee)
                )
            grad = torch.autograd.grad(loss, [x_hee])[0]
            x_hee = x_hee.detach() + lr_hee * torch.sign(grad.detach())
            x_hee = torch.clamp(x_hee, 0.0, 1.0)
    finally:
        model.train(was_training)
    return x_hee


def generate_l1(x, net, teacher_rob, steps_hee=10, lr_hee=0.03):
    """Perturb ``x`` by signed-gradient ascent on the L1 gap between two nets.

    Starting from ``x`` plus Gaussian noise of scale 0.001, each step moves
    the input by ``lr_hee`` along the sign of the gradient of
    ``L1(net(x_hee), teacher_rob(x_hee))`` and clamps the result to
    ``[0, 1]``.

    Args:
        x: Input batch; values are presumably expected in ``[0, 1]`` given
            the clamp — confirm against the caller.
        net: Module evaluated on the perturbed input.  It is switched to
            eval mode for the duration of the call and then returned to the
            mode it was in before (previously it was always left in train
            mode, even if the caller had it in eval mode).
        teacher_rob: Second module whose output is compared with ``net``'s;
            its train/eval mode is not touched here.
        steps_hee: Number of ascent steps.
        lr_hee: Step size of each ascent step.

    Returns:
        The perturbed batch, detached from the autograd graph.
    """
    l1_criterion = nn.L1Loss()

    was_training = net.training
    net.eval()
    try:
        # Draw the noise directly on x's device instead of copying from CPU.
        x_hee = x.detach() + 0.001 * torch.randn(x.shape, device=x.device)
        for _ in range(steps_hee):
            x_hee.requires_grad_()
            # Callers may be inside torch.no_grad(); gradients are needed here.
            with torch.enable_grad():
                loss = l1_criterion(net(x_hee), teacher_rob(x_hee))
            grad = torch.autograd.grad(loss, [x_hee])[0]
            x_hee = x_hee.detach() + lr_hee * torch.sign(grad.detach())
            x_hee = torch.clamp(x_hee, 0.0, 1.0)
    finally:
        net.train(was_training)
    return x_hee


def get_xadv(img, teacher_nat, teacher_rob, eps=0.3):
    """Build a one-step signed-gradient perturbation from two teachers.

    The loss is the L1 distance between ``teacher_rob(img)`` and
    ``teacher_nat(img)``; the image is moved by ``eps`` along the sign of
    that loss's gradient with respect to the image.  The result is not
    clipped to any pixel range.

    Args:
        img: Input batch.  It is moved to the CUDA device when one is
            available (as before) and otherwise kept on the CPU instead of
            raising.
        teacher_nat: First module; its parameter gradients are zeroed before
            the backward pass.
        teacher_rob: Second module.
        eps: Magnitude of the perturbation.

    Returns:
        A tuple ``(adv_x, pertubation)``: the perturbed batch on the compute
        device and the perturbation itself as a CPU tensor.

    Note:
        ``loss.backward()`` also accumulates gradients into the parameters
        of both teachers; this side effect is kept from the original.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Fresh leaf tensor so that .grad is populated for the image itself
    # (replaces the deprecated Variable(img, requires_grad=True)).
    img = img.to(device).detach().requires_grad_(True)

    outputs_nat = teacher_nat(img)
    outputs_rob = teacher_rob(img)
    loss = nn.L1Loss()(outputs_rob, outputs_nat)

    teacher_nat.zero_grad()
    loss.backward()

    # Stay in torch on the compute device; only the returned perturbation
    # is copied to the CPU, matching the original return devices.
    pertubation = eps * torch.sign(img.grad.detach())
    adv_x = img.detach() + pertubation
    return adv_x, pertubation.cpu()



def strong_aug(image):
    """Randomly augment a 32x32 image tensor.

    Steps, in order: centre-crop to a random height and width of
    ``int(32 * u)`` with ``u`` drawn from ``[0.95, 1.0]``, resize back to
    32x32, add Gaussian noise of scale 0.001 and clamp to ``[0, 1]``,
    vertical flip and horizontal flip (each when a uniform draw exceeds
    0.5), then rotate by -15, 0 or 15 degrees.

    Args:
        image: Image tensor accepted by the torchvision functional
            transforms; presumably ``(..., 32, 32)`` — confirm with caller.

    Returns:
        The augmented image tensor.
    """
    crop_h = int(32.0 * random.uniform(0.95, 1.0))
    crop_w = int(32.0 * random.uniform(0.95, 1.0))
    image = TF.resize(TF.center_crop(image, [crop_h, crop_w]), [32, 32])

    # randn_like already allocates on the image's device.
    image = torch.clamp(image + torch.randn_like(image) * 0.001, 0.0, 1.0)

    # Vertical flip is decided first, then horizontal, one draw each.
    for flip in (TF.vflip, TF.hflip):
        if random.uniform(0, 1) > 0.5:
            image = flip(image)

    return TF.rotate(image, random.choice([-15, 0, 15]))


def strong_aug_224(image):
    """Randomly augment a 224x224 image tensor.

    Steps, in order: centre-crop to a random height and width of
    ``int(224 * u)`` with ``u`` drawn from ``[0.95, 1.0]``, resize back to
    224x224, add Gaussian noise of scale 0.001 and clamp to ``[0, 1]``,
    vertical flip and horizontal flip (each when a uniform draw exceeds
    0.5), then rotate by -15, 0 or 15 degrees.

    Args:
        image: Image tensor accepted by the torchvision functional
            transforms; presumably ``(..., 224, 224)`` — confirm with caller.

    Returns:
        The augmented image tensor.
    """
    crop_h = int(224.0 * random.uniform(0.95, 1.0))
    crop_w = int(224.0 * random.uniform(0.95, 1.0))
    image = TF.resize(TF.center_crop(image, [crop_h, crop_w]), [224, 224])

    # randn_like already allocates on the image's device.
    image = torch.clamp(image + torch.randn_like(image) * 0.001, 0.0, 1.0)

    # Vertical flip is decided first, then horizontal, one draw each.
    for flip in (TF.vflip, TF.hflip):
        if random.uniform(0, 1) > 0.5:
            image = flip(image)

    return TF.rotate(image, random.choice([-15, 0, 15]))


# Image shape used to size the crop below; the last two entries are taken as
# the spatial size.
img_shape = (3, 32, 32)

# Module-level kornia pipeline: a random crop with padding=4 back to the
# spatial size of img_shape, followed by a random horizontal flip.
std_aug = augmentation.container.ImageSequential(
    augmentation.RandomCrop(size=list(img_shape[-2:]), padding=4),
    augmentation.RandomHorizontalFlip(),
)


def generate_hee_03(model, x, steps_hee=10, lr_hee=0.3):
    """Perturb ``x`` by signed-gradient ascent on ``L.Entropy_Loss`` (step 0.3).

    Starting from ``x`` plus Gaussian noise of scale 0.001, each step moves
    the input by ``lr_hee`` in the direction of the sign of the gradient of
    ``L.Entropy_Loss(reduction="mean")(model(x_hee))`` and clamps the result
    to ``[0, 1]``.  The default step size here is 0.3.

    Args:
        model: Module evaluated on the perturbed input.  It is switched to
            eval mode for the duration of the call and then returned to the
            mode it was in before (previously it was always left in train
            mode, even if the caller had it in eval mode).
        x: Input batch; values are presumably expected in ``[0, 1]`` given
            the clamp — confirm against the caller.
        steps_hee: Number of ascent steps.
        lr_hee: Step size of each ascent step.

    Returns:
        The perturbed batch, detached from the autograd graph.
    """
    criterion = L.Entropy_Loss(reduction="mean")

    was_training = model.training
    model.eval()
    try:
        # Draw the noise directly on x's device instead of copying from CPU.
        x_hee = x.detach() + 0.001 * torch.randn(x.shape, device=x.device)
        for _ in range(steps_hee):
            x_hee.requires_grad_()
            # Callers may be inside torch.no_grad(); gradients are needed here.
            with torch.enable_grad():
                loss = criterion(model(x_hee))
            grad = torch.autograd.grad(loss, [x_hee])[0]
            x_hee = x_hee.detach() + lr_hee * torch.sign(grad.detach())
            x_hee = torch.clamp(x_hee, 0.0, 1.0)
    finally:
        model.train(was_training)
    return x_hee