import torch
import numpy as np

from torch.utils.data import Dataset, DataLoader
from cl_gym.algorithms import ContinualAlgorithm
import matplotlib.pyplot as plt
import time

import copy
import os

def bool2idx(arr):
    """Return the positions of the entries of ``arr`` that compare equal to 1.

    Works for any iterable whose elements support ``== 1`` (bools, ints,
    0-d tensors). The result always has an integer dtype, so an input with
    no matching entries yields an empty array that can still be used for
    indexing (a bare ``np.array([])`` would be float64, which NumPy rejects
    as an index).

    Parameters:
        arr (iterable): Mask-like sequence of 0/1 or False/True values.

    Returns:
        np.ndarray: 1-D ``np.intp`` array of matching positions, in order.
    """
    return np.array([i for i, e in enumerate(arr) if e == 1], dtype=np.intp)

class Heuristic(ContinualAlgorithm):
    """Base class for continual-learning algorithms that re-weight current-task samples.

    For every task after the first, ``prepare_train_loader`` computes group
    losses and per-sample gradients (``get_loss_grad_all``), converts them into
    solver inputs (``converter`` or ``converter_LP_lower``), obtains one weight
    per current-task sample from a caller-supplied ``solver``, and returns a
    loader over the samples whose weight exceeds a drop threshold.

    ``get_loss_grad`` / ``get_loss_grad_all`` / ``measure_loss`` / ``converter``
    / ``training_step`` raise ``NotImplementedError`` here and must be provided
    by a subclass. ``converter_LP_lower`` is referenced but not defined in this
    class -- presumably also supplied by a subclass (TODO confirm).
    """

    # Implementation is partially based on: https://github.com/MehdiAbbanaBennani/continual-learning-ogdplus
    def __init__(self, backbone, benchmark, params, **kwargs):
        self.backbone = backbone
        self.benchmark = benchmark
        self.params = params
        # Scales the gradient term of the re-weighting (see the debug
        # expected-loss estimate); divided by 10 on decay epochs if enabled.
        self.alpha = self.params['alpha']
        # One entry per finished task: the list of per-epoch weight dicts.
        self.weight_all = list()
        # Per-task bookkeeping; the inner dicts are keyed by epoch.
        self.optim_time = dict()
        self.etc_prepare_time = dict()
        self.grad_calculation_time = dict()
        self.overall_training_time = dict()  # per task: list of epoch durations
        self.true_loss = dict()
        self.expected_loss = dict()
        self.debug = False
        super(Heuristic, self).__init__(backbone, benchmark, params, **kwargs)

    def before_training_epoch(self):
        """Run the parent's per-epoch hook (if it has one) and start the epoch timer."""
        # Must forward to the parent's *epoch* hook; forwarding to
        # before_training_task here would re-run task setup every epoch.
        if hasattr(super(), "before_training_epoch"):
            super().before_training_epoch()
        self.training_start_time = time.time()

    def training_epoch_end(self):
        """Record the epoch's wall-clock duration, then defer to the parent hook."""
        training_time = time.time() - self.training_start_time
        self.overall_training_time[self.current_task].append(training_time)
        return super().training_epoch_end()

    def get_num_current_classes(self, task):
        """Return the number of classes belonging to ``task``.

        If ``task`` is None, or the class list still has room for a full split,
        this is ``benchmark.num_classes_per_split``. When the total number of
        classes is smaller than ``num_classes_per_split * task``, the remainder
        left after the previous ``task - 1`` full splits is returned instead.
        """
        if task is None:
            return self.benchmark.num_classes_per_split
        else:
            if len(self.benchmark.class_idx) - self.benchmark.num_classes_per_split * task < 0:
                return len(self.benchmark.class_idx) - self.benchmark.num_classes_per_split * (task-1)
            else:
                return self.benchmark.num_classes_per_split

    def memory_indices_selection(self, task):
        """Store indices ``0 .. per_task_memory_examples - 1`` as ``task``'s memory indices.

        ``self.per_task_memory_examples`` is not set in this class; presumably
        it is provided by the parent or a subclass -- TODO confirm.
        """
        indices_train = np.arange(self.per_task_memory_examples)
        assert len(indices_train) == self.per_task_memory_examples
        self.benchmark.memory_indices_train[task] = indices_train[:]

    def update_episodic_memory(self):
        """Rebuild the episodic-memory loader for the current task and reset its iterator."""
        # self.memory_indices_selection(self.current_task)
        self.episodic_memory_loader, _ = self.benchmark.load_memory_joint(self.current_task,
                                                                          batch_size=self.params['batch_size_memory'],
                                                                          shuffle=True,
                                                                          pin_memory=True)
        self.episodic_memory_iter = iter(self.episodic_memory_loader)

    def before_training_task(self):
        """Run the parent's task hook, then reset the per-task bookkeeping."""
        if hasattr(super(), "before_training_task"):
            super().before_training_task()
        self.weight_for_task = list()
        self.classwise_mean_grad = list()
        self.optim_time[self.current_task] = dict()
        self.etc_prepare_time[self.current_task] = dict()
        self.grad_calculation_time[self.current_task] = dict()
        self.overall_training_time[self.current_task] = list()
        self.true_loss[self.current_task] = dict()
        self.expected_loss[self.current_task] = dict()

    def sample_batch_from_memory(self):
        """Return the next episodic-memory batch, restarting the loader when exhausted.

        Returns:
            tuple: ``(inp, targ, task_id, extra)``. ``inp`` (a tensor or a list
            of tensors), ``targ`` and ``task_id`` are moved to
            ``params['device']``; ``extra`` is the list of any remaining batch
            fields, left untouched.
        """
        try:
            batch = next(self.episodic_memory_iter)
        except StopIteration:
            self.episodic_memory_iter = iter(self.episodic_memory_loader)
            batch = next(self.episodic_memory_iter)

        device = self.params['device']
        inp, targ, task_id, *extra = batch
        if isinstance(inp, list):
            inp = [x.to(device) for x in inp]
        else:
            inp = inp.to(device)

        return inp, targ.to(device), task_id.to(device), extra

    def training_task_end(self):
        """
        Archive this task's per-epoch weight records, then run the parent hook
        (which decides what to store in memory).
        """
        print("training_task_end")
        self.weight_all.append(self.weight_for_task)
        super().training_task_end()

    def get_loss_grad(self):
        raise NotImplementedError

    def get_loss_grad_all(self):
        raise NotImplementedError

    def measure_loss(self):
        raise NotImplementedError

    def converter(self):
        raise NotImplementedError

    def training_step(self):
        raise NotImplementedError

    def prepare_train_loader(self, task_id, solver=None, epoch=0):
        """Build the (possibly re-weighted) training DataLoader for ``task_id``.

        For the first task, or when ``alpha == 0`` and ``params['alpha_debug']``
        is not set, the benchmark's plain loader is returned. Otherwise:

        1. the task's original sample order is saved (``epoch <= 1``) or restored;
        2. group losses/gradients and per-sample gradients come from
           ``get_loss_grad_all``;
        3. ``alpha`` is divided by 10 if ``params['alpha_decay']`` is set and
           ``epoch`` is in ``params['learning_rate_decay_epoch']``;
        4. the converter outputs are cast to float64 numpy arrays and passed to
           ``solver``, which returns one weight per current-task sample;
        5. the weights are pushed to the benchmark, grouped by sensitive
           attribute, recorded and plotted;
        6. a loader over the samples with weight above the drop threshold is
           returned (padded with dropped samples toward a full last batch).

        Parameters:
            task_id (int): ID of the task whose training loader is prepared.
            solver (callable, optional): Maps the converter outputs to
                per-sample weights. Required whenever the weighted path is
                taken (``alpha != 0``); there is no built-in default.
            epoch (int, optional): Current epoch number. Defaults to 0.

        Returns:
            DataLoader: The DataLoader for the task's training data.

        Raises:
            ValueError: If the weighted path is needed but ``solver`` is None.
        """
        prepare_start_time = time.time()
        num_workers = self.params.get('num_dataloader_workers', torch.get_num_threads())
        batch_size = self.params['batch_size_train']
        if task_id == 1:  # first task: no memory, nothing to re-weight
            return self.benchmark.load(task_id, batch_size,
                                       num_workers=num_workers, pin_memory=True)[0]

        if self.alpha == 0 and not self.params.get('alpha_debug', False):
            return self.benchmark.load(task_id, batch_size,
                                       num_workers=num_workers, pin_memory=True)[0]

        # Only the alpha == 0 (alpha_debug) path below returns without using the
        # solver, so fail fast here instead of after the costly gradient pass.
        if solver is None and self.alpha != 0:
            raise ValueError("prepare_train_loader requires a `solver` callable when alpha != 0")

        if epoch <= 1:
            # Remember the untouched index order so later epochs can restore it.
            self.original_seq_indices_train = self.benchmark.seq_indices_train[task_id]
            if hasattr(self.benchmark.trains[task_id], "sensitive"):
                print(f"Num. of sensitives: {(self.benchmark.trains[task_id].sensitive[self.original_seq_indices_train] != self.benchmark.trains[task_id].targets[self.original_seq_indices_train]).sum().item()}")
        else:
            self.benchmark.seq_indices_train[task_id] = copy.deepcopy(self.original_seq_indices_train)
        self.non_select_indexes = list(range(len(self.benchmark.seq_indices_train[task_id])))

        grad_start_time = time.time()
        loss_group, grad_group, grad_data, new_batch = self.get_loss_grad_all(task_id)
        grad_data_prev, grad_data_current = grad_data

        if self.alpha == 0:  # alpha_debug: gradients were computed only for inspection
            return self.benchmark.load(task_id, batch_size,
                                       num_workers=num_workers, pin_memory=True)[0]

        print(f"{loss_group=}")

        if self.params.get('alpha_decay', False) and epoch in self.params.get('learning_rate_decay_epoch', []):  # decay
            self.alpha = self.alpha / 10
        if not self.params.get('all_layer_gradient', False) or self.params.get('old', False):
            converter_out = self.converter(loss_group, self.alpha, grad_group, grad_data_current, task=task_id, grad_data_prev=grad_data_prev)
        else:
            converter_out = self.converter_LP_lower(grad_data_current, loss_group, self.alpha, task=task_id)
        # Even-indexed converter outputs keep their shape; odd-indexed ones are flattened.
        optim_in = [
            (e if i % 2 == 0 else e.view(-1)).cpu().detach().numpy().astype('float64')
            for i, e in enumerate(converter_out)
        ]
        gradient_calculation_time = time.time()
        grad_time = gradient_calculation_time - grad_start_time
        print(f"Elapsed time(grad):{np.round(grad_time, 3)}")
        self.grad_calculation_time[self.current_task][epoch] = grad_time

        weight = solver(*optim_in)
        solver_calculation_time = time.time()
        solver_time = solver_calculation_time - gradient_calculation_time
        print(f"Elapsed time(optim):{np.round(solver_time, 3)}")
        self.optim_time[self.current_task][epoch] = solver_time

        print(f"Fairness:{np.matmul(optim_in[0], weight)-optim_in[1]}")
        if self.debug:
            group_loss = loss_group  # (group_num)
            group_grad = grad_group  # (group_num) x (number of weight & bias dims)
            data_grad = grad_data_current.T  # (number of weight & bias dims) x (number of current-step candidate samples)
            weight_torch = torch.Tensor(weight)
            expected_loss = group_loss - self.alpha * torch.matmul(group_grad, torch.matmul(data_grad, weight_torch))
            self.expected_loss[self.current_task][epoch] = expected_loss
            print(f"Current class expected loss:{expected_loss}")

        weight_np = np.array(weight)
        tensor_weight = torch.tensor(weight_np, dtype=torch.float32)

        drop_threshold = 0.05
        selected_idx, updated_seq_indices = self._select_samples(
            weight_np, self.benchmark.seq_indices_train[task_id], batch_size, drop_threshold)

        print(f"{len(updated_seq_indices)=}")
        self.benchmark.update_sample_weight(task_id, tensor_weight)

        # updated_seq_indices is only reported here; seq_indices_train itself is
        # left unchanged and the filtering is applied to the dataset built below.
        if hasattr(self.benchmark.trains[task_id], "sensitive"):
            print(f"sensitive samples / selected samples = {(self.benchmark.trains[task_id].sensitive[updated_seq_indices] != self.benchmark.trains[task_id].targets[updated_seq_indices]).sum().item()} / {len(updated_seq_indices)}")

        args = self._concat_candidate_batches(new_batch)
        args[4] = tensor_weight  # field 4 carries the per-sample weight

        # For weight-distribution figures.
        sen_weight = self._weights_by_sensitive(args)
        self.weight_for_task.append(sen_weight)
        draw_figs(sen_weight, self.params['output_dir'], drop_threshold,
                  min(self.params['per_task_examples'], len(self.benchmark.trains[task_id])),
                  task_id, epoch)

        # Drop the samples below the threshold (list-valued fields element-wise).
        for i, field in enumerate(args):
            if isinstance(field, list):
                args[i] = [part[selected_idx] for part in field]
            else:
                args[i] = field[selected_idx]
        dataset = WeightModifiedDataset(args)

        train_loader = DataLoader(dataset, batch_size, True, num_workers=num_workers,
                                  pin_memory=True)
        prepare_end_time = time.time()
        # Total time minus the gradient and solver phases.
        etc_prepare_time = prepare_end_time - solver_calculation_time + grad_start_time - prepare_start_time
        print(f"Elapsed time(etc):{np.round(etc_prepare_time, 3)}")
        self.etc_prepare_time[self.current_task][epoch] = etc_prepare_time

        if self.debug:
            self._debug_train(dataset, task_id, epoch, num_workers)

        return train_loader

    @staticmethod
    def _select_samples(weight, seq_indices, batch_size, drop_threshold):
        """Return a boolean keep-mask over samples and the kept sequence indices.

        Samples whose weight exceeds ``drop_threshold`` are kept. If the kept
        count is not a multiple of ``batch_size``, randomly chosen dropped
        samples (as many as are available) are re-added to fill the last batch.
        The re-added samples keep their small weights, so they mainly affect
        batch-norm statistics. Padding is preferred over ``drop_last=True``,
        and it avoids batches made entirely of zero-weight samples.
        """
        seq_indices = np.array(seq_indices)
        selected_idx = weight > drop_threshold
        updated_seq_indices = seq_indices[selected_idx]
        if len(updated_seq_indices) % batch_size > 0:
            dropped = np.where(np.logical_not(selected_idx))[0]
            num_to_add = min(-len(updated_seq_indices) % batch_size, len(dropped))
            add_idx = np.random.choice(dropped, num_to_add, replace=False)
            selected_idx = np.logical_or(selected_idx, np.isin(np.arange(len(weight)), add_idx))
            updated_seq_indices = seq_indices[selected_idx]
        return selected_idx, updated_seq_indices

    @staticmethod
    def _concat_candidate_batches(new_batch):
        """Concatenate a list of batch tuples field-by-field along dim 0.

        List-valued fields (e.g. multiple input tensors) are concatenated
        element-wise, producing a list of tensors for that field.
        """
        fields = [list() for _ in new_batch[0]]
        for items in new_batch:
            for i, item in enumerate(items):
                fields[i].append(item)

        args = list()
        for field in fields:
            if isinstance(field[0], list):
                parts = [list() for _ in field[0]]
                for item in field:
                    for j, part in enumerate(item):
                        parts[j].append(part)
                args.append([torch.cat(p, dim=0) for p in parts])
            else:
                args.append(torch.cat(field, dim=0))
        return args

    def _weights_by_sensitive(self, args):
        """Group the sample weights (field 4) by sensitive value (field 5).

        For ``BiasedMNIST``, the groups are 0 (sensitive == field 1) and
        1 (sensitive != field 1). Otherwise there is one group per unique value
        of field 5. Values are numpy arrays of weights.
        """
        weights, field1, sensitive = args[4], args[1], args[5]
        if self.params['dataset'] in ["BiasedMNIST"]:
            return {
                0: weights[sensitive == field1].cpu().detach().numpy(),
                1: weights[sensitive != field1].cpu().detach().numpy(),
            }
        keys = [sen.item() for sen in torch.unique(sensitive)]
        return {k: weights[sensitive == k].cpu().detach().numpy() for k in keys}

    def _debug_train(self, dataset, task_id, epoch, num_workers):
        """Debug only: train a copy of the backbone for one pass and record true group losses.

        The result of ``measure_loss`` on the trained copy is stored in
        ``self.true_loss[self.current_task][epoch]``; the real backbone is untouched.
        """
        print("temporal training...")
        criterion = torch.nn.CrossEntropyLoss(reduction="none")
        device = self.params['device']
        dummy_backbone = copy.deepcopy(self.backbone).to(device)
        # Setting `.shuffle` on an existing DataLoader has no effect (the sampler
        # is fixed at construction), so build an explicitly sequential loader.
        dummy_loader = DataLoader(dataset, self.params['batch_size_train'], shuffle=False,
                                  num_workers=num_workers, pin_memory=True)
        optimizer = torch.optim.SGD(dummy_backbone.parameters(), lr=self.params['learning_rate'])
        # Apply this epoch's LR decay once, not once per batch.
        if epoch in self.params.get('learning_rate_decay_epoch', []):  # decay
            for g in optimizer.param_groups:
                g['lr'] = g['lr'] / 10
        dummy_backbone.train()
        for items in dummy_loader:
            item_to_devices = [item.to(device) if isinstance(item, torch.Tensor) else item for item in items]
            inp, targ, task_ids, _, sample_weight, *_ = item_to_devices
            optimizer.zero_grad()
            pred = dummy_backbone(inp, task_ids)
            loss = criterion(pred, targ)
            if sample_weight is not None:
                loss = loss * sample_weight
            loss = loss.mean()
            loss.backward()
            optimizer.step()
        # training done
        loss_group = self.measure_loss(task_id, dummy_backbone)
        self.true_loss[self.current_task][epoch] = loss_group
        print("temporal training done")

class WeightModifiedDataset(Dataset):
    """Map-style dataset over pre-materialized, column-wise sample fields.

    ``args`` is a sequence of fields. Each field is either directly indexable
    (e.g. a tensor or array) or a ``list`` of indexables (e.g. several input
    views). Item ``idx`` is a tuple holding ``field[idx]`` for plain fields and
    ``[part[idx] for part in field]`` for list fields. The dataset length is
    the length of field 1.
    """

    def __init__(self, args):
        self.args = args
        self.arglen = len(args)

    def getarg_if_list(self, arg, idx):
        """Index every element of a list-valued field at ``idx``."""
        return list(map(lambda part: part[idx], arg))

    def getarg(self, arg, idx):
        """Index one field at ``idx``, descending into list-valued fields."""
        return self.getarg_if_list(arg, idx) if isinstance(arg, list) else arg[idx]

    def __len__(self):
        reference_field = self.args[1]
        return len(reference_field)

    def __getitem__(self, idx):
        row = [self.getarg(field, idx) for field in self.args]
        return tuple(row)

def draw_figs(weight_dict, output_dir, drop_threshold, y_lim, tid, epoch):
    """Save a stacked histogram of sample weights, one stack per group.

    Writes ``{output_dir}/figs/tid_{tid}_epoch_{epoch}_weight_distribution.pdf``
    (creating the ``figs`` directory if needed). The global pyplot rc settings
    (Times New Roman, size-15 fonts, TrueType PDF fonts) are updated as a side
    effect, as before.

    Parameters:
        weight_dict (dict): Group label -> array of weights; weights are binned
            into 20 equal bins over [0, 1].
        output_dir (str): Base output directory.
        drop_threshold (float): Currently unused (the threshold line is disabled).
        y_lim (float): Upper y-axis limit.
        tid (int): Task id, used in the file name.
        epoch (int): Epoch number, used in the file name.
    """
    num_bins = 20
    bins = np.arange(0, 1+1/num_bins, 1/num_bins)
    plt.rcParams["font.family"] = "Times New Roman"
    plt.rcParams['pdf.fonttype'] = 42
    plt.rc('font', size=15)
    plt.rc('axes', labelsize=15)
    plt.rc('xtick', labelsize=15)
    plt.rc('ytick', labelsize=15)
    plt.rc('legend', fontsize=15)
    plt.rc('figure', titlesize=15)
    # Draw on a dedicated figure rather than pyplot's implicit "current" one,
    # and always close it, so other figures are never drawn on or cleared.
    fig, ax = plt.subplots()
    try:
        ax.hist(list(weight_dict.values()), bins, stacked=True,
                edgecolor='black', histtype='bar', label=list(weight_dict.keys()))
        ax.set_xlim([0, 1])
        ax.set_ylim([0, y_lim])
        ax.set_xlabel('Weight')
        ax.set_ylabel('Number of samples')
        ax.legend(loc='upper center')
        # ax.axvline(drop_threshold, color='black', linestyle='dashed')
        os.makedirs(f"{output_dir}/figs", exist_ok=True)
        fig.savefig(f"{output_dir}/figs/tid_{tid}_epoch_{epoch}_weight_distribution.pdf", bbox_inches="tight")
    finally:
        plt.close(fig)