import torch
import torch.nn as nn
import math
import torch.optim as optim
import numpy as np
#import matplotlib.pyplot as plt
import random
import argparse
import pathlib

from BN_match_target_permute import *

def write_result_to_csv(**kwargs):
    """Append one result row to ``BNfeat/resultsScaling.csv``.

    The ``BNfeat`` directory and the header line are created on first use,
    so the function also works in a fresh working directory.

    Keyword Args:
        target, init, err, errCond, CondNumber, nout, nin, k1, k2, m, mt,
        method, seed: values written, in this order, as one ", "-separated
        row. Note that ``err`` is stored under the header column ``error``.
        Any additional keyword arguments are ignored by ``str.format``.

    Raises:
        KeyError: if one of the keys listed above is missing.
    """
    results = pathlib.Path("BNfeat") / "resultsScaling.csv"

    # Without this, the very first write fails with FileNotFoundError when
    # the output directory has not been created by hand.
    results.parent.mkdir(parents=True, exist_ok=True)

    if not results.exists():
        results.write_text(
            "target, "
            "init, "
            "error, "
            "errCond, "
            "CondNumber, "
            "nout, "
            "nin, "
            "k1, "
            "k2, "
            "m, "
            "mt, "
            "method, "
            "seed\n"
        )

    row_template = (
        "{target}, "
        "{init}, "
        "{err}, "
        "{errCond}, "
        "{CondNumber}, "
        "{nout}, "
        "{nin}, "
        "{k1}, "
        "{k2}, "
        "{m}, "
        "{mt}, "
        "{method}, "
        "{seed}\n"
    )

    with open(results, "a") as f:
        f.write(row_template.format(**kwargs))

def main():
    """Run the scaling sweep and append one CSV row per repetition.

    Parses the command line, seeds the Python, NumPy and PyTorch RNGs and
    then, for every target type, layer size, feature initialisation and
    repetition, draws a target weight tensor, fits it with the solver
    selected by ``--method`` (the solvers come from
    ``BN_match_target_permute``) and records the reported errors and
    condition number through ``write_result_to_csv``.

    The layer sizes swept are hard-coded in the loop below. The options
    ``--nin``, ``--nout``, ``--m``, ``--adddim``, ``--nper``, ``--njobs``,
    ``--save`` and ``--file`` are parsed but not used by the active code.
    """
    global args
    parser = argparse.ArgumentParser(description='Estimating the approximation error of approximating a given target layer with two random layers and trainable batch normalization parameters.')
    parser.add_argument('--nin', type=int, default=10,
                        help='Number of target input features.')
    parser.add_argument('--nout', type=int, default=10,
                        help='Number of target output features.')
    parser.add_argument('--k1', type=int, default=1,
                        help='First target kernel dimension. Default: 1.')
    parser.add_argument('--k2', type=int, default=1,
                        help='Second target kernel dimension. Default: 1.')
    parser.add_argument('--method', type=str, default='flexible', help='Method to identify scaling parameters. Default: approximate the target network exactly. Alternative: LBFGS.')
    parser.add_argument('--seed', type=int, default=1, help='Random seed (default=1).')
    parser.add_argument('--m', type=int, default=50, help='Choose width of random intermediary layer.')
    parser.add_argument('--mt', type=int, default=1000, help='Choose width of random intermediary layer of target.')
    parser.add_argument('--adddim', type=int, default=0, help='Choose width increase of random intermediary layer in optimization.')
    parser.add_argument('--nper', type=int, default=0, help='Number of considered permutations.')
    parser.add_argument('--epochs', type=int, default=100, help='Number of LBFGS epochs.')
    parser.add_argument('--njobs', type=int, default=6, help='Number of parallel processes during additional feature construction.')
    parser.add_argument('--device', type=str, default='cpu', help='Device: cuda or cpu. Default: cpu')
    parser.add_argument('--rep', type=int, default=1, help='Nbr of repetitions')
    parser.add_argument('--save', action='store_true', default=False,
                        help='Save results in file with specified filename.')
    parser.add_argument('--file', type=str, default='feat', help='Beginning of filename')

    args = parser.parse_args()

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.use_deterministic_algorithms(True)

    def pick_factor_init(target):
        """Return the ``init_*`` factory for a factorised target, else None.

        The names are looked up only when the matching branch is taken, so a
        factory missing from ``BN_match_target_permute`` only fails when it
        is actually requested.
        """
        if target == "He":
            return init_He
        if target == "uni":
            return init_uni
        if target == "ortho":
            return init_ortho
        if target == "ER":
            return init_ER
        if target == "ER_norm":
            return init_ER_norm
        if target == "sign":
            return init_sign
        if target == "ER_sign":
            return init_ER_sign
        if target == "id":
            return init_id
        if target == "precond":
            return init_precond
        return None

    def get_target(target, seed, nin, nout):
        """Draw a target weight tensor rescaled to unit root-mean-square.

        Args:
            target: Name of the target distribution. ``"HePure"`` draws
                Gaussian weights scaled by ``1/sqrt(2*fan_in)``; the names
                known to ``pick_factor_init`` build the target as
                ``w02 * diag(gamma) * w01`` from a pair of random factors;
                ``"ERfull"`` (2-D targets only) draws a Bernoulli-masked
                Gaussian matrix; any other name yields a tensor with ``m``
                randomly placed Gaussian entries and zeros elsewhere.
            seed: Base seed; ``seed + 1`` is handed to the factor
                initialisers and used to reseed the sparse fallback.
            nin: Number of input features.
            nout: Number of output features.

        Returns:
            The target tensor divided by its root-mean-square value. For
            ``"HePure"``, ``"ERfull"`` and the fallback the shape is
            ``(nout, nin)``, or ``(nout, nin, k1, k2)`` when
            ``max(k1, k2) > 1``; for factorised targets the shape is
            whatever the einsum of the factors produces (presumably the
            same -- the factor shapes are not visible here).
        """
        # NOTE(review): the reseed uses args.seed, not the `seed` argument,
        # so gamma and the "HePure"/"ERfull" draws are identical for every
        # repetition -- confirm this is intended.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

        is_conv = max(args.k1, args.k2) > 1
        if is_conv:
            wtsize = (nout, nin, args.k1, args.k2)
            fan_in = nin*args.k1*args.k2
            mfull = nout*(fan_in-1)+1
            contraction = 'impq,m,mjkl->ijpq'
        else:
            wtsize = (nout, nin)
            fan_in = nin
            mfull = nout*nin
            contraction = 'im,m,mj->ij'
        numel = nout*fan_in  # number of entries in a tensor of shape wtsize

        m = min(args.mt, mfull)
        # gamma is drawn before branching on `target`, for every target
        # type, so the random stream seen by the branches stays unchanged.
        gamma = torch.randn(m)

        init_fn = pick_factor_init(target)
        if target == "HePure":
            wt = torch.randn(list(wtsize))
            wt = wt/math.sqrt(2*fan_in)
        elif init_fn is not None:
            w01, w02 = init_fn(wtsize, m, seed+1)
            wt = torch.einsum(contraction, w02, gamma, w01)
        elif target == "ERfull" and not is_conv:
            wt = torch.randn([nout,nin])
            mask = torch.bernoulli(torch.ones(nout,nin)*m/(nout*nin))
            wt = wt*mask
            wt = wt/torch.sqrt(torch.mean(wt**2))
        else:
            # Sparse fallback: m Gaussian entries at random positions.
            random.seed(seed+1)
            torch.manual_seed(seed+1)
            # The flat buffer must hold exactly `numel` entries. The former
            # size `mfull+nout-1` equals numel only for conv targets and made
            # the reshape fail for 2-D targets with nout > 1.
            indflat = torch.randperm(numel)[:m]
            x = torch.zeros(numel)
            x[indflat] = torch.randn(len(indflat))
            wt = x.reshape(wtsize)

        baseline = torch.sqrt(torch.mean(wt**2))
        wt = wt/baseline
        return wt

    # Initialisations for which the "exact" solve is attempted without a
    # fallback on failure.
    feat_list_prob = ["ER_sign", "ER", "sign"]
    feat_list = ["ER_norm", "ER_sign", "He", "uni", "ortho","precond","id"]
    target_list = ["HePure"]

    for target in target_list:
        for nin in [30, 80, 100, 120, 150]:
            for init in feat_list:
                for r in range(args.rep):
                    # Square layers only: nout is set equal to nin.
                    wt = get_target(target, args.seed+r, nin, nin)
                    solver_seed = 2*args.seed+r+100
                    if args.method == "exact":
                        width = nin*(nin-1)+1
                        # The trailing boolean is a solver switch defined in
                        # BN_match_target_permute; its meaning is not visible
                        # from this file.
                        if init in feat_list_prob:
                            err, errCond, errPerm, errPermCond, CondNumber = proxy_target_LBFGS_precond(wt, solver_seed, args.epochs, width, init, args.device, True)
                        else:
                            try:
                                err, errCond, errPerm, errPermCond, CondNumber = proxy_target_LBFGS_precond(wt, solver_seed, args.epochs, width, init, args.device, True)
                            except Exception:
                                # Best-effort retry with the switch turned
                                # off. Only ordinary errors are caught, so
                                # Ctrl-C and SystemExit still stop the run.
                                err, errCond, errPerm, errPermCond, CondNumber = proxy_target_LBFGS_precond(wt, solver_seed, args.epochs, width, init, args.device, False)
                    else:
                        err, errCond, CondNumber = proxy_target_LBFGS_precond_noPerm(wt, solver_seed, args.epochs, nin*nin, init, args.device, False)
                    # NOTE(review): `m` is logged as nin*nin even when the
                    # "exact" branch used width nin*(nin-1)+1, and `seed`
                    # does not include the repetition index r -- confirm.
                    write_result_to_csv(
                            target=target,
                            init=init,
                            err=err.item(),
                            errCond=errCond.item(),
                            CondNumber=CondNumber.item(),
                            nout=nin,
                            nin=nin,
                            k1=args.k1,
                            k2=args.k2,
                            m=(nin*nin),
                            mt=args.mt,
                            method=args.method,
                            seed=args.seed
                    )

#    filename = args.file + "_"  + args.method + "_mean_nout_" + str(args.nout) + "_nin_" + str(args.nin) + "_k1_" + str(args.k1) + "_k2_" + str(args.k2) + "_m_" + str(args.m) + "_mt_" + str(args.mt) + "_seed_" + str(args.seed) + "_nper_" + str(args.nper) + "_adddim_" + str(args.adddim)  + "_rep_" + str(args.rep)
#    #if args.save:
#    np.savetxt(filename + ".txt", out_mean, delimiter=",")
#    filename = args.file + "_"  + args.method + "_ci_nout_"  + str(args.nout) + "_nin_" + str(args.nin) + "_k1_" + str(args.k1) + "_k2_" + str(args.k2) + "_m_" + str(args.m) + "_mt_" + str(args.mt) + "_seed_" + str(args.seed) + "_nper_" + str(args.nper) + "_adddim_" + str(args.adddim) + "_rep_" + str(args.rep)
#    #if args.save:
#    np.savetxt(filename + ".txt", out_ci, delimiter=",")

# Run the experiment sweep only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
