import torch
import math

from tqdm import tqdm
import argparse


def rel_l2(x, y):
    """Return the relative L2 error of ``x`` with respect to the reference ``y``.

    Computed as ``||x - y|| / ||y||`` using the default tensor norm. No guard
    is applied for a zero-norm ``y``.
    """
    deviation = (x - y).norm()
    reference = y.norm()
    return deviation / reference

def coserr(x, y):
    """Return the cosine error ``1 - cos(x, y)`` between two 1-D tensors.

    The inner product is taken with ``torch.dot``, so both arguments must be
    one-dimensional tensors of equal length.
    """
    inner = torch.dot(x, y)
    magnitude = torch.norm(x) * torch.norm(y)
    return 1 - inner / magnitude

def test(m, n, alpha, beta, nbit):
    """Run one random trial comparing RTN and DiaQ activation quantization.

    A random ``m x n`` weight matrix and a random length-``n`` activation are
    drawn from the global torch RNG. The weights are quantized once and shared
    by every product, so the reported errors come from the activation only.

    Args:
        m: Number of weight rows.
        n: Number of weight columns (and activation length).
        alpha: Factor applied to the scaled unit direction that is added to
            the activation before flooring/ceiling.
        beta: Weight of the direction term in the floor-vs-ceil score.
        nbit: Bit width used to derive the quantization step sizes.

    Returns:
        An 8-tuple of scalar tensors, in this order: relative L2 error of the
        RTN activation, relative L2 error of the DiaQ activation, cosine error
        of the RTN activation, cosine error of the DiaQ activation, then the
        same four metrics for the matrix-vector outputs.
    """
    n_steps = 2**nbit - 1

    # Draw order (weights, then activation) determines the values produced
    # under a given seed, so it must not be swapped.
    weight = torch.randn(m, n)
    act = torch.randn(n)

    # Per-row weight step: row range divided by (2**nbit - 1); weights are
    # snapped to the nearest multiple of their row's step.
    row_range = torch.max(weight, dim=1).values - torch.min(weight, dim=1).values
    row_step = (row_range / n_steps).unsqueeze(1)
    weight_q = torch.round(weight / row_step) * row_step
    y_ref = weight_q @ act

    # Activation step: twice the largest magnitude divided by (2**nbit - 1).
    peak_pos = torch.max(act)
    peak_neg = -torch.min(act)
    peak = peak_neg if peak_neg > peak_pos else peak_pos
    step = (2 * peak) / n_steps

    # Baseline: plain round-to-nearest on the activation (not clamped).
    act_rtn = torch.round(act / step) * step
    y_rtn = weight_q @ act_rtn

    # Unit direction of the activation, scaled to one quantization step.
    direction = act / torch.norm(act) * step
    extended = act + alpha * direction

    # Grid points just below and just above the extended activation.
    extended_in_steps = extended / step
    lower = torch.floor(extended_in_steps) * step
    upper = torch.ceil(extended_in_steps) * step

    # Score combines the original activation's offset from the cell midpoint
    # with a beta-weighted direction term; a positive score selects the ceiling.
    offset = act - (lower + step * 0.5)
    score = offset * 4 + direction * math.sqrt(n) * beta
    act_dia = torch.where(score > 0, upper, lower)
    act_dia = torch.clamp(
        act_dia,
        step * (-2**(nbit - 1)),
        step * (2**(nbit - 1) - 1),
    )

    # Rescale so the result has the same norm as the original activation;
    # the small epsilon avoids dividing by an exactly-zero norm.
    dia_norm = torch.norm(act_dia)
    act_dia = act_dia / (dia_norm + 1e-14) * torch.norm(act)
    y_dia = weight_q @ act_dia

    metrics = (rel_l2, coserr)
    x_errors = tuple(fn(q, act) for fn in metrics for q in (act_rtn, act_dia))
    y_errors = tuple(fn(q, y_ref) for fn in metrics for q in (y_rtn, y_dia))
    return x_errors + y_errors

def main(args):
    """Run ``args.iter`` random trials of ``test`` and print mean error metrics.

    Args:
        args: Object exposing ``seed``, ``iter``, ``m``, ``n``, ``bit``,
            ``alpha`` and ``beta`` attributes (e.g. an argparse namespace).
    """
    torch.manual_seed(args.seed)

    rows, cols, nbit = args.m, args.n, args.bit
    alpha, beta = args.alpha, args.beta

    settings_banner = (
        '====  Settings  ====',
        f'[{rows} x {cols}] @ [{cols} x 1] with {nbit}-bit quantization',
        f'Extension hyperparameter (alpha): {alpha}',
        f'Balancing hyperparameter (beta) : {beta}',
        f'{args.iter} iterations',
        '====================',
    )
    for line in settings_banner:
        print(line)

    # Names follow the order of the tuple returned by test().
    metric_names = (
        'x_rtn_l2', 'x_dia_l2', 'x_rtn_cos', 'x_dia_cos',
        'y_rtn_l2', 'y_dia_l2', 'y_rtn_cos', 'y_dia_cos',
    )
    samples = {name: [] for name in metric_names}

    for _ in tqdm(range(args.iter)):
        outcome = test(rows, cols, alpha, beta, nbit)
        for name, value in zip(metric_names, outcome):
            samples[name].append(value)

    # Average each metric over all trials.
    mean = {
        name: torch.mean(torch.tensor(values))
        for name, values in samples.items()
    }

    report = (
        '==== RTN  Error ====',
        f"x l2: {mean['x_rtn_l2']:.04f}, cos: {mean['x_rtn_cos']:.04f}",
        f"y l2: {mean['y_rtn_l2']:.04f}, cos: {mean['y_rtn_cos']:.04f}",
        '==== DiaQ Error ====',
        f"x l2: {mean['x_dia_l2']:.04f}, cos: {mean['x_dia_cos']:.04f}",
        f"y l2: {mean['y_dia_l2']:.04f}, cos: {mean['y_dia_cos']:.04f}",
    )
    for line in report:
        print(line)



if __name__ == "__main__":
    # Command-line options as (flag, type, default, help) rows, registered in
    # this order: run control, problem size / precision, DiaQ hyperparameters.
    option_table = (
        ('--seed', int, 0, 'random seed'),
        ('--iter', int, 1000, 'number of iterations'),
        ('--m', int, 1024, 'number of rows'),
        ('--n', int, 1024, 'number of columns'),
        ('--bit', int, 4, 'bit precision'),
        ('--alpha', float, 0.5, 'extension hyperparameter in DiaQ'),
        ('--beta', float, 1.0, 'balancing hyperparameter in DiaQ'),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value, help_text in option_table:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    args = parser.parse_args()

    main(args)