import torch
import random
import numpy as np
import cvxpy as cp
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from collections import defaultdict
from torch.utils.data import TensorDataset, DataLoader
from auto_LiRPA import PerturbationLpNorm, BoundedParameter, BoundedModule

# Seed Python's `random`, NumPy's global RNG and PyTorch's RNG with one fixed
# value before any sampling or model initialisation happens below.
seed = 123
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
from scipy.stats import multivariate_normal as mvn
from scipy.stats import norm, uniform

# Running index appended to every saved figure's file name; shared by
# visualize_predictions and visualize_task.
count = 1


def visualize_predictions(x_t, l_t, preds):
    """Scatter each sample, encoding its label and its prediction, then save.

    The marker shows the true label ("x" when the label equals 1, "o"
    otherwise) and the colour shows the prediction (red when it equals 1,
    blue otherwise). The current figure is NOT cleared first, so the points
    are drawn on top of whatever is already on it. The figure is written to
    ``toy_ours_task_<count>.pdf`` and the module-level ``count`` is then
    incremented.

    Args:
        x_t: 2-D indexable of samples; column 0 is plotted on x, column 1 on y.
        l_t: Labels, one per sample; its first dimension sets the loop length.
        preds: Predicted labels, indexed in step with ``l_t``.
    """
    global count
    n_samples = l_t.shape[0]
    for idx in range(n_samples):
        if l_t[idx] == 1:
            marker_style = "x"
        else:
            marker_style = "o"
        if preds[idx] == 1:
            colour = "r"
        else:
            colour = "b"
        plt.scatter(x=x_t[idx, 0], y=x_t[idx, 1], marker=marker_style, c=colour)
    plt.savefig(f'toy_ours_task_{count}.pdf')
    count += 1

def visualize_task(x_t, l_t):
    """Plot the labelled samples of one task in black and save the figure.

    Clears the current figure, fixes the y-axis to (-1.2, 1.2), draws each
    sample with an "x" marker when its label equals 1 and an "o" marker
    otherwise, writes ``toy_task_<count>.pdf`` and increments the
    module-level ``count``.

    Args:
        x_t: 2-D indexable of samples; column 0 is plotted on x, column 1 on y.
        l_t: Labels, one per sample; its first dimension sets the loop length.
    """
    global count
    plt.clf()
    plt.ylim((-1.2, 1.2))
    for idx in range(l_t.shape[0]):
        marker_style = "x" if l_t[idx] == 1 else "o"
        plt.scatter(x=x_t[idx, 0], y=x_t[idx, 1], marker=marker_style, c="black")
    plt.savefig(f'toy_task_{count}.pdf')
    count += 1
def update_layer_weights(layer, grad_w, grad_b, lr):
    """Apply one descent step ``param <- param - lr * grad`` to a layer.

    Args:
        layer: Object exposing ``weight`` and ``bias`` tensors (e.g. ``nn.Linear``).
        grad_w: NumPy array holding the weight step; reshaped to ``layer.weight.shape``.
        grad_b: NumPy array holding the bias step; reshaped to ``layer.bias.shape``.
        lr: Scalar step size multiplying both steps.
    """
    # Weight first, then bias; each step is cast to float32 before scaling.
    for param, grad in ((layer.weight, grad_w), (layer.bias, grad_b)):
        step = torch.from_numpy(grad.reshape(param.shape)).float() * lr
        param.data = param.data - step


def update_weights_with_constraint(model, buffer, buffer_labels):
    """Take one constrained step that keeps correct buffer predictions correct.

    The gradients already stored on ``model.fc1`` / ``model.fc2`` are used as
    the target direction. The model parameters are wrapped as bounded
    parameters (``model.bound_parameters``), a ``BoundedModule`` is built, and
    ``compute_bounds(method="crown-ibp", return_A=True)`` supplies, per buffer
    sample, linear lower/upper bound coefficients (``lA``/``uA``) and biases
    (``lbias``/``ubias``) of the output with respect to the requested inputs.

    A convex problem is then solved for update directions ``d``:

    * objective: ``dist_prob(..., mode=2)`` between the gradients and ``d``;
    * box: every entry of ``d`` lies in ``[-|g|, |g|]`` of its gradient;
    * per correctly-classified buffer sample: the linear lower bound evaluated
      at ``w - lr * d`` stays ``>= 0`` for a positive label, or the linear
      upper bound stays ``<= 0`` otherwise.

    Finally ``update_layer_weights`` applies ``d`` to both layers.

    Relies on the module-level step size ``lr``.

    Args:
        model: Network exposing ``fc1``, ``fc2`` and ``bound_parameters``;
            gradients must already be populated by a backward pass.
        buffer: Stored inputs whose predictions should be preserved.
        buffer_labels: Labels for ``buffer``. A value ``> 0`` is the positive
            class; anything else (``0`` or ``-1``) is the negative class.

    Raises:
        RuntimeError: If the solver returns no solution (for example when the
            constraints are infeasible).
    """
    # Snapshot gradients and the current buffer logits before the parameters
    # are replaced by bounded ones.
    g_w1 = model.fc1.weight.grad.view(-1).detach().numpy()
    g_b1 = model.fc1.bias.grad.view(-1).detach().numpy()
    g_w2 = model.fc2.weight.grad.view(-1).detach().numpy()
    g_b2 = model.fc2.bias.grad.view(-1).detach().numpy()

    buf_out = model(buffer).detach()
    model.bound_parameters(lr)
    # Only used to trace the model; forward() reshapes it to a single sample.
    dummy_input = torch.randn(2, 1)
    model_bounded = BoundedModule(
        model,
        dummy_input,
        bound_opts={
            "relu": False,
            "loss_fusion": True,
            "sparse_intermediate_bounds": False,
            "sparse_conv_intermediate_bounds": False,
            "sparse_intermediate_bounds_with_ibp": False,
        },
    )
    output_name = model_bounded.output_name[0]

    # Ask for the output's linear coefficients w.r.t. every input name.
    required_A = defaultdict(set)
    required_A[output_name].update(model_bounded.input_name[0:])
    _, _, A_dict = model_bounded.compute_bounds(
        x=(buffer,),
        method="crown-ibp",
        return_A=True,
        needed_A_dict=required_A,
        need_A_only=True,
    )
    A_dict = A_dict[output_name]

    n_buffer = buffer.size(0)
    mat_list_u = []
    mat_list_l = []
    b_list_u = []
    b_list_l = []
    for coeffs in A_dict.values():
        mat_list_l.append(coeffs["lA"].view(n_buffer, 1, -1))
        mat_list_u.append(coeffs["uA"].view(n_buffer, 1, -1))
        b_list_l.append(coeffs["lbias"].view(n_buffer, 1, -1))
        b_list_u.append(coeffs["ubias"].view(n_buffer, 1, -1))

    w1 = model.fc1.weight.view(-1).detach().numpy()
    b1 = model.fc1.bias.view(-1).detach().numpy()
    w2 = model.fc2.weight.view(-1).detach().numpy()
    b2 = model.fc2.bias.view(-1).detach().numpy()
    d_w1 = cp.Variable(w1.shape)
    d_b1 = cp.Variable(b1.shape)
    d_w2 = cp.Variable(w2.shape)
    d_b2 = cp.Variable(b2.shape)

    # Box constraints: each direction entry is bounded by its gradient magnitude.
    consts = []
    for direction, grad in ((d_w1, g_w1), (d_b1, g_b1), (d_w2, g_w2), (d_b2, g_b2)):
        bound = np.abs(grad)
        consts.append(direction <= bound)
        consts.append(direction >= -bound)

    for j in range(buffer_labels.shape[0]):
        is_positive = bool(buffer_labels[j] > 0)
        # A sample is "correct" when the logit sign matches its class. The
        # class is derived from `label > 0` so that 0/1 labels work too: the
        # previous `sign(out) * label > 0` test was never true for label 0,
        # which silently dropped every negative-class constraint.
        if is_positive:
            correct = bool(buf_out[j] > 0)
        else:
            correct = bool(buf_out[j] < 0)
        if not correct:
            continue

        if is_positive:
            As, b = mat_list_l, b_list_l
        else:
            As, b = mat_list_u, b_list_u
        # NOTE(review): assumes the first four A_dict entries correspond, in
        # order, to fc1.weight, fc1.bias, fc2.weight, fc2.bias — confirm
        # against model_bounded.input_name.
        expr = As[0][j] @ (w1 - lr * d_w1)
        expr += As[1][j] @ (b1 - lr * d_b1)
        expr += As[2][j] @ (w2 - lr * d_w2)
        expr += As[3][j] @ (b2 - lr * d_b2)
        # Only the first entry's bias term is added, as in the original code.
        expr += b[0][j]
        consts.append(0 <= expr if is_positive else expr <= 0)

    prob = dist_prob([g_w1, g_b1, g_w2, g_b2], [d_w1, d_b1, d_w2, d_b2], consts, mode=2)
    prob.solve()

    # Without a solution every `.value` is None; fail with a clear message
    # instead of an AttributeError inside update_layer_weights.
    solution = (d_w1.value, d_b1.value, d_w2.value, d_b2.value)
    if any(value is None for value in solution):
        raise RuntimeError(
            f"Constrained weight update has no solution (solver status: {prob.status})"
        )

    update_layer_weights(model.fc1, d_w1.value, d_b1.value, lr)
    update_layer_weights(model.fc2, d_w2.value, d_b2.value, lr)

def dist_prob(grads, weights, consts, mode=1):
    """Build the CVXPY problem that pulls ``weights`` towards ``grads``.

    Args:
        grads: Sequence of gradient arrays.
        weights: Sequence of CVXPY variables/expressions, paired with ``grads``
            by position.
        consts: List of CVXPY constraints attached to the problem.
        mode: ``1``, ``2`` or ``'inf'`` minimises the sum over pairs of the
            corresponding norm of ``grad - weight``; ``'dot'`` minimises the
            negated sum of inner products ``grad.T @ weight``.

    Returns:
        A ``cp.Problem`` minimising the selected objective under ``consts``.

    Raises:
        ValueError: If ``mode`` is none of the supported values.
    """
    indices = range(len(grads))
    if mode in (1, 2, 'inf'):
        objective = sum(cp.norm(grads[k] - weights[k], mode) for k in indices)
        return cp.Problem(cp.Minimize(objective), consts)
    if mode == 'dot':
        objective = -sum(grads[k].T @ weights[k] for k in indices)
        return cp.Problem(cp.Minimize(objective), consts)
    raise ValueError('Dist mode {} not supported'.format(mode))

def train_task(
    model,
    dataset,
    epochs,
    criterion,
    constrained=False,
    buffer=None,
    buffer_labels=None,
    batch_size=64,
):
    """Train ``model`` on one task, then evaluate and plot on the test set.

    Each mini-batch runs a forward/backward pass. In unconstrained mode the
    module-level optimizer ``opt`` takes the step; in constrained mode the
    step is delegated to ``update_weights_with_constraint`` so that the
    predictions on ``buffer`` are taken into account.

    After training, the model is evaluated on the module-level test set
    ``x_t`` / ``l_t``: accuracy is printed and three figures are saved via
    ``visualize_task`` (samples with y >= 0, then y < 0) and
    ``visualize_predictions``.

    Args:
        model: Network to train.
        dataset: Dataset yielding ``(input, label)`` pairs.
        epochs: Number of passes over ``dataset``.
        criterion: Loss called as ``criterion(logits, target=labels)``.
        constrained: Use the constrained update instead of ``opt.step()``.
        buffer: Inputs to preserve; required when ``constrained`` is true.
        buffer_labels: Labels of ``buffer``; required when ``constrained`` is true.
        batch_size: Mini-batch size for the data loader (default 64).

    Returns:
        ``(buf_x, buf_y)``: inputs and labels (both cast to float) of the very
        first mini-batch drawn, or ``(None, None)`` if no batch was drawn.

    Raises:
        ValueError: If ``constrained`` is true but ``buffer`` or
            ``buffer_labels`` is missing.
    """
    if constrained and (buffer is None or buffer_labels is None):
        raise ValueError(
            "constrained training requires both `buffer` and `buffer_labels`"
        )

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    buf_x = None
    buf_y = None
    for _ in range(epochs):
        for data in dataloader:
            # Clear gradients on the model's *current* parameters: the
            # constrained update replaces the parameter objects, so the
            # optimizer's own references can be stale in that mode.
            model.zero_grad()
            x = data[0].float()
            labels = data[1].float()
            if buf_x is None:
                # Keep the first batch as the buffer handed back to the caller.
                buf_x = x
                buf_y = labels
            loss = criterion(model(x), target=labels.unsqueeze(1))
            loss.backward()
            if constrained:
                update_weights_with_constraint(model, buffer, buffer_labels)
            else:
                opt.step()

    # Evaluation needs no autograd graph.
    with torch.no_grad():
        preds = torch.where(torch.sigmoid(model(x_t)) > 0.5, 1, 0).view(-1)
    print("Accuracy: ", (preds == l_t).float().mean().item())

    upper_half = x_t[:, 1] >= 0
    lower_half = x_t[:, 1] < 0
    visualize_task(x_t[upper_half], l_t[upper_half])
    visualize_task(x_t[lower_half], l_t[lower_half])
    visualize_predictions(x_t, l_t, preds)
    return buf_x, buf_y


class mlp_2layer_weight_perturb(nn.Module):
    """Two-layer ReLU MLP with a single output whose parameters can be bounded.

    Args:
        in_ch: Input channels; together with ``in_dim`` sets the input width
            ``in_ch * in_dim``.
        in_dim: Input dimension per channel.
        width: Multiplier for the hidden layer, which has ``16 * width`` units.
    """

    def __init__(self, in_ch=1, in_dim=2, width=1):
        super().__init__()
        self.fc1 = nn.Linear(in_ch * in_dim, 16 * width)
        self.fc2 = nn.Linear(16 * width, 1)
        # Layers whose parameters bound_parameters() wraps.
        self.params = [self.fc1, self.fc2]

    def bound_parameters(self, lr, pert_weight=True, pert_bias=True, norm=float("inf")):
        """Replace layer parameters with ``BoundedParameter`` objects.

        Each wrapped parameter gets a ``PerturbationLpNorm`` whose radius is
        ``torch.linalg.norm(param.grad) * lr``, so gradients must already be
        populated by a backward pass.

        Args:
            lr: Step size scaling the perturbation radius.
            pert_weight: Wrap the weights when true (previously ignored).
            pert_bias: Wrap the biases when true (previously ignored).
            norm: Norm passed to ``PerturbationLpNorm``.
        """
        for layer in self.params:
            if pert_weight:
                radius = torch.linalg.norm(layer.weight.grad) * lr
                ptb = PerturbationLpNorm(norm=norm, eps=radius)
                layer.weight = BoundedParameter(layer.weight.data, ptb)
            if pert_bias:
                radius = torch.linalg.norm(layer.bias.grad) * lr
                ptb = PerturbationLpNorm(norm=norm, eps=radius)
                layer.bias = BoundedParameter(layer.bias.data, ptb)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, in_ch * in_dim)`` and return the logits."""
        # Use the layer's real input width instead of a hard-coded 2 so that
        # non-default in_ch / in_dim values work.
        x = x.view(-1, self.fc1.in_features)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


def plot_decision_boundry(model, points, labels, buffer=None, buffer_labels=None):
    """Draw the model's probability map with the data on top and save it.

    The model is evaluated on a 200 x 200 grid over [-1.5, 1.5]^2; the
    sigmoid of its output is shown as filled contours, with the 0.5 level
    drawn as a black line. ``points`` are overlaid in black and, when given,
    ``buffer`` in red ("o" for label 0, "x" for label 1). The figure is
    written to ``toy_ours_final.pdf``.

    Args:
        model: Callable mapping a float32 tensor of grid points to logits.
        points: 2-D samples, indexable by boolean mask and column.
        labels: Labels of ``points`` (compared against 0 and 1).
        buffer: Optional extra samples to highlight.
        buffer_labels: Labels of ``buffer``.
    """
    print('whitihn ploting')

    def scatter_by_label(xy, lab, colour):
        # Label 0 first, then label 1, matching the original draw order.
        for value, marker_style in ((0, 'o'), (1, 'x')):
            plt.scatter(
                xy[lab == value, 0],
                xy[lab == value, 1],
                c=colour,
                marker=marker_style,
            )

    axis_ticks = np.linspace(-1.5, 1.5, 200)
    xx, yy = np.meshgrid(axis_ticks, axis_ticks)
    grid_tensor = torch.tensor(
        np.column_stack((xx.ravel(), yy.ravel())), dtype=torch.float32
    )

    with torch.no_grad():
        probabilities = torch.sigmoid(model(grid_tensor)).numpy()
    probabilities = probabilities.reshape(xx.shape)

    plt.figure(figsize=(8, 6))
    plt.contourf(xx, yy, probabilities, levels=50, cmap="RdBu_r", alpha=0.8)
    plt.contour(xx, yy, probabilities, levels=[0.5], colors="black", linewidths=1)
    scatter_by_label(points, labels, "black")
    if buffer is not None:
        scatter_by_label(buffer, buffer_labels, "red")
    plt.savefig('toy_ours_final.pdf', format='pdf')

if __name__ == "__main__":
    # Two classes: a Gaussian blob at the origin (label 0) and a noisy ring of
    # radius ~1 (label 1). Task 1 uses the half-plane y >= 0, task 2 uses y < 0.
    # NOTE: `opt`, `lr`, `x_t` and `l_t` are read as module globals by the
    # training functions above, so their names must stay as they are.

    def split_by_y(pts):
        """Return (rows with y >= 0, rows with y < 0)."""
        return pts[pts[:, 1] >= 0], pts[pts[:, 1] < 0]

    def polar_to_xy(radius, angle):
        """Stack polar samples into an (n, 2) tensor of x/y coordinates."""
        return torch.stack(
            [radius * torch.cos(angle), radius * torch.sin(angle)], dim=1
        )

    blob_mean = [0, 0]
    ring_radius = 1
    blob_dist = mvn(blob_mean, cov=np.ones(2) * .05)
    radius_dist = norm(ring_radius, scale=0.05)
    angle_dist = uniform(0, 2 * np.pi)

    data_size = 1000

    # ---- training data (sampling order: blob, ring radius, ring angle) ----
    n_train = data_size * 2
    blob_train = torch.tensor(blob_dist.rvs(size=n_train))
    blob_up, blob_down = split_by_y(blob_train)

    ring_r = torch.tensor(radius_dist.rvs(size=n_train))
    ring_a = torch.tensor(angle_dist.rvs(size=n_train))
    ring_up, ring_down = split_by_y(polar_to_xy(ring_r, ring_a))

    dataset = [
        TensorDataset(
            torch.cat([half_blob, half_ring]),
            torch.cat(
                [torch.zeros(half_blob.size(0)), torch.ones(half_ring.size(0))]
            ),
        )
        for half_blob, half_ring in ((blob_up, ring_up), (blob_down, ring_down))
    ]

    print(dataset[0][0])

    # ---- test data (same sampling order as above) ----
    n_test = data_size // 5
    blob_test = torch.tensor(blob_dist.rvs(size=n_test))
    blob_up_t, blob_down_t = split_by_y(blob_test)

    ring_r_t = torch.tensor(radius_dist.rvs(size=n_test))
    ring_a_t = torch.tensor(angle_dist.rvs(size=n_test))
    ring_up_t, ring_down_t = split_by_y(polar_to_xy(ring_r_t, ring_a_t))

    # All blob points come first, then all ring points, matching the label
    # vector below (n_test zeros followed by n_test ones).
    x_t = torch.from_numpy(
        np.concatenate([blob_up_t, blob_down_t, ring_up_t, ring_down_t], axis=0)
    ).float()
    l_t = torch.from_numpy(
        np.concatenate([np.zeros(n_test), np.ones(n_test)], axis=0)
    )

    # ---- model and optimiser ----
    criterion = nn.BCEWithLogitsLoss()
    model = mlp_2layer_weight_perturb(width=2)

    weight_decay = 0
    lr = 0.1
    epochs = 10
    opt = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)
    batch_size = 64

    print("Training task 1")
    buffer, buffer_labels = train_task(model, dataset[0], epochs, criterion)

    print("\nTraining task 2")
    train_task(
        model,
        dataset[1],
        epochs,
        criterion,
        constrained=True,
        buffer=buffer,
        buffer_labels=buffer_labels,
    )
    plot_decision_boundry(model, x_t, l_t, buffer, buffer_labels)
