import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
matplotlib.rcParams['text.usetex'] = True
from code.optim.aligned import ProcrustesSolver
from code.optim.pcgrad import RandomProjectionSolver
from code.optim.mgda import MinNormSolver


class Example2D(torch.nn.Module):
    """Two-parameter toy landscape made of two logarithmic objectives.

    The module owns two scalar parameters, ``alpha`` and ``beta``
    (registered in that order), and its forward pass evaluates both
    objectives at the current parameter values.
    """

    def __init__(self, alpha, beta) -> None:
        """Wrap the initial values ``alpha`` and ``beta`` as trainable parameters."""
        super().__init__()
        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=True)
        self.beta = torch.nn.Parameter(torch.tensor(beta), requires_grad=True)

    @staticmethod
    def loss(x, y, total=False):
        """Evaluate the two objectives at ``(x, y)``.

        Args:
            x: Tensor holding the first coordinate(s).
            y: Tensor holding the second coordinate(s).
            total: When true, return the sum of both objectives instead of
                the pair.

        Returns:
            ``(loss_left, loss_right)`` by default, or their sum when
            ``total`` is true.
        """

        def floored_log_abs(value):
            # |value| is bounded below by 5e-6 so the logarithm stays finite.
            return torch.log(torch.max(torch.abs(value), torch.tensor(0.000005)))

        # Each term recomputes 0.5 * x and tanh(y), so the two results share
        # no intermediate autograd nodes and can be back-propagated separately.
        loss_left = 20 * floored_log_abs(0.5 * x + torch.tanh(y))
        loss_right = 25 * floored_log_abs(0.5 * x - torch.tanh(y) + 2)

        return loss_left + loss_right if total else (loss_left, loss_right)

    def forward(self):
        """Return both objectives evaluated at the module's own parameters."""
        return self.loss(self.alpha, self.beta)


def collect_grad(module):
    """Return all gradients of ``module`` as a single flat tensor.

    Parameters whose ``.grad`` is ``None`` are skipped. The remaining
    gradients are flattened and concatenated in ``module.parameters()``
    order. The result is a detached copy, so modifying it does not touch
    the parameters' gradients.
    """
    flat_grads = []
    for param in module.parameters():
        if param.grad is None:
            continue
        flat_grads.append(param.grad.detach().flatten().clone())
    return torch.cat(flat_grads)


def update_grad(module, grad):
    """Overwrite the gradients of ``module`` with slices of a flat tensor.

    ``grad`` is consumed in ``module.parameters()`` order: each parameter
    that already has a gradient receives the next ``numel`` entries,
    reshaped to that gradient's shape. Parameters whose ``.grad`` is
    ``None`` are skipped and consume nothing.
    """
    cursor = 0
    for param in module.parameters():
        current = param.grad
        if current is None:
            continue
        stop = cursor + current.shape.numel()
        current.data = grad[cursor:stop].view(current.shape)
        cursor = stop


def aligned_gradient_descent(num_iters=15000, init_point=(0.5, -3.0)):
    """Optimise the 2-D toy landscape with Procrustes-aligned gradients and Adam.

    Args:
        num_iters: Number of optimisation steps to run.
        init_point: Initial ``(alpha, beta)`` values passed to ``Example2D``.

    Returns:
        ``np.ndarray`` with ``num_iters + 1`` rows of
        ``[total_loss, alpha, beta]``: one row recorded before each optimiser
        step, plus a final row for the point reached after the last step.
    """

    def aligned_balancer_step(landscape):
        """Write the aligned gradient into ``landscape``; return both loss values."""
        grads = []
        losses = landscape()
        for loss in losses:
            # Back-propagate each objective on its own so that the per-task
            # gradients can be collected separately.
            landscape.zero_grad()
            loss.backward()
            grads.append(collect_grad(landscape))
        # One column per objective; the solver receives an extra leading axis.
        grads = torch.stack(grads, dim=1)

        grads, _, singulars = ProcrustesSolver.apply(grads.unsqueeze(0))

        # In this case we need to control the length because both vectors
        # do not change direction and length: rescale so that the combined
        # gradient has (approximately) unit norm.
        # NOTE(review): `singulars` presumably holds singular values - confirm
        # against ProcrustesSolver.
        if len(singulars) == 1:
            grads = grads / (torch.norm(grads[0].sum(-1)) + 1e-5)

        update_grad(landscape, grads[0].sum(dim=-1))

        return [loss.item() for loss in losses]

    landscape = Example2D(*init_point)
    optim = torch.optim.Adam(landscape.parameters(), 0.01)

    path = []
    for _ in range(num_iters):
        left, right = aligned_balancer_step(landscape)
        # Store plain Python floats so np.array() below does not have to
        # coerce 0-dim tensors element by element.
        path.append([left + right, landscape.alpha.item(), landscape.beta.item()])
        optim.step()

    # Final point, reached after the last optimiser step.
    path.append([
        sum(landscape()).item(),
        landscape.alpha.item(),
        landscape.beta.item(),
    ])

    return np.array(path)


def random_projected_gradient_descent(num_iters=15000, init_point=(0.5, -3.0)):
    """Optimise the 2-D toy landscape with PCGrad-style projected gradients and Adam.

    Args:
        num_iters: Number of optimisation steps to run.
        init_point: Initial ``(alpha, beta)`` values passed to ``Example2D``.

    Returns:
        ``np.ndarray`` with ``num_iters + 1`` rows of
        ``[total_loss, alpha, beta]``: one row recorded before each optimiser
        step, plus a final row for the point reached after the last step.
    """

    def pcgrad_balancer_step(landscape):
        """Write the summed projected gradient into ``landscape``; return both loss values."""
        grads = []
        losses = landscape()
        for loss in losses:
            # Back-propagate each objective on its own so that the per-task
            # gradients can be collected separately.
            landscape.zero_grad()
            loss.backward()
            grads.append(collect_grad(landscape))
        # One row per objective.
        grads = torch.stack(grads, dim=0)

        proj_grads = RandomProjectionSolver.apply(grads)

        update_grad(landscape, proj_grads.sum(0))

        return [loss.item() for loss in losses]

    landscape = Example2D(*init_point)
    optim = torch.optim.Adam(landscape.parameters(), 0.001)

    path = []
    for _ in range(num_iters):
        left, right = pcgrad_balancer_step(landscape)
        # Store plain Python floats so np.array() below does not have to
        # coerce 0-dim tensors element by element.
        path.append([left + right, landscape.alpha.item(), landscape.beta.item()])
        optim.step()

    # Final point, reached after the last optimiser step.
    path.append([
        sum(landscape()).item(),
        landscape.alpha.item(),
        landscape.beta.item(),
    ])

    return np.array(path)


def pareto_optimal_gradient_descent(num_iters=15000, init_point=(0.5, -3.0)):
    """Optimise the 2-D toy landscape with MGDA-weighted gradients and Adam.

    Args:
        num_iters: Number of optimisation steps to run.
        init_point: Initial ``(alpha, beta)`` values passed to ``Example2D``.

    Returns:
        ``np.ndarray`` with ``num_iters + 1`` rows of
        ``[total_loss, alpha, beta]``: one row recorded before each optimiser
        step, plus a final row for the point reached after the last step.
    """

    def mgda_balancer_step(landscape):
        """Write the solver-weighted gradient sum into ``landscape``; return both loss values."""
        grads = []
        losses = landscape()
        for loss in losses:
            # Back-propagate each objective on its own so that the per-task
            # gradients can be collected separately.
            landscape.zero_grad()
            loss.backward()
            grads.append(collect_grad(landscape))
        # One row per objective.
        grads = torch.stack(grads, dim=0)

        # Scale every objective's gradient row by its solver-provided weight.
        scales, _ = MinNormSolver.apply(grads)
        grads = grads * scales.view(-1, 1)

        update_grad(landscape, grads.sum(0))

        return [loss.item() for loss in losses]

    landscape = Example2D(*init_point)
    optim = torch.optim.Adam(landscape.parameters(), 0.01)

    path = []
    for _ in range(num_iters):
        left, right = mgda_balancer_step(landscape)
        # Store plain Python floats so np.array() below does not have to
        # coerce 0-dim tensors element by element.
        path.append([left + right, landscape.alpha.item(), landscape.beta.item()])
        optim.step()

    # Final point, reached after the last optimiser step.
    path.append([
        sum(landscape()).item(),
        landscape.alpha.item(),
        landscape.beta.item(),
    ])

    return np.array(path)


def adam_gradient_descent(num_iters=15000, init_point=(0.5, -3.0)):
    """Optimise the 2-D toy landscape by running Adam on the plain sum of both losses.

    Args:
        num_iters: Number of optimisation steps to run.
        init_point: Initial ``(alpha, beta)`` values passed to ``Example2D``.

    Returns:
        ``np.ndarray`` with ``num_iters + 1`` rows of
        ``[total_loss, alpha, beta]``: one row recorded before each optimiser
        step, plus a final row for the point reached after the last step.
    """
    landscape = Example2D(*init_point)
    optim = torch.optim.Adam(landscape.parameters(), 0.001)

    path = []
    for _ in range(num_iters):
        left, right = landscape()
        loss = left + right
        # Store plain Python floats so np.array() below does not have to
        # coerce 0-dim tensors element by element.
        path.append([loss.item(), landscape.alpha.item(), landscape.beta.item()])

        landscape.zero_grad()
        loss.backward()
        optim.step()

    # Final point, reached after the last optimiser step.
    path.append([
        sum(landscape()).item(),
        landscape.alpha.item(),
        landscape.beta.item(),
    ])

    return np.array(path)



if __name__ == "__main__":
    # Regular evaluation grid for the contour plots, with shared contour levels.
    X, Y = np.meshgrid(np.arange(-9, 5, 0.25), np.arange(-5, 6, 0.25))
    lev = list(range(-500, 40, 10))

    X = torch.as_tensor(X).float()
    Y = torch.as_tensor(Y).float()
    Z = Example2D.loss(X, Y, total=True).numpy()
    Z_1, Z_2 = Example2D.loss(X, Y)

    # One filled-contour figure per individual objective.
    for task_surface, out_file in ((Z_1, 'Z1_contour.pdf'), (Z_2, 'Z2_contour.pdf')):
        fig, task_ax = plt.subplots()
        fig.set_size_inches(4, 4)
        task_ax.contourf(X, Y, task_surface, levels=lev, cmap='gist_heat')
        plt.tight_layout()
        task_ax.set_xlim((-4, 2.1))
        task_ax.set_xlabel("$\\theta_1$", fontsize=20)
        task_ax.set_ylabel("$\\theta_2$", fontsize=20)
        plt.savefig(out_file, dpi=120, format='pdf', bbox_inches='tight')

    # Disabled 3-D surface rendering of the total loss, kept for reference.
    # NOTE(review): it calls sys.exit(), but `sys` is not among the imports
    # visible in this file - add the import before re-enabling.
    #
    # #fig = plt.figure(figsize=(6, 2))
    # fig = plt.figure()
    # ax = plt.axes(projection='3d')
    #
    # surf = ax.plot_surface(X, Y, Z, cmap='gist_heat', edgecolor='none')
    # #ax.contourf(X, Y, Z, levels=lev, cmap='binary')
    # #surf._facecolors2d = surf._facecolors3d
    # #surf._edgecolors2d = surf._edgecolors3d
    # ax.xaxis.set_major_locator(plt.MaxNLocator(2))
    # ax.yaxis.set_major_locator(plt.MaxNLocator(2))
    # ax.zaxis.set_major_locator(plt.MaxNLocator(2))
    # ax.set_xlabel("$\\theta_1$", fontsize=20)
    # ax.set_ylabel("$\\theta_2$", fontsize=20)
    # ax.tick_params(axis='both', which='major', labelsize=15)
    # plt.tight_layout()
    # plt.savefig('toy3d.pdf', dpi=1200, format='pdf', bbox_inches='tight')
    # #plt.show()
    # sys.exit()

    # Run every optimiser for the same number of steps.
    num_iters = 5000
    aligned_gd_path = aligned_gradient_descent(num_iters=num_iters)
    pcgrad_gd_path = random_projected_gradient_descent(num_iters=num_iters)
    mgda_gd_path = pareto_optimal_gradient_descent(num_iters=num_iters)
    adam_gd_path = adam_gradient_descent(num_iters=num_iters)

    fig, ax = plt.subplots()
    fig.set_size_inches(7, 4)

    # (trajectory, colour, legend template) in drawing order; the template is
    # filled with the total loss stored in the trajectory's last row.
    runs = (
        (mgda_gd_path, "blue", "MGDA,\t f($\\theta_1, \\theta_2$)={:.2f}"),
        (aligned_gd_path, "red", "$\\theta$-aligned, f($\\theta_1, \\theta_2$)={:.2f}"),
        (pcgrad_gd_path, "green", "PCGrad, f($\\theta_1, \\theta_2$)={:.2f}"),
        (adam_gd_path, "magenta", "Adam,\t f($\\theta_1, \\theta_2$)={:.2f}"),
    )
    for trajectory, colour, label_template in runs:
        end_point = trajectory[-1]
        ax.plot(
            trajectory[:, 1],
            trajectory[:, 2],
            color=colour,
            label=label_template.format(end_point[0]),
        )
        # A star marks where the run ended.
        ax.plot(end_point[1], end_point[2], color=colour, marker="*", markersize=19)

    ax.legend(prop={"size": 12}, loc='upper right')
    ax.contourf(X, Y, Z, cmap='gist_heat', levels=lev)
    ax.set_xlabel("$\\theta_1$", fontsize=20)
    ax.set_ylabel("$\\theta_2$", fontsize=20)
    ax.tick_params(axis='both', which='major', labelsize=15)
    ax.set_xlim((-4, 2.1))
    plt.savefig('toy2d_tgd_m.pdf', dpi=1200, format='pdf', bbox_inches='tight')

    print('Done')
