import os
import re
from collections import OrderedDict

import cv2
import numpy as np
import PIL
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.nn.modules.loss import _Loss
from torch.optim.lr_scheduler import _LRScheduler
from torchvision import transforms


class AvgrageMeter(object):
    """Track the most recent value and a weighted running mean of a quantity.

    All statistics are plain public attributes:
    ``val`` is the last value passed to :meth:`update`, ``sum`` the weighted
    sum of every value seen, ``cnt`` the accumulated weight, and ``avg`` the
    ratio ``sum / cnt``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.avg = self.sum = self.cnt = self.val = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean.

        Args:
            val: the new value (e.g. a batch-averaged loss).
            n: weight of ``val``, such as the number of samples it averages.
        """
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.cnt += n
        self.avg = self.sum / self.cnt


def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy, in percent, of ``output`` scores vs ``target``.

    Args:
        output: score tensor whose classes are ranked along dim 1
            (presumably ``(batch, num_classes)`` -- confirm with callers).
        target: ground-truth class indices, one per sample; ``target.size(0)``
            is used as the batch size.
        topk: k values to evaluate.

    Returns:
        A list with one 0-dim float tensor per entry of ``topk``: the
        percentage of samples whose target is among the k highest scores.
    """
    k_max = max(topk)
    n_samples = target.size(0)

    _, ranked = output.topk(k_max, 1, True, True)
    # After the transpose, row i holds every sample's (i+1)-th best guess,
    # so slicing the first k rows gives the top-k guesses.
    ranked = ranked.t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))

    scale = 100.0 / n_samples
    return [hits[:k].reshape(-1).float().sum(0).mul_(scale) for k in topk]


def save_checkpoint(model, iters, path, optimizer=None, scheduler=None):
    """Serialize training state to ``<path>/checkpoint.pth`` (best effort).

    The saved object is a dict with keys ``"model"`` (``model.state_dict()``)
    and ``"iteration"`` (``iters``), plus ``"optimizer"`` / ``"scheduler"``
    (their ``state_dict()``) when those arguments are given. An existing
    checkpoint file at that location is overwritten.

    Args:
        model: object exposing ``state_dict()`` (e.g. an ``nn.Module``).
        iters: training iteration stored under ``"iteration"``.
        path: directory to write into; created (with parents) if missing.
        optimizer: optional object exposing ``state_dict()``.
        scheduler: optional object exposing ``state_dict()``.

    Failures to create the directory or write the file are reported with
    ``print`` and swallowed, so training can continue.
    """
    filename = os.path.join(path, "checkpoint.pth")
    print("Saving checkpoint to file {}".format(filename))

    state_dict = {"model": model.state_dict(), "iteration": iters}
    if optimizer is not None:
        state_dict["optimizer"] = optimizer.state_dict()
    if scheduler is not None:
        state_dict["scheduler"] = scheduler.state_dict()

    try:
        # exist_ok avoids the race of a separate exists()/makedirs() check,
        # and doing it inside the try keeps directory errors best-effort too.
        os.makedirs(path, exist_ok=True)
        torch.save(state_dict, filename)
    except (OSError, RuntimeError) as err:
        # Recent torch versions raise RuntimeError (not OSError) from their
        # zipfile writer when the target cannot be opened, so catch both.
        print("save {} failed, continue training ({})".format(filename, err))

def compute_au(w, d, lr):
    """Return the relative angular update of ``w`` for a step along ``d``.

    Both tensors are flattened, and ``au = sin(theta) * ||d|| * lr / ||w||``
    is computed, where ``cos(theta)`` is their cosine similarity. Flattening
    uses ``reshape``, so non-contiguous tensors (e.g. transposed weights) are
    accepted.

    Args:
        w: weight tensor.
        d: update tensor, expected to have as many elements as ``w``.
        lr: learning rate scaling ``d``.

    Returns:
        Tuple ``(au, w_n, d_n, cos)``. ``w_n`` and ``d_n`` are the L2 norms of
        ``w`` and ``d``, and ``cos`` is their cosine similarity clamped to
        ``[-1, 1]``. ``w_n``, ``d_n`` and ``cos`` are Python floats; ``au`` is
        a float when ``lr`` is a Python number.

    Raises:
        ZeroDivisionError: if ``w`` has zero norm and ``lr`` is a Python number.
    """
    w_flat = w.reshape(1, -1)
    d_flat = d.reshape(1, -1)
    w_n = torch.norm(w_flat).item()
    d_n = torch.norm(d_flat).item()
    cos = F.cosine_similarity(w_flat, d_flat, dim=1).item()
    # Rounding can push |cos| slightly past 1; 1 - cos**2 would then be
    # negative and ``** 0.5`` would return a complex number.
    cos = max(-1.0, min(1.0, cos))
    sin = (1.0 - cos ** 2) ** 0.5
    au = sin * d_n * lr / w_n

    return au, w_n, d_n, cos