import numpy as np
import torch
import torch.nn.functional as F



@torch.no_grad()
def log_norm(model):
    """Compute the global L2 norms of a model's parameters and gradients.

    The per-tensor 2-norms are squared, summed over
    ``model.named_parameters()`` and square-rooted, which equals the 2-norm
    of all values concatenated into one vector.

    Parameters without a gradient (``p.grad is None``) still contribute to
    the parameter norm but are skipped for the gradient norm instead of
    raising ``AttributeError``.

    Args:
        model: Object exposing ``named_parameters()``, e.g. a
            ``torch.nn.Module``.

    Returns:
        Tuple ``(total_param_norm, total_param_grad)``. Each entry is a
        0-dim tensor, or the Python float ``0.0`` when no tensor
        contributed to it.
    """
    total_param_norm = 0.0
    total_param_grad = 0.0

    for _name, p in model.named_parameters():
        total_param_norm += p.data.norm(2) ** 2

        # No gradient has been populated for this parameter; nothing to add.
        if p.grad is None:
            continue
        total_param_grad += p.grad.norm(2) ** 2

    return total_param_norm ** 0.5, total_param_grad ** 0.5

    # log_dict["param/total"] = total_param_norm
    # log_dict["grad/total"] = total_param_grad


@torch.no_grad()
def log_sign(model, optim_state):
    """Return the fraction of gradient entries whose sign is unchanged.

    For each parameter the sign of the current gradient is multiplied
    element-wise with the tensor stored at ``optim_state[p]["grad_sign"]``.
    An entry is counted when that product exceeds 0.5, i.e. when both signs
    agree and neither is zero (assuming the stored tensor holds values in
    {-1, 0, 1} -- confirm against the code that writes ``"grad_sign"``).

    Note that the result is the share of entries that *kept* their sign,
    not the share that flipped.

    Parameters without a gradient (``p.grad is None``) are skipped and do
    not count towards the denominator.

    Args:
        model: Object exposing ``named_parameters()``, e.g. a
            ``torch.nn.Module``.
        optim_state: Mapping from parameter to a per-parameter dict that
            contains a ``"grad_sign"`` tensor.

    Returns:
        The unchanged-sign fraction as a 0-dim tensor, or the Python float
        ``0.0`` when no parameter has a gradient.
    """
    total_sign_unchanged = 0.0
    num_params = 0

    for _name, p in model.named_parameters():
        if p.grad is None:
            continue

        grad_sign = torch.sign(p.grad)
        pre_grad_sign = optim_state[p]["grad_sign"]

        total_sign_unchanged += torch.sum(grad_sign * pre_grad_sign > 0.5)
        num_params += p.numel()

    # Avoid ZeroDivisionError when nothing contributed.
    if num_params == 0:
        return 0.0

    return total_sign_unchanged / num_params



@torch.no_grad()
def log_cosine(model, optim_state):
    """Cosine similarity between gradients and second-moment state.

    Two similarities are accumulated across all parameters as if every
    tensor were flattened and concatenated into one vector:

    * ``cosine``  -- between ``|grad|`` and ``sqrt(exp_avg_sq)``
    * ``cosine2`` -- between ``grad ** 2`` and ``exp_avg_sq``

    where ``exp_avg_sq`` is read from ``optim_state[p]["exp_avg_sq"]``
    (presumably an Adam-style second-moment estimate -- confirm against the
    optimizer in use).

    Parameters without a gradient (``p.grad is None``) are skipped, and
    their optimizer state is not accessed. Tensors are flattened with
    ``reshape`` so that non-contiguous gradients or state are handled.

    Args:
        model: Object exposing ``named_parameters()``, e.g. a
            ``torch.nn.Module``.
        optim_state: Mapping from parameter to a per-parameter dict that
            contains an ``"exp_avg_sq"`` tensor.

    Returns:
        Tuple ``(total_param_cosine, total_param_cosine2)`` of 0-dim
        tensors. Both are the Python float NaN when no parameter has a
        gradient; a zero-norm vector likewise yields NaN via 0/0.
    """
    total_param_grad = 0.0
    total_param_v = 0.0
    total_param_dot = 0.0

    total_param_grad2 = 0.0
    total_param_v2 = 0.0
    total_param_dot2 = 0.0

    contributed = False

    for _name, p in model.named_parameters():
        grad = p.grad
        if grad is None:
            continue
        contributed = True

        v2 = optim_state[p]["exp_avg_sq"]
        v = v2.sqrt()  # computed once and reused for the norm and the dot
        grad2 = grad ** 2

        total_param_grad += grad.norm(2) ** 2
        total_param_v += v.norm(2) ** 2
        total_param_dot += torch.dot(grad.abs().reshape(-1), v.reshape(-1))

        total_param_grad2 += grad2.norm(2) ** 2
        total_param_v2 += v2.norm(2) ** 2
        total_param_dot2 += torch.dot(grad2.reshape(-1), v2.reshape(-1))

    # With nothing accumulated the totals are plain floats and the division
    # below would raise ZeroDivisionError; report "undefined" instead.
    if not contributed:
        return float("nan"), float("nan")

    total_param_grad = total_param_grad ** 0.5
    total_param_v = total_param_v ** 0.5
    total_param_cosine = total_param_dot / (total_param_grad * total_param_v)

    total_param_grad2 = total_param_grad2 ** 0.5
    total_param_v2 = total_param_v2 ** 0.5
    total_param_cosine2 = total_param_dot2 / (total_param_grad2 * total_param_v2)

    return total_param_cosine, total_param_cosine2

    # log_dict["cosine/total"] = total_param_cosine
    # log_dict["cosine2/total"] = total_param_cosine2


@torch.no_grad()
def log_river(model, optim_state):
    """Average ``|m| / sqrt(v)`` and ``m**2 / v`` over all parameter entries.

    ``m`` and ``v`` are read from ``optim_state[p]["exp_avg"]`` and
    ``optim_state[p]["exp_avg_sq"]`` (presumably Adam-style moment
    estimates -- confirm against the optimizer in use).

    Args:
        model: Object exposing ``named_parameters()``, e.g. a
            ``torch.nn.Module``.
        optim_state: Mapping from parameter to a per-parameter dict that
            contains ``"exp_avg"`` and ``"exp_avg_sq"`` tensors.

    Returns:
        Tuple of two values, each a sum over every entry divided by the
        total number of parameter entries: first the mean of
        ``|m| / (sqrt(v) + 1e-10)``, then the mean of ``m**2 / (v + 1e-20)``.
    """
    abs_ratio_sum = 0.0
    sq_ratio_sum = 0.0
    entry_count = 0

    for _, param in model.named_parameters():
        first_moment = optim_state[param]["exp_avg"]
        second_moment = optim_state[param]["exp_avg_sq"]

        # The small offsets in the denominators guard against v == 0.
        abs_ratio_sum += (first_moment.abs() / (second_moment.sqrt() + 1e-10)).sum()
        sq_ratio_sum += (first_moment ** 2 / (second_moment + 1e-20)).sum()
        entry_count += param.numel()

    return abs_ratio_sum / entry_count, sq_ratio_sum / entry_count

    # log_dict["river/total"] = total_river_por
    # log_dict["river2/total"] = total_river2_por
