import torch
import torch.nn as nn
from torch import optim
import numpy as np

from matplotlib import pyplot as plt
from sklearn.decomposition import PCA


import os
import sys
import errno
import shutil
import os.path as osp


def _build_dataset(n, d, n_informative=None, random_seed=2024):
    """Draw a noiseless linear-regression dataset ``y = X @ w_star``.

    NumPy's global RNG is re-seeded with ``random_seed`` on every call.

    Args:
        n: number of samples (rows of ``X``).
        d: number of features.
        n_informative: entries of ``w_star`` from this index on are set to
            zero; ``None`` keeps all ``d`` entries.
        random_seed: seed passed to ``np.random.seed``.

    Returns:
        ``(X, y, w_star)`` as float32 torch tensors.
    """
    informative = d if n_informative is None else n_informative

    np.random.seed(random_seed)
    covariance = np.eye(d)

    # The target weights are drawn before the features; keep this order so a
    # given seed always yields the same dataset.
    w_star = np.random.randn(d)
    w_star[informative:] = 0

    X = np.random.multivariate_normal(np.zeros(d), covariance, n)
    y = X @ w_star

    def as_float_tensor(array):
        return torch.from_numpy(array).to(torch.float32)

    return as_float_tensor(X), as_float_tensor(y), as_float_tensor(w_star)


def l2distance(x, y):
    """Return the Euclidean distance between two tensors as a NumPy scalar array.

    Both inputs are flattened to column vectors before subtracting, so they
    only need to hold the same number of elements.
    """
    difference = x.reshape(-1, 1) - y.reshape(-1, 1)
    # Detach so the conversion to NumPy works even for graph-attached tensors.
    return difference.detach().norm(2).numpy()


def mse(y_hat, y):
    """Return half of the mean squared error between ``y_hat`` and ``y``.

    Both tensors are reshaped to column vectors first, so only their element
    counts have to agree.
    """
    prediction = y_hat.reshape(-1, 1)
    target = y.reshape(-1, 1)
    criterion = nn.MSELoss()
    return criterion(prediction, target) / 2


def jacobian(wpm, nabla_w, p=True):
    """Return ``2 * wpm * nabla_w``, with the sign flipped when ``p`` is false.

    Args:
        wpm: weight tensor of one branch.
        nabla_w: tensor multiplied element-wise with ``wpm``.
        p: truthy selects the factor ``+2``, falsy selects ``-2``.
    """
    factor = 2 if p else -2
    return factor * wpm * nabla_w


def hessian(wpm, nabla_w, X, p=True):
    """Return ``±2 * diag(nabla_w) + (4 / n) * sum_i (wpm * x_i)(wpm * x_i)^T``.

    Args:
        wpm: weight tensor with ``d`` elements (any shape that flattens to ``d``).
        nabla_w: tensor with ``d`` elements that forms the diagonal term.
        X: 2-D tensor of shape ``(n, d)`` whose rows are the samples ``x_i``.
        p: truthy gives the diagonal term the factor ``+2``, falsy ``-2``.

    Returns:
        A ``(d, d)`` tensor.
    """
    sign = 2 if p else -2
    diag_grad = sign * torch.diag(nabla_w.reshape(-1))
    n, _ = X.shape

    # Row i of ``scaled`` is wpm * x_i, so ``scaled.T @ scaled`` is the sum of
    # all per-sample outer products, computed in one matmul instead of a
    # Python-level loop over the rows of X.
    scaled = X * wpm.reshape(1, -1)
    # Keep the accumulated term in the dtype of the diagonal term.
    matrix_term = (scaled.T @ scaled).to(diag_grad.dtype)
    return diag_grad + 4 * matrix_term / n

def grad_hessian_product(wpm, nabla_w, X, p=True):
    """Return the matrix product ``hessian(...) @ jacobian(...)``.

    All arguments are forwarded unchanged to ``hessian`` and ``jacobian``.
    """
    return hessian(wpm, nabla_w, X, p) @ jacobian(wpm, nabla_w, p)


def HBFReg(wpm, nabla_w, X, p, mu, eta):
    """Return the Hessian-gradient product scaled by ``eta (1 + mu) / (2 (1 - mu)^3)``.

    Args:
        wpm, nabla_w, X, p: forwarded to ``grad_hessian_product``.
        mu: momentum coefficient entering the scaling factor.
        eta: step size entering the scaling factor.
    """
    numerator = eta * (1 + mu)
    denominator = 2 * (1 - mu) ** 3
    product = grad_hessian_product(wpm, nabla_w, X, p)
    return numerator / denominator * product


def grad(wpm, nabla_w, X, p=True, hbf=False, mu=0.9, eta=1e-3):
    """Return ``jacobian(...) / (1 - mu)``, optionally plus the ``HBFReg`` term.

    Args:
        wpm: weight tensor of one branch.
        nabla_w: tensor forwarded to ``jacobian`` (and ``HBFReg``).
        X: sample matrix, only used when ``hbf`` is set.
        p: sign selector forwarded to ``jacobian`` / ``HBFReg``.
        hbf: when truthy, add ``HBFReg(wpm, nabla_w, X, p, mu, eta)``.
        mu: momentum coefficient used for the ``1 / (1 - mu)`` rescaling.
        eta: step size forwarded to ``HBFReg``.
    """
    rescaled = jacobian(wpm, nabla_w, p) / (1 - mu)
    if not hbf:
        return rescaled
    # ``rescaled`` is a fresh tensor, so accumulating in place is safe.
    rescaled += HBFReg(wpm, nabla_w, X, p, mu, eta)
    return rescaled


def train(
    net,
    X,
    y,
    w_star,
    eta,
    num_it,
    eta_euler=1e-4,
    mu=0.9,
    hbf=False,
    return_trajectory=True,
    return_loss=False,
    factor=10,
):
    """Update ``net.wp`` / ``net.wm`` with explicit steps of size ``eta_euler``.

    Each iteration calls ``net(X, y)`` to obtain ``nabla_w``, evaluates
    ``grad`` for both branches and subtracts ``eta_euler`` times the result
    in place. At least one iteration is always executed.

    Args:
        net: callable returning ``(y_hat, nabla_w)`` and exposing ``wp``,
            ``wm`` and ``w``.
        X, y: data passed to ``net`` (``X`` is also passed to ``grad``).
        w_star: reference weights for the reported distance.
        eta: step size forwarded to ``grad`` (used by the HBF term).
        num_it: the loop stops once this many iterations have run.
        eta_euler: step size of the parameter update.
        mu: momentum coefficient forwarded to ``grad``.
        hbf: forwarded to ``grad``; enables the ``HBFReg`` term.
        return_trajectory: record ``net.w.detach()`` every ``factor`` steps.
        return_loss: record ``l2distance(net.w, w_star)`` every ``factor``
            steps. Must not be combined with ``return_trajectory``.
        factor: recording period in iterations.

    Returns:
        ``(distance_history, net)``. The history starts with the initial
        snapshot (or the initial distance when ``return_trajectory`` is
        false); if neither flag is set nothing else is appended.
    """
    # factor = eta / eta_euler
    first_entry = net.w.detach() if return_trajectory else l2distance(net.w, w_star)
    distance_history = [first_entry]
    assert (
        not return_trajectory or not return_loss
    ), "only one quantity should be returned"

    step = 0
    finished = False
    while not finished:
        _, nabla_w = net(X, y)

        grad_plus = grad(net.wp, nabla_w, X, hbf=hbf, mu=mu, eta=eta)
        grad_minus = grad(net.wm, nabla_w, X, p=False, hbf=hbf, mu=mu, eta=eta)

        net.wp -= eta_euler * grad_plus
        net.wm -= eta_euler * grad_minus
        step += 1

        # Only every ``factor``-th iterate is recorded.
        if step % factor == 0:
            if return_trajectory:
                distance_history.append(net.w.detach())

            if return_loss:
                distance_history.append(l2distance(net.w, w_star))

        if step % 1000 == 0:
            print(f"loss: {l2distance(net.w, w_star)}")

        # Checked at the end so the body runs at least once.
        finished = step >= num_it
    return distance_history, net


def hb_train(
    net,
    X,
    y,
    w_star,
    eta,
    num_it,
    eta_euler=1e-4,
    mu=0.9,
    return_trajectory=True,
    return_loss=False,
    **kwargs,
):
    """Update ``net.wp`` / ``net.wm`` with a momentum (heavy-ball style) rule.

    Per iteration and branch: ``m = mu * m - eta * jacobian(...)`` followed by
    an in-place ``w += m``. At least one iteration is always executed.

    Args:
        net: callable returning ``(y_hat, nabla_w)`` and exposing ``wp``,
            ``wm`` and ``w``.
        X, y: data passed to ``net``; ``X`` must be 2-D, its second dimension
            sizes the momentum buffers.
        w_star: reference weights for the reported distance.
        eta: step size applied to the gradient.
        num_it: the loop stops once this many iterations have run.
        eta_euler: accepted but not used by this function.
        mu: momentum coefficient.
        return_trajectory: record ``net.w.detach()`` after every step.
        return_loss: record ``l2distance(net.w, w_star)`` after every step.
            Must not be combined with ``return_trajectory``.
        **kwargs: accepted and ignored.

    Returns:
        ``(distance_history, net)``; the history starts with the initial
        snapshot (or the initial distance when ``return_trajectory`` is false).
    """
    assert (
        not return_trajectory or not return_loss
    ), "only one quantity should be returned"
    first_entry = net.w.detach() if return_trajectory else l2distance(net.w, w_star)
    distance_history = [first_entry]

    _, dim = X.shape
    # eta = eta_euler
    velocity_p = torch.zeros((dim, 1))
    velocity_m = torch.zeros((dim, 1))

    step = 0
    finished = False
    while not finished:
        _, nabla_w = net(X, y)
        grad_plus = jacobian(net.wp, nabla_w)
        grad_minus = jacobian(net.wm, nabla_w, p=False)
        velocity_p = mu * velocity_p - eta * grad_plus
        velocity_m = mu * velocity_m - eta * grad_minus

        net.wp += velocity_p
        net.wm += velocity_m
        step += 1

        if return_trajectory:
            distance_history.append(net.w.detach())

        if return_loss:
            distance_history.append(l2distance(net.w, w_star))

        if step % 1000 == 0:
            print(f"loss: {l2distance(net.w, w_star)}")

        # Checked at the end so the body runs at least once.
        finished = step >= num_it
    return distance_history, net


def find_para_diff(trajectory):
    """Return every earlier iterate minus the final one, one row per iterate.

    Args:
        trajectory: sequence of tensors that can be concatenated along
            dimension 1; the last entry is the reference point and is not
            part of the output.

    Returns:
        The transposed difference, so row ``i`` belongs to ``trajectory[i]``.
    """
    endpoint = trajectory[-1]
    earlier = torch.cat(trajectory[:-1], dim=1)
    return (earlier - endpoint).T


def find_projection(x, direction):
    return torch.dot(x.reshape(-1), direction.reshape(-1)) / direction.norm(2)


def _get_traj_directions(trajectory, pca_directions=None):
    """Project a trajectory (relative to its end point) onto two directions.

    Args:
        trajectory: sequence of iterates accepted by ``find_para_diff``.
        pca_directions: two direction vectors as NumPy arrays. When ``None``,
            a two-component PCA is fitted on the trajectory differences and
            its components are used.

    Returns:
        ``(x_proj, y_proj)`` as 1-D NumPy arrays. If the directions were
        fitted here they are returned as a third element.
    """
    para_diff = find_para_diff(trajectory)

    fitted_here = pca_directions is None
    if fitted_here:
        pca = PCA(n_components=2)
        pca.fit(para_diff)
        pca_directions = pca.components_

    first_axis, second_axis = (torch.from_numpy(row) for row in pca_directions)
    # Match dtypes so the matrix products below do not fail on a mismatch.
    para_diff = para_diff.to(first_axis.dtype)

    def project(axis):
        return (para_diff @ axis.reshape(-1, 1)).numpy().reshape(-1)

    x_proj = project(first_axis)
    y_proj = project(second_axis)
    if fitted_here:
        return x_proj, y_proj, pca_directions
    return x_proj, y_proj


def plot_trajectory(trajectory_list, name_list, pca_directions=None, save_path=None):
    """Plot several trajectories in one shared two-dimensional projection.

    Args:
        trajectory_list: list of trajectories, one curve each.
        name_list: legend labels, paired with ``trajectory_list`` by position.
        pca_directions: projection directions; when ``None`` they are taken
            from the first trajectory and reused for the remaining ones.
        save_path: if not ``None``, the figure is written there after it has
            been shown.
    """
    assert isinstance(trajectory_list, list), "Input variable is not a list!"
    fig, ax = plt.subplots()
    for traj, name in zip(trajectory_list, name_list):
        if pca_directions is not None:
            x_proj, y_proj = _get_traj_directions(traj, pca_directions)
        else:
            # No basis yet: this trajectory defines it for all later curves.
            x_proj, y_proj, pca_directions = _get_traj_directions(traj)
        ax.plot(x_proj, y_proj, label=f"{name}")
    plt.legend()
    plt.show()
    if save_path is None:
        return
    fig.savefig(save_path)


def mkdir_if_missing(directory):
    """Create ``directory`` (including missing parents) unless it already exists.

    An empty ``directory`` is treated as "nothing to create". That is what
    ``os.path.dirname`` returns for a bare file name, and passing it on to
    ``os.makedirs`` would raise ``FileNotFoundError``.

    Args:
        directory: path of the directory to create.

    Raises:
        OSError: if the directory cannot be created for any reason other than
            it already existing.
    """
    if not directory or osp.exists(directory):
        return
    try:
        os.makedirs(directory)
    except FileExistsError:
        # Created by someone else between the existence check and makedirs.
        pass


class AverageMeter(object):
    """Keep the most recent value and the running weighted mean of a series.

       Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the latest value, mean, weighted sum and count back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count


def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
    """Serialize ``state`` to ``fpath`` and optionally copy it as the best model.

    Args:
        state: object handed to ``torch.save``.
        is_best: when truthy, the saved file is also copied to
            ``best_model.pth.tar`` in the same directory.
        fpath: destination file; its parent directory is created if needed.
    """
    directory = osp.dirname(fpath)
    # A bare file name (such as the default) has an empty dirname: the file
    # goes to the current working directory and there is nothing to create.
    if directory:
        mkdir_if_missing(directory)
    torch.save(state, fpath)
    if is_best:
        shutil.copy(fpath, osp.join(directory, 'best_model.pth.tar'))


class Logger(object):
    """
    Write console output to external text file.

    Every message goes to the stream that was ``sys.stdout`` when the logger
    was created and, if ``fpath`` was given, to that file as well. The object
    can be used as a context manager; the file is closed on exit.

    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
    """

    def __init__(self, fpath=None):
        """Open ``fpath`` for writing (truncating it) when it is not ``None``."""
        self.console = sys.stdout
        self.file = None
        if fpath is not None:
            directory = os.path.dirname(fpath)
            # A bare file name has an empty dirname; nothing to create then.
            if directory:
                mkdir_if_missing(directory)
            self.file = open(fpath, 'w')

    def __del__(self):
        self.close()

    def __enter__(self):
        # Return the logger so ``with Logger(path) as log:`` binds it.
        return self

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        """Write ``msg`` to the console and to the log file if it is open."""
        self.console.write(msg)
        if self.file is not None and not self.file.closed:
            self.file.write(msg)

    def flush(self):
        """Flush the console and force the log file contents to disk."""
        self.console.flush()
        if self.file is not None and not self.file.closed:
            self.file.flush()
            os.fsync(self.file.fileno())

    def close(self):
        """Flush, sync and close the log file; the console is left open."""
        if self.file is not None and not self.file.closed:
            self.file.flush()
            os.fsync(self.file.fileno())
            self.file.close()
