import numpy as np
import math
import pdb
import torch
import torch.nn as nn
import torch.nn.functional as F

class Buffer(nn.Module):
    def __init__(self, args, input_size=None):
        super().__init__()
        self.args = args
        self.k    = 0.03

        self.place_left = True

        if input_size is None:
            input_size = args.input_size

        # TODO(change this:)
        if args.gen:
            if 'mnist' in args.dataset:
                img_size = 784
                economy = img_size // input_size[0]
            elif 'cifar' in args.dataset:
                img_size = 32 * 32 * 3
                economy = img_size // (input_size[0] ** 2)
            elif 'imagenet' in args.dataset:
                img_size = 84 * 84 * 3
                economy = img_size // (input_size[0] ** 2)
        else:
            economy = 1

        buffer_size = args.buffer_size
        print('buffer has %d slots' % buffer_size,args.buffer_size)

        bx = torch.FloatTensor(buffer_size, *input_size).fill_(0)
        print("bx",bx.shape)
        by = torch.LongTensor(buffer_size).fill_(0)
        bt = torch.LongTensor(buffer_size).fill_(0)
        heatmap = torch.FloatTensor(buffer_size, 1,32,32).fill_(0)
        logits= torch.FloatTensor(buffer_size, 512).fill_(0)
        #if args.cuda:
        bx = bx#.cuda()#to(args.device)
        by = by#.cuda()#to(args.device)
        bt = bt#.cuda()#to(args.device)
        logits = logits#.cuda()#to(args.device)
        heatmap=heatmap
        #feature=feature#.cuda()
        self.save_logits=None
        self.save_heat = None

        self.current_index = 0
        self.n_seen_so_far = 0
        self.is_full       = 0

        # registering as buffer allows us to save the object using `torch.save`
        self.register_buffer('bx', bx)
        self.register_buffer('by', by)
        self.register_buffer('bt', bt)
        self.register_buffer('logits', logits)
        self.register_buffer('heatmap',heatmap)
        self.to_one_hot  = lambda x : x.new(x.size(0), args.n_classes).fill_(0).scatter_(1, x.unsqueeze(1), 1)
        self.arange_like = lambda x : torch.arange(x.size(0)).to(x.device)
        self.shuffle     = lambda x : x[torch.randperm(x.size(0))]

    @property
    def x(self):
        return self.bx[:self.current_index]
    def is_empty(self) -> bool:
        """
        Tell whether nothing has been offered to the buffer yet
        (``n_seen_so_far`` is still zero).
        """
        return bool(self.n_seen_so_far == 0)

    @property
    def y(self):
        return self.to_one_hot(self.by[:self.current_index])
    def y_idx(self):
        return self.by[:self.current_index]

    @property
    def t(self):
        return self.bt[:self.current_index]

    @property
    def valid(self):
        return self.is_valid[:self.current_index]

    def display(self, gen=None, epoch=-1):
        """Write the stored inputs as an image grid and print class counts.

        The grid is saved to ``samples/buffer_<epoch>.png``; the directory is
        created when missing. Nothing is written for an empty buffer.

        Args:
            gen: optional object with a ``decode`` method; when given, the
                stored inputs are passed through ``gen.decode`` first.
            epoch: number used in the output file name.
        """
        import os
        from torchvision.utils import save_image

        # An empty buffer would give a grid with zero columns.
        if self.current_index == 0:
            return

        # Image layout is picked from the dataset name.
        if 'cifar' in self.args.dataset:
            shp = (-1, 3, 32, 32)
        elif 'tinyimagenet' in self.args.dataset:
            shp = (-1, 3, 64, 64)
        else:
            shp = (-1, 1, 28, 28)

        x = self.x if gen is None else gen.decode(self.x)

        os.makedirs('samples', exist_ok=True)
        # `* 0.5 + 0.5` rescales before saving; roughly square grid.
        save_image((x.reshape(shp) * 0.5 + 0.5), 'samples/buffer_%d.png' % epoch, nrow=int(self.current_index ** 0.5))
        print(self.y.sum(dim=0))

    def add_reservoir(self, x, y, logits, t,heatmaps=None):
        n_elem = x.size(0)
       # x=x.reshape(x.size(0),1,1,-1)
        place_left = max(0, self.bx.size(0) - self.current_index)
        offset = min(place_left, n_elem)
      #  print(self.bx.shape,x[:offset].shape)
        save_logits = logits is not None
        save_heat=heatmaps is not None
        self.save_logits=logits is not None
        self.save_heat = heatmaps is not None

        # add whatever still fits in the buffer
        place_left = max(0, self.bx.size(0) - self.current_index)
        if place_left:
            offset = min(place_left, n_elem)
           # print(offset)
           # print(self.bx[self.current_index: self.current_index + offset].data.shape)
           # print(x[:offset].shape)
            self.bx[self.current_index: self.current_index + offset].data.copy_(x[:offset])
            self.by[self.current_index: self.current_index + offset].data.copy_(y[:offset])
            self.bt[self.current_index: self.current_index + offset].fill_(t)


            if save_logits:
                #print("存")
                self.logits[self.current_index: self.current_index + offset].data.copy_(logits[:offset])
                #self.feature[self.current_index: self.current_index+offset].data.copy_(feature[:offset])
            if save_heat:
                #print("存")
                self.heatmap[self.current_index: self.current_index + offset].data.copy_(heatmaps[:offset])
                #self.feature[self.current_index: self.current_index+offset].data.copy_(feature[:offset])
            self.current_index += offset
            self.n_seen_so_far += offset

            # everything was added
            if offset == x.size(0):
                return

        self.place_left = False

        # remove what is already in the buffer
        x, y = x[place_left:], y[place_left:]

        indices = torch.FloatTensor(x.size(0)).to(x.device).uniform_(0, self.n_seen_so_far).long()
        valid_indices = (indices < self.bx.size(0)).long()

        idx_new_data = valid_indices.nonzero().squeeze(-1)
        idx_buffer   = indices[idx_new_data]

        self.n_seen_so_far += x.size(0)

        if idx_buffer.numel() == 0:
            return
        #print()
        assert idx_buffer.max() < self.bx.size(0), pdb.set_trace()
        assert idx_buffer.max() < self.by.size(0), pdb.set_trace()
        assert idx_buffer.max() < self.bt.size(0), pdb.set_trace()

        assert idx_new_data.max() < x.size(0), pdb.set_trace()
        assert idx_new_data.max() < y.size(0), pdb.set_trace()

        # perform overwrite op
        self.bx[idx_buffer] = x[idx_new_data].cuda()
        self.by[idx_buffer] = y[idx_new_data].cuda()
        self.bt[idx_buffer] = t
       # pdb.set_trace()

        if save_logits:
            self.logits[idx_buffer] = logits[idx_new_data]
        if save_heat:
                self.heatmap[idx_buffer] = heatmaps[idx_new_data]
            #self.feature[idx_buffer] = feature[idx_new_data]
        return  valid_indices
    def add_reservoir_return(self, x, y, logits, t):
        n_elem = x.size(0)
       # x=x.reshape(x.size(0),1,1,-1)
        place_left = max(0, self.bx.size(0) - self.current_index)
        offset = min(place_left, n_elem)
      #  print(self.bx.shape,x[:offset].shape)
        save_logits = logits is not None
        self.save_logits=logits is not None

        # add whatever still fits in the buffer
        place_left = max(0, self.bx.size(0) - self.current_index)
        if place_left:
            offset = min(place_left, n_elem)
           # print(offset)
           # print(self.bx[self.current_index: self.current_index + offset].data.shape)
           # print(x[:offset].shape)
            self.bx[self.current_index: self.current_index + offset].data.copy_(x[:offset])
            self.by[self.current_index: self.current_index + offset].data.copy_(y[:offset])
            self.bt[self.current_index: self.current_index + offset].fill_(t)


            if save_logits:
                #print("存")
                self.logits[self.current_index: self.current_index + offset].data.copy_(logits[:offset])
                #self.feature[self.current_index: self.current_index+offset].data.copy_(feature[:offset])
            self.current_index += offset
            self.n_seen_so_far += offset

            # everything was added
            if offset == x.size(0):
                return

        self.place_left = False

        # remove what is already in the buffer
        x, y = x[place_left:], y[place_left:]

        indices = torch.FloatTensor(x.size(0)).to(x.device).uniform_(0, self.n_seen_so_far).long()
        valid_indices = (indices < self.bx.size(0)).long()

        idx_new_data = valid_indices.nonzero().squeeze(-1)
        #pdb.set_trace()
        idx_buffer   = indices[idx_new_data]

        self.n_seen_so_far += x.size(0)

        if idx_buffer.numel() == 0:
            return
        #print()
        assert idx_buffer.max() < self.bx.size(0), pdb.set_trace()
        assert idx_buffer.max() < self.by.size(0), pdb.set_trace()
        assert idx_buffer.max() < self.bt.size(0), pdb.set_trace()

        assert idx_new_data.max() < x.size(0), pdb.set_trace()
        assert idx_new_data.max() < y.size(0), pdb.set_trace()

        # perform overwrite op
        discard_x=self.bx[idx_buffer]
        discard_y=self.by[idx_buffer]
        discard_bt=self.bt[idx_buffer]
        self.bx[idx_buffer] = x[idx_new_data].cuda()
        self.by[idx_buffer] = y[idx_new_data].cuda()
        self.bt[idx_buffer] = t
       # pdb.set_trace()

        if save_logits:
            self.logits[idx_buffer] = logits[idx_new_data]
            #self.feature[idx_buffer] = feature[idx_new_data]
        return  discard_x,discard_y,discard_bt
    def measure_valid(self, generator, classifier):
        with torch.no_grad():
            # fetch valid examples
            valid_indices = self.valid.nonzero()
            valid_x, valid_y = self.bx[valid_indices], self.by[valid_indices]
            one_hot_y = self.to_one_hot(valid_y.flatten())

            hid_x = generator.idx_2_hid(valid_x)
            x_hat = generator.decode(hid_x)

            logits = classifier(x_hat)
            _, pred = logits.max(dim=1)
            one_hot_pred = self.to_one_hot(pred)
            correct = one_hot_pred * one_hot_y

            per_class_correct = correct.sum(dim=0)
            per_class_deno    = one_hot_y.sum(dim=0)
            per_class_acc     = per_class_correct.float() / per_class_deno.float()
            self.class_weight = 1. - per_class_acc
            self.valid_acc    = per_class_acc
            self.valid_deno   = per_class_deno

    def shuffle_(self):
        indices = torch.randperm(self.current_index).to(self.args.device)
        self.bx = self.bx[indices]
        self.by = self.by[indices]
        self.bt = self.bt[indices]


    def delete_up_to(self, remove_after_this_idx):
        self.bx = self.bx[:remove_after_this_idx]
        self.by = self.by[:remove_after_this_idx]
        self.br = self.bt[:remove_after_this_idx]

    def sample(self, amt, exclude_task = None, ret_ind = False):

        if self.save_logits:
            if exclude_task is not None:
                valid_indices = (self.t != exclude_task)
                valid_indices = valid_indices.nonzero().squeeze()
                bx, by, bt, logits= self.bx[valid_indices], self.by[valid_indices], self.bt[valid_indices],self.logits[valid_indices]
            else:
                bx, by, bt, logits = self.bx[:self.current_index], self.by[:self.current_index], self.bt[:self.current_index],self.logits[:self.current_index]#,self.feature[:self.current_index]

            if bx.size(0) < amt:
                if ret_ind:
                    return bx, by, logits,bt, torch.from_numpy(np.arange(bx.size(0)))
                else:
                    return bx, by, logits,bt
            else:
                indices = torch.from_numpy(np.random.choice(bx.size(0), amt, replace=False))

                #if self.args.cuda:
                indices = indices.cuda()#to(self.args.device)
         #       import pdb
          #      pdb.set_trace()

                if ret_ind:
                    return bx[indices], by[indices],logits[indices],bt[indices], indices
                else:
                    return bx[indices], by[indices],logits[indices], bt[indices]
        elif self.save_heat:
            if exclude_task is not None:
                valid_indices = (self.t != exclude_task)
                valid_indices = valid_indices.nonzero().squeeze()
                bx, by, bt, heats = self.bx[valid_indices], self.by[valid_indices], self.bt[valid_indices], \
                                     self.heatmap[valid_indices]
            else:
                bx, by, bt, heats = self.bx[:self.current_index], self.by[:self.current_index], self.bt[
                                                                                                 :self.current_index], self.heatmap[
                                                                                                                       :self.current_index]  # ,self.feature[:self.current_index]

            if bx.size(0) < amt:
                if ret_ind:
                    return bx, by, heats, bt, torch.from_numpy(np.arange(bx.size(0)))
                else:
                    return bx, by, heats, bt
            else:
                indices = torch.from_numpy(np.random.choice(bx.size(0), amt, replace=False))

                # if self.args.cuda:
                indices = indices.cuda()  # to(self.args.device)
                #       import pdb
                #      pdb.set_trace()

                if ret_ind:
                    return bx[indices], by[indices], heats[indices], bt[indices], indices
                else:
                    return bx[indices], by[indices], heats[indices], bt[indices]
        else:

            if exclude_task is not None:
                valid_indices = (self.t != exclude_task)
                valid_indices = valid_indices.nonzero().squeeze()
                bx, by, bt = self.bx[valid_indices], self.by[valid_indices], self.bt[valid_indices]
            else:
                bx, by, bt = self.bx[:self.current_index], self.by[:self.current_index], self.bt[:self.current_index]

            if bx.size(0) < amt:
                if ret_ind:
                    return bx, by, bt, torch.from_numpy(np.arange(bx.size(0)))
                else:
                    return bx, by, bt
            else:
                indices = torch.from_numpy(np.random.choice(bx.size(0), amt, replace=False))

                #if self.args.cuda:
                indices = indices.cuda()#to(self.args.device)

                if ret_ind:
                    return bx[indices], by[indices], bt[indices], indices
                else:
                    return bx[indices], by[indices], bt[indices]


    def split(self, amt):
        indices = torch.randperm(self.current_index).to(self.args.device)
        return indices[:amt], indices[amt:]

    def presample(self, amt, task = None, ret_ind = False):

        if self.save_logits:
            if task is not None:
                valid_indices = (self.t <= task)
                valid_indices = valid_indices.nonzero().squeeze()
                bx, by, bt, logits= self.bx[valid_indices], self.by[valid_indices], self.bt[valid_indices],self.logits[valid_indices]
            else:
                bx, by, bt, logits = self.bx[:self.current_index], self.by[:self.current_index], self.bt[:self.current_index],self.logits[:self.current_index]

            if bx.size(0) < amt:
                if ret_ind:
                    return bx, by, logits,bt, torch.from_numpy(np.arange(bx.size(0)))
                else:
                    return bx, by, logits,bt
            else:
                indices = torch.from_numpy(np.random.choice(bx.size(0), amt, replace=False))

                #if self.args.cuda:
                indices = indices.cuda()#to(self.args.device)

                if ret_ind:
                    return bx[indices], by[indices],logits[indices],bt[indices], indices
                else:
                    return bx[indices], by[indices],logits[indices], bt[indices]
        else:
            return 0
    def prob_index(self,distribution,amt):
        n=int(len(distribution)/2)
        valid_sum_indices=None
        for task_index in range(n):
            prob_cur_task=distribution[task_index]+distribution[task_index+1]
            va_cur_index=(self.t==task_index)
            valid_cur_indices = va_cur_index.nonzero().squeeze()

            indices = torch.from_numpy(np.random.choice(len(valid_cur_indices), int(amt*prob_cur_task), replace=False))
            valid_cur_indices=valid_cur_indices[indices]
            if valid_sum_indices is None:
                valid_sum_indices=(valid_cur_indices)
            else:
                valid_sum_indices = torch.cat((valid_cur_indices,valid_sum_indices))
        return  valid_sum_indices


    def pro_sample(self, amt, distribution, ret_ind = False):
        #task=exclude_task
        #if task>=2:
         #   import pdb
          #  pdb.set_trace()
        #
        if self.save_logits:
            #if task is not None:
              #  valid_indices = (self.t == task)
             #   valid_indices = valid_indices.nonzero().squeeze()
            #    bx, by, bt, logits= self.bx[valid_indices], self.by[valid_indices], self.bt[valid_indices],self.logits[valid_indices]
           # else:
            probi_index= self.prob_index(distribution, amt)
            return self.bx[probi_index], self.by[probi_index], self.bt[probi_index],self.logits[probi_index]
        else:
            probi_index = self.prob_index(distribution, amt)
            return self.bx[probi_index], self.by[probi_index], self.bt[probi_index]
    def prob_class_index(self,distribution,amt):
        n=int(len(distribution))
        valid_sum_indices=None
       # import pdb
        #pdb.set_trace()
        for class_index in range(n):
            prob_cur_class=distribution[class_index]#+distribution[task_index+1]
            va_cur_index=(self.by==class_index)
            valid_cur_indices = va_cur_index.nonzero().squeeze()

            indices = torch.from_numpy(np.random.choice(len(valid_cur_indices), int(amt*prob_cur_class), replace=False))
            valid_cur_indices=valid_cur_indices[indices]
            if valid_sum_indices is None:
                valid_sum_indices=(valid_cur_indices)
            else:
                valid_sum_indices = torch.cat((valid_cur_indices,valid_sum_indices))
        return  valid_sum_indices


    def pro_class_sample(self, amt, distribution, ret_ind = False):
        #task=exclude_task
        #if task>=2:
         #   import pdb
          #  pdb.set_trace()
        #
        if self.save_logits:
            #if task is not None:
              #  valid_indices = (self.t == task)
             #   valid_indices = valid_indices.nonzero().squeeze()
            #    bx, by, bt, logits= self.bx[valid_indices], self.by[valid_indices], self.bt[valid_indices],self.logits[valid_indices]
           # else:
          #  pdb.set_trace()
            probi_index= self.prob_class_index(distribution, amt)
         #   bx,by,bt,logits=self.bx.squeeze(0), self.by.squeeze(0), self.bt.squeeze(0), self.logits.squeeze(0)
          #  if probi_index is None:probi_index=torch.tensor([], device='cuda:0', dtype=torch.int64)
           # import pdb
           # pdb.set_trace()
            return self.bx[probi_index],self.by[probi_index],self.bt[probi_index],self.logits[probi_index]
        else:
            probi_index = self.prob_class_index(distribution, amt)
            return self.bx[probi_index], self.by[probi_index], self.bt[probi_index]
    def only_class_sample(self, amt, class_id= None, ret_ind = False):

        if self.save_logits:
            if class_id is not None:
                valid_indices = (self.y == class_id)
                valid_indices = valid_indices.nonzero().squeeze()
                bx, by, bt, logits= self.bx[valid_indices], self.by[valid_indices], self.bt[valid_indices],self.logits[valid_indices]
            else:
                bx, by, bt, logits = self.bx[:self.current_index], self.by[:self.current_index], self.bt[:self.current_index],self.logits[:self.current_index]

            if bx.size(0) < amt:
                if ret_ind:
                    return bx, by, logits,bt, torch.from_numpy(np.arange(bx.size(0)))
                else:
                    return bx, by, logits,bt
            else:
                indices = torch.from_numpy(np.random.choice(bx.size(0), amt, replace=False))

                #if self.args.cuda:
                indices = indices.cuda()#to(self.args.device)

                if ret_ind:
                    return bx[indices], by[indices],logits[indices],bt[indices], indices
                else:
                    return bx[indices], by[indices],logits[indices], bt[indices]
        else:
            if class_id is not None:
                valid_indices = (self.y == class_id)
                valid_indices = valid_indices.nonzero().squeeze()
                bx, by, bt = self.bx[valid_indices], self.by[valid_indices], self.bt[valid_indices]
                pdb.set_trace()
            else:
                bx, by, bt= self.bx[:self.current_index], self.by[:self.current_index], self.bt[
                                                                                                 :self.current_index]

            if bx.size(0) < amt:
                if ret_ind:
                    return bx, by, bt, torch.from_numpy(np.arange(bx.size(0)))
                else:
                    return bx, by, bt
            else:
                pdb.set_trace()
                indices = torch.from_numpy(np.random.choice(bx.size(0), amt, replace=False))

                # if self.args.cuda:
                indices = indices.cuda()  # to(self.args.device)

                if ret_ind:
                    return bx[indices], by[indices], bt[indices], indices
                else:
                    return bx[indices], by[indices], bt[indices]
    def onlysample(self, amt, task = None, ret_ind = False):

        if self.save_logits:
            if task is not None:
                valid_indices = (self.t == task)
                valid_indices = valid_indices.nonzero().squeeze()
                bx, by, bt, logits= self.bx[valid_indices], self.by[valid_indices], self.bt[valid_indices],self.logits[valid_indices]
            else:
                bx, by, bt, logits = self.bx[:self.current_index], self.by[:self.current_index], self.bt[:self.current_index],self.logits[:self.current_index]

            if bx.size(0) < amt:
                if ret_ind:
                    return bx, by, logits,bt, torch.from_numpy(np.arange(bx.size(0)))
                else:
                    return bx, by, logits,bt
            else:
                indices = torch.from_numpy(np.random.choice(bx.size(0), amt, replace=False))

                #if self.args.cuda:
                indices = indices.cuda()#to(self.args.device)

                if ret_ind:
                    return bx[indices], by[indices],logits[indices],bt[indices], indices
                else:
                    return bx[indices], by[indices],logits[indices], bt[indices]
        else:
            if task is not None:
                valid_indices = (self.t == task)
                valid_indices = valid_indices.nonzero().squeeze()
                bx, by, bt = self.bx[valid_indices], self.by[valid_indices], self.bt[valid_indices]
            else:
                bx, by, bt= self.bx[:self.current_index], self.by[:self.current_index], self.bt[
                                                                                                 :self.current_index]

            if bx.size(0) < amt:
                if ret_ind:
                    return bx, by, bt, torch.from_numpy(np.arange(bx.size(0)))
                else:
                    return bx, by, bt
            else:
                indices = torch.from_numpy(np.random.choice(bx.size(0), amt, replace=False))

                # if self.args.cuda:
                indices = indices.cuda()  # to(self.args.device)

                if ret_ind:
                    return bx[indices], by[indices], bt[indices], indices
                else:
                    return bx[indices], by[indices], bt[indices]




def get_cifar_buffer(args, hH=8, gen=None):
    """Build a ``Buffer`` whose stored inputs have shape ``(hH, hH)``.

    Sets ``args.input_size`` and ``args.gen`` on the passed-in ``args``
    before constructing the buffer.

    Args:
        args: configuration object handed to ``Buffer``; modified in place.
        hH: side length used for ``args.input_size``.
        gen: kept for call compatibility; not used.
    """
    args.input_size = (hH, hH)
    args.gen = True

    # Buffer.__init__ accepts (args, input_size) only; forwarding `gen=`
    # raised a TypeError.
    return Buffer(args)
