import torch
import copy
import time
import math
import torch
import numpy as np

#from . import utils as ut
from bi_utils import *

class AdaSLS(torch.optim.Optimizer):
    """Adam-preconditioned update of hyper-parameters with a backtracking line search.

    The optimizer keeps running averages of the gradients (``mv``) and of the
    squared gradients (``gv``) of the hyper-parameters ``hparams``.  On every
    call to :meth:`step` it backtracks over the step size: a candidate update
    is written into ``hparams``, the inner parameters are adapted with a few
    SGD steps (:meth:`steps_sgd`), and the candidate is accepted once the
    outer loss satisfies the sufficient-decrease test in :meth:`check_term`.

    Bookkeeping is stored in ``self.state`` under the string keys ``'step'``,
    ``'step_size'``, ``'search_cost'``, ``'n_forwards'``, ``'n_backwards'``,
    ``'gv'``, ``'mv'``, ``'gv_stats'`` and ``'grad_norm'``.

    Args:
        hparams: indexable sequence of tensors to optimise (updated in place).
        n: stored on the instance; not used by the code in this class.
        bs: stored on the instance; not used by the code in this class.
        train_iterator: iterator yielding ``(images, labels)`` batches for the
            inner SGD steps; advanced with ``next()``.
        fhnet: callable invoked as ``fhnet(images, params=hparams)``.
        fcnet: callable invoked as ``fcnet(feats, params=params)``.
        criterion: callable invoked as ``criterion(outputs, labels)``.
        lamba: coefficient of the L2 penalty on ``hparams`` in the inner loss.
        beta: decay rate of the running average of squared gradients.
        momentum: decay rate of the running average of gradients.
        line_inner_steps: number of inner SGD steps per line-search trial.
        beta_b: multiplicative factor applied to the step size on a failed trial.
        reset_option: how the step size is reset before each search
            (1: ``eta_max_upper``; 2: previous step size times ``gamma``;
            anything else: keep the previous step size).
        gamma: growth factor used by ``reset_option == 2``.
        eta_max_upper: initial step size, and reset value for ``reset_option == 1``.
        c: sufficient-decrease constant of the line search.
        delta: tolerance added to the sufficient-decrease test.
    """

    def __init__(self,
                 hparams,
                 n,
                 bs,
                 train_iterator,
                 fhnet,
                 fcnet,
                 criterion,
                 lamba = 0.001,
                 beta=0.99,
                 momentum=0.9,
                 line_inner_steps=1,
                 beta_b = 0.9,
                 reset_option = 1,
                 gamma = 10.0,
                 eta_max_upper = 1.0,
                 c = 0.1,
                 delta = 0.0
                 ):
        super().__init__(hparams, {})

        # Bi-level problem definition.
        self.train_iterator = train_iterator
        self.fhnet = fhnet
        self.fcnet = fcnet
        self.criterion = criterion
        self.lamba = lamba

        # Adam-style moving-average coefficients.
        self.momentum = momentum
        self.beta = beta

        # Line-search configuration.
        self.line_inner_steps = line_inner_steps
        self.beta_b = beta_b
        self.reset_option = reset_option
        self.c = c
        self.gamma = gamma
        self.eta_max_upper = eta_max_upper
        self.n = n
        self.bs = bs
        self.delta = delta

        # The very same tensor objects are updated in place by the optimizer.
        self.hparams = hparams

        self.state['step'] = 0
        self.state['step_size'] = eta_max_upper
        self.state['search_cost'] = 0
        self.state['n_forwards'] = 0
        self.state['n_backwards'] = 0
        # Running averages of squared gradients (gv) and of gradients (mv).
        self.state['gv'] = [torch.zeros(p.shape, device=p.device) for p in hparams]
        self.state['mv'] = [torch.zeros(p.shape, device=p.device) for p in hparams]

    def step(self, params, loss_curr, grad, data, step_size_inner, closure=None, clip_grad=False):
        """Perform one line-searched, preconditioned update of ``self.hparams``.

        The gradients are expected to be computed by the caller and passed in
        through ``grad``; nothing is back-propagated here.

        Args:
            params: inner parameters, used as the starting point of the inner
                SGD steps during the line search.
            loss_curr: outer loss at the current hyper-parameters.
            grad: gradients of the outer loss w.r.t. ``self.hparams``, one
                tensor per hyper-parameter.
            data: ``(images, labels)`` batch on which trial losses are evaluated.
            step_size_inner: learning rate of the inner SGD steps.
            closure: unused.
            clip_grad: if true, clip the ``.grad`` fields of ``self.hparams``
                to a total norm of 0.25.

        Raises:
            ValueError: if the first hyper-parameter contains NaNs after the update.
        """
        if clip_grad:
            # NOTE(review): this clips ``p.grad`` of the hyper-parameters; it only
            # affects ``grad`` if the caller passed those same tensors -- confirm.
            torch.nn.utils.clip_grad_norm_(self.hparams, 0.25)

        # Count the forward/backward pass the caller did to obtain loss and grad.
        self.state['n_forwards'] += 1
        self.state['n_backwards'] += 1

        # Snapshot of the current values; every trial update starts from it.
        # detach().clone() copies only the values and, unlike deepcopy, also
        # works for tensors that are not graph leaves.
        hparams_current = [p.detach().clone() for p in self.hparams]
        # compute_grad_norm comes from bi_utils; presumably it returns the
        # gradient norm as a tensor (``.item()`` is called on it below).
        grad_norm = compute_grad_norm(grad)

        # Update the running second and first moments.
        for i, g in enumerate(grad):
            self.state['gv'][i] = (1 - self.beta) * (g ** 2) + self.beta * self.state['gv'][i]
            self.state['mv'][i] = (1 - self.momentum) * g + self.momentum * self.state['mv'][i]

        pp_norm = self.get_pp_norm(grad_current=grad)

        # Search for the step size, then apply it starting from the snapshot.
        step_size, search_cost = self.line_search(
            params, hparams_current, loss_curr, grad, data, pp_norm, step_size_inner)
        self.try_sgd_precond_update(self.hparams, step_size, hparams_current, grad)

        self.state['step_size'] = step_size
        self.state['search_cost'] = search_cost

        # Summary statistics of the second-moment estimates.
        gv_max = 0.
        gv_min = np.inf
        gv_sum = 0
        gv_count = 0
        for gv in self.state['gv']:
            gv_max = max(gv_max, gv.max().item())
            gv_min = min(gv_min, gv.min().item())
            gv_sum += gv.sum().item()
            gv_count += len(gv.view(-1))
        self.state['gv_stats'] = {'gv_max':gv_max, 'gv_min':gv_min, 'gv_mean': gv_sum/gv_count}
        self.state['grad_norm'] = grad_norm.item()

        if torch.isnan(self.hparams[0]).sum() > 0:
            raise ValueError('nans detected')

        self.state['step'] += 1

    def get_pp_norm(self, grad_current):
        """Return the squared gradient norm weighted by the preconditioner.

        Computes ``sum_i sum(g_i**2 / (sqrt(gv_i_scaled) + 1e-8))`` where
        ``gv_i_scaled`` is the second-moment estimate passed through
        ``scale_vector``.
        """
        pp_norm = 0
        for g_i, gv_i in zip(grad_current, self.state['gv']):
            gv_i_scaled = scale_vector(gv_i, self.beta, self.state['step'])
            pv_i = 1. / (torch.sqrt(gv_i_scaled) + 1e-8)
            pp_norm += ((g_i ** 2) * pv_i).sum()
        return pp_norm

    def reset_step(self, step_size):
        """Return the step size from which the next line search starts.

        ``reset_option == 1`` restarts from ``eta_max_upper``,
        ``reset_option == 2`` enlarges the previous step size by ``gamma``,
        and any other value keeps ``step_size`` unchanged.
        """
        if self.reset_option == 1:
            return self.eta_max_upper
        if self.reset_option == 2:
            return step_size * self.gamma
        return step_size

    @staticmethod
    def _batch_on_device(data, device):
        """Unpack an ``(images, labels)`` batch and move both tensors to ``device``."""
        images, labels = data
        return images.to(device), labels.to(device)

    def inner_loss(self, params, hparams, data):
        """Return the criterion on ``data`` plus the L2 penalty on ``hparams``.

        The batch is moved to the device of ``hparams[0]`` (rather than
        unconditionally to CUDA) so the computation also works on CPU or on a
        non-default GPU.
        """
        images, labels = self._batch_on_device(data, hparams[0].device)
        feats = self.fhnet(images, params=hparams)
        outputs = self.fcnet(feats, params=params)
        loss = self.criterion(outputs, labels)
        l2_penalty = 0.5 * self.lamba * sum([(p ** 2).sum() for p in hparams])
        return loss + l2_penalty

    @torch.no_grad()
    def outer_loss(self, params, data):
        """Return the criterion on ``data`` using ``self.hparams``, without autograd tracking.

        The batch is moved to the device of ``self.hparams[0]``.
        """
        images, labels = self._batch_on_device(data, self.hparams[0].device)
        feats = self.fhnet(images, params=self.hparams)
        outputs = self.fcnet(feats, params=params)
        return self.criterion(outputs, labels)

    def steps_sgd(self, params, step_size_inner):
        """Run ``line_inner_steps`` SGD steps on the inner loss and return the result.

        Works on detached copies, so ``params`` itself is left untouched.
        Each step draws a new batch with ``next(self.train_iterator)``.
        """
        params_hat = [p.detach().clone().requires_grad_() for p in params]
        for _ in range(self.line_inner_steps):
            batch = next(self.train_iterator)
            loss = self.inner_loss(params_hat, self.hparams, batch)
            grads = torch.autograd.grad(loss, params_hat)
            params_hat = [p - step_size_inner * g for p, g in zip(params_hat, grads)]
        return params_hat

    def line_search(self, params, hparams_current, loss, grads, data, pp_norm, step_size_inner):
        """Backtrack on the step size until the sufficient-decrease test passes.

        Each trial writes the candidate update into ``self.hparams``, adapts
        the inner parameters with :meth:`steps_sgd`, and evaluates the outer
        loss on ``data``.  The search is skipped when the gradient norm is
        below ``1e-8``.

        Returns:
            Tuple ``(step_size, n_trials)``.  If no step size is accepted
            within the trial budget, ``step_size`` falls back to ``1e-6`` and
            ``n_trials`` equals the budget.
        """
        step_size_old = self.state["step_size"]

        # The very first search starts from the initial step size as is.
        if self.state['step'] != 0:
            step_size = self.reset_step(step_size_old)
        else:
            step_size = step_size_old

        grad_norm = compute_grad_norm(grads)
        e = 0
        n_search = 200
        if grad_norm >= 1e-8:
            found = 0
            for e in range(1, n_search + 1):
                # Make a potential move from the saved hyper-parameters.
                self.try_sgd_precond_update(self.hparams, step_size, hparams_current, grads)
                params_hat = self.steps_sgd(params, step_size_inner)
                loss_next = self.outer_loss(params_hat, data)
                # Honour the tolerance configured on the optimizer.
                found, step_size = self.check_term(
                    step_size, loss_next, loss, pp_norm, delta=self.delta)
                if found == 1:
                    break
            if found == 0:
                print(f"Watch: not found after {n_search} eps")
                step_size = 1e-6
                e = n_search
        return step_size, e

    def check_term(self,
                    step_size,
                    loss_next,
                    loss,
                    pp_norm,
                    delta=0.0):
        """Test the sufficient-decrease condition for one trial step size.

        The trial is accepted when
        ``loss_next - (loss - step_size * c * pp_norm) <= delta``.

        Returns:
            Tuple ``(found, step_size)``: ``found`` is 1 on acceptance and the
            step size is returned unchanged; otherwise ``found`` is 0 and the
            step size is multiplied by ``beta_b``.
        """
        break_condition = loss_next - (loss - step_size * self.c * pp_norm)
        if break_condition <= delta:
            return 1, step_size
        # Decrease the step size by a multiplicative factor.
        return 0, step_size * self.beta_b

    @torch.no_grad()
    def try_sgd_precond_update(self, hparams, step_size, hparams_current, grad_current):
        """Overwrite ``hparams`` in place with a preconditioned momentum step.

        Each tensor becomes
        ``p_current - step_size * mv_scaled / (sqrt(gv_scaled) + 1e-8)``,
        where the moment estimates are passed through ``scale_vector``.
        ``grad_current`` is only used to align the iteration; the update
        direction comes from the stored first moment.
        """
        zipped = zip(hparams, hparams_current, grad_current, self.state['gv'], self.state['mv'])
        for p_next, p_current, _g_current, gv_i, mv_i in zipped:
            gv_i_scaled = scale_vector(gv_i, self.beta, self.state['step'])
            pv_list = 1. / (torch.sqrt(gv_i_scaled) + 1e-8)
            mv_i_scaled = scale_vector(mv_i, self.momentum, self.state['step'])

            # Restart from the saved values so repeated trials do not accumulate.
            p_next.data[:] = p_current.data
            p_next.data.add_((pv_list * mv_i_scaled), alpha=-step_size)

def scale_vector(vector, alpha, step, eps=1e-8):
    """Divide ``vector`` by the bias-correction factor ``1 - alpha ** max(1, step)``.

    ``step`` values below 1 are treated as 1.  ``eps`` is accepted but is not
    used in the computation.
    """
    exponent = step if step > 1 else 1
    correction = 1 - alpha ** exponent
    return vector / correction