import matplotlib.pyplot as plt
import numpy as np
from hydra import compose, initialize
import argparse

def save_kernels(model, save_path=""):
    """Collect every layer's convolution kernel and save them stacked into one ``.npy`` file.

    For each entry of ``model.model.layers`` the array at
    ``layer.layer.kernel.kernel`` is moved to CPU, detached and converted to
    NumPy. It is reshaped to ``(1, d0, d1, d2)``, where ``(d0, d1, d2)`` are the
    first three dimensions of layer 0's kernel. The results are concatenated
    along axis 0, giving an array of shape ``(num_layers, d0, d1, d2)``.

    Side effects: calls ``model.eval()`` and sets ``.lam = 0`` on every kernel
    object before reading it. Both changes persist on ``model``.

    Args:
        model: object exposing ``eval()`` and ``model.layers`` (indexable, iterable).
        save_path: output prefix; the file written is ``save_path + ".npy"``.

    Raises:
        ValueError: if ``model.model.layers`` is empty.
    """
    model.eval()
    layers = model.model.layers
    if len(layers) == 0:
        raise ValueError("model.model.layers is empty; there are no kernels to save")

    shape = None
    per_layer = []
    for layer in layers:
        conv = layer.layer.kernel
        # NOTE(review): zeroing `lam` presumably disables a soft-threshold before
        # the kernel is read; if `.kernel` is a plain tensor/Parameter this
        # assignment does not change the values read below -- confirm.
        conv.kernel.lam = 0
        k = conv.kernel.cpu().detach().numpy()
        if shape is None:
            # All layers are reshaped to layer 0's shape with a leading singleton axis.
            shape = (1, k.shape[0], k.shape[1], k.shape[2])
        per_layer.append(np.reshape(k, shape))

    # Concatenate once instead of re-copying a growing array for every layer.
    kernel = np.concatenate(per_layer)

    print("Saving kernel")
    np.save(save_path + ".npy", kernel)


def hydraload(config,path,dataset):
    """Build a SequenceLightningModule from a Hydra config and load checkpoint weights into it.

    Args:
        config: composed Hydra config (``main`` passes the result of ``compose``).
            It is run through ``utils.train.process_config``, and the processed
            config then has ``train.pretrained_model_path`` and
            ``dataset.data_dir`` overwritten.
        path: checkpoint file handed to ``load_from_checkpoint``.
        dataset: value written to ``config.dataset.data_dir``.

    Returns:
        The module returned by ``SequenceLightningModule.load_from_checkpoint``.
    """
    import sys
    import os
    # `train` and `src` are imported from the directory named by the BASE_PATH env var.
    # NOTE(review): if BASE_PATH is unset, None is appended to sys.path and the imports
    # below only succeed if those modules are already importable -- confirm.
    BASE_PATH = os.environ.get("BASE_PATH")
    sys.path.append(BASE_PATH)
    from train import SequenceLightningModule
    #from train import preemption_setup
    import src.utils as utils
    config = utils.train.process_config(config)
    config.train.pretrained_model_path = path
    # Printed before data_dir is overridden, so the printed config shows the original data_dir.
    utils.train.print_config(config, resolve=True)

    #config = preemption_setup(config)
    config.dataset.data_dir = dataset
    # NOTE(review): this instance is discarded by the reassignment below. It may be kept
    # for side effects of construction (e.g. on `config`) -- confirm before removing.
    model = SequenceLightningModule(config)
    model = SequenceLightningModule.load_from_checkpoint(
            config.train.pretrained_model_path,
            config=config,
            strict=config.train.pretrained_model_strict_load,
        )
    return model

def FFT(kernel):
    """Return the one-sided FFT (``numpy.fft.rfft``) of ``kernel`` along its last axis."""
    spectrum = np.fft.rfft(kernel, axis=-1)
    return spectrum
    
def fourier_kernel_plot(path, lam):
    """Plot the low-frequency spectrum of each layer's dominant kernel channel to PDFs.

    Loads ``path + ".npy"`` (the array written by ``save_kernels``; axis 0 is the
    layer, the remaining axes are presumably (head, channel, length) -- TODO
    confirm). The steps are:

    1. Soft-threshold the array as ``sign(K) * max(0, |K| - lam)``.
    2. Reorder axis 2 of every ``K[j, z]`` by descending max-abs value.
    3. Take the real FFT along the last axis.
    4. For every layer ``j``, plot the real part of the first ``n_bins // 8``
       bins of ``K[j, 0, 0]``, i.e. the channel with the largest peak.

    Args:
        path: prefix for the input ``.npy`` file and the output PDFs.
        lam: soft-threshold applied before sorting and transforming.

    Output files are ``f"{path}FFT_{j + 1}.pdf"`` (1-based layer index); each
    path is printed as it is written.

    Raises:
        ValueError: if the spectrum has fewer than 8 bins, so nothing could be plotted.
    """
    K = np.load(path + ".npy")
    K = np.sign(K) * np.maximum(0, np.abs(K) - lam)

    # Reorder axis 2 of every K[j, z] by descending infinity norm along the last
    # axis (vectorized form of a per-(j, z) argsort loop).
    peak = np.abs(K).max(axis=3)
    order = np.argsort(peak, axis=2)[:, :, ::-1]
    K = np.take_along_axis(K, order[..., np.newaxis], axis=2)

    K = FFT(K)
    kh = K.shape[3] // 8
    if kh == 0:
        raise ValueError(
            f"spectrum has only {K.shape[3]} bins; need at least 8 to plot the lowest 1/8"
        )
    x = np.linspace(1, kh, kh)

    for j in range(K.shape[0]):
        spectrum = K[j, 0, 0, :kh]
        # Limits use the complex magnitude, which bounds |real part|, so the curve fits.
        ym = np.max(np.abs(spectrum))

        # Explicit figure: never draw onto a figure the caller left open and
        # don't mutate global rcParams.
        fig, ax = plt.subplots(figsize=(5.5, 3))
        # Plot the real part explicitly; passing a complex array makes
        # matplotlib drop the imaginary part with a ComplexWarning.
        ax.plot(x, spectrum.real, label="Head1", color='#F28E2B', linewidth=2)
        ax.set_xlim([1, kh])
        if ym > 0:
            # Skip when everything was thresholded to zero (singular limits).
            ax.set_ylim([-ym, ym])
        ax.set_xticks([])
        ax.set_yticks([])

        savepath = f"{path}FFT_{j + 1}.pdf"
        print(savepath)
        fig.tight_layout()
        fig.savefig(savepath, dpi=400)
        plt.close(fig)

def kernel_plot(path, lam):
    """Plot each layer's dominant (time-domain) kernel channel to its own PDF.

    Loads ``path + ".npy"`` (the array written by ``save_kernels``; axis 0 is the
    layer, the remaining axes are presumably (head, channel, length) -- TODO
    confirm). The steps are:

    1. Soft-threshold the array as ``sign(K) * max(0, |K| - lam)``.
    2. Reorder axis 2 of every ``K[j, z]`` by descending max-abs value.
    3. For every layer ``j``, plot the full ``K[j, 0, 0]``, i.e. the channel
       with the largest peak.

    Args:
        path: prefix for the input ``.npy`` file and the output PDFs.
        lam: soft-threshold applied before sorting and plotting.

    Output files are ``f"{path}{j + 1}.pdf"`` (1-based layer index); each path
    is printed as it is written.
    """
    K = np.load(path + ".npy")
    K = np.sign(K) * np.maximum(0, np.abs(K) - lam)

    # Reorder axis 2 of every K[j, z] by descending infinity norm along the last
    # axis (vectorized form of a per-(j, z) argsort loop).
    peak = np.abs(K).max(axis=3)
    order = np.argsort(peak, axis=2)[:, :, ::-1]
    K = np.take_along_axis(K, order[..., np.newaxis], axis=2)

    kh = K.shape[3]
    x = np.linspace(1, kh, kh)

    for j in range(K.shape[0]):
        trace = K[j, 0, 0, :kh]
        ym = np.max(np.abs(trace))

        # Explicit figure: never draw onto a figure the caller left open and
        # don't mutate global rcParams.
        fig, ax = plt.subplots(figsize=(5.5, 3))
        ax.plot(x, trace, label="Head1", color="#4E79A7", linewidth=2)
        if ym > 0:
            # Skip when everything was thresholded to zero (singular limits).
            ax.set_ylim([-ym, ym])
        ax.set_xlim([1, kh])
        ax.set_xticks([])
        ax.set_yticks([])

        savepath = f"{path}{j + 1}.pdf"
        print(savepath)
        fig.tight_layout()
        fig.savefig(savepath, dpi=400)
        plt.close(fig)
    
def main():
    """Command-line entry point: optionally extract kernels from a checkpoint, then plot them.

    Positional arguments, in order: experiment, checkpoint_path, save_path,
    data_path, get_kernel, lam.

    When ``get_kernel`` is non-zero, the model is built via Hydra from
    ``checkpoint_path`` and its kernels are written to ``save_path + ".npy"``.
    Experiment, checkpoint and data path must then be non-empty; otherwise the
    parser exits with status 2.

    In every case the time-domain and Fourier plots are then produced from
    ``save_path + ".npy"`` using soft-threshold ``lam``.
    """
    parser = argparse.ArgumentParser(description="Params for the experiment")
    parser.add_argument(
        "experiment",
        type=str,
        help="Experiment config passed to Hydra as experiment=<value> (e.g. lra/listops-long-conv)",
    )
    parser.add_argument("checkpoint_path", type=str, help="Checkpoint to load when get_kernel is non-zero")
    parser.add_argument("save_path", type=str, help="Prefix for the kernel .npy file and the output PDFs")
    parser.add_argument("data_path", type=str, help="Dataset directory written to config.dataset.data_dir")
    parser.add_argument("get_kernel", type=int, help="Non-zero: extract kernels from the checkpoint first")
    parser.add_argument("lam", type=float, help="Soft-threshold applied to the kernels before plotting")

    args = parser.parse_args()

    if args.get_kernel:
        # parser.error (unlike assert) survives `python -O`, prints usage and exits with status 2.
        required = (
            (args.experiment, "The experiment needs to be specified"),
            (args.checkpoint_path, "The checkpoint path needs to be specified"),
            (args.data_path, "The data path needs to be specified"),
        )
        for value, message in required:
            if not value:
                parser.error(message)

        # Hydra resolves config_path relative to the file that calls initialize().
        initialize(version_base=None, config_path="./../../../configs/")
        cfg = compose(config_name="config.yaml",
                      overrides=["experiment=" + args.experiment])

        model = hydraload(cfg, args.checkpoint_path, args.data_path)
        save_kernels(model, args.save_path)

    lam = args.lam
    kernel_plot(args.save_path, lam)
    fourier_kernel_plot(args.save_path, lam)
    
# Run the command-line interface only when executed as a script, not when imported.
if __name__ =="__main__":
    main()

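# Example invocations.
# Positional argument order: experiment checkpoint_path save_path data_path get_kernel lam
#LISTOPS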
#python get_trained_model_kernel.py lra/listops-long-conv /mnt/hippo-dan/outputs/2023-01-24/21-12-41-166904/checkpoints/val/accuracy.ckpt /mnt/hippo-dan/notebooks/elliot_notebooks/plot/listops_kernel_raw_lam_0_ /data/lra_release/listops-1000/ 1 0
#python get_trained_model_kernel.py lra/listops-long-conv /mnt/hippo-dan/outputs/2023-01-24/21-36-04-641981/checkpoints/val/accuracy.ckpt /mnt/hippo-dan/notebooks/elliot_notebooks/plot/listops_kernel_raw_lam_0.003_ /data/lra_release/listops-1000/ 1 0.003
#python get_trained_model_kernel.py lra/listops-long-conv /mnt/hippo-dan/outputs/2023-01-25/13-16-43-665779/checkpoints/val/accuracy.ckpt /mnt/hippo-dan/notebooks/elliot_notebooks/plot/listops_kernel_raw_lam_0_smoothing_ /data/data/lra_release/listops-1000/ 1 0.0

#CIFAR
#python get_trained_model_kernel.py lra/cifar-long-conv /mnt/hippo-dan/outputs/2023-01-25/00-33-32-948046/checkpoints/val/accuracy.ckpt /mnt/hippo-dan/notebooks/elliot_notebooks/plot/cifar_kernel_raw_lam_0_ /data/ 1 0.0
#python get_trained_model_kernel.py lra/cifar-long-conv /mnt/hippo-dan/outputs/2023-01-25/00-35-23-958081/checkpoints/val/accuracy.ckpt /mnt/hippo-dan/notebooks/elliot_notebooks/plot/cifar_kernel_raw_lam_0.003_ /data/ 1 0.003

###PATHFINDER
#PATHFINDER
#python get_trained_model_kernel.py lra/pathfinder-long-conv /mnt/hippo-dan/outputs/2023-01-25/18-15-43-196824/checkpoints/val/accuracy.ckpt /mnt/hippo-dan/notebooks/elliot_notebooks/plot/path32_kernel_raw_lam_0.001_ /data/data/lra_release/lra_release/pathfinder32/ 1 0.001


#AMOS
#python get_trained_model_kernel.py segmentation/segmentation_3D_amos_small /mnt/safari-internal/outputs/2023-05-25/14-03-28-844280/checkpoints/last.ckpt /mnt/safari-internal/scripts/notebooks/plot/amos_kernel_plot /data/amos22_processed_normalized/ 1 0.003
#/mnt/safari-internal/scripts/notebooks/plot/amos_kernel_plot