import torch
import torch.nn as nn
import math
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import random
import argparse
import pathlib

from BN_match_target_permute import *

def _build_target(args, seed):
    """Draw a random target weight tensor and rescale it to unit RMS.

    Args:
        args: Parsed command-line namespace. Reads ``nin``, ``nout``, ``k1``,
            ``k2``, ``mt`` and ``target``.
        seed: Per-repetition seed. The global RNGs are seeded with ``seed``;
            the ``init_*`` helpers and the sparse fallback use ``seed + 1``.

    Returns:
        A tensor of shape ``(nout, nin)`` or, when ``max(k1, k2) > 1``,
        ``(nout, nin, k1, k2)``, divided by its root-mean-square entry.

    Raises:
        ValueError: If the drawn target has zero or non-finite RMS and hence
            cannot be normalised (e.g. ``--mt 0`` or an empty ``ERfull`` mask).
    """
    # Seed with the per-repetition seed. Seeding with args.seed here (as was
    # done before) made gamma -- and the whole "HePure"/"ERfull" target --
    # identical in every repetition, although every other random component is
    # keyed on `seed`.
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    is_conv = max(args.k1, args.k2) > 1
    if is_conv:
        wtsize = (args.nout, args.nin, args.k1, args.k2)
        fan_in = args.nin * args.k1 * args.k2
        # Contracts the intermediate index m and the kernel indices of w01.
        contraction = 'impq,m,mjkl->ijpq'
    else:
        wtsize = (args.nout, args.nin)
        fan_in = args.nin
        contraction = 'im,m,mj->ij'

    # The intermediate width of the target is capped at mfull.
    mfull = args.nout * (fan_in - 1) + 1
    m = min(args.mt, mfull)
    # gamma is drawn before any target-specific sampling so that the order of
    # random draws matches across target types.
    gamma = torch.randn(m)

    # Targets of the form w02 * diag(gamma) * w01 with two random layers.
    # NOTE(review): the init_* helpers come from BN_match_target_permute and
    # are assumed to return (w01, w02) shaped to fit `contraction` -- confirm.
    factorized_inits = {
        "He": init_He,
        "uni": init_uni,
        "ortho": init_ortho,
        "ER": init_ER,
        "ER_norm": init_ER_norm,
        "sign": init_sign,
        "ER_sign": init_ER_sign,
    }

    if args.target == "HePure":
        wt = torch.randn(wtsize) / math.sqrt(2 * fan_in)
    elif args.target in factorized_inits:
        w01, w02 = factorized_inits[args.target](wtsize, m, seed + 1)
        wt = torch.einsum(contraction, w02, gamma, w01)
    elif args.target == "ERfull" and not is_conv:
        # Dense Gaussian matrix with a Bernoulli mask keeping m entries in
        # expectation. Only available for linear (non-convolutional) targets.
        wt = torch.randn(wtsize)
        mask = torch.bernoulli(torch.ones(args.nout, args.nin) * m / (args.nout * args.nin))
        wt = wt * mask
        wt = wt / torch.sqrt(torch.mean(wt**2))
    else:
        # Fallback for any other target name: exactly m non-zero Gaussian
        # entries at uniformly random positions of the weight tensor.
        random.seed(seed + 1)
        torch.manual_seed(seed + 1)
        n_entries = args.nout * fan_in  # equals mfull + nout - 1
        support = torch.randperm(n_entries)[:m]
        flat = torch.zeros(n_entries)
        flat[support] = torch.randn(len(support))
        wt = flat.reshape(wtsize)

    baseline = torch.sqrt(torch.mean(wt**2))
    rms = baseline.item()
    if not math.isfinite(rms) or rms == 0.0:
        # Dividing by such a baseline would silently yield a NaN/inf target.
        raise ValueError(
            "Cannot normalise target '%s': its root-mean-square is %r "
            "(check --mt, --nin and --nout)." % (args.target, rms)
        )
    return wt / baseline


def _approximation_error(args, wt, seed):
    """Approximate target ``wt`` with the selected method and return the error.

    ``--method permute`` calls ``proxy_target_permute``; every other method
    calls ``proxy_target_LBFGS_overparam``. For ``exact`` the width increase
    is forced to 0 and the first trailing flag is True; otherwise
    ``args.adddim`` is used and the flag is False.

    Returns:
        The first element returned by the proxy routine, as a Python float.
    """
    if args.method == "permute":
        err, _, _ = proxy_target_permute(args.nper, wt, seed, args.epochs, args.m, args.init, args.device)
    else:
        exact = args.method == "exact"
        adddim = 0 if exact else args.adddim
        # NOTE(review): the meaning of the two trailing booleans is defined in
        # BN_match_target_permute; the values below reproduce the original calls.
        err, _, _ = proxy_target_LBFGS_overparam(wt, seed, args.epochs, args.m, args.init, args.device, adddim, args.nper, args.njobs, exact, False)
    return err.item()


def main():
    """Command-line entry point.

    Parses the options, draws ``--rep`` random targets, approximates each one
    with the selected method, optionally saves the per-repetition errors to
    ``<filename>.txt`` and prints their mean and 1.96 * std / sqrt(rep).
    """
    # Kept as a module-level global for backward compatibility with code that
    # reads `args` after main() has run.
    global args
    parser = argparse.ArgumentParser(description='Estimating the approximation error of approximating a given target layer with two random layers and trainable batch normalization parameters.')
    parser.add_argument('--nin', type=int, default=10,
                        help='Number of target input features.')
    parser.add_argument('--nout', type=int, default=10,
                        help='Number of target output features.')
    parser.add_argument('--k1', type=int, default=1,
                        help='First target kernel dimension. Default: 1.')
    parser.add_argument('--k2', type=int, default=1,
                        help='Second target kernel dimension. Default: 1.')
    parser.add_argument('--method', type=str, default='exact', help='Method to identify scaling parameters. Default: approximate the target network exactly. Alternative: LBFGS.')
    parser.add_argument('--seed', type=int, default=1, help='Random seed (default=1).')
    parser.add_argument('--m', type=int, default=50, help='Choose width of random intermediary layer.')
    parser.add_argument('--mt', type=int, default=1000, help='Choose width of random intermediary layer of target.')
    parser.add_argument('--adddim', type=int, default=0, help='Choose width increase of random intermediary layer in optimization.')
    parser.add_argument('--nper', type=int, default=0, help='Number of considered permutations.')
    parser.add_argument('--epochs', type=int, default=100, help='Number of LBFGS epochs.')
    parser.add_argument('--njobs', type=int, default=6, help='Number of parallel processes during additional feature construction.')
    parser.add_argument('--device', type=str, default='cpu', help='Device: cuda or cpu. Default: cpu')
    parser.add_argument('--init', type=str, default='He', help='Random parameter distribution')
    # The help text used to describe the approximation; --target actually
    # selects how the *target* weights are drawn.
    parser.add_argument('--target', type=str, default='He', help='Method to initialize the random parameters of the target')
    parser.add_argument('--rep', type=int, default=1, help='Nbr of repetitions')
    parser.add_argument('--save', action='store_true', default=False,
                        help='Save results in file with specified filename.')
    parser.add_argument('--file', type=str, default='stats', help='Beginning of filename')

    args = parser.parse_args()
    if args.rep < 1:
        # With no repetitions the mean and the interval below are undefined.
        parser.error('--rep must be at least 1.')

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.use_deterministic_algorithms(True)

    output = np.zeros(args.rep)
    for r in range(args.rep):
        # Target and approximation of one repetition share the same seed.
        rep_seed = 2 * args.seed + r
        wt = _build_target(args, rep_seed)
        output[r] = _approximation_error(args, wt, rep_seed)

    filename = (
        f"{args.file}_target_{args.target}_nout_{args.nout}_nin_{args.nin}"
        f"_k1_{args.k1}_k2_{args.k2}_init_{args.init}_m_{args.m}"
        f"_mt_{args.mt}_seed_{args.seed}"
    )
    if args.save:
        np.savetxt(filename + ".txt", output, delimiter=",")

    print(filename)
    print("Mean approx. error: ", np.mean(output))
    # Half-width of a normal-approximation 95% interval (population std).
    print("CI of approx. error: ", 1.96*np.std(output)/np.sqrt(args.rep))


# Run the experiment only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
