from .er import *
import math
from torch.optim import Adam
import copy
from utils.utils import get_config_attr

class FOExperienceEvolve(ExperienceReplay):
    """Experience replay whose sampled memory inputs are edited with gradients.

    On each `observe` call a memory batch is sampled, the model is temporarily
    updated on the incoming batch, and every sampled memory input is moved
    along the gradient of its loss increase (loss after the temporary update
    minus loss before it).  The edited inputs are written back to the
    reservoir and replayed together with the incoming batch.
    """

    def __init__(self, base, optimizer, input_size, cfg, goal):
        """Build the replay buffer and read the editing hyper-parameters.

        All arguments are forwarded unchanged to ``ExperienceReplay.__init__``.
        ``cfg`` is additionally queried for ``EXTERNAL.OCL.GRAD_ITER``
        (default 1, number of edit iterations per `observe` call) and
        ``EXTERNAL.OCL.GRAD_STRIDE`` (default 10., step size of each edit).
        """
        super().__init__(base, optimizer, input_size, cfg, goal)
        # `self.mem_limit` is read here without being set in this class;
        # presumably the parent constructor provides it.
        self.reservoir = {'x': np.zeros((self.mem_limit, input_size)),
                          'y': [None] * self.mem_limit,
                          'y_extra': [None] * self.mem_limit,
                          'x_origin': np.zeros((self.mem_limit, input_size)),
                          'x_edit_state': [None] * self.mem_limit,
                          'loss_stats': [None] * self.mem_limit,
                          'loss_stat_steps': [None] * self.mem_limit,
                          'forget': [None] * self.mem_limit,
                          'support': [None] * self.mem_limit,
                          # number of edits applied to each slot since it was last (re)filled
                          'age': np.zeros(self.mem_limit),
                          }
        self.itf_cnt = 0
        self.total_cnt = 0
        self.grad_iter = get_config_attr(cfg, 'EXTERNAL.OCL.GRAD_ITER', default=1)
        self.grad_stride = get_config_attr(cfg, 'EXTERNAL.OCL.GRAD_STRIDE', default=10.)

    def observe(self, x, y, task_ids, extra=None, optimize=True, sequential=False):
        """Edit a sampled memory batch, replay it with ``(x, y)`` and store ``(x, y)``.

        Args:
            x: input batch; ``x.size(0)`` is used as the batch size.
            y: targets; either a single indexable object or a tuple/list of
                per-part targets that are each indexed by example.
            task_ids: forwarded to the first ("before") forward pass on the
                memory batch only.
            extra: optional per-example extras stored alongside each example.
            optimize: when true, one real optimizer step is taken on the
                combined batch.
            sequential: currently ignored, see the note below.

        Returns:
            The dict returned by ``forward_net`` on the combined batch.
        """
        # NOTE(review): the caller-supplied value is overridden, so the
        # per-example branch below is always taken.  Kept as in the original
        # because changing it would alter training behaviour.
        sequential = True
        # Only the disabled interference check further down updates these.
        global total_cnt, itf_cnt
        self.optimizer.zero_grad()
        mem_x_indices, mem_x_origin, mem_y, mem_task_ids = self.sample_mem_batch(x.device, return_indices=True)

        batch_size = x.size(0)
        y_is_multi = isinstance(y, (tuple, list))
        if mem_x_indices is None:
            # Nothing to replay or edit yet: train on the incoming batch only.
            # (The original re-sampled the memory here, without a device and
            # with a three-value unpack, which cannot produce a usable batch.)
            combined_x, combined_y = x, y
        else:
            mem_x, indices = mem_x_indices
            self.store_cache()
            for i in range(self.grad_iter):
                # evaluate loss on mem_x, mem_y
                mem_x.requires_grad = True
                mem_x.grad = None
                mem_x_origin.requires_grad = True

                # evaluate grad of l wrt mem
                self.optimizer.zero_grad()
                # NOTE(review): this pass receives the *incoming* batch's
                # task_ids while the "after" pass receives none and
                # `mem_task_ids` is unused -- confirm this is intended.
                ret_dict_mem_before = self.forward_net(mem_x, mem_y, reduce=False, task_ids=task_ids)
                # grad_l = -torch.autograd.grad(torch.sum(ret_dict_mem_origin_before['loss']), mem_x_origin, retain_graph=True)[0]

                # train the model on D
                if not sequential:
                    self.get_loss_and_pseudo_update(x, y, task_ids)
                else:
                    self._sequential_pseudo_update(x, y)

                ret_dict_mem_after = self.forward_net(mem_x, mem_y, reduce=False)
                loss_delta = ret_dict_mem_after['loss'] - ret_dict_mem_before['loss']
                if 'mask_cnts' not in ret_dict_mem_after:
                    loss_increase = loss_delta.mean()
                else:
                    # normalise by the total mask count; epsilon guards against 0
                    loss_increase = loss_delta.sum() / \
                                        (sum(ret_dict_mem_after['mask_cnts']) + 1e-10)
                loss_increase.backward(retain_graph=False)
                grad_delta = mem_x.grad

                # undo the temporary model update before the next iteration
                self.load_cache()
                # step the memory inputs in the direction that increases the loss increase
                proposed_mem_x = mem_x + self.grad_stride * grad_delta
                proposed_mem_x.detach_()

                # Disabled experiment: reject edits that increase the loss of
                # the original memory inputs.
                #mem_x_origin = mem_x
                # self.optimizer.zero_grad()
                # ret_dict_mem_origin_before = self.forward_net(mem_x_origin, mem_y, reduce=False)
                # ret_dict_mem_update_before = self.forward_net(proposed_mem_x, mem_y, reduce=True)
                # ret_dict_mem_update_before['loss'].backward()
                # if isinstance(self.optimizer, torch.optim.SGD):
                #     step_wo_state_update_sgd(self.optimizer, amp=1.)
                # elif isinstance(self.optimizer, torch.optim.Adam):
                #     step_wo_state_update_adam(self.optimizer, amp=1.)
                # else:
                #     raise NotImplementedError
                # ret_dict_mem_origin_after = self.forward_net(mem_x_origin, mem_y, reduce=False)
                # mem_origin_loss_increase = ret_dict_mem_origin_after['loss'] - ret_dict_mem_origin_before['loss']
                # mem_origin_loss_increase = mem_origin_loss_increase.view(mem_x_origin.size(0), -1).sum(-1)
                # for b in range(mem_origin_loss_increase.size(0)):
                #     if mem_origin_loss_increase[b].item() > 0:
                #         itf_cnt += 1
                #         proposed_mem_x[b] = mem_x_origin[b]
                #     total_cnt += 1
                mem_x = proposed_mem_x
                mem_x = mem_x.detach()
            self.evolve_mem(mem_x, indices)

            # load cached parameters back
            self.load_cache()
            combined_x = torch.cat([x, mem_x], 0)
            combined_y = concat_with_padding([y, mem_y])

        ret_dict = self.forward_net(combined_x, combined_y)

        # Store every incoming example.  Tuple and list targets are both
        # split per example, matching the per-example update above (the
        # original only split tuples here).
        for b in range(batch_size):
            y_b = [part[b] for part in y] if y_is_multi else y[b]
            extra_b = extra[b] if extra is not None else None
            self.update_mem(x[b], y_b, extra_b)

        loss = ret_dict['loss']
        if optimize:
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        return ret_dict

    def _sequential_pseudo_update(self, x, y):
        """Update the model on ``(x, y)`` one example at a time.

        Each example gets its own forward/backward pass followed by
        ``step_wo_state_update_sgd`` / ``step_wo_state_update_adam``
        (presumably steps that leave the optimizer state untouched -- the
        helpers are defined outside this file).  Train mode is switched off
        for the duration and switched on afterwards.

        Raises:
            NotImplementedError: if the optimizer is neither SGD nor Adam.
        """
        y_is_multi = isinstance(y, (tuple, list))
        self.train(False)
        for b in range(x.size(0)):
            x_b = x[b].unsqueeze(0)
            if y_is_multi:
                y_b = [part[b].unsqueeze(0) for part in y]
            else:
                y_b = y[b].unsqueeze(0)
            ret_dict_db = self.forward_net(x_b, y_b)
            self.optimizer.zero_grad()
            ret_dict_db['loss'].backward()
            if isinstance(self.optimizer, torch.optim.SGD):
                step_wo_state_update_sgd(self.optimizer, amp=1.)
            elif isinstance(self.optimizer, torch.optim.Adam):
                step_wo_state_update_adam(self.optimizer, amp=1.)
            else:
                raise NotImplementedError(
                    'pseudo update supports only SGD and Adam, got '
                    + type(self.optimizer).__name__)
        self.train(True)

    def evolve_mem(self, x, indices):
        """Write edited inputs ``x`` into reservoir slots ``indices`` and age them.

        ``x[i]`` replaces ``reservoir['x'][indices[i]]`` and the slot's
        ``age`` counter is incremented by one.
        """
        # one device-to-host transfer for the whole batch instead of one per row
        x_np = x.cpu().numpy()
        for i, idx in enumerate(indices):
            #if self.reservoir['age'][idx] < 10:
            self.reservoir['x'][idx] = x_np[i]
            self.reservoir['age'][idx] += 1

    def update_mem_reservoir(self, x, y, y_extra=None):
        """Offer one example to the reservoir.

        While fewer than ``mem_limit`` examples have been seen the example is
        appended; afterwards a slot index is drawn from a ``RandomState``
        seeded with ``example_seen + cfg.SEED`` and the example overwrites
        that slot only if the index falls inside the buffer.
        ``example_seen`` is incremented in every case.
        """
        x = x.cpu().numpy()
        y = y_to_cpu(y) if isinstance(y, (list, tuple)) else y_to_np(y)
        y_extra = y_to_cpu(y_extra) if isinstance(y_extra, (list, tuple)) else y_to_np(y_extra)

        if self.example_seen < self.mem_limit:
            slot = self.example_seen
        else:
            # NOTE(review): the upper bound is exclusive, so the draw covers
            # [0, example_seen - 1]; textbook reservoir sampling would use
            # example_seen + 1.  Left unchanged to keep runs reproducible.
            slot = np.random.RandomState(self.example_seen + self.cfg.SEED).randint(0, self.example_seen)
            if slot >= self.mem_limit:
                slot = None  # example is not stored

        if slot is not None:
            self.reservoir['x'][slot] = x
            self.reservoir['x_origin'][slot] = x
            self.reservoir['y'][slot] = y
            self.reservoir['y_extra'][slot] = y_extra
            self.reservoir['x_edit_state'][slot] = {}
            self.reservoir['age'][slot] = 0
        self.example_seen += 1

    # Disabled override of sample_mem_batch that additionally returns x_origin.
    # def sample_mem_batch(self, device, return_indices=False, mem_k=None, seed=1,
    #                      mir=False, input_x=None, input_y=None, input_task_ids=None, mir_k=0,
    #                      skip_task=None
    #                      ):
    #     assert seed != 0
    #     random_state = self.get_random(seed)
    #     if mem_k is None:
    #         mem_k = self.mem_bs
    #     if not self.balanced:
    #         # reservoir
    #         n_max = min(self.mem_limit, self.example_seen)
    #         if n_max == 0:
    #             return None, None, None, None
    #         elif n_max < mem_k:
    #             indices = np.arange(n_max)
    #         else:
    #             indices = random_state.choice(n_max, mem_k, replace=False)
    #     else:
    #         available_index = self.get_available_index()
    #         if len(available_index) == 0:
    #             return None, None, None, None
    #         elif len(available_index) < mem_k:
    #             indices = np.array(available_index)
    #         else:
    #             indices = random_state.choice(available_index, mem_k, replace=False)
    #
    #     x = self.reservoir['x'][indices]
    #     x_origin = self.reservoir['x_origin'][indices]
    #
    #     x = torch.from_numpy(x).to(device).float()
    #     x_origin = torch.from_numpy(x_origin).to(device).float()
    #     y = index_select(self.reservoir['y'], indices, device) # [  [...], [...] ]
    #     y_extra = index_select(self.reservoir['y_extra'], indices, device)
    #     if type(y[0]) not in [list, tuple]:
    #         y_pad = concat_with_padding(y)
    #     else:
    #         y_pad = [torch.stack(_).to(device) for _ in zip(*y)]
    #     y_extra = concat_with_padding(y_extra)
    #     if not return_indices:
    #         return x, x_origin, y_pad, y_extra
    #     else:
    #         return (x, indices), x_origin, y_pad, y_extra

    def edit_mem_adam(self, p, grad, state, lr, beta1=0.9, beta2=0.999, eps=1e-8):
        """Apply one Adam step to ``p.data`` in place.

        Args:
            p: tensor to update; only ``p.data`` is modified.
            grad: dense gradient for ``p``.
            state: mutable dict holding ``step``, ``exp_avg`` and
                ``exp_avg_sq``; an empty dict is initialised on first use.
            lr: learning rate.
            beta1, beta2: decay rates of the first/second moment estimates.
            eps: constant added to the denominator.

        Raises:
            RuntimeError: if ``grad`` is sparse.
        """
        #grad = p.grad.data
        if grad.is_sparse:
            raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
        # State initialization
        if len(state) == 0:
            state['step'] = 0
            # Exponential moving average of gradient values
            state['exp_avg'] = torch.zeros_like(p.data)
            # Exponential moving average of squared gradient values
            state['exp_avg_sq'] = torch.zeros_like(p.data)

        exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

        state['step'] += 1

        # Decay the first and second moment running average coefficient.
        # Keyword forms replace the deprecated positional-scalar overloads.
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        denom = exp_avg_sq.sqrt().add_(eps)

        bias_correction1 = 1 - beta1 ** state['step']
        bias_correction2 = 1 - beta2 ** state['step']
        step_size = lr * math.sqrt(bias_correction2) / bias_correction1

        p.data.addcdiv_(exp_avg, denom, value=-step_size)



# Module-level counters declared `global` in FOExperienceEvolve.observe;
# only the disabled interference check there would update them.
total_cnt = 0
itf_cnt = 0
def proj_grad(a, b, binary, always_proj):
    """Adjust vector ``b`` according to the sign of its dot product with ``a``.

    * ``dot(a, b) >= 0`` and ``always_proj`` false: ``b`` is returned as is
      (the same object).
    * ``dot(a, b) >= 0`` and ``always_proj`` true: returns
      ``b - (dot(-a, b) / dot(a, a)) * a``.
      NOTE(review): because of the negated ``a`` this *adds* the component of
      ``b`` along ``a`` rather than removing it -- confirm the sign is intended.
    * ``dot(a, b) < 0`` and ``binary`` true: returns zeros shaped like ``a``.
    * ``dot(a, b) < 0`` and ``binary`` false: returns
      ``b - (dot(a, b) / dot(a, a)) * a``, i.e. ``b`` with its component
      along ``a`` removed.

    ``a`` and ``b`` are passed to ``torch.dot``.
    """
    inner = torch.dot(a, b)

    if inner >= 0:
        if always_proj:
            coeff = torch.dot(-a, b) / torch.dot(a, a)
            return b - coeff * a
        return b

    if binary:
        return torch.zeros_like(a)

    coeff = inner / torch.dot(a, a)
    return b - coeff * a
