import torch
import torch.nn as nn
from ntk_net import NeuralTangentFeature

import matplotlib.pyplot as plt
from numpy import linalg as LA
import numpy as np

from trainer import NTK_MAML
from datasets import QuadraticMetaDataset

# Number of meta-batches the dataset yields (passed as num_total_batches in
# main); also a factor in the denominator of the alpha* estimate.
N = 85
# Passed to the dataset as nt= (presumably training samples per task — confirm
# against QuadraticMetaDataset); also a factor in the alpha* estimate.
K = 19
# Number of inner-loop steps handed to NTK_MAML.
Inner_steps = 1
# Layer widths for NeuralTangentFeature. An earlier configuration appears to
# have been [1, 1024, 1024, 1].
DIMS = [1, 10000, 1]

def create_ntk_model():
    """Build and return a fresh NeuralTangentFeature network sized by DIMS."""
    return NeuralTangentFeature(DIMS)



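## Earlier alpha* estimator for the linear-model case, kept for reference only
## (upp and low would need to be initialised to 0 before this could run).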
# def get_est_alpha_star(phi, phi_hat):
#     for i, pi, pi_hat in zip(range(N), phi, phi_hat ):
#         upp += np.trace(LA.multi_dot([pi.T, pi, pi_hat.T, pi_hat]))
#         low += np.trace(LA.multi_dot([pi.T, pi, pi_hat.T, pi_hat, pi.T, pi]))
#     return 0.5*K*upp/low

def get_est_alpha_star(feature_matrix, k=None, n=None):
    """Estimate the optimal inner-loop step size alpha* from NTK features.

    All feature tensors are concatenated along dim 0. The estimate is
    ``0.5 / (k * n * s**2)``, where ``s`` is the sample (unbiased, torch's
    default) standard deviation over every entry of the result. The mean and
    std are printed as a diagnostic.

    Args:
        feature_matrix: non-empty sequence of tensors joinable by
            ``torch.cat`` along dim 0 (presumably the NTK feature matrices
            collected in ``main`` — confirm shapes against NTK_MAML).
        k: factor in the denominator; defaults to the module-level ``K``.
        n: factor in the denominator; defaults to the module-level ``N``.

    Returns:
        The alpha* estimate as a Python float.

    Raises:
        ValueError: if ``feature_matrix`` is empty.
        ZeroDivisionError: if every feature entry is identical (std == 0).
    """
    if len(feature_matrix) == 0:
        raise ValueError('feature_matrix must contain at least one tensor')
    # Resolve defaults at call time so the module constants stay authoritative.
    if k is None:
        k = K
    if n is None:
        n = N
    # Concatenate once and reuse for both statistics.
    features = torch.cat(feature_matrix)
    mean = features.mean().item()
    std = features.std().item()
    print('Mean: ', mean, 'std:', std)
    return 0.5 / (k * n * std ** 2)


def main():
    """Sweep NTK-MAML inner-loop learning rates and print the results.

    For every batch yielded by the quadratic meta-dataset, the model is
    evaluated once per candidate alpha in ``alpha_range``. The losses and NTK
    features returned by ``NTK_MAML.evaluate`` are accumulated per alpha. The
    function then prints the alpha* estimate from ``get_est_alpha_star``. For
    each alpha it prints the alpha, followed by a line of comma-terminated
    losses formatted with four decimals.
    """
    # Fall back to CPU so the script still runs on machines without a GPU.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model = create_ntk_model()
    dataset = QuadraticMetaDataset(num_total_batches=N,
        nt=K, nv=10, meta_batch_size=1, nr=20,
        w_0=0, nu=0.5, p=10, sigma=1, train=True, device=device)
    loss_func = torch.nn.MSELoss()
    maml = NTK_MAML(dataset, model, 0.01, loss_func, Inner_steps, device)

    # Alpha sweep labelled for d=10000. Sweeps recorded for other d:
    #   d=50000: [-1e-4, -8e-5, -5e-5, -3e-5, -1e-5, -1e-6, 0., 1e-6, 1e-5,
    #             3e-5, 5e-5, 8e-5, 1e-4]
    #   d=5000:  [-1e-3, -7e-4, -3e-4, 0., 1e-4, 7e-4, 1e-3, 2e-3, 3e-3, 5e-3]
    #   d=1000:  [5e-4, 8e-4, 1e-3, 2e-3, 4e-3, 6e-3, 8e-3, 1e-2]
    alpha_range = [-3e-4, -1e-4, -5e-5, -1e-5, 0., 1e-5, 5e-5, 1e-4, 2e-4, 3e-4, 4e-4, 5e-4, 8e-4]

    cum_loss = {a: [] for a in alpha_range}
    ntk_features = {a: [] for a in alpha_range}
    for train_tasks, val_tasks in dataset:
        for alpha in alpha_range:
            maml.fast_lr = alpha
            batch_loss, batch_feature = maml.evaluate(train_tasks, val_tasks)
            # list += iterable extends the per-alpha accumulators in place.
            cum_loss[alpha] += batch_loss
            ntk_features[alpha] += batch_feature

    print('-' * 50)
    # NOTE(review): the estimate uses the features collected for the last
    # alpha of the sweep; confirm this is intended if features vary with alpha.
    est_alpha = get_est_alpha_star(ntk_features[alpha_range[-1]])
    print(est_alpha)
    print('-' * 50)
    for alpha in alpha_range:
        cum_loss_np = torch.Tensor(cum_loss[alpha]).detach().cpu().numpy()
        print(alpha)
        # Every value is followed by a comma, including the last one.
        print(''.join('%.4f,' % loss for loss in cum_loss_np))


# Run the sweep only when executed as a script, not when imported.
if __name__ == "__main__":
    main()


