import copy
import sys

import numpy as np
import torch
import torch.optim as optim

sys.path.append(".")
from src.tools.sharpness_tools.math_utils import compute_loss
from src.tools.sharpness_tools.utils import load_weights, zero_grad
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler


def generate_gaussian_vec(grad_list):
    """Draw one standard-normal sample per element of ``grad_list``.

    A single flat noise vector is drawn with ``torch.randn_like`` on the
    concatenation of all tensors (so it inherits their promoted dtype and
    device), then cut back into pieces shaped like each input tensor.

    :param grad_list: list of tensors used only as shape/dtype/device templates
    :return: list of new tensors, the i-th shaped like ``grad_list[i]``
    """
    piece_sizes = [template.numel() for template in grad_list]
    flat_template = torch.concat([template.view(-1) for template in grad_list])
    flat_noise = torch.randn_like(flat_template)

    # Each chunk is cloned so the returned tensors do not share storage.
    chunks = torch.split(flat_noise, piece_sizes)
    return [chunk.clone().view_as(template) for chunk, template in zip(chunks, grad_list)]


def calc_norm(vec_list):
    """Return the global L2 norm of a list of tensors as a Python float.

    Each tensor's L2 norm is taken individually, squared in Python floating
    point, summed left to right, and the square root of the total is returned.
    An empty list yields ``0.0``.
    """
    squared_norms = [tensor.data.view(-1).norm(2).item() ** 2 for tensor in vec_list]
    total = 0.0
    for squared in squared_norms:
        total += squared
    return total ** 0.5


def _sgd_step_loss_delta(model, data_loader, scalar, eta, base_loss):
    """Take one SGD step of learning rate ``eta`` with the model's current
    gradients, re-evaluate the loss, and return ``(loss - base_loss, loss)``.

    The weights are left perturbed; the caller is responsible for restoring them.
    """
    # Fresh optimizer each call: positional args are lr=eta, momentum=0.0,
    # dampening=0.0, so there is no optimizer state carried between trials.
    optimizer = optim.SGD(model.parameters(), eta, 0.0, 0.0)
    scalar.step(optimizer)
    scalar.update()
    curr_loss, _ = compute_loss(model, data_loader, scalar)
    return curr_loss - base_loss, curr_loss


def compute_radius(model, data_loader, vec, epsilon, scalar, tol=1e-6):
    """Search for the step size along ``vec`` at which the loss rises by ``epsilon``.

    ``vec`` is rescaled to unit global L2 norm. Each trial resets the weights to
    their values at entry via ``load_weights(model, theta_star_params, vec)``,
    takes an SGD step of size ``eta`` through ``scalar.step``, re-evaluates the
    loss, and compares it with the loss at entry. The search first grows an
    upper bound on ``eta`` by a factor of 5 per round until the loss increase
    is no longer below ``epsilon``. It then bisects the bracket until the
    increase is within ``tol`` of ``epsilon`` or the bracket is narrower than
    ``tol``.

    NOTE(review): this assumes ``load_weights`` copies ``vec`` into the
    parameters' ``.grad`` (its third argument) — confirm against its definition.
    NOTE(review): when the GradScaler is enabled, ``scalar.step`` unscales the
    gradients (divides by the current loss scale) before stepping — confirm the
    resulting step length is what is intended.

    :param model: model whose parameters are perturbed; on return its weights are
        reloaded from the snapshot taken at entry and ``zero_grad`` is called.
    :param data_loader: data passed to ``compute_loss``.
    :param vec: list of tensors giving the perturbation direction (presumably one
        per parameter, in ``model.parameters()`` order — confirm).
    :param epsilon: target loss increase.
    :param scalar: ``GradScaler`` used for the optimizer step and by ``compute_loss``.
    :param tol: tolerance on both the loss increase and the bracket width.
    :return: the step size found, as ``np.abs`` of a numpy float.
    :raises ValueError: if ``vec`` has zero global norm (no direction is defined).
    """
    vec_norm = calc_norm(vec)
    if vec_norm == 0.0:
        raise ValueError("compute_radius: direction vector has zero norm")
    vec = [v / vec_norm for v in vec]

    # Snapshot of the weights at entry; every trial step is undone by reloading it.
    theta_star_params = [copy.deepcopy(p) for p in model.parameters()]

    base_loss, _ = compute_loss(model, data_loader, scalar)
    eta_range = [1e-16, 1]

    # Install the normalized direction before the first trial step. Without
    # this, the first step would use whatever gradients the model currently
    # holds rather than ``vec``.
    load_weights(model, theta_star_params, vec)

    # Phase 1: grow the upper bound eta_range[1] until the step overshoots.
    for itr in range(10 ** 7):
        eta = eta_range[1]
        d, curr_loss = _sgd_step_loss_delta(model, data_loader, scalar, eta, base_loss)
        load_weights(model, theta_star_params, vec)
        # ``not d < epsilon`` (rather than ``d >= epsilon``) also stops on NaN.
        if not d < epsilon:
            break
        eta_range[1] = eta * 5
        if (itr + 1) % 10 == 0:
            print(
                f"{itr}, {eta:.2E}, {eta_range[0]:.2E}, {eta_range[1]:.2E}, {d:.2E}, {base_loss:.2E}, {curr_loss:.2E}")

    # Phase 2: bisection on [eta_range[0], eta_range[1]].
    for _ in range(10 ** 7):
        eta = np.mean(eta_range)
        d, _ = _sgd_step_loss_delta(model, data_loader, scalar, eta, base_loss)
        load_weights(model, theta_star_params, vec)
        width = eta_range[1] - eta_range[0]
        if (epsilon - tol <= d <= epsilon + tol) or width < 0 or np.abs(width) < tol:
            zero_grad(model)
            return np.abs(eta)
        if d < epsilon - tol:
            eta_range[0] = eta
        else:
            eta_range[1] = eta

    # Iteration budget exhausted (not expected in practice). Return the bracket
    # midpoint instead of falling off the end and returning None.
    zero_grad(model)
    return np.abs(np.mean(eta_range))


def sample_flatness(model, data_loader, mcmc, epsilon, tol=1e-6):
    """
    Estimate sharpness along ``mcmc`` random Gaussian directions.

    The gradients returned by ``compute_loss(..., ascent_stats=True)`` serve
    only as shape/dtype/device templates for the random directions. For each
    sample, ``compute_radius`` finds the step size at which the loss rises by
    ``epsilon``, and the reciprocal of that step is recorded and printed.

    :param model: the to-be-evaluated model
    :param data_loader: data used to evaluate the loss
    :param mcmc: number of random directions to sample
    :param epsilon: target loss increase passed to compute_radius
    :param tol: tolerance passed to compute_radius
    :return: list of ``mcmc`` values, each ``1.0 / radius``
    """
    grad_scaler = GradScaler()
    _, template_grads, _, _ = compute_loss(model, data_loader, grad_scaler, ascent_stats=True)

    sharpness_list = []
    for sample_idx in range(mcmc):
        direction = generate_gaussian_vec(template_grads)
        step_size = compute_radius(model, data_loader, direction, epsilon, grad_scaler, tol)
        sharpness_list.append(1.0 / step_size)
        print(f"Sample {sample_idx} time: {sharpness_list[-1]}")
    return sharpness_list


def sam_flatness(model, data_loader, epsilon, tol=1e-6):
    """
    Compute an epsilon-flatness measure along the gradient direction, in the
    spirit of Keskar et al., using a bisection search.

    The gradients returned by ``compute_loss(..., ascent_stats=True)`` are used
    as the search direction for ``compute_radius``.

    :param model: the to-be-evaluated model
    :param data_loader: the training data loader
    :param epsilon: loss deviation to reach
    :param tol: tolerance on the loss deviation
    :return: sharpness measure ``1.0 / radius``
    """
    grad_scaler = GradScaler()
    _, ascent_grads, _, _ = compute_loss(model, data_loader, grad_scaler, ascent_stats=True)
    return 1.0 / compute_radius(model, data_loader, ascent_grads, epsilon, grad_scaler, tol)
