import time
import math
from copy import deepcopy
import torch
import torch.nn as nn
from torchvision import datasets, transforms

import matplotlib.pyplot as plt

# Global figure text settings: typeset labels with LaTeX (matplotlib's
# ``text.usetex`` mode) using a serif family, preferring Palatino.
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.serif"] = ["Palatino"]


# Device used for training and validation. The experiments were set up for the
# fourth GPU ('cuda:3'); fall back to the first GPU, or to the CPU, so the
# script still runs on machines that expose fewer devices.
if not torch.cuda.is_available():
    DEVICE = torch.device('cpu')
elif torch.cuda.device_count() > 3:
    DEVICE = torch.device('cuda:3')
else:
    DEVICE = torch.device('cuda:0')

from utils import *
import mnist
from linearization import *
from analysis import *
from trainer import *
from models import *


def train_model(net, X_tr, y_tr, X_te, y_te,
                lr=0.2, batch_size=3, nsteps=100000, display=5000, plot=False):
    """Train ``net`` with ``train`` and print its train/test validation results.

    Args:
        net: Model to train; it is passed to ``train`` and ``validate`` and
            finally moved back to the CPU via ``net.cpu()``.
        X_tr, y_tr: Training inputs and targets.
        X_te, y_te: Held-out inputs and targets, only used for the final
            ``validate`` call.
        lr: Learning rate forwarded to ``train``.
        batch_size: Batch size forwarded to ``train``.
        nsteps: Number of training steps forwarded to ``train``.
        display: Display interval forwarded to ``train``.
        plot: If true, draw ``history['train']`` on a log-scaled y axis in the
            current matplotlib axes.

    Returns:
        The history object produced by ``train`` (indexable by ``'train'``).
        Earlier versions returned ``None``; callers that ignore the return
        value are unaffected.
    """
    history = train(net, X_tr, y_tr, lr, nsteps,
                    batch_size=batch_size, display=display,
                    num_checkpoint=0, device=DEVICE)

    if plot:
        plt.semilogy(history['train'], '-')
    print('train error', validate(net, X_tr, y_tr, device=DEVICE))
    print('test error', validate(net, X_te, y_te, device=DEVICE))

    # Rebinding the local name (``net = net.cpu()``) had no effect for the
    # caller; the call itself is what matters. NOTE(review): this relies on
    # ``net`` being an ``nn.Module``, whose ``cpu()`` moves parameters in place.
    net.cpu()
    return history



def main_para_vs_flatness():
    """Plot the Hessian-based flatness measure against CNN parameter count.

    For every width in a fixed list, several CNNs are built with
    ``build_CNN``, trained on a 500-sample ``mnist.load_data`` split, and
    analysed with ``AnalyzeLargeNet``; the value returned by
    ``hessian_fro()`` is recorded per (width, trial). The mean and standard
    deviation over trials are drawn as an error-bar plot against
    ``log10`` of the parameter count and saved to
    ``flatness_vs_paraSize_CNN.pdf``.
    """
    n = 500
    d = 28*28
    widths = [10, 20, 30, 40, 50]
    ntries = 5
    X_tr, y_tr, X_te, y_te = mnist.load_data(n)

    report = '({:}, {:})-> took {:.1f} seconds, hessian: {:.1e}'

    # One row per width, one column per independent training run.
    flatness_psz = torch.zeros(len(widths), ntries)
    for row, width in enumerate(widths):
        for trial in range(ntries):
            started = time.time()

            model = build_CNN(d, width)
            train_model(model, X_tr, y_tr, X_te, y_te,
                        lr=0.1, batch_size=3, nsteps=10000, display=2000, plot=False)

            analyzer = AnalyzeLargeNet(model, X_tr, y_tr)
            analyzer.compute_grads()
            hess = analyzer.hessian_fro()
            flatness_psz[row, trial] = hess

            print(report.format(row, trial, time.time()-started, hess))

    plt.figure(figsize=(4,3))

    # Fresh networks are built only to count their parameters.
    para_counts = [num_para(build_CNN(d, width)) for width in widths]
    nparas = torch.tensor(para_counts).float()

    mean_flatness = flatness_psz.mean(dim=1)
    std_flatness = flatness_psz.std(dim=1)
    plt.errorbar(nparas.log10(), mean_flatness, std_flatness,
                 linestyle='-', marker='o', markersize=5)

    plt.title(r'Flatness vs. parameter size', fontsize=20)
    plt.xlabel(r'$\log_{10}$(para. size)', fontsize=20)
    plt.ylabel(r'$\|H\|_F$', fontsize=20)
    plt.tick_params(axis='both', labelsize=13)

    plt.savefig('flatness_vs_paraSize_CNN.pdf', bbox_inches='tight')


def main_sample_vs_faltness():
    """Plot the Hessian-based flatness measure against training-set size.

    For every sample size in ``sample_sz``, ``ntries`` CNNs of width ``m``
    are built with ``build_CNN``, trained on data from
    ``mnist.load_data(n)``, and analysed with ``AnalyzeLargeNet``; the value
    returned by ``hessian_fro()`` is recorded per (sample size, trial). The
    mean and standard deviation over trials are drawn as an error-bar plot
    against ``log10`` of the sample size and saved to
    ``flatness_vs_samplesize_CNN.pdf``.
    """
    sample_sz = [3200]  # extend (e.g. with 6400) to compare several sizes
    ntries = 5
    m = 50
    bz = 3
    lr = 0.1
    # Input dimension passed to build_CNN. It was previously referenced
    # without being defined in this function; the value matches the one used
    # in main_para_vs_flatness.
    d = 28*28

    # One row per sample size, one column per independent training run.
    flatness_ssz = torch.zeros(len(sample_sz), ntries)
    for i, n in enumerate(sample_sz):

        for j in range(ntries):
            time_st = time.time()
            # Data is reloaded for every trial, as in the original code.
            # NOTE(review): hoist out of the inner loop if load_data is
            # deterministic for a given n.
            X_tr, y_tr, X_te, y_te = mnist.load_data(n)

            net = build_CNN(d, m)
            print(num_para(net))
            train_model(net, X_tr, y_tr, X_te, y_te,
                        lr=lr, batch_size=bz, nsteps=5000, display=1000, plot=False)

            ana = AnalyzeLargeNet(net, X_tr, y_tr)
            ana.compute_grads()
            H = ana.hessian_fro()
            flatness_ssz[i, j] = H

            print('({:}, {:})-> took {:.1f} seconds, flatness: {:.1e}'.format(i, j, time.time()-time_st, H))
    plt.figure(figsize=(4,3))

    # Cast to float before log10, as main_para_vs_flatness does, so the call
    # does not depend on integer-to-float promotion in the installed torch.
    sizes = torch.tensor(sample_sz).float()
    plt.errorbar(sizes.log10(), flatness_ssz.mean(dim=1), flatness_ssz.std(dim=1),
                 linestyle='-', marker='o', linewidth=2, markersize=5)

    plt.xlabel(r' $\log_{10}$(sample size)', fontsize=20)
    plt.tick_params(axis='both', labelsize=13)
    plt.ylabel(r'$\|H\|_F$', fontsize=20)
    plt.title(r'Flatness vs. sample size', fontsize=20)

    plt.savefig('flatness_vs_samplesize_CNN.pdf', bbox_inches='tight')

if __name__ == '__main__':
    # Script entry point: only the sample-size experiment is run; the
    # parameter-size experiment below is left disabled.
    # main_para_vs_flatness()
    main_sample_vs_faltness()
