import torch
import random
import argparse
import importlib
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from os.path import join
from models.init import get_conf_params

# Global matplotlib defaults for this script: figure size and font sizes.
plt.rcParams.update(
    {
        "figure.figsize": (20, 10),
        "font.size": 18,
        "axes.labelsize": 20,
    }
)

def main(net_name, architecture, dataset, batchnorm, init, fan, device):
    """Run one forward/backward pass and save the hooked variance statistics.

    A random batch of 128 inputs of shape 3x32x32 with integer targets in
    [0, 10) is pushed through the network and back-propagated through a
    cross-entropy loss. The ``"var"`` entries exposed by ``net.get_hooks()``
    are then written to
    ``/tmp/<net_name>_<architecture without underscores>_<batchnorm>_<init>_<fan>.npy``.

    Args:
        net_name: Name of the class to look up in ``models.<architecture>``.
        architecture: Name of the ``models`` submodule that defines the class.
        dataset: Dataset name supplied by the caller. It is not used here;
            the batch is always synthetic.
        batchnorm: Forwarded to the network constructor as ``batch_norm``.
        init: Passed to ``get_conf_params``; the result is forwarded to the
            network constructor as ``conf_params``.
        fan: Forwarded to the network constructor as ``fan``.
        device: Torch device string the data and the network are moved to.

    Raises:
        AttributeError: If ``models.<architecture>`` has no attribute named
            ``net_name``.
    """
    print(f"Device: {device}")
    print("==> Preparing data..")
    # The synthetic batch gets its own name so the `dataset` argument is not
    # silently overwritten.
    inputs = torch.normal(mean=0.0, std=1.0, size=[128, 3, 32, 32], requires_grad=True).to(device)
    targets = torch.randint(low=0, high=10, size=[128]).to(device)
    num_classes = 10

    print(f"==> Using {net_name} ({architecture})")
    module_ = importlib.import_module(f"models.{architecture}")
    try:
        net_class_ = getattr(module_, net_name)
    except AttributeError as err:
        # Chain explicitly so the original lookup failure stays visible.
        raise AttributeError(
            f"Class {net_name} does not exist in models.{architecture}."
        ) from err

    net = net_class_(
        num_classes=num_classes,
        batch_norm=batchnorm,
        conf_params=get_conf_params(init),
        fan=fan,
        hooks=True,
    )
    print(net)
    net = net.to(device)
    num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print("==> Total number of parameters to be trained", num_params)
    print(f"==> Attributes\n\tClasses = {num_classes}\n\tArchitecture = {architecture}")
    print(f"\tBatchnorm = {batchnorm}\n\tInit = {init}\n\tFan = {fan}")

    print("==> Computing signal propagation plots..")
    criterion = nn.CrossEntropyLoss()
    out = net(inputs)
    loss = criterion(out, targets)
    # The backward pass is needed so the "bw" hook entries get populated.
    loss.backward()
    hooks = net.get_hooks()

    # The first recorded entry is dropped in both directions.
    fw_spp = hooks["fw"][1:]
    # NOTE(review): the forward entries appear to come in triples
    # (shortcut, residual, summation); only every third one, the summation,
    # is kept. Confirm against the network's hook registration.
    fw_summation = fw_spp[2::3]
    bw_summation = hooks["bw"][1:]

    fw_vars = [entry["var"] for entry in fw_summation]
    bw_vars = [entry["var"] for entry in bw_summation]

    # Row 0 holds the forward variances, row 1 the backward ones.
    # NOTE(review): assumes both lists have the same length so they stack
    # into one rectangular array; confirm.
    out_path = join("/tmp", f"{net_name}_{architecture.replace('_', '')}_{batchnorm}_{init}_{fan}.npy")
    np.save(file=out_path, arr=np.array([fw_vars, bw_vars]))

if __name__ == "__main__":
    # Command-line interface: one positional network class name plus options.
    cli = argparse.ArgumentParser()
    cli.add_argument("net", type=str)
    cli.add_argument("--dataset", type=str, default="CIFAR10")
    cli.add_argument("--architecture", "-a", default="short_mult")
    cli.add_argument("--batchnorm", "-bn", action="store_true")
    # Plain boolean switches, registered in the same order as before.
    for switch in ("--brock", "--fan_out", "--cuda"):
        cli.add_argument(switch, action="store_true")
    cli.add_argument("--seed", type=int, default=42)
    opts = cli.parse_args()

    # Reproducibility: seed torch, numpy and the stdlib RNG with the same value.
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(opts.seed)

    # Translate the boolean switches into the string options main() expects.
    if opts.brock:
        init = "brock"
    else:
        init = "he"

    if opts.fan_out:
        fan = "fan_out"
    else:
        fan = "fan_in"

    if opts.cuda:
        device = "cuda"
    else:
        device = "cpu"

    main(
        net_name=opts.net,
        architecture=opts.architecture,
        dataset=opts.dataset,
        batchnorm=opts.batchnorm,
        init=init,
        fan=fan,
        device=device,
    )
