import torch
import torch.nn as nn
import os
import pickle
import numpy as np
import argparse

import sys
sys.path.append("../")
from networks import get_network
from loaders import get_loaders
from AIDomains.abstract_layers import Sequential
from utils import seed_everything

# Interactive guard: the layer slicing and shapes below are hard-coded for one
# specific checkpoint, so make the user confirm they have adapted them first.
while True:
    try:
        decision = input("This scripts assumes you to directly change the layers and shapes below; please check the code before running it. y/n: ")
    except EOFError:
        # stdin closed (e.g. a non-interactive run): abort cleanly instead of
        # crashing with a traceback.
        sys.exit("No confirmation received on stdin; aborting.")
    # Tolerate surrounding whitespace and upper-case answers (" Y", "N ").
    decision = decision.strip().lower()
    if decision == "y":
        break
    if decision == "n":
        # sys.exit is always available; the bare `exit` builtin is injected by
        # the `site` module and is missing under `python -S` or in frozen apps.
        sys.exit(0)
    print("Please input y or n. y/n: ")

seed_everything(0)

# --- Checkpoint configuration: edit these for the model being converted. ---
# NOTE(review): model_root / model_name / model_type must describe the same
# network; get_network(model_type, ...) has to build the same layer structure
# as the pickled IBP-R model stored at model_root/model_name.
model_root = "../theory_models/ibpr_models/cifar10/dev/1/convmedbig_flat_4_4_8_500_0.00784/1695220903/"
model_name = "net_800.pth"
model_type = "ConvMedBig_w64"
# Input shape handed to Sequential.from_concrete_network at the end
# (presumably (C, H, W) of a CIFAR-10 image).
input_dim = (3, 32, 32)
# (C, H, W) of the activation that is flattened into the first linear layer;
# used to permute that layer's weight columns. Keep exactly one line active,
# matching model_type.
# out_shape_last_conv = (128, 8, 8)  # for convmedbig default
out_shape_last_conv = (128*2, 8, 8)  # for 2x convmedbig default


# Build the target network. It must have the same architecture as the pickled
# IBP-R model, otherwise load_state_dict (strict by default) fails on
# missing/unexpected keys.
torch_net = get_network(model_type, "cifar10", "cpu")

# The IBP-R checkpoint is a whole pickled nn.Module, not a state_dict. Since
# PyTorch 2.6 torch.load defaults to weights_only=True, which refuses to
# unpickle arbitrary module classes, so opt out explicitly (the parameter
# exists since PyTorch 1.13). This executes pickled code: only load
# checkpoints from a trusted source.
ibpr_net = torch.load(
    os.path.join(model_root, model_name), map_location="cpu", weights_only=False
)

# Re-assemble the IBP-R layers so that their indices line up with torch_net.
# - Keep torch_net's own first layer; presumably this is its input
#   normalization, which ibpr_net[0:3] is skipped in favour of.
# - Take ibpr_net[3:9].
# - Put nn.Flatten() in place of ibpr_net[9:11].
# - Append the remaining IBP-R layers.
# NOTE(review): these slice bounds are specific to this architecture; verify
# them against print(ibpr_net) when converting a different model.
temp_net = nn.Sequential(torch_net[0], *ibpr_net[3:9], nn.Flatten(), *ibpr_net[11:])
torch_net.load_state_dict(temp_net.state_dict())

print(torch_net)

# Permute the input columns of the first nn.Linear that follows an nn.Flatten.
# The columns are treated as ordered (H, W, C) and rearranged to (C, H, W), the
# order nn.Flatten yields for an (N, C, H, W) activation.
# NOTE(review): presumably the IBP-R model flattened channel-last features;
# confirm with the commented-out comparison below.
linear_layer = None
layers = list(torch_net)
# Pairwise scan so the last layer never triggers an out-of-range lookup.
for prev_layer, next_layer in zip(layers, layers[1:]):
    if isinstance(prev_layer, nn.Flatten) and isinstance(next_layer, nn.Linear):
        linear_layer = next_layer
        break
if linear_layer is None:
    raise RuntimeError(
        "No nn.Flatten followed by nn.Linear found in torch_net; "
        "cannot fix the linear weight layout."
    )

n_ch, n_h, n_w = out_shape_last_conv
n_in = n_ch * n_h * n_w
# A mis-set out_shape_last_conv whose product still divides the weight size
# would otherwise silently produce a weight of the wrong shape.
if linear_layer.weight.shape[1] != n_in:
    raise ValueError(
        f"out_shape_last_conv={out_shape_last_conv} implies {n_in} linear inputs, "
        f"but the layer has {linear_layer.weight.shape[1]}; check out_shape_last_conv."
    )
W = (
    linear_layer.weight.detach()
    .reshape(-1, n_h, n_w, n_ch)
    .permute(0, 3, 1, 2)
    .reshape(linear_layer.weight.shape[0], -1)
)
# Write in place so the Parameter object (and its registration) is preserved.
with torch.no_grad():
    linear_layer.weight.copy_(W)

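# Optional sanity check (disabled): compare the converted network against the
# original IBP-R network on matching inputs. `x` and `xprime` are not defined in
# this script and must be supplied by hand; presumably xprime is x in the input
# layout torch_net expects. A mean difference near 0 means the conversion is correct.
# torch_net(xprime)
# ibpr_net(x)
# print((torch_net(xprime) - ibpr_net(x)).abs().mean())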


# get_loaders reads its settings from an argparse Namespace. No CLI options are
# defined here (parse_args rejects any extra argument), so the needed fields
# are filled in by hand after parsing.
parser = argparse.ArgumentParser()
args = parser.parse_args()
loader_settings = {
    "train_batch": 128,
    "test_batch": 128,
    "frac_valid": None,
    "grad_accu_batch": None,
    "dataset": "cifar10",
}
for field_name, field_value in loader_settings.items():
    setattr(args, field_name, field_value)
loaders, input_size, input_channel, n_class = get_loaders(args)

# Natural (clean) accuracy of the converted network, as a sanity check that the
# weight transfer and column permutation above are correct.
# Prefer CUDA but fall back to CPU so the script also runs on CPU-only hosts.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_net = torch_net.to(device)
torch_net.eval()
n_correct = 0
n_total = 0
# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    # NOTE(review): loaders[1] is presumably the test split; confirm in get_loaders.
    for x, y in loaders[1]:
        x, y = x.to(device), y.to(device)
        y_pred = torch_net(x).argmax(dim=1)
        n_correct += (y_pred == y).sum().item()
        n_total += y.shape[0]
if n_total == 0:
    print("natural accuracy: n/a (empty evaluation loader)")
else:
    print(f"natural accuracy: {100 * n_correct / n_total:.2f}%")


# Wrap the concrete network in the AIDomains abstract-layer Sequential and save
# its state_dict next to the source checkpoint as "torch_<stem>.ckpt"
# (e.g. net_800.pth -> torch_net_800.ckpt).
torch_net = Sequential.from_concrete_network(torch_net, input_dim)
# os.path.splitext removes the extension. str.rstrip("pth") instead strips any
# trailing 'p'/'t'/'h' *characters*, mangling names such as "model.bin" or "depth".
ckpt_stem = os.path.splitext(model_name)[0]
torch.save(torch_net.state_dict(), os.path.join(model_root, f"torch_{ckpt_stem}.ckpt"))