import numpy as np
import math


import torch.nn.functional as F
import torch

from ray import tune


def possibly_gridsearch(arg):
    """Turn a plain ``list`` into a Ray Tune grid-search spec.

    Only objects whose type is exactly ``list`` are wrapped with
    ``tune.grid_search``; every other value (tuples, list subclasses,
    scalars, strings) is handed back untouched.
    """
    return tune.grid_search(arg) if type(arg) is list else arg


def angle_between(p1, p2, fang):
    """Return the signed angle, in radians, from the bearing p1->p2 to ``fang``.

    The bearing of ``p2`` seen from ``p1`` is measured counter-clockwise from
    the +x axis in degrees in [0, 360). It is computed as ``atan2(dx, dy)``
    (angle from +y) and converted via ``(450 - deg) % 360``. The result is
    ``fang - bearing`` wrapped into [-pi, pi].

    Args:
        p1: (x, y) origin point.
        p2: (x, y) target point.
        fang: heading in radians. Any real value is accepted, including
            headings that have accumulated several full turns.

    Returns:
        The wrapped angular difference in radians.
    """
    bearing = np.arctan2(p2[0] - p1[0], p2[1] - p1[1])
    bearing = (360 - np.rad2deg(bearing) + 90) % 360
    diff = np.rad2deg(fang) - bearing
    # A single +/-360 correction only normalizes |diff| <= 540. math.remainder
    # reduces by any multiple of 360 into [-180, 180], and leaves values that
    # are already in that interval unchanged.
    diff = math.remainder(diff, 360)
    return np.deg2rad(diff)


def distance_between(p1, p2):
    """Return the Euclidean distance between the 2-D points ``p1`` and ``p2``.

    ``math.hypot`` avoids the intermediate overflow/underflow of squaring the
    coordinate differences by hand. For example, ``(1e200) ** 2`` would raise
    ``OverflowError``.
    """
    return math.hypot(p2[0] - p1[0], p2[1] - p1[1])


def svd_sol(A, b):
    """Solve ``A @ x = b`` with an SVD-based (truncated) pseudo-inverse.

    Singular values below 1e-10 are treated as zero and their reciprocals are
    set to zero. Rank-deficient systems therefore get a finite solution
    instead of a division by zero.

    Args:
        A: (m, n) coefficient matrix; it does not need to be square.
        b: right-hand side with m rows.

    Returns:
        ``x = V @ diag(1 / sigma) @ U.T @ b``, with n rows.
    """
    # Reduced SVD: U is (m, k) and Vt is (k, n) with k = min(m, n), so the
    # shapes line up with the k x k diagonal even when A is not square.
    U, sigma, Vt = np.linalg.svd(A, full_matrices=False)
    sigma_reci = np.zeros_like(sigma)
    keep = sigma >= 1e-10
    sigma_reci[keep] = 1.0 / sigma[keep]
    return Vt.T @ np.diag(sigma_reci) @ U.T @ b


def can_shoot(
    agent_p_pos,
    agent_p_ang,
    opp_p_pos,
    *,
    shoot_rad=0.8,
    shoot_win=np.pi / 5,
    size=0.05,
):
    """Return True if ``opp_p_pos`` lies inside the agent's triangular shooting cone.

    The cone is a triangle with apex ``pt1``, which sits ``size`` ahead of the
    agent along its heading. Its two far corners are ``shoot_rad`` from the
    apex at ``shoot_win / 2`` radians either side of the heading. The opponent
    is inside, or on the boundary, when all barycentric coordinates solved by
    ``svd_sol`` are non-negative.

    Args:
        agent_p_pos: agent position as a length-2 numpy array.
        agent_p_ang: agent heading in radians.
        opp_p_pos: (x, y) position of the opponent.
        shoot_rad: length of the cone's sides (default 0.8).
        shoot_win: full angular width of the cone in radians (default pi/5).
        size: forward offset of the cone apex from ``agent_p_pos`` (default 0.05).

    Returns:
        A Python ``bool``.
    """
    ang = agent_p_ang
    half_win = shoot_win / 2
    pt1 = agent_p_pos + size * np.array([np.cos(ang), np.sin(ang)])
    pt2 = pt1 + shoot_rad * np.array([np.cos(ang + half_win), np.sin(ang + half_win)])
    pt3 = pt1 + shoot_rad * np.array([np.cos(ang - half_win), np.sin(ang - half_win)])
    # Columns are the triangle vertices in homogeneous coordinates. Solving
    # A @ x = b gives weights x with x1*pt1 + x2*pt2 + x3*pt3 = opp and
    # sum(x) = 1, i.e. barycentric coordinates.
    A = np.array([[pt1[0], pt2[0], pt3[0]], [pt1[1], pt2[1], pt3[1]], [1, 1, 1]])
    b = np.array([[opp_p_pos[0]], [opp_p_pos[1]], [1]])
    x = svd_sol(A, b)
    return bool(np.all(x >= 0))


def init(module, weight_init, bias_init, gain=1):
    """Initialize ``module``'s weight and bias in place, then hand it back.

    Args:
        module: object exposing ``.weight`` and ``.bias`` with ``.data``
            tensors. Presumably a torch layer such as ``nn.Linear``; confirm
            against call sites.
        weight_init: callable invoked as ``weight_init(tensor, gain=gain)`` on
            ``module.weight.data``.
        bias_init: callable invoked as ``bias_init(tensor)`` on
            ``module.bias.data``.
        gain: forwarded to ``weight_init`` as the ``gain`` keyword.

    Returns:
        The very same ``module`` object, so calls can be chained.
    """
    weight_tensor = module.weight.data
    weight_init(weight_tensor, gain=gain)
    bias_tensor = module.bias.data
    bias_init(bias_tensor)
    return module


"""PyTorch implementation of the Class-Balanced Loss.
   Reference: "Class-Balanced Loss Based on Effective Number of Samples"
   Authors: Yin Cui, Menglin Jia, Tsung-Yi Lin, Yang Song, and Serge J. Belongie
   https://arxiv.org/abs/1901.05555, CVPR'19.
"""


def focal_loss(labels, logits, alpha, gamma):
    """Compute the focal loss between `logits` and the ground truth `labels`.

    Focal loss = -alpha_t * (1-pt)^gamma * log(pt)
    where pt is the probability of being classified to the true class.
    pt = p (if true class), otherwise pt = 1 - p. p = sigmoid(logit).

    Args:
      labels: A float tensor of size [batch, num_classes] holding 0/1 targets.
      logits: A float tensor of size [batch, num_classes].
      alpha: Per-element weight broadcastable to ``logits`` (``CB_loss`` passes
        a [batch, num_classes] tensor). It scales the balanced cross entropy.
      gamma: A float scalar modulating loss from hard and easy examples.
    Returns:
      A scalar tensor: the weighted loss summed over all elements and divided
      by ``torch.sum(labels)``. This is NaN/inf if the labels sum to zero.
    """
    bce = F.binary_cross_entropy_with_logits(
        input=logits, target=labels, reduction="none"
    )

    if gamma == 0.0:
        modulator = 1.0
    else:
        # (1 - pt)^gamma written in log space:
        #   exp(-gamma * y * x - gamma * log(1 + exp(-x)))
        # softplus(-x) equals log(1 + exp(-x)) but does not overflow to inf for
        # very negative logits. The naive form drove the modulator, and hence
        # the loss, to 0 for confidently-wrong positive examples.
        modulator = torch.exp(-gamma * labels * logits - gamma * F.softplus(-logits))

    weighted_loss = alpha * (modulator * bce)
    total = torch.sum(weighted_loss)
    # Normalize by the number of positive targets.
    total /= torch.sum(labels)
    return total


def CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta, gamma):
    """Compute the Class Balanced Loss between `logits` and the ground truth `labels`.

    Class Balanced Loss: ((1-beta)/(1-beta^n))*Loss(labels, logits)
    where Loss is one of the standard losses used for Neural Networks.

    Args:
      labels: An int64 tensor of size [batch] with class indices.
      logits: A float tensor of size [batch, no_of_classes].
      samples_per_cls: A python list (or tensor) of size [no_of_classes].
      no_of_classes: total number of classes. int
      loss_type: string. One of "sigmoid", "focal", "softmax".
      beta: float. Hyperparameter for Class balanced loss.
      gamma: float. Hyperparameter for Focal loss.
    Returns:
      cb_loss: A float tensor representing class balanced loss.
    Raises:
      ValueError: if ``loss_type`` is not one of the supported strings.
    """
    if loss_type not in ("focal", "sigmoid", "softmax"):
        raise ValueError(
            f"Unsupported loss_type {loss_type!r}; "
            "expected 'focal', 'sigmoid' or 'softmax'."
        )

    # torch.pow(float, list) raises TypeError, so convert explicitly. Placing
    # the counts on the logits' device keeps the weights combinable with them.
    samples = torch.as_tensor(
        samples_per_cls, dtype=torch.float32, device=logits.device
    )
    effective_num = 1.0 - torch.pow(beta, samples)

    # The 1e-3 keeps the weight finite for classes with zero samples
    # (effective_num == 0).
    weights = (1.0 - beta) / (effective_num + 1e-3)
    weights = weights / torch.sum(weights) * no_of_classes

    labels_one_hot = F.one_hot(labels, no_of_classes).float()

    # Each sample gets its true class's weight, repeated across all classes.
    weights = weights[labels].unsqueeze(1).repeat(1, no_of_classes)

    if loss_type == "focal":
        cb_loss = focal_loss(labels_one_hot, logits, weights, gamma)
    elif loss_type == "sigmoid":
        # NOTE(review): pos_weight scales only the positive-target term. The
        # CB paper weights the whole per-sample loss (``weight=``). Confirm
        # this is intentional.
        cb_loss = F.binary_cross_entropy_with_logits(
            input=logits, target=labels_one_hot, pos_weight=weights
        )
    else:  # "softmax"
        pred = logits.softmax(dim=1)
        cb_loss = F.binary_cross_entropy(
            input=pred, target=labels_one_hot, weight=weights
        )
    return cb_loss
