import torch
import torch.nn as nn
import os
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import argparse
import copy
from vgg import vgg16
from transform import *
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
from data_set import *
import sys
import argparse
from struct import pack, unpack 

def get_p(m):
    """Return the gradient of the first parameter of module *m*.

    "First" means whichever tensor ``m.parameters()`` yields first
    (presumably the weight for the ``nn.Sequential(nn.Linear(...))`` models
    used in this script -- confirm if reused elsewhere).

    Returns ``None`` when *m* has no parameters. The returned ``.grad`` may
    itself be ``None`` if no backward pass has populated it yet.
    """
    first = next(iter(m.parameters()), None)
    return None if first is None else first.grad
def get_w(m):
    """Return the first tensor yielded by ``m.parameters()``.

    For the ``nn.Sequential(nn.Linear(...))`` models in this script this is
    presumably the linear weight. Returns ``None`` if *m* has no parameters.
    """
    return next(iter(m.parameters()), None)

def manual_grad(grad_out, inp):
    """Sum the per-sample outer products ``grad_out[i] (x) inp[i]`` over the batch.

    Row ``i`` of *grad_out* and row ``i`` of *inp* are each flattened to a
    vector, and their outer products are summed over ``i``. This equals
    ``G.T @ X`` with ``G``/``X`` the flattened rows, which is the weight
    gradient of a linear map ``y = x @ W.T`` given the upstream gradient
    ``dL/dy``.

    Args:
        grad_out: tensor whose first dimension is the batch (N rows).
        inp: tensor with at least N rows along its first dimension; only the
            first N rows are used.

    Returns:
        A ``(G, I)`` tensor, where G and I are the per-row element counts of
        *grad_out* and *inp*. An empty batch (N == 0) yields an all-zero
        matrix rather than ``None``.

    Raises:
        IndexError: if *inp* has fewer rows than *grad_out*.
    """
    n = grad_out.size(0)
    if inp.size(0) < n:
        raise IndexError(
            "inp has %d rows but grad_out has %d" % (inp.size(0), n)
        )
    # shape[1:].numel() is the per-row element count. Unlike reshape(n, -1)
    # it also works for n == 0, and it gives 1 for 1-D inputs (each row a
    # scalar), matching the per-row reshape(-1, 1) / reshape(1, -1).
    g = grad_out.reshape(n, grad_out.shape[1:].numel())
    x = inp[:n].reshape(n, inp.shape[1:].numel())
    # One matmul replaces the Python loop. Results agree with the loop up to
    # floating-point summation order.
    return g.t().matmul(x)


def element_unpack(x):
    """Reinterpret the float32 encoding of *x* as a signed 32-bit integer.

    *x* is packed with ``struct`` format ``'f'`` (single precision, native
    order), and the same four bytes are read back with format ``'i'``.
    """
    (bits,) = unpack('i', pack('f', x))
    return bits

def get_man(x):
    """Return the low 23 bits (mantissa field) of *x*'s float32 encoding."""
    (bits,) = unpack('i', pack('f', x))
    return bits & 0x7FFFFF
def get_exp(x):
    """Return the 8-bit biased exponent field (bits 23-30) of *x*'s float32 encoding."""
    (bits,) = unpack('i', pack('f', x))
    # Mask bits 23..30 first, then shift them down; this gives the same result
    # as shifting first and masking with 0xff, including for negative values.
    return (bits & 0x7F800000) >> 23

def np_float32_to_bits(r):
    """Return the raw float32 bit patterns of the values in *r* as int32.

    Each value is rounded to IEEE-754 single precision. Its 32 bits are then
    reinterpreted (not numerically converted) as a signed 32-bit integer.
    For finite in-range values this matches ``element_unpack`` applied
    element by element. Values beyond the float32 range become +/-inf here,
    where the ``struct``-based path would raise ``OverflowError``.

    Args:
        r: array-like of floats (list, numpy array, or torch tensor) of any
            shape.

    Returns:
        A new ``np.int32`` array with the same shape as *r*.
    """
    if isinstance(r, torch.Tensor):
        # np.array() cannot read CUDA tensors or tensors that require grad.
        r = r.detach().cpu().numpy()
    # np.array copies into a fresh contiguous float32 buffer, so the int32
    # view reinterprets those bytes without aliasing the caller's data.
    return np.array(r, dtype=np.float32).view(np.int32)

def np_float32_to_bits2D(r):
    """Return the raw float32 bit patterns of a 2-D array-like *r* as int32.

    Every value is rounded to single precision and its 32 bits are
    reinterpreted as a signed 32-bit integer. The conversion is vectorized:
    there is no per-row or per-element Python loop. An empty input still
    yields an ``int32`` array, so the later ``& 0x7fffff`` masking in
    ``float32_get_man`` keeps working. Inputs of other ranks are accepted too
    and keep their shape.

    Args:
        r: nested list, numpy array, or torch tensor of floats.

    Returns:
        A new ``np.int32`` array with the same shape as *r*.
    """
    if isinstance(r, torch.Tensor):
        # np.array() cannot read CUDA tensors or tensors that require grad.
        r = r.detach().cpu().numpy()
    # Fresh contiguous float32 copy, then reinterpret its bytes as int32.
    return np.array(r, dtype=np.float32).view(np.int32)

def float32_get_man(x):
    """Return the 23-bit mantissa field of float32 bit pattern(s) *x*.

    *x* holds raw bits as an int or an integer numpy array, for example the
    output of ``np_float32_to_bits2D`` as used by ``data_export``. This
    helper no longer prints ``x.dtype`` on every call (leftover debug
    output), so it is side-effect free like ``float32_get_exp``.
    """
    return x & 0x7FFFFF

def float32_get_exp(x):
    """Return the 8-bit biased exponent field of float32 bit pattern(s) *x*.

    *x* is an int or an integer numpy array of raw bits. The shift moves
    bits 23..30 to the bottom, and the mask discards the sign bit.
    """
    shifted = x >> 23
    exponent_mask = 0xFF
    return shifted & exponent_mask

def data_export(x):
    """Split the float32 values in *x* into mantissa and exponent bit fields.

    Args:
        x: 2-D array-like of floats, as accepted by ``np_float32_to_bits2D``.

    Returns:
        ``(man, exp)``: integer arrays holding the 23-bit mantissa field and
        the 8-bit biased exponent field of every element.
    """
    # Convert once; both fields are read from the same bit patterns.
    bits = np_float32_to_bits2D(x)
    return float32_get_man(bits), float32_get_exp(bits)

device = torch.device("cuda:0")


def _f32(rows):
    """Build a float32 tensor on ``device`` from a nested list of rows."""
    return torch.tensor(rows, dtype=torch.float32, device=device)


# Every row is [1/2, 1/2, 0]. Both update matrices below are this matrix
# right-multiplied by the corresponding unblinding matrix.
_half_rows = _f32([[1 / 2.0, 1 / 2.0, 0.0]] * 3)

# First blinding scheme: zero third column in the first two rows.
blind_mat1 = _f32([[1.2, 0.8, 0.0],
                   [0.8, 1.2, 0.0],
                   [1.1, 0.4, 0.6]])
unblind_mat1 = blind_mat1.inverse()
update_mat1 = _half_rows.matmul(unblind_mat1)

# Second blinding scheme: ones in the third column of the first two rows.
blind_mat = _f32([[1.2, 0.8, 1.0],
                  [0.8, 1.2, 1.0],
                  [1.1, 0.4, 0.6]])
unblind_mat = blind_mat.inverse()
update_mat = _half_rows.matmul(unblind_mat)


def _load_head():
    """Create a 4096->10 linear head and load the state dict stored in "layer1"."""
    head = nn.Sequential(nn.Linear(4096, 10))
    head.load_state_dict(torch.load("layer1"))
    return head


# Six identical copies of the saved layer. model1-3 are later fed
# image1-3, and model4-6 are fed image4-6.
model1 = _load_head()
model2 = _load_head()
model3 = _load_head()
model4 = _load_head()
model5 = _load_head()
model6 = _load_head()

# Move every copy to the GPU only after all of them have been loaded.
for _head in (model1, model2, model3, model4, model5, model6):
    _head.cuda()

# Raw inputs; only the commented-out mantissa check below reads them.
oi1, oi2 = torch.load("img0"), torch.load("img1")
# print("original input mantissa")
# man, expot = data_export(oi1.cpu().numpy())
# print(np.median(expot))
# man, expot = data_export(oi2.cpu().numpy())
# print(np.median(expot))

# Linear-layer inputs: the "_1_k" files feed model1-3 and the "_0_k" files
# feed model4-6.
image1, image2, image3 = (
    torch.load("linear_input_1_0"),
    torch.load("linear_input_1_1"),
    torch.load("linear_input_1_2"),
)
image4, image5, image6 = (
    torch.load("linear_input_0_0"),
    torch.load("linear_input_0_1"),
    torch.load("linear_input_0_2"),
)
oi4, oi5, oi6 = (
    torch.load("oi_0_0"),
    torch.load("oi_0_1"),
    torch.load("oi_0_2"),
)

# Saved upstream gradients, later passed to y.backward(gradient=...).
grad00, grad01, grad02 = (
    torch.load("saved_grad_1_0"),
    torch.load("saved_grad_1_1"),
    torch.load("saved_grad_1_2"),
)
grad10, grad11, grad12 = (
    torch.load("saved_grad_0_0"),
    torch.load("saved_grad_0_1"),
    torch.load("saved_grad_0_2"),
)

# Largest value of each linear-layer input (quick range check).
print(*(torch.max(t).item() for t in (image1, image2, image3)))
print(*(torch.max(t).item() for t in (image4, image5, image6)))

# Unblind image1-3 with unblind_mat1 and image4-6 with unblind_mat.
unblinded1 = unblind([image1, image2, image3], unblind_mat1)
unblinded2 = unblind([image4, image5, image6], unblind_mat)

# First-parameter tensors of model4-6 (not read again in this script).
weight4, weight5, weight6 = get_w(model4), get_w(model5), get_w(model6)

print("image format test")
print(*(t[0][0].item() for t in (image4, image5, image6)))
print(*(t[0][0].item() for t in (image1, image2, image3)))
print(*(unblinded2[k][0][0].item() for k in range(3)))
print(*(unblinded1[k][0][0].item() for k in range(3)))

# Per-component L1 distance between the two unblinded reconstructions.
for _k in range(3):
    print(torch.sum(torch.abs(unblinded1[_k] - unblinded2[_k])))

print("original input data mean")
print(unblinded1[0].size())
print(*(torch.mean(unblinded1[k]).item() for k in range(3)))
print(*(torch.mean(unblinded2[k]).item() for k in range(3)))
# Disabled: median exponent field of each unblinded component.
# print("mantissa")
# man, expot = data_export(unblinded1[0])
# print(np.median(expot))
# man, expot = data_export(unblinded1[1])
# print(np.median(expot))
# man, expot = data_export(unblinded1[2])
# print(np.median(expot))
#
# man, expot = data_export(unblinded2[0])
# print(np.median(expot))
# man, expot = data_export(unblinded2[1])
# print(np.median(expot))
# man, expot = data_export(unblinded2[2])
# print(np.median(expot))

print("gradout data mean")
print(grad00.size())
print(*(torch.mean(g).item() for g in (grad00, grad01, grad02)))
print(*(torch.mean(g).item() for g in (grad10, grad11, grad12)))

grad00np, grad01np, grad02np = (
    g.detach().cpu().numpy() for g in (grad00, grad01, grad02)
)
# NOTE(review): labelled "mantissa", but the value printed is the median of
# the exponent field returned by data_export.
print("mantissa")
for _grad_np in (grad00np, grad01np, grad02np):
    man, expot = data_export(_grad_np)
    print(np.median(expot))

# Forward image1-3 through model1-3, then backpropagate grad00-02.
y1, y2, y3 = model1(image1), model2(image2), model3(image3)
for _out, _upstream in ((y1, grad00), (y2, grad01), (y3, grad02)):
    _out.backward(gradient=_upstream)

gradw1, gradw2, gradw3 = get_p(model1), get_p(model2), get_p(model3)

# Weight gradients go through unblind with update_mat1; outputs with unblind_mat1.
ub_grads1 = unblind([gradw1, gradw2, gradw3], update_mat1)
ub_ys1 = unblind([y1, y2, y3], unblind_mat1)

# Forward image4-6 through model4-6 and unblind with the second scheme.
y4, y5, y6 = model4(image4), model5(image5), model6(image6)
ub_ys = unblind([y4, y5, y6], unblind_mat)

# L1 distance between the first two unblinded outputs of the two schemes.
for _k in range(2):
    print(torch.sum(torch.abs(ub_ys[_k] - ub_ys1[_k])))

for _out, _upstream in ((y4, grad10), (y5, grad11), (y6, grad12)):
    _out.backward(gradient=_upstream)

# Hand-computed weight gradient for model4 (computed but not printed).
grad21_manual = manual_grad(grad10, image4)

grad21, grad22, grad23 = get_p(model4), get_p(model5), get_p(model6)


def _print_corners(*tensors):
    """Print element [0][0] of every tensor in *tensors* on a single line."""
    print(*(t[0][0] for t in tensors))


print(grad21)
print(image4.size(), grad10.size())
_print_corners(grad21, grad22, grad23)
_print_corners(gradw1, gradw2, gradw3)
# Accumulators used only by the disabled loop below.
s = 0.0
s1 = 0.0
print("------------------------")

_print_corners(grad10, grad11, grad12)

_print_corners(grad00, grad01, grad02)
print("calculated grads")
_print_corners(grad21, grad22, grad23)
_print_corners(gradw1, gradw2, gradw3)

# Disabled: manual column-0 dot products over the first 16 rows.
# for i in range(16):
#     ps = image4[i][0] * grad10[i][0]
#     s += image4[i][0] * grad10[i][0]
#     ps1 = image1[i][0] * grad00[i][0]
#     s1 += ps1
#     print(image4[i][0].item(), grad10[i][0].item(), ps)
#     print(image1[i][0].item(), grad00[i][0].item(), ps1)
print("------------------------")
# print(s.item())
# print(s1.item())

ub_grads = unblind([grad21, grad22, grad23], update_mat)
print(ub_grads[0].size())
print(torch.sum(torch.abs(ub_grads[0] - ub_grads1[0])))
# print("mantissa")
# print(ub_grads[0].size())
# man, expot = data_export(ub_grads[0])
# print(np.median(expot))
# man, expot = data_export(ub_grads1[0])
# print(np.median(expot))
_print_corners(ub_grads[0], ub_grads[1], ub_grads[2])
_print_corners(ub_grads1[0], ub_grads1[1], ub_grads1[2])

# print(update_mat)
# print(update_mat1)
