import os
from networks import LeNet5Feats, ResNetFeats18, classifier
#import resnet
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision.datasets.mnist import MNIST
from torchvision.datasets import CIFAR10
from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import argparse
import higher
import hypergrad as hg
#from utils import save_checkpoint
import time
import matplotlib.pyplot as plt
import pickle
import numpy as np
import random

# Command-line configuration. Upper-level variables are the LeNet-5 feature
# extractor weights (hparams); lower-level variables are the classifier
# weights (params).
# NOTE(review): --dataset accepts 'cifar10', but the loading code below always
# builds MNIST, so the flag currently has no effect.
parser = argparse.ArgumentParser(description='Bilevel Training')

# --- data / output ----------------------------------------------------------
parser.add_argument('--dataset', type=str, default='MNIST', choices=['MNIST', 'cifar10'])
parser.add_argument('--data', type=str, default='./data')
# Created at startup. NOTE(review): the final results pickle is written under
# Results_final/, not into this directory.
parser.add_argument('--output_dir', type=str, default='Results_mnist')
# --- step-size rules --------------------------------------------------------
# Outer optimizer used when --upper_search 0: 'SGD', 'Adam' or 'SGDM'.
parser.add_argument('--opt_out', type=str, default='SGD')
# Initial-trial-step rule for the line searches (see reset_step):
# 1 = restart from eta_max, 2 = previous step * gamma, otherwise keep previous.
parser.add_argument('--opt_lower', type=int, default=2)
parser.add_argument('--opt_upper', type=int, default=2)
#parser.add_argument('--opt', type=int, default=1)
# NOTE(review): --eval_interval is not read anywhere in this script.
parser.add_argument('--eval_interval', type=int, default=10)
parser.add_argument('--seed', type=int, default=0)
# Fixed lower-level step size when --lower_search 0.
parser.add_argument('--alpha', type=float, default=10.0)
# Base upper-level learning rate when --upper_search 0 (decayed per --decr).
parser.add_argument('--beta', type=float, default=0.001)
# Coefficient of the 0.5 * lamba * ||hparams||^2 term in inner_loss.
parser.add_argument('--lamba', type=float, default=0.001)
# Restart step sizes for the line searches; eta_max_upper is also the
# (decayed) upper clip of the Polyak step.
parser.add_argument('--eta_max_upper', type=float, default=0.1)
parser.add_argument('--eta_max_lower', type=float, default=10.0)
# Armijo constant, backtracking shrink factor and step growth factor.
parser.add_argument('--c', type=float, default=0.1)
parser.add_argument('--beta_b', type=float, default=0.9)
parser.add_argument('--gamma', type=float, default=10.0)
#parser.add_argument('--gamma', type=float, default=1000.0)
# Slack tolerated by the upper-level sufficient-decrease test.
parser.add_argument('--delta', type=float, default=0.01)
# Mini-batch size of all data loaders.
parser.add_argument('--bs', type=int, default=256)
# NOTE(review): --n is unused; --n_train / --n_test are only forwarded to
# reset_step, which ignores them.
parser.add_argument('--n', type=int, default=10000)
parser.add_argument('--n_train', type=int, default=10000)
parser.add_argument('--n_test', type=int, default=60000)
# Number of outer iterations and lower-level steps per outer iteration.
parser.add_argument('--K', type=int, default=50)
parser.add_argument('--inner_steps', type=int, default=10)
# SGD steps used to re-fit params inside the upper-level line search.
parser.add_argument('--line_inner_steps', type=int, default=1)
# Conjugate-gradient iterations for the hypergradient (hg.CG).
parser.add_argument('--cg_steps', type=int, default=5)
# 0 = fixed lower step (--alpha); anything else = lower-level line search.
parser.add_argument('--lower_search', type=int, default=1)
# 0 = torch optimizer (--opt_out) for the upper level; anything else =
# adaptive upper step (line search or Polyak, see --line_search_upper).
parser.add_argument('--upper_search', type=int, default=1)
# With an adaptive upper step: 1 = backtracking line search, else Polyak step.
parser.add_argument('--line_search_upper', type=int, default=1)
# 2 = draw an independent hypergradient sample to choose the upper step size.
parser.add_argument('--samples', type=int, default=1)
# Decay of the upper lr / Polyak bound: 'sqrt', 'linear', anything else = constant.
parser.add_argument('--decr', type=str, default="sqrt")

args = parser.parse_args()

# Ensure the output directory exists (still raises if the path is a file).
os.makedirs(args.output_dir, exist_ok=True)

# Seed every RNG the script relies on, before any model/loader is created.
torch.random.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

# Shared MNIST preprocessing: resize to 32x32, convert to tensor, normalise
# with the MNIST mean/std constants.
mnist_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# NOTE: the names are swapped relative to torchvision's split on purpose:
# "data_test" is the 60k MNIST *training* split (upper level / evaluation),
# "data_train" is the 10k MNIST *test* split (lower level).
data_test = MNIST(args.data, download=True, transform=mnist_transform)
data_train = MNIST(args.data, train=False, download=True, transform=mnist_transform)

class CustomTensorIterator:
    """Endless iterator over a ``DataLoader``.

    Wraps ``DataLoader(dataset, batch_size=batch_size, **loader_kwargs)`` and
    restarts it whenever an epoch is exhausted (re-shuffling when the loader
    was built with ``shuffle=True``), so ``next()`` keeps producing batches as
    long as the loader yields at least one batch per epoch.
    """

    def __init__(self, dataset, batch_size, **loader_kwargs):
        self.loader = DataLoader(dataset, batch_size=batch_size, **loader_kwargs)
        self.iterator = iter(self.loader)

    def __iter__(self):
        # Complete the iterator protocol so the object also works with
        # ``iter()``, ``for`` loops and ``itertools`` helpers, not only ``next()``.
        return self

    def __next__(self, *args):
        """Return the next batch, starting a new epoch when the current one ends.

        ``*args`` is accepted and ignored for backward compatibility.
        """
        try:
            return next(self.iterator)
        except StopIteration:
            # Epoch finished: start a fresh pass over the loader and retry once.
            self.iterator = iter(self.loader)
            return next(self.iterator)

# data_test - 60000  (MNIST training split, used for the upper level)
# data_train - 10000 (MNIST test split, used for the lower level)
# Coefficient of the L2 term added in inner_loss.
lamba = args.lamba
# Endless shuffled batch streams: lower-level batches from data_train,
# upper-level (hypergradient) batches from data_test.
# NOTE: creating these iterators draws from the global RNG, so the order of
# the statements in this section affects reproducibility.
train_iterator = CustomTensorIterator(data_train, batch_size=args.bs, shuffle=True, num_workers=1)
val_iterator = CustomTensorIterator(data_test, batch_size=args.bs, shuffle=True, num_workers=1)
# Finite loader over data_test used by evaluate() for full-pass metrics.
val_loader = DataLoader(data_test, batch_size=args.bs, shuffle=True)
nd_train = len(data_train)
nd_test = len(data_test)
# Batches drawn by outer_func are recorded here so the upper-level line search
# can score trial steps on the most recent one; cleared every outer iteration.
data_outer = []
# Upper level: LeNet-5 feature extractor; its weights are the "hparams".
hypernet = LeNet5Feats().cuda()
# Lower level: classifier head on the 84-dim features; its weights are "params".
cnet = classifier(n_features=84, n_classes=10).cuda()
# Functional (stateless) versions that take explicit weights via params=...
fhnet = higher.monkeypatch(hypernet, copy_initial_weights=True).cuda()
hparams = list(hypernet.parameters())
hparams = [hparam.requires_grad_(True) for hparam in hparams]
fcnet = higher.monkeypatch(cnet, copy_initial_weights=True).cuda()
params = list(cnet.parameters())
params = [param.requires_grad_(True) for param in params]
# Per-step logs filled by inner_solver (accepted step sizes / search costs).
lower_step_track = []
search_lower_cost_track = []

criterion = torch.nn.CrossEntropyLoss().cuda()

def sps_compute_step(loss, grads, bound_lower, bound):
    """Stochastic Polyak step size ``loss / ||grads||^2`` clipped to ``[bound_lower, bound]``.

    Values below ``bound_lower`` -- including NaN, e.g. a zero loss with a zero
    gradient -- are raised to ``bound_lower``; values at or above ``bound`` are
    capped at ``bound``.

    Args:
        loss: scalar loss at the current point (Python number).
        grads: list of gradient tensors; ``None`` entries are skipped.
        bound_lower: minimum step size.
        bound: maximum step size.

    Returns:
        ``bound_lower`` or ``bound`` when clipping applies, otherwise the
        Polyak step as a plain Python ``float``.
    """
    grad_norm = compute_grad_norm(grads)
    # Divide in tensor arithmetic so a zero gradient yields inf/NaN instead of
    # raising ZeroDivisionError, then convert to a Python float so the
    # in-range branch does not return a 0-d (possibly CUDA) tensor.
    s1 = float(loss / (grad_norm ** 2))
    # ``not >=`` (rather than ``<``) deliberately maps NaN to bound_lower.
    if not s1 >= bound_lower:
        s1 = bound_lower
    if s1 >= bound:
        return bound
    return s1

def evaluate(params, hparams):
    """Mean loss and accuracy of the current model over the whole ``val_loader``.

    Args:
        params: lower-level (classifier) weights passed to ``fcnet``.
        hparams: upper-level (feature-extractor) weights passed to ``fhnet``.

    Returns:
        ``(loss, acc)``: the sample-weighted mean loss (float) and the accuracy
        as a 0-d NumPy array, both normalised by ``nd_test``.
    """
    total_loss = 0.
    n_correct = 0.
    for batch_images, batch_labels in val_loader:
        batch_images = batch_images.cuda()
        batch_labels = batch_labels.cuda()
        logits = fcnet(fhnet(batch_images, params=hparams), params=params)
        # criterion averages over the batch; re-weight by the batch size so the
        # final division by nd_test gives a per-sample mean.
        total_loss = total_loss + criterion(logits, batch_labels).item() * len(batch_images)
        predicted = logits.data.max(1)[1]
        n_correct = n_correct + predicted.eq(batch_labels.data.view_as(predicted)).sum()
    mean_acc = n_correct / nd_test
    mean_loss = total_loss / nd_test
    return mean_loss, mean_acc.cpu().numpy()

def inner_loss(params, hparams, data):
    """Lower-level objective: cross-entropy of ``fcnet(fhnet(x))`` plus an L2 term.

    Args:
        params: lower-level (classifier) weights for ``fcnet``.
        hparams: upper-level (feature-extractor) weights for ``fhnet``.
        data: ``(images, labels)`` batch; moved to the GPU here.

    Returns:
        Scalar tensor ``CE + 0.5 * lamba * sum(||w||^2 for w in hparams)``.

    NOTE(review): the L2 penalty is taken over ``hparams`` rather than
    ``params``, so it contributes no gradient w.r.t. the lower-level weights --
    confirm this is intended.
    """
    images, labels = data
    images = images.cuda()
    labels = labels.cuda()
    logits = fcnet(fhnet(images, params=hparams), params=params)
    reg = 0.5 * lamba * sum([(w ** 2).sum() for w in hparams])
    return criterion(logits, labels) + reg

# def inner_solver(params, hparams, args, step_size_old, steps=1): 

#     search_cost_lower = 0
#     params = [p.requires_grad_(True) for p in params]
#     #step_track_temp = []
#     for _ in range(steps):
#         data  = next(train_iterator)
#         loss = inner_loss(params, hparams,data)
#         grads = torch.autograd.grad(loss, params)
#         if args.lower_search == 0:
#             step_size = args.alpha
#             params = [p - step_size * g for p,g in zip(params, grads)]
#         else:
#             step_size, search_cost_lower = line_search(params, hparams, step_size_old, loss.item(), grads, args, data, upper=False)    
#             #print(step_size, search_cost_lower, loss)
#             params = [p - step_size * g for p,g in zip(params, grads)]
#             lower_step_track.append(step_size)
#     return params, step_size, search_cost_lower

def inner_solver(params, hparams, args, step_size, steps=1):
    """Run ``steps`` lower-level gradient steps on ``inner_loss``.

    Each step uses a fresh batch from ``train_iterator``. With
    ``args.lower_search == 0`` the fixed step ``args.alpha`` is used; otherwise
    the step comes from ``line_search`` (warm-started from the previous step
    size), and the chosen step and its trial count are appended to
    ``lower_step_track`` / ``search_lower_cost_track``.

    Returns:
        ``(params, step_size, search_cost_lower)``: the updated weights, the
        last step size used, and the trial count of the last line search
        (0 if no search was performed).
    """
    search_cost_lower = 0
    params = [w.requires_grad_(True) for w in params]
    use_line_search = args.lower_search != 0
    for _ in range(steps):
        batch = next(train_iterator)
        loss = inner_loss(params, hparams, batch)
        grads = torch.autograd.grad(loss, params)
        if use_line_search:
            step_size, search_cost_lower = line_search(
                params, hparams, step_size, loss.item(), grads, args, batch,
                upper=False)
        else:
            step_size = args.alpha
        params = [w - step_size * g for w, g in zip(params, grads)]
        if use_line_search:
            lower_step_track.append(step_size)
            search_lower_cost_track.append(search_cost_lower)
    return params, step_size, search_cost_lower

def check_term(step_size, loss_next, loss, grad_norm, upper=False):
    """Armijo-style sufficient-decrease test used by ``line_search``.

    The trial step is accepted when
    ``loss_next - (loss - step_size * args.c * grad_norm**2)`` is at most the
    tolerance: ``args.delta`` for the upper level, ``0`` for the lower level.
    A rejected step is shrunk by ``args.beta_b``.

    Returns:
        ``(found, step_size)``: ``(1, step_size)`` if accepted, otherwise
        ``(0, step_size * args.beta_b)``.
    """
    target = loss - step_size * args.c * grad_norm ** 2
    tolerance = args.delta if upper else 0
    if loss_next - target <= tolerance:
        return 1, step_size
    return 0, step_size * args.beta_b

def reset_step(step_size, args, n, upper=False, bs=256):
    """Choose the initial trial step for a new line search.

    The rule is ``args.opt_upper`` for the upper level and ``args.opt_lower``
    for the lower level:

    * ``1`` -- restart from the maximum step (``args.eta_max_upper`` /
      ``args.eta_max_lower``);
    * ``2`` -- grow the previous step by ``args.gamma``;
    * anything else -- reuse the previous step unchanged.

    ``n`` and ``bs`` are accepted for interface compatibility but not used.
    """
    rule = args.opt_upper if upper else args.opt_lower
    if rule == 1:
        return args.eta_max_upper if upper else args.eta_max_lower
    if rule == 2:
        return step_size * args.gamma
    return step_size

def steps_sgd(params, hparams, step_size, args):
    """Take ``args.line_inner_steps`` plain SGD steps on a detached copy of ``params``.

    Used by the upper-level line search to re-fit the lower-level weights for
    a trial ``hparams``. Each step draws a fresh batch from ``train_iterator``;
    the caller's ``params`` tensors are not modified.

    Returns:
        The list of updated weight tensors.
    """
    weights = [w.detach().clone().requires_grad_() for w in params]
    for _ in range(args.line_inner_steps):
        batch = next(train_iterator)
        step_loss = inner_loss(weights, hparams, batch)
        step_grads = torch.autograd.grad(step_loss, weights)
        weights = [w - step_size * g for w, g in zip(weights, step_grads)]
    return weights

def line_search(params, hparams, step_size_old, loss, grads, args, data, step_size_inner=None, upper=False, n_search=200):
    """Backtracking line search on the lower- or upper-level objective.

    The first trial step comes from ``reset_step(step_size_old, ...)``; every
    rejected trial is shrunk by ``args.beta_b`` inside ``check_term``.

    * Lower level (``upper=False``): trial point ``params - step * grads``,
      scored with ``inner_loss`` on ``data``; sufficient decrease with no slack.
    * Upper level (``upper=True``): trial point ``hparams - step * grads``;
      the lower-level weights are re-fitted with ``steps_sgd`` (step size
      ``step_size_inner``) and the trial is scored with ``outer_loss`` on
      ``data``; sufficient decrease relaxed by ``args.delta``.

    If the gradient norm is below 1e-8 (or NaN) the initial trial step is
    returned without searching. If no trial is accepted within ``n_search``
    attempts, a warning is printed and the fallback step 1e-6 is returned.

    Args:
        params, hparams: current lower/upper weights (not modified).
        step_size_old: step size accepted by the previous search.
        loss: loss (Python float) at the current point.
        grads: gradients w.r.t. ``params`` (lower) or ``hparams`` (upper).
        args: parsed command-line arguments.
        data: ``(images, labels)`` batch used to score trial points.
        step_size_inner: lower-level step used by ``steps_sgd`` (upper only).
        upper: run the upper-level variant.
        n_search: maximum number of trials (default 200).

    Returns:
        ``(step_size, n_trials)``: the accepted step and the number of trials
        performed (0 if the search was skipped).
    """
    # Initial trial step (restart / grow / keep, depending on args.opt_*).
    if upper:
        step_size = reset_step(step_size_old, args, args.n_test, upper=True, bs=args.bs)
    else:
        step_size = reset_step(step_size_old, args, args.n_train, upper=False, bs=args.bs)
    grad_norm = compute_grad_norm(grads)
    # ``not >=`` also skips the search when the norm is NaN.
    if not grad_norm >= 1e-8:
        return step_size, 0
    if not upper:
        # Detached snapshots so lower-level trials never build an autograd
        # graph back into the caller's tensors.
        params_temp = [p.detach().clone() for p in params]
        hparams_temp = [p.detach().clone() for p in hparams]
    for e in range(1, n_search + 1):
        if upper:
            hparams_new = [p.detach().clone() - step_size * g for p, g in zip(hparams, grads)]
            # Re-fit the lower-level weights for the trial hparams before scoring.
            params_hat = steps_sgd(params, hparams_new, step_size_inner, args)
            loss_next = outer_loss([p.detach() for p in params_hat], hparams_new, data)
        else:
            params_new = [p - step_size * g for p, g in zip(params_temp, grads)]
            loss_next = inner_loss(params_new, hparams_temp, data)
        found, step_size = check_term(step_size, loss_next, loss, grad_norm, upper=upper)
        if found == 1:
            return step_size, e
    # No trial accepted: report the real trial budget and fall back to a tiny step.
    print("Watch: not found after {} eps".format(n_search))
    return 1e-6, n_search

# used for line-search

def outer_loss(params, hparams, data):
    """Upper-level objective on a given batch: plain cross-entropy, no L2 term.

    Used by the upper-level line search to score trial hyperparameters.

    Args:
        params: lower-level (classifier) weights for ``fcnet``.
        hparams: upper-level (feature-extractor) weights for ``fhnet``.
        data: ``(images, labels)`` batch; moved to the GPU here.

    Returns:
        Scalar cross-entropy loss tensor.
    """
    images, labels = data
    images = images.cuda()
    labels = labels.cuda()
    logits = fcnet(fhnet(images, params=hparams), params=params)
    return criterion(logits, labels)

# used in the cg function
def outer_func(params, hparams, more=False):
    """Upper-level objective on a fresh ``val_iterator`` batch (passed to ``hg.CG``).

    Every batch drawn is appended to the module-level ``data_outer`` list so
    the upper-level line search can later re-use the most recent one.

    Args:
        params: lower-level (classifier) weights for ``fcnet``.
        hparams: upper-level (feature-extractor) weights for ``fhnet``.
        more: if true, also return the batch accuracy.

    Returns:
        The cross-entropy loss tensor, or ``(loss, acc)`` with ``acc`` the
        batch accuracy as a float in [0, 1] when ``more`` is true.
    """
    images, labels = next(val_iterator)
    images = images.cuda()
    labels = labels.cuda()
    logits = fcnet(fhnet(images, params=hparams), params=params)
    loss = criterion(logits, labels)
    predictions = logits.data.max(1)[1]
    n_correct = predictions.eq(labels.data.view_as(predictions)).sum()
    acc = float(n_correct) / labels.size(0)
    data_outer.append([images, labels])
    return (loss, acc) if more else loss

def compute_grad_norm(grad_list):
    """Return the global L2 norm ``sqrt(sum_i ||g_i||^2)`` of a list of gradients.

    ``None`` entries (parameters that received no gradient) are skipped.

    Args:
        grad_list: iterable of tensors and/or ``None``.

    Returns:
        A 0-d tensor. If there is no non-``None`` gradient, a 0-d CPU tensor
        ``0.`` is returned; the previous float accumulator made ``torch.sqrt``
        raise TypeError in that case.
    """
    total = None
    for g in grad_list:
        if g is None:
            continue
        sq = torch.sum(torch.mul(g, g))
        # Accumulate in the same order as before to keep results bit-identical.
        total = sq if total is None else total + sq
    if total is None:
        return torch.tensor(0.)
    return torch.sqrt(total)

# Outer (upper-level) torch optimizer, only used with --upper_search 0; otherwise
# the upper step is chosen by a line search or a Polyak rule in the main loop.
if args.upper_search == 0:
    #outer_opt = torch.optim.SGD(lr=args.beta, params=hparams)
    if args.opt_out == "Adam":
        outer_opt = torch.optim.Adam(lr=args.beta, params=hparams)
    elif args.opt_out == "SGD":
        outer_opt = torch.optim.SGD(lr=args.beta, params=hparams)
    elif args.opt_out == "SGDM":
        outer_opt = torch.optim.SGD(lr=args.beta, momentum=0.9, params=hparams)
    else:
        # Fail fast here instead of crashing later with a NameError on outer_opt.
        raise ValueError(
            "Unknown --opt_out {!r}; expected 'SGD', 'Adam' or 'SGDM'".format(args.opt_out))

# Per-run bookkeeping filled by the main loop.
total_time, val_losses, running_time, hg_norms = 0, [], [], []
val_acc = []
# Current lower/upper step sizes, carried across outer iterations.
step_size_lower = args.eta_max_lower
step_size_upper = args.eta_max_upper

# Map handed to hg.CG, built from inner_loss with unit step and batches from
# train_iterator. NOTE(review): presumably the lower-level GD fixed-point map
# -- confirm against the hypergrad package.
inner_opt_cg = hg.GradientDescent(inner_loss, 1., data_or_iter=train_iterator)
# NOTE(review): decay_steps is not used anywhere in this script.
decay_steps = [25, 40]
upper_step_track = []
search_cost_upper_track = []
# Clipping range of the stochastic Polyak step (--line_search_upper 0); the
# upper bound is decayed per --decr inside the main loop.
bound = args.eta_max_upper
#bound_lower = args.eta_max_upper / np.sqrt(args.K * args.inner_steps)
bound_lower = 1e-6

# Main bilevel loop. Each of the K outer iterations:
#   1) runs args.inner_steps lower-level steps (inner_solver),
#   2) estimates the hypergradient with conjugate gradient (hg.CG),
#   3) takes an upper-level step (torch optimizer, line search or Polyak),
#   4) evaluates on the full val_loader and logs metrics.
for k in range(args.K):
    
    step_start_time = time.time() 
    nes = []
    # Lower level, warm-started from the previous params and step size.
    params, step_size_lower, search_cost_lower = inner_solver(params, hparams, args, step_size_lower, steps=args.inner_steps)
    #params, step_size_lower, search_cost_lower = inner_solver(params, hparams, args, args.eta_max_lower, steps=args.inner_steps)
    t1 = time.time() - step_start_time # inner loop time
    #step_size_lower = step_size_lower * 0.9 
    if args.upper_search == 0:
        # Upper level via torch optimizer. NOTE(review): set_grad=True
        # presumably writes the hypergradient into .grad of hparams, which
        # outer_opt.step() then consumes -- confirm in hypergrad.
        outer_opt.zero_grad()
        grads, cost = hg.CG(params, hparams, args.cg_steps, inner_opt_cg, outer_func, stochastic=True, set_grad=True)
        outer_opt.step()
        # hparams = [p - step_size_upper * g for p,g in zip(hparams,grads)]
        # hparams = [p.requires_grad_() for p in hparams]
        # Decay the learning rate for the next outer iteration.
        if args.decr == "sqrt":
            step_size_upper = args.beta / np.sqrt((k+1)*args.inner_steps) 
            #step_size_upper = args.beta / np.sqrt((k+1)) 
        elif args.decr == "linear":
            step_size_upper = args.beta / ((k+1)*args.inner_steps)
            #step_size_upper = args.beta / ((k+1))
        else:
            step_size_upper = args.beta
        for g in outer_opt.param_groups:
            g['lr'] = step_size_upper
    else:
        # Adaptive upper step. With --samples 2, one hypergradient sample
        # (grads_step/cost_step) selects the step size and an independent one
        # (grads) is applied.
        if args.samples == 2:
            grads_step, cost_step = hg.CG(params, hparams, args.cg_steps, inner_opt_cg, outer_func, stochastic=True, set_grad=False)
            grads, cost = hg.CG(params, hparams, args.cg_steps, inner_opt_cg, outer_func, stochastic=True, set_grad=False)
        else:
            grads, cost = hg.CG(params, hparams, args.cg_steps, inner_opt_cg, outer_func, stochastic=True, set_grad=False)
        if args.line_search_upper == 1:
            # Backtracking line search scored on the last batch outer_func drew
            # (data_outer[-1]), re-fitting params with step_size_lower.
            # NOTE(review): with --samples 2, data_outer[-1] comes from the
            # second hg.CG call while cost_step/grads_step come from the first,
            # so the Armijo test may compare losses on different batches.
            if args.samples == 2:
                step_size_upper, search_cost_upper = line_search(params, hparams, step_size_upper, cost_step.item(), grads_step, args,data_outer[-1],step_size_lower, upper=True)
            else:
                step_size_upper, search_cost_upper = line_search(params, hparams, step_size_upper, cost.item(), grads, args,data_outer[-1],step_size_lower, upper=True)
            search_cost_upper_track.append(search_cost_upper)
        else:
            # Stochastic Polyak step clipped to [bound_lower, bound]; bound is
            # decayed afterwards, so the new bound applies from the next iteration.
            if args.samples == 2:
                step_size_upper = sps_compute_step(cost_step.item(),grads_step,bound_lower, bound)
            else:
                step_size_upper = sps_compute_step(cost.item(),grads,bound_lower, bound)
            print(step_size_lower,step_size_upper)
            if args.decr == "sqrt":
                bound = args.eta_max_upper / np.sqrt((k+1)*args.inner_steps)
                #bound = args.eta_max_upper / np.sqrt((k+1))
            elif args.decr == "linear":
                bound = args.eta_max_upper / ((k+1)*args.inner_steps)
                #bound = args.eta_max_upper / ((k+1))
            else:
                bound = args.eta_max_upper
        # Plain gradient step on the upper-level weights. NOTE(review): since
        # the old hparams require grad, the new list consists of non-leaf
        # tensors chained to the previous iteration through autograd.
        hparams = [p - step_size_upper * g for p,g in zip(hparams,grads)]
        #print(step_size_upper, search_cost_upper)
        upper_step_track.append(step_size_upper)
        #print(step_size_upper)
    #print(len(data_outer))
    # Drop the batches recorded by outer_func during this iteration.
    data_outer = []
    # Full-pass metrics over val_loader (data_test) without building a graph.
    with torch.no_grad():
        val_loss_temp, acc_temp = evaluate(params, hparams)
        val_losses.append(val_loss_temp)
        val_acc.append(acc_temp)
    step_time = time.time()-step_start_time
    total_time +=step_time
    running_time.append(total_time)
    # NOTE(review): stores a tensor (on the device of grads[0]), not a float;
    # only the first hparam's hypergradient is measured.
    hg_norms.append(torch.norm(grads[0]))
    print('outer step={} | val loss={} | val acc={} |hypergrad norm = {:.3e}'.format(k, val_losses[-1], val_acc[-1],torch.norm(grads[0])))

# Collect all tracked metrics and pickle them for later analysis/plotting.
results = {
    "val_losses": val_losses,
    "val_acc": val_acc,
    "running_time": running_time,
    "upper_step_track": upper_step_track,
    "search_cost_upper_track": search_cost_upper_track,
    "lower_step_track": lower_step_track,
    "search_lower_cost_track": search_lower_cost_track,
    "hg_norms": hg_norms,
}
# The file name encodes every hyper-parameter of the run.
p_file = 'mnist_final_lineInnerSteps{}_innerSteps{}_cgSteps{}_delta{}_lSearch{}_uSearch{}_alpha{}_beta{}_bs{}_opt{}_etaU{}_etaL{}_lamba{}_K{}_optU{}_optL{}_gamma{}_ns{}_UL{}_decr{}_s{}'.format(
    args.line_inner_steps, args.inner_steps, args.cg_steps, args.delta,
    args.lower_search, args.upper_search, args.alpha, args.beta, args.bs,
    args.opt_out, args.eta_max_upper, args.eta_max_lower, args.lamba, args.K,
    args.opt_upper, args.opt_lower, args.gamma, args.samples,
    args.line_search_upper, args.decr, args.seed)
results_dir = "Results_final"
# Only --output_dir is created at startup; make sure the results directory
# exists too, instead of failing with FileNotFoundError after a full run.
os.makedirs(results_dir, exist_ok=True)
p_file = results_dir + "/" + p_file + ".pk"
# Context manager so the pickle is flushed and the file handle closed.
with open(p_file, "wb") as results_fh:
    pickle.dump(results, results_fh)
