import torch
import torch.nn as nn
import math
import torch.optim as optim
import numpy as np
#import matplotlib.pyplot as plt
import random
import argparse
import pathlib
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.patches import Rectangle


from BN_match_target_permute import *

def main():
    """Plot singular-value spectra of products of two random layers.

    For each of the initialisations "He", "ortho" and "uni" this draws a pair
    of factors ``(w01, w02)`` with ``get_init`` and builds the matrix whose
    k-th column is the flattened outer product ``w02[:, k] x w01[k, :]`` (see
    ``spectrum``). It then histograms that matrix's singular values.

    Each empirical spectrum is min-max rescaled onto ``[1 - sqrt(p), 1 + sqrt(p)]``
    with ``p = nin * nout / mt``. Two reference curves are overlaid:

    * red: the Marchenko-Pastur singular-value density;
    * black: the quarter-circle law.

    The figure is saved to ``fig/spectrum<mt>.pdf`` (the directory is created
    if needed) and then shown.

    The parsed options are stored in the module-level ``args`` because the
    nested helpers (e.g. ``get_target``) read ``args.k1`` / ``args.k2``.
    """
    global args
    parser = argparse.ArgumentParser(description='Estimating the approximation error of approximating a given target layer with two random layers and trainable batch normalization parameters.')
    parser.add_argument('--nin', type=int, default=10,
                        help='Number of target input features.')
    parser.add_argument('--nout', type=int, default=10,
                        help='Number of target output features.')
    parser.add_argument('--k1', type=int, default=1,
                        help='First target kernel dimension. Default: 1.')
    parser.add_argument('--k2', type=int, default=1,
                        help='Second target kernel dimension. Default: 1.')
    parser.add_argument('--seed', type=int, default=1, help='Random seed (default=1).')
    parser.add_argument('--m', type=int, default=50, help='Choose width of random intermediary layer.')
    parser.add_argument('--mt', type=int, default=100, help='Choose width of random intermediary layer of target.')
    parser.add_argument('--init', type=str, default='He', help='Random parameter distribution')
    parser.add_argument('--rep', type=int, default=1, help='Nbr of repetitions')
    # NOTE(review): --m, --rep, --save and --file are parsed but not used below,
    # and --init is only echoed; kept so existing command lines keep working.
    parser.add_argument('--save', action='store_true', default=False,
                        help='Save results in file with specified filename.')
    parser.add_argument('--file', type=str, default='spectrum_', help='Beginning of filename')

    args = parser.parse_args()

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.use_deterministic_algorithms(True)

    def spectrum(w01, w02):
        """Return ``(S, cond)`` for the stacked rank-one products of ``w02`` and ``w01``.

        ``M[i, j, k] = w02[i, k] * w01[k, j]`` is reshaped to a
        ``(w02.size(0) * w01.size(1), w02.size(1))`` matrix. ``S`` holds its
        singular values in descending order, and ``cond = S[0] / S[-1]``
        (``inf`` if the smallest singular value is 0).
        """
        M = torch.einsum('ik,kj->ijk', w02, w01)
        M = M.reshape(M.size(0) * M.size(1), M.size(2))
        # Only the singular values are needed; svdvals skips computing U and V.
        S = torch.linalg.svdvals(M)
        cond_number = S[0] / S[-1]
        return S, cond_number

    def init_factors(kind, size, width, factor_seed):
        """Return ``(w01, w02)`` from the ``init_<kind>`` factory of ``BN_match_target_permute``.

        For a 2-d ``size == (nout, nin)`` the einsum ``'im,m,mj->ij'`` used
        below implies ``w01`` is ``(width, nin)`` and ``w02`` is ``(nout, width)``.
        "He" and any unrecognised ``kind`` use ``init_He``.
        """
        if kind == "ortho":
            return init_ortho(size, width, factor_seed)
        if kind == "uni":
            return init_uni(size, width, factor_seed)
        if kind == "precond":
            return init_precond(size, width, factor_seed)
        if kind == "id":
            return init_id(size, width, factor_seed)
        if kind == "ER":
            return init_ER(size, width, factor_seed)
        if kind == "ER_norm":
            return init_ER_norm(size, width, factor_seed)
        if kind == "ER_sign":
            return init_ER_sign(size, width, factor_seed)
        if kind == "sign":
            return init_sign(size, width, factor_seed)
        if kind == "aij":
            return init_aij(size, width, factor_seed)
        return init_He(size, width, factor_seed)

    # Target kinds built as w02 * diag(gamma) * w01 from an init_<kind> factory.
    factored_targets = ("He", "uni", "ortho", "ER", "ER_norm", "sign",
                        "ER_sign", "id", "precond")

    def random_sparse(wtsize, m, sparse_seed):
        """Return a tensor of shape ``wtsize`` with ``m`` random N(0, 1) entries and zeros elsewhere."""
        random.seed(sparse_seed)
        torch.manual_seed(sparse_seed)
        # The flat buffer must have exactly as many entries as the target
        # tensor, otherwise the final reshape fails.
        numel = int(np.prod(wtsize))
        indflat = torch.randperm(numel)[:m]
        x = torch.zeros(numel)
        x[indflat] = torch.randn(len(indflat))
        return x.reshape(wtsize)

    def get_target(target, seed, nin, nout, mt):
        """Draw a target weight tensor of kind ``target``, normalised to unit RMS.

        Args:
            target: the kind of target to draw. One of:

                * "HePure" for a dense Gaussian tensor;
                * a kind in ``factored_targets``, built as
                  ``w02 * diag(gamma) * w01``;
                * "ERfull" for a Bernoulli-masked Gaussian matrix (dense
                  layers only);
                * anything else, which gives a random sparse tensor.
            seed: RNG seed; the factor/sparse draws use ``seed + 1``.
            nin, nout: target input/output features.
            mt: target intermediary width, capped at the number of degrees
                of freedom ``mfull``.

        The kernel size comes from the global ``args.k1`` / ``args.k2``. When
        either exceeds 1 the result is ``(nout, nin, k1, k2)``; otherwise it
        is ``(nout, nin)``.
        """
        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
        is_conv = max(args.k1, args.k2) > 1
        if is_conv:
            wtsize = (nout, nin, args.k1, args.k2)
            fan_in = nin * args.k1 * args.k2
            mfull = nout * (fan_in - 1) + 1
            einsum_spec = 'impq,m,mjkl->ijpq'
        else:
            wtsize = (nout, nin)
            fan_in = nin
            mfull = nout * nin
            einsum_spec = 'im,m,mj->ij'
        m = min(mt, mfull)
        # Drawn before dispatch for every target kind, so it always advances the RNG.
        gamma = torch.randn(m)
        if target == "HePure":
            wt = torch.randn(list(wtsize)) / math.sqrt(2 * fan_in)
        elif target in factored_targets:
            w01, w02 = init_factors(target, wtsize, m, seed + 1)
            wt = torch.einsum(einsum_spec, w02, gamma, w01)
        elif target == "ERfull" and not is_conv:
            wt = torch.randn([nout, nin])
            mask = torch.bernoulli(torch.ones(nout, nin) * m / (nout * nin))
            wt = wt * mask
            wt = wt / torch.sqrt(torch.mean(wt ** 2))
        else:
            wt = random_sparse(wtsize, m, seed + 1)
        baseline = torch.sqrt(torch.mean(wt ** 2))
        return wt / baseline

    def semi_circle(x):
        """Wigner semicircle density on [-2, 2]."""
        return np.sqrt(4 - x ** 2) / (2 * np.pi)

    def cdf_semi_circle(x):
        """Twice the semicircle CDF; its increments over [0, 2] give quarter-circle bin masses."""
        return (0.5 + np.sqrt(4 - x ** 2) * x / (4 * np.pi) + np.arcsin(x / 2) / np.pi) * 2

    def cdf_semi_circle_scaled(x, p):
        # NOTE(review): unused in this script; formula not verified against a reference.
        b = (1 + np.sqrt(p)) ** 2
        a = (1 - np.sqrt(p)) ** 2
        return (0.5 + np.sqrt((b - x) * (x - a)) * (x - a) / (4 * np.pi * (b - a)) + np.arcsin((x - a) / (2 * (b - a))) / np.pi)

    def pdf_discrete(N, p):
        """Quarter-circle probability mass of each of ``N`` equal bins on [0, 2] (``p`` is unused)."""
        x = np.linspace(0, 2, N + 1)
        return cdf_semi_circle(x[1:]) - cdf_semi_circle(x[:-1])

    def MP_prob(N, p):
        """Marchenko-Pastur singular-value density at the midpoints of ``N`` bins.

        The support is ``[1 - sqrt(p), 1 + sqrt(p)]``. The values are
        normalised to sum to 1, so they are approximate bin probabilities.
        """
        b = (1 + np.sqrt(p))
        a = (1 - np.sqrt(p))
        x = np.linspace(a, b, N + 1) + 0.5 * (b - a) / N
        x = x[:-1]
        dd = np.sqrt((x ** 2 - a ** 2) * (b ** 2 - x ** 2)) / (np.pi * p * x)
        return dd / np.sum(dd)

    def get_init(init, seed, nin, nout, mt):
        """Return ``(w01, w02)`` for a dense ``(nout, nin)`` layer with intermediary width ``mt``.

        ``init`` picks the ``init_<init>`` factory (including "aij");
        unrecognised names fall back to He.
        """
        return init_factors(init, (nout, nin), mt, seed)

    def rescale_to_support(s, lo, hi):
        """Affinely map ``s`` so that its minimum becomes ``lo`` and its maximum ``hi``."""
        s_min = torch.min(s)
        span = torch.max(s) - s_min
        if span == 0:
            # A flat spectrum has no scale to stretch; centre it instead of dividing by zero.
            return torch.full_like(s, (lo + hi) / 2)
        return (s - s_min) * (hi - lo) / span + lo

    print(args.init)
    # NOTE(review): w02 is scaled by sqrt(nin / 2), presumably to undo a He-style
    # 2 / nin variance; confirm against init_He.
    w02_scale = np.sqrt(args.nin / 2)
    spectra = []
    for kind in ("He", "ortho", "uni"):
        w01, w02 = get_init(kind, args.seed, args.nin, args.nout, args.mt)
        spectra.append(spectrum(w01, w02 * w02_scale))
    S, cn = spectra[0]

    print(cn)
    print(torch.min(S))
    print(torch.max(S))
    # Aspect ratio of the nin*nout x mt matrix and the MP singular-value support edges.
    # NOTE(review): a = 1 - sqrt(p) is negative when p > 1, i.e. the overlay assumes mt >= nin*nout.
    p = (args.nin * args.nout) / args.mt
    b = (1 + np.sqrt(p))
    a = (1 - np.sqrt(p))
    print(a)
    print(b)

    N = 100
    rescaled = [rescale_to_support(s, a, b) for s, _ in spectra]

    # Set the font before creating the figure so every element picks it up.
    font = {'family': 'Times New Roman',
            'weight': 'bold',
            'size': 18,
            }
    plt.rc('font', **font)
    fig, ax = plt.subplots()
    ax.set_xlabel('singular values', fontsize=24, fontweight='bold')
    ax.set_ylabel('counts', fontsize=24, fontweight='bold')

    colors = ("blue", "gray", "orange")
    alphas = (0.5, 0.3, 0.3)
    bins = np.linspace(a, b, N + 1)
    for s, color, alpha in zip(rescaled, colors, alphas):
        ax.hist(s.detach().numpy(), bins=bins, alpha=alpha, color=color)

    # Draw both reference curves at bin midpoints. MP_prob evaluates its
    # density there, and each pdf_discrete value is the mass of the bin
    # centred there.
    mp_x = bins[:-1] + 0.5 * (b - a) / N
    N2 = int(2 * N / (b - a))
    y = pdf_discrete(N2, 1)
    qc_x = np.linspace(0, 2, N2 + 1)[:-1] + 1 / N2
    # Scale probabilities to counts: mt * p = nin * nout singular values when p <= 1.
    ax.plot(mp_x, MP_prob(N, p) * args.mt * p, color="red")
    ax.plot(qc_x, y * args.mt * p, color="black")

    handles = [Rectangle((0, 0), 1, 1, color=c, ec="k") for c in colors]
    labels = ["He norm", "Ortho", "He uniform"]
    ax.legend(handles, labels)
    out_dir = pathlib.Path("fig")
    out_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_dir / ('spectrum' + str(args.mt) + '.pdf'), bbox_inches='tight')
    plt.show()



if __name__ == "__main__":
    # Script entry point; main() reads its options from sys.argv.
    main()
