import math
import torch
from modules import MLP
from scipy.spatial import distance


def get_rectangular_data(n_data, xmin, xmax, ymin, ymax):
    """Sample points uniformly from the rectangle [xmin, xmax] x [ymin, ymax].

    Returns a float tensor of shape (n_data, 2) whose rows are (x, y) pairs.
    """
    samples = torch.rand(size=(n_data, 2))
    # Stretch the unit square to the rectangle's extent, then move its corner
    # to (xmin, ymin). In-place ops keep the dtype of the uniform samples.
    samples.mul_(torch.tensor([[xmax - xmin, ymax - ymin]]))
    samples.add_(torch.tensor([[xmin, ymin]]))
    return samples


def get_round_data(n_data, radius, center_x, center_y):
    """Sample points uniformly from the disc of ``radius`` centred at (center_x, center_y).

    Returns a float tensor of shape (n_data, 2) whose rows are (x, y) pairs.
    """
    # Taking the square root of a uniform variate makes the radius density
    # proportional to r, which spreads the points evenly over the disc's area.
    radii = torch.rand(size=(n_data,)).sqrt() * radius
    angles = torch.rand(size=(n_data,)) * 2 * torch.pi

    xs = radii * torch.cos(angles) + center_x
    ys = radii * torch.sin(angles) + center_y
    return torch.stack((xs, ys), dim=1)


def get_2d_tasks(data_per_task, scaler=1.0, device='cpu', include_supp=False, make_less_separable=False):
    """ Get 3 binary classification tasks whose joint max-margin direction is (1,1).

    Each task contains ``data_per_task`` points. The number of positives among
    the randomly placed points is a Binomial(data_per_task - n_supp, 1/2) draw,
    where n_supp is 1 when ``include_supp`` is set and 0 otherwise.

    Args:
        data_per_task: number of points in each task.
        scaler: factor every coordinate is multiplied by (after moving to ``device``).
        device: device the returned tensors are moved to.
        include_supp: if True, one fixed corner point replaces one random point per
            task: (5, -9) as a negative in task 1, (-7, 3) as a negative in task 2,
            and (2, 2) as a positive in task 3.
        make_less_separable: if True, shift each task's positive points by (-4, -4).
            In task 3 this also shifts the fixed support point.

    Returns:
        (Xs, ys, X, y). ``Xs``/``ys`` are per-task lists of features (n, 2) and
        int64 labels (n, 1) in {+1, -1}. ``X``/``y`` are their row-wise concatenation.
    """
    n_supp = 1 if include_supp else 0
    shift = torch.tensor([[4., 4.]])

    def coin_flip_count():
        # Number of positives among the randomly labelled (non-support) points.
        return (torch.rand(size=(data_per_task - n_supp,)) < 0.5).sum().item()

    def with_supp(points, corner):
        # Append the fixed support point as the last row, when requested.
        if not include_supp:
            return points
        return torch.cat([points, torch.tensor(corner)[None, :]], dim=0)

    def finish(X_pos, X_neg, n_pos, n_neg):
        # Stack positives before negatives and build the matching +1/-1 labels.
        assert n_pos == X_pos.size(0)
        assert n_neg == X_neg.size(0)
        features = torch.cat([X_pos, X_neg], dim=0).to(device) * scaler
        labels = torch.cat([torch.ones(n_pos).long(), -torch.ones(n_neg).long()])[:, None].to(device)
        return features, labels

    tasks = []

    # Task 1: +1 disc of radius 3 at (2, 15); -1 rectangle x in [0, 5], y in [-13, -9].
    n_pos = coin_flip_count()
    n_neg = data_per_task - n_pos
    X_pos = get_round_data(n_pos, 3, 2, 15)
    if make_less_separable:
        X_pos -= shift
    X_neg = with_supp(get_rectangular_data(n_neg - n_supp, 0, 5, -13, -9), [5., -9.])
    tasks.append(finish(X_pos, X_neg, n_pos, n_neg))

    # Task 2: +1 disc of radius 2.5 at (17, 0); -1 rectangle x in [-14, -7], y in [-3, 3].
    n_pos = coin_flip_count()
    n_neg = data_per_task - n_pos
    X_pos = get_round_data(n_pos, 2.5, 17, 0)
    if make_less_separable:
        X_pos -= shift
    X_neg = with_supp(get_rectangular_data(n_neg - n_supp, -14, -7, -3, 3), [-7., 3.])
    tasks.append(finish(X_pos, X_neg, n_pos, n_neg))

    # Task 3: +1 rectangle x in [2, 10], y in [2, 9]; -1 disc of radius 4 at (-10, -8).
    # Here the support point is a positive, so it is counted into n_pos up front.
    n_pos = n_supp + coin_flip_count()
    n_neg = data_per_task - n_pos
    X_pos = with_supp(get_rectangular_data(n_pos - n_supp, 2, 10, 2, 9), [2., 2.])
    if make_less_separable:
        X_pos -= shift
    X_neg = get_round_data(n_neg, 4, -10, -8)
    tasks.append(finish(X_pos, X_neg, n_pos, n_neg))

    Xs = [features for features, _ in tasks]
    ys = [labels for _, labels in tasks]
    return Xs, ys, torch.cat(Xs, dim=0), torch.cat(ys, dim=0)


def get_3d_toy_tasks(device='cpu'):
    """Return two tiny 3-D tasks whose points are all labelled +1.

    Task 1 holds the points (1, 1, 0) and (1, -2, 1). Task 2 holds (1, 0, 1) and
    (1, 1, -2). Labels are float tensors of ones with shape (2, 1).

    Args:
        device: device every returned tensor is placed on. The previous
            implementation accepted this argument but ignored it.

    Returns:
        ([X1, X2], [y1, y2], XT, yT), where XT / yT stack both tasks row-wise.
    """
    X1 = torch.tensor([[1., 1., 0.], [1., -2., 1.]], device=device)
    y1 = torch.ones(X1.size(0), 1, device=device)

    X2 = torch.tensor([[1., 0., 1.], [1., 1., -2.]], device=device)
    y2 = torch.ones(X2.size(0), 1, device=device)

    XT = torch.cat([X1, X2])
    yT = torch.cat([y1, y2])
    return [X1, X2], [y1, y2], XT, yT


def get_model(model_type, hidden_dims=None, net_act=None, bias=None):
    """Build an ``MLP`` of the requested family.

    Args:
        model_type: 'ReLU' or 'Linear' (case-sensitive).
        hidden_dims: layer widths. Defaults to [2, 500, 1] for 'ReLU' and
            [2, 1] for 'Linear'.
        net_act: activation name for 'ReLU' (default 'relu'). It is ignored for
            'Linear', which always passes ``net_act=None``.
        bias: whether the MLP uses bias terms. Defaults to True for 'ReLU' and
            False for 'Linear'.

    Returns:
        The constructed ``MLP``.

    Raises:
        ValueError: if ``model_type`` is not one of the supported names.
    """
    if model_type == 'ReLU':
        if hidden_dims is None:
            hidden_dims = [2, 500, 1]
        if net_act is None:
            net_act = 'relu'
        if bias is None:
            bias = True
    elif model_type == 'Linear':
        if hidden_dims is None:
            hidden_dims = [2, 1]
        # A linear model has no activation, whatever the caller passed.
        net_act = None
        if bias is None:
            bias = False
    else:
        # Previously this fell through to an UnboundLocalError on `model`.
        raise ValueError(f"Unknown model_type {model_type!r}; expected 'ReLU' or 'Linear'.")

    return MLP(hidden_dims=hidden_dims, net_act=net_act, bias=bias)


def get_sine_angle(model1, model2):
    """Return |sin(theta)| for the angle theta between two parameter vectors.

    Each argument is either a ``torch.Tensor`` or an ``MLP``. For an ``MLP``,
    its ``get_param()`` result is used. Inputs are flattened to 1-D, so column
    vectors such as shape (2, 1) are accepted. The result lies in [0, 1].

    Raises:
        TypeError: if an argument is neither a tensor nor an ``MLP``.
    """
    def as_flat_array(x):
        # Tensors are checked first, so plain-tensor calls never touch MLP.
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy().ravel()
        if isinstance(x, MLP):
            param = x.get_param()
            # Presumably get_param() returns a tensor; detach/move it so
            # .numpy() also works for parameters that require grad or live on GPU.
            if isinstance(param, torch.Tensor):
                param = param.detach().cpu()
            return param.numpy().ravel()
        raise TypeError(f"expected torch.Tensor or MLP, got {type(x).__name__}")

    w1 = as_flat_array(model1)
    w2 = as_flat_array(model2)

    # distance.cosine(w1, w2) is in fact "1-cos(\theta(w1, w2))".
    cosine_squared = (1 - distance.cosine(w1, w2)) ** 2
    # Clamp at 0: floating-point rounding can push 1 - cos^2 slightly negative
    # for (nearly) parallel vectors, which would make math.sqrt raise.
    sine_angle = math.sqrt(max(0.0, 1.0 - cosine_squared))

    return sine_angle


def get_orthogonal_line(vector, scale=100):
    """Return two points spanning a segment through the origin orthogonal to a 2-D vector.

    For ``vector = (a, b)`` the endpoints are ``scale * (b, -a)`` and
    ``scale * (-b, a)``. Their direction is perpendicular to (a, b) and their
    midpoint is the origin.

    Args:
        vector: tensor holding exactly two elements, in any shape (e.g. (2,) or (2, 1)).
        scale: multiplier on the endpoints (default 100, the previous hard-coded value).

    Returns:
        (x, y): tensors with the x-coordinates and y-coordinates of the two endpoints.

    Raises:
        ValueError: if ``vector`` does not contain exactly two elements.
    """
    # reshape also handles inputs that view() cannot flatten.
    v = vector.reshape(-1)
    # An explicit check, unlike `assert`, survives `python -O`.
    if v.size(0) != 2:
        raise ValueError(f"expected a vector with 2 elements, got {v.size(0)}")
    x = torch.tensor([v[1], -v[1]]) * scale
    y = torch.tensor([-v[0], v[0]]) * scale
    return x, y


def get_predictions_xyz(model:MLP, xmin, xmax, ymin, ymax, n_per_dim=100, device='cpu'):
    """Evaluate a classifier's score on a regular 2-D grid.

    Args:
        model: an ``MLP``, evaluated as ``model(points)``. Any other object is
            treated as a classifier exposing ``predict_proba``; presumably a
            binary scikit-learn-style model, since only column 1 is used.
        xmin, xmax, ymin, ymax: inclusive grid bounds.
        n_per_dim: number of grid points along each axis.
        device: device the grid is sent to for ``MLP`` evaluation.

    Returns:
        (x, y, z). ``x`` and ``y`` are (n_per_dim, n_per_dim) tensors from
        ``torch.meshgrid(..., indexing='ij')``. ``z`` has the same shape and holds
        either the MLP output (CPU tensor) or ``predict_proba[:, 1] - 0.5``
        (NumPy array).
    """
    x, y = torch.meshgrid([
        torch.linspace(xmin, xmax, n_per_dim),
        torch.linspace(ymin, ymax, n_per_dim)], indexing='ij')
    # Flatten the grid into an (n_per_dim**2, 2) batch of (x, y) points.
    plane = torch.stack((x.ravel(), y.ravel()), dim=1)
    if isinstance(model, MLP):
        # Inference only: no_grad avoids recording an autograd graph over the
        # whole grid that would be discarded by .detach() anyway.
        with torch.no_grad():
            out = model(plane.to(device))
        z = out.detach().cpu().view(x.size())
    else:
        z = model.predict_proba(plane.cpu().numpy())[:,1:].reshape(x.numpy().shape) - 0.5
    return x, y, z
