import random
from collections import defaultdict
from typing import OrderedDict
from functools import partial
import argparse
from numpy.linalg import norm

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import xlrd

import sys

# import packages from parent directory
sys.path.append('..')
from optimizer.DAdaST import TiAda, TiAda_Adam, TiAda_wo_max

from tensorboard_logger import Logger

# Command-line arguments. Every option has a default, so the script runs with
# no arguments at all.
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--n_iter', type=int, default=3000, help='number of gradient calls')
parser.add_argument('--lr_y', type=float, default=0.01, help='learning rate of y')
# lr_x is derived from this ratio as lr_y / r.
parser.add_argument('--r', type=float, default=1, help='stepsize of y / stepsize of x')
# When left unset (None), the corresponding initial point is drawn from a
# standard normal distribution.
parser.add_argument('--init_x', type=float, default=None, help='init value of x')
parser.add_argument('--init_y', type=float, default=None, help='init value of y')
# The noise options multiply a standard-normal sample, so they act as a
# standard deviation rather than a variance.
parser.add_argument('--grad_noise_y', type=float, default=0.1,
                    help='standard deviation of the Gaussian noise added to the gradient of y')
parser.add_argument('--grad_noise_x', type=float, default=0.1,
                    help='standard deviation of the Gaussian noise added to the gradient of x')

parser.add_argument('--func', type=str, default='quadratic', help='function name')
# NOTE(review): --L_ave is parsed but never read in the visible code; the
# script assigns a hard-coded L_ave further down. Confirm which one is intended.
parser.add_argument('--L_ave', type=float, default=2, help='parameter for the test function')

parser.add_argument('--optim', type=str, default='TiAda', help='optimizer')
# alpha is the exponent applied to the x stepsize scale, beta the one applied
# to the y stepsize scale.
parser.add_argument('--alpha', type=float, default=0.6, help='parameter for TiAda')
parser.add_argument('--beta', type=float, default=0.4, help='parameter for TiAda')
# The training loop compares this flag against 1 only.
parser.add_argument('--tracking', type=int, default=0,
                    help='set to 1 to also mix the accumulated stepsize scales '
                         'with the communication matrix W at every iteration')

args = parser.parse_args()

def excel_to_matrix(path):
    """Load the first sheet of an Excel workbook into a dense float array.

    Args:
        path: Location of the workbook, passed straight to
            ``xlrd.open_workbook``.

    Returns:
        A ``numpy.ndarray`` of shape ``(nrows, ncols)`` holding the cell
        values of the first sheet, converted to float.

    Raises:
        ValueError: If a cell value cannot be converted to float (raised by
            NumPy during assignment).
    """
    # NOTE(review): which workbook formats open here depends on the installed
    # xlrd version -- confirm before pointing this at an .xlsx file.
    sheet = xlrd.open_workbook(path).sheets()[0]
    n_rows, n_cols = sheet.nrows, sheet.ncols
    data = np.zeros((n_rows, n_cols))
    for col_idx in range(n_cols):
        # A plain 1-D array is enough to fill a column; np.matrix is no longer
        # recommended by NumPy and only added a redundant leading axis here.
        data[:, col_idx] = np.asarray(sheet.col_values(col_idx))
    return data



# Create every tensor in double precision from here on
torch.set_default_dtype(torch.float64)

# Reproducibility: seed each random number generator with the same value and
# pin cuDNN to deterministic, non-benchmarked kernels
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(args.seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

n = 50  # number of nodes

# One coefficient per node, evenly spaced between 1.5 + 1/(2n) and
# 2.5 - 1/(2n), stored as an (n, 1) column
L = torch.linspace(1.5 + 1/(2*n), 2.5-1/(2*n), n).view(-1, 1)

# NOTE(review): hard-coded; the --L_ave command-line option is not consulted.
L_ave = 2

# Matrix read from the spreadsheet (no header row, no index column); it is used
# below to mix the per-node values in the "communication" steps
matrix = pd.read_excel(
    f"graphs/my_graph_exp_graph_{n}.xlsx",
    engine="openpyxl",
    header=None,
    index_col=None,
).to_numpy()

# Spectral radius of J - W, where J is the n x n matrix with every entry 1/n
J = np.ones((n, n)) / n
spectral_radius_diff = np.abs(np.linalg.eigvals(J - matrix)).max()
print("Spectral Radius of J - W:", spectral_radius_diff)

W = torch.tensor(matrix)


# Test objectives, keyed by the name accepted for --func. Each entry maps
# "func" to the objective and, where available, "grad_x" / "grad_y" to its
# closed-form partial derivatives.


def _quadratic_value(x, y):
    """Quadratic objective; L is the module-level per-node coefficient."""
    return -1 / 2 * (y ** 2) + L * x * y - (L ** 2 / 2) * (x ** 2) - 2*L*x + L*y


def _quadratic_grad_x(x, y):
    """Partial derivative of the quadratic objective with respect to x."""
    return L*y - (L**2)*x - 2*L


def _quadratic_grad_y(x, y):
    """Partial derivative of the quadratic objective with respect to y."""
    return -y + L*x + L


def _mccormick_value(x, y):
    """McCormick-style objective in x with a bilinear coupling and a quadratic penalty in y.

    Reads its coordinates as x[0], x[1], y[0], y[1].
    """
    x0, x1 = x[0], x[1]
    y0, y1 = y[0], y[1]
    return (torch.sin(x0 + x1) + (x0 - x1) ** 2 - 1.5 * x0 + 2.5 * x1 + 1
            + y0 * x0 + y1 * x1
            - 0.5 * (y0 ** 2 + y1 ** 2))


functions = OrderedDict()

functions["quadratic"] = {
    "func": _quadratic_value,
    "grad_x": _quadratic_grad_x,
    "grad_y": _quadratic_grad_y,
}

# NOTE(review): no "grad_x" / "grad_y" entries are registered for McCormick,
# while the script looks both keys up for whichever function is selected.
functions["McCormick"] = {
    "func": _mccormick_value,
}

# Registry of optimizer constructors, keyed by the name accepted for --optim
optimizers = OrderedDict()

# Adam is extremely unstable on McCormick functions, so we need a large eps
eps = 0.8 if args.func == 'McCormick' else 1e-8
optimizers["Adam"] = partial(torch.optim.Adam, eps=eps)
optimizers["AMSGrad"] = partial(torch.optim.Adam, amsgrad=True, eps=eps)

# Remaining entries are registered in a fixed order: the stock PyTorch
# optimizers first, then the TiAda family
for _label, _ctor in (
    ("AdaGrad", torch.optim.Adagrad),
    ("GDA", torch.optim.SGD),
    ("TiAda", TiAda),
    ("TiAda_Adam", TiAda_Adam),
    ("TiAda_wo_max", TiAda_wo_max),
):
    optimizers[_label] = _ctor



n_iter = args.n_iter
ratio = args.r

print(f"Function: {args.func}")
print(f"Optimizer: {args.optim}")

# The loop below needs the objective and both closed-form gradients. Check for
# them up front so that an unknown name, or an entry registered without
# gradients, ends with a usage message instead of a bare KeyError.
if args.func not in functions:
    parser.error(f"unknown --func {args.func!r}; available: {', '.join(functions)}")
_objective = functions[args.func]
_missing = [key for key in ("func", "grad_x", "grad_y") if key not in _objective]
if _missing:
    parser.error(f"--func {args.func!r} does not define {', '.join(_missing)}, "
                 "which this script requires")

fun = _objective["func"]
grad_x = _objective["grad_x"]
grad_y = _objective["grad_y"]

# Number of columns of the (n, dim) variables x and y
dim = 2 if args.func == "McCormick" else 1

# Tensorboard: the log directory name spells out the run configuration
_name_parts = [f"./logs/{args.optim}_"]
# The quadratic problem is labelled by L_ave, any other problem by its name
_name_parts.append(f"L{L_ave}" if args.func == 'quadratic' else f"{args.func}")
_name_parts.append(f"_r_{ratio}_lry_{args.lr_y}")
if 'TiAda' in args.optim:
    _name_parts.append(f"_a_{args.alpha}_b_{args.beta}")
# Noise levels and the tracking flag are only mentioned when they are active
if args.grad_noise_x != 0:
    _name_parts.append(f"_noisex_{args.grad_noise_x}")
if args.grad_noise_y != 0:
    _name_parts.append(f"_noisey_{args.grad_noise_y}")
if args.tracking == 1:
    _name_parts.append("_tracking")
filename = "".join(_name_parts)

logger = Logger(filename)

# Stepsizes: --r is the ratio lr_y / lr_x
lr_y = args.lr_y
lr_x = lr_y / ratio

# Per-node stepsize accumulators, one row per node, all starting at one
scale_x = np.ones((n, 1))
scale_y = np.ones((n, 1))


def _initial_point(value):
    """Return an (n, dim) start point: standard normal if value is None, else constant."""
    if value is None:
        return torch.randn(n, dim)
    return torch.Tensor(value * np.ones((n, dim)))


# x is built before y so the random draws happen in a fixed order
init_x = _initial_point(args.init_x)
init_y = _initial_point(args.init_y)
if args.func != 'bilinear':
    print(f"init x: {init_x}, init y: {init_y}")

# Wrap copies of the start points as parameters so optimizers can hold them
x = torch.nn.parameter.Parameter(init_x.clone())
y = torch.nn.parameter.Parameter(init_y.clone())

# For NeAda runs the registry key is --optim without its first six characters
# (presumably a "NeAda-" prefix -- confirm against how the script is invoked).
optim_name = args.optim[6:] if "NeAda" in args.optim else args.optim

# NOTE(review): optim_x / optim_y are built here, but the visible training loop
# updates x and y by hand and never calls .step() on them.

# TiAda variants whose x optimizer receives the y optimizer as its opponent,
# with the extra keyword arguments each one takes
_coupled = {
    'TiAda': (TiAda, {}),
    'TiAda_Adam': (TiAda_Adam, {'eps': eps}),
}

if args.optim in _coupled:
    _cls, _extra = _coupled[args.optim]
    # y first: the x optimizer needs a reference to it
    optim_y = _cls([y], lr=lr_y, alpha=args.beta, **_extra)
    optim_x = _cls([x], opponent_optim=optim_y, lr=lr_x, alpha=args.alpha, **_extra)
elif args.optim == 'TiAda_wo_max':
    optim_x = TiAda_wo_max([x], lr=lr_x, alpha=args.alpha)
    optim_y = TiAda_wo_max([y], lr=lr_y, alpha=args.beta)
else:
    # Anything else is looked up in the registry and used for both variables
    optim = optimizers[optim_name]
    optim_x = optim([x], lr=lr_x)
    optim_y = optim([y], lr=lr_y)

# Main loop. `i` counts gradient evaluations (one for x plus one for y per
# single-loop iteration) and the loop runs until it reaches n_iter.
#
# Every gradient used below comes from the closed-form grad_x / grad_y, so
# autograd is switched off for the whole loop. x and y start out as Parameters;
# without no_grad each reassignment would extend an autograd graph that is
# never used, growing for the whole run.
i = 0
outer_loop_count = 0
with torch.no_grad():
    while i < n_iter:
        if "NeAda" in args.optim:
            # inner loop: ascent steps on y until the squared gradient norm is
            # at most 1 / (outer_loop_count + 1) or the step budget is used up
            required_err = 1 / (outer_loop_count + 1)
            inner_step = 0
            # start above the tolerance so the error test alone cannot skip
            # the first step (the step budget below still can)
            inner_err = required_err + 1
            stop_constant = 1  # stop when number of steps >= stop_constant * outer_loop_count
            if args.func == 'quadratic':
                # Stop earlier in quadratic case
                stop_constant = 0.1
            while inner_err > required_err and i < n_iter and inner_step < stop_constant * outer_loop_count:
                inner_step += 1
                l = fun(x, y)
                y_grad = grad_y(x, y)
                # stochastic gradient
                y_grad = y_grad + args.grad_noise_y * torch.randn(n, dim)
                y = y + lr_y * torch.tensor(scale_y ** (-0.5)) * y_grad
                # elementwise square of the noisy gradient
                y_grad_norm = y_grad.detach().numpy() ** 2
                # NOTE(review): y_grad_norm is already squared, so this adds the
                # 4th power of the gradient -- confirm that is intended.
                scale_y += y_grad_norm * y_grad_norm

                inner_err = torch.norm(y_grad) ** 2
                i += 1

            if i == n_iter:
                break

            # one descent step on x
            l = fun(x, y)
            x_grad = grad_x(x, y)
            # stochastic gradient
            x_grad = x_grad + args.grad_noise_x * torch.randn(n, dim)
            x = x - lr_x*torch.tensor(scale_x**(-0.5))*x_grad
            # elementwise square of the noisy gradient
            x_grad_norm = x_grad.detach().numpy()**2

            i += 1
            outer_loop_count += 1

            y = torch.matmul(W, y)  # y communication
            x = torch.matmul(W, x)  # x communication

            # squared gradient with respect to x, evaluated at the node averages
            x_ave = torch.mean(x)
            y_ave = torch.mean(y)
            x_grad_ave = grad_x(x_ave, y_ave)
            x_grad_ave_norm = (torch.mean(x_grad_ave).item())**2
            # NOTE(review): same 4th-power accumulation as for scale_y above.
            scale_x += x_grad_norm * x_grad_norm
            # deviation of the per-node x stepsize scales from the one implied
            # by their mean
            zeta_v = norm((np.mean(scale_x) ** 0.5) / (scale_x ** 0.5) - 1) / n

            logger.scalar_summary('x_grad_ave', step=i, value=x_grad_ave_norm)
            logger.scalar_summary('scalar_x', step=i, value=np.mean(scale_x))
            logger.scalar_summary('zeta_v', step=i, value=zeta_v)
            if i == n_iter-2:
                print('Value_x_y:\n x', x, '\n y', y)

        else:  # other single-loop optimizers

            l = fun(x, y)
            x.grad = grad_x(x, y)
            y.grad = grad_y(x, y)
            # stochastic gradient
            y.grad = y.grad + args.grad_noise_y * torch.randn(n, dim)
            x.grad = x.grad + args.grad_noise_x * torch.randn(n, dim)

            i += 2

            # elementwise squares of the noisy gradients
            y_grad_norm = y.grad.detach().numpy()**2
            x_grad_norm = x.grad.detach().numpy()**2

            # NOTE(review): the *_grad_norm arrays are already squared, so these
            # add the 4th power of the gradient -- confirm that is intended.
            scale_y += y_grad_norm * y_grad_norm
            scale_x += x_grad_norm * x_grad_norm
            if args.alpha > args.beta:
                # keep the x accumulator at least as large as the y accumulator
                scale_x = np.maximum(scale_x, scale_y)

            logger.scalar_summary('scalar_x', step=i, value=np.mean(scale_x))

            # squared gradient with respect to x, evaluated at the node averages
            x_ave = torch.mean(x)
            y_ave = torch.mean(y)
            x_grad_ave = grad_x(x_ave, y_ave)
            x_grad_ave_norm = (torch.mean(x_grad_ave).item())**2

            logger.scalar_summary('x_grad_ave', step=i, value=x_grad_ave_norm)

            # ascent on y, descent on x, each with its own per-node stepsize
            y = y + lr_y*torch.tensor(scale_y**(-args.beta))*y.grad
            x = x - lr_x*torch.tensor(scale_x**(-args.alpha))*x.grad

            y = torch.matmul(W, y)  # y communication
            x = torch.matmul(W, x)  # x communication
            if args.tracking == 1:
                scale_y = np.dot(W, scale_y)  # V communication
                scale_x = np.dot(W, scale_x)  # U communication

            # deviation of the per-node stepsize scales from the one implied by
            # their mean; zeta_u (from scale_y) is computed but only zeta_v
            # (from scale_x) is logged
            zeta_u = norm((np.mean(scale_y)**args.beta) / (scale_y**args.beta) - 1)/n
            zeta_v = norm((np.mean(scale_x)**args.alpha) / (scale_x**args.alpha) - 1)/n
            logger.scalar_summary('zeta_v', step=i, value=zeta_v)

            if i == n_iter-2:
                print('Value_x_y:\n x', x, '\n y', y)

        # stop early once the iterates have clearly diverged
        if x_grad_ave_norm > 1e4:
            break
