from pathlib import Path
from typing import Optional

import comet_ml
import numpy as np
import torch


def calc_accuracy(logits, labels):
    """Return the fraction of binary predictions that agree with ``labels``.

    A prediction is 1 where ``logits > 0.5`` and 0 otherwise. The number of
    agreeing elements is summed over the whole tensor and divided by
    ``len(logits)``, i.e. the size of the first dimension.

    NOTE(review): the input is thresholded at 0.5, which is the natural
    cut-off for probabilities rather than raw logits -- confirm what
    callers actually pass in.
    """
    predictions = (logits > 0.5).long()
    # An element counts as a hit when prediction and label are both 1
    # (first term) or both 0 (second term).
    hits = predictions * labels + (1 - predictions) * (1 - labels)
    n_samples = len(logits)
    return torch.sum(hits) / n_samples


def sigmoid_cross_entropy_with_logits(logits, labels):
    """Return the mean sigmoid cross-entropy between ``logits`` and ``labels``.

    Uses the numerically stable element-wise form

        max(x, 0) - x * z + log(1 + exp(-|x|))

    where ``x`` is a logit and ``z`` the corresponding label, then averages
    over every element.

    Args:
        logits: Tensor of raw (pre-sigmoid) scores, of any shape.
        labels: Tensor of targets broadcastable against ``logits``.

    Returns:
        A scalar tensor holding the mean loss.
    """
    # Element-wise max(x, 0). Clamping works for any input shape, whereas
    # stacking a zero column and reducing over dim=1 is only correct for
    # 1-D logits (for (N, 1) input it reduces over the wrong axis).
    positive_part = torch.clamp(logits, min=0)
    # log(1 + exp(-|x|)); the exponent is never positive, so exp cannot
    # overflow, and log1p keeps precision when exp(-|x|) is tiny.
    log_term = torch.log1p(torch.exp(-torch.abs(logits)))
    per_element = positive_part - logits * labels + log_term
    return torch.mean(per_element)


def logsigmoid(a):
    """Return ``log(sigmoid(a))`` element-wise, computed as ``-softplus(-a)``."""
    negated = -a
    return -torch.nn.functional.softplus(negated)


def logit_bernoulli_entropy(logits):
    """Return the mean of ``(1 - sigmoid(x)) * x - logsigmoid(x)`` over ``logits``.

    This expression is the entropy of a Bernoulli distribution written in
    terms of its logit ``x``; the result is averaged over every element.
    """
    complement_prob = 1. - torch.sigmoid(logits)
    per_element = complement_prob * logits - logsigmoid(logits)
    return per_element.mean()


def mse_loss(source: np.ndarray, target: np.ndarray):
    """Return the mean of the squared element-wise difference of two arrays."""
    residual = source - target
    squared = residual**2
    return squared.mean()


def calc_alignment_score(target_states,
                         source_states_hat,
                         apply_shift: bool = False,
                         device: str = 'cuda:0'):
    """Score predicted source states against permuted target states.

    The last axis of ``target_states`` is reordered with the index list
    ``[1, 0, 3, 2]`` (each adjacent pair of components is swapped), and
    optionally ``[6, 6, 0, 0]`` is subtracted from it. The score is the mean
    squared error between that reference and ``source_states_hat``.

    Args:
        target_states: Tensor whose last axis is indexed with
            ``[1, 0, 3, 2]``, so it needs at least four entries.
        source_states_hat: Tensor compared against the permuted states.
        apply_shift: When true, subtract ``[6, 6, 0, 0]`` from the permuted
            states before scoring.
        device: Device on which the comparison is carried out.

    Returns:
        A tuple ``(align_score, inverted_states)``: the MSE as a Python
        float and the permuted (and possibly shifted) tensor on ``device``.
    """
    target_states = target_states.to(device)
    # Both operands must live on the same device for the loss below; a
    # no-op when source_states_hat is already there.
    source_states_hat = source_states_hat.to(device)

    # Indexing with a list returns a copy, so the in-place shift below never
    # touches the caller's tensor.
    inverted_states = target_states[..., [1, 0, 3, 2]]
    if apply_shift:
        shift = torch.tensor([6, 6, 0, 0],
                             device=inverted_states.device,
                             dtype=torch.float)
        inverted_states -= shift

    align_score = torch.nn.MSELoss()(inverted_states, source_states_hat).item()
    return align_score, inverted_states


def plot_alignment(
    true_source_states,
    predicted_source_states,
    logdir: Path,
    epoch: int,
    n_plot: int = 1000,
    experiment: Optional[comet_ml.Experiment] = None,
    n_links: int = 15,
):
    """Scatter-plot true vs. predicted source states and save the figure.

    Both inputs are stacked with ``np.vstack`` and truncated to the first
    ``n_plot`` rows. The first two state components are drawn as x/y
    coordinates, and the first ``n_links`` true/predicted pairs are joined
    by black line segments. The figure is written to
    ``logdir / 'all-z.png'`` and, if ``experiment`` is given, logged to it.

    Args:
        true_source_states: Sequence of arrays accepted by ``np.vstack``.
        predicted_source_states: Sequence of arrays accepted by
            ``np.vstack``.
        logdir: Directory that receives ``all-z.png``; also shown in the
            plot title.
        epoch: Epoch number shown in the title and used as the logging step.
        n_plot: Maximum number of rows taken from each input.
        experiment: Optional Comet experiment to which the image is logged.
        n_links: Maximum number of true/predicted pairs joined by a line;
            fewer are drawn when fewer points are available.
    """
    # Imported lazily so that importing this module does not pull in
    # matplotlib.
    import matplotlib.pyplot as plt

    # Operates on matplotlib's current figure: wipe whatever was there.
    plt.clf()
    true_source_states = np.vstack(true_source_states)[:n_plot]
    predicted_source_states = np.vstack(predicted_source_states)[:n_plot]

    # Axis limits span both point clouds.
    concat_states = np.concatenate(
        (true_source_states, predicted_source_states))
    state_min, state_max = concat_states.min(0), concat_states.max(0)
    plt.xlim([state_min[0], state_max[0]])
    plt.ylim([state_min[1], state_max[1]])

    true_source_pos = true_source_states[..., :2]
    predicted_source_pos = predicted_source_states[..., :2]
    plt.scatter(*true_source_pos.T, label='source', marker='.')
    plt.scatter(*predicted_source_pos.T, label='predicted', marker='.')

    # The reported distance uses every state component, not just the two
    # that are plotted.
    distance = mse_loss(true_source_states, predicted_source_states)

    # Slicing + zip stops at the shortest input, so having fewer than
    # n_links points no longer raises IndexError.
    link_count = max(n_links, 0)
    for first, second in zip(true_source_pos[:link_count],
                             predicted_source_pos[:link_count]):
        plt.plot((first[0], second[0]), (first[1], second[1]),
                 color='black',
                 lw=1)
    plt.legend()
    plt.title(f'(All) {logdir}\n epoch {epoch}: dist={distance:.5f}')
    plt.tight_layout()
    png_path = logdir / 'all-z.png'
    plt.savefig(png_path)
    if experiment:
        experiment.log_image(str(png_path), step=epoch)
