
from .distiller_utils import *
from .distiller_basic import BasicDistiller


class GeneralDistiller(BasicDistiller):
    """
    Supports intermediate features matching. **Recommended for single-teacher single-task distillation**.

    Args:
        train_config (:class:`TrainingConfig`): training configuration.
        distill_config (:class:`DistillationConfig`): distillation configuration.
        model_T (:class:`torch.nn.Module`): teacher model.
        model_S (:class:`torch.nn.Module`): student model.
        adaptor_T (Callable): teacher model's adaptor.
        adaptor_S (Callable): student model's adaptor.
        custom_matches (list): supports more flexible user-defined matches (testing).

    The roles of `adaptor_T` and `adaptor_S` are explained in :py:func:`adaptor`.

    """
    def __init__(self, train_config,
                 distill_config,
                 model_T,
                 model_S,
                 adaptor_T,
                 adaptor_S,
                 custom_matches: Optional[List[CustomMatch]] = None):
        # custom_matches=[{'module_T': module_T, 'module_S':module_S,
        #                 'loss': loss, 'weight': weight},...]
        super(GeneralDistiller, self).__init__(train_config, distill_config, model_T, model_S, adaptor_T, adaptor_S)

        self.projs = []
        self.projs_group = []
        for im in self.d_config.intermediate_matches:
            if im.proj is not None:
                projection = im.proj[0]
                dim_in = im.proj[1]
                dim_out = im.proj[2]
                self.projs_group.append(im.proj[3])
                self.projs.append(PROJ_MAP[projection](dim_in,dim_out))
                self.projs[-1].to(self.t_config.device)
            else:
                self.projs.append(None)
                self.projs_group.append(None)

        self.has_custom_matches = False
        if custom_matches:
            self.handles_T = []
            self.handles_S = []
            self.custom_matches_cache = {'hook_outputs_T': [], 'hook_outputs_S': [], 'match_proj_funcs': [],
                                         'match_weights':  [], 'match_losses':   [], 'match_proj_groups': []}
            for match in custom_matches:
                self.add_match(match)
            self.has_custom_matches = True
        
        self.d_config.is_caching_logits = False


    def save_and_callback(self,global_step, step, epoch, callback):
        if self.has_custom_matches:
            handles_T = self.model_T._forward_hooks
            handles_S = self.model_S._forward_hooks
            self.model_S._forward_hooks = OrderedDict()  # clear hooks
            self.model_T._forward_hooks = OrderedDict()

        super(GeneralDistiller, self).save_and_callback(global_step, step, epoch, callback)

        if self.has_custom_matches:
            self.model_S._forward_hooks = handles_S  # restore hooks
            self.model_T._forward_hooks = handles_T


    def train_on_batch(self, batch, args):

        (teacher_batch, results_T), (student_batch, results_S) = get_outputs_from_batch(batch, self.t_config.device, self.model_T, self.model_S, args)

        results_T = post_adaptor(self.adaptor_T(teacher_batch,results_T))
        results_S = post_adaptor(self.adaptor_S(student_batch,results_S))

        total_loss, losses_dict = self.compute_loss(results_S, results_T)

        return total_loss, losses_dict


    def compute_loss(self,results_S,results_T):

        losses_dict = dict()

        total_loss = 0

        if 'logits' in results_T and 'logits' in results_S:
            logits_list_T = results_T['logits']  # list of tensor
            logits_list_S = results_S['logits']  # list of tensor
            labels = results_S['labels']
            total_kd_loss = 0
            if 'logits_mask' in results_S:
                masks_list_S = results_S['logits_mask']
                logits_list_S = select_logits_with_mask(logits_list_S,masks_list_S)  #(mask_sum, num_of_class)
            if 'logits_mask' in results_T:
                masks_list_T = results_T['logits_mask']
                logits_list_T = select_logits_with_mask(logits_list_T,masks_list_T)  #(mask_sum, num_of_class)

            if self.d_config.probability_shift is True:
                labels_list = results_S['labels']
                for l_T, l_S, labels in zip(logits_list_T, logits_list_S, labels_list):
                    l_T = probability_shift_(l_T, labels)
                    if self.d_config.temperature_scheduler is not None:
                        temperature = self.d_config.temperature_scheduler(l_S, l_T, self.d_config.temperature)
                    else:
                        temperature = self.d_config.temperature
                    total_kd_loss += self.kd_loss(l_S, l_T, temperature)
            else:
                for l_T,l_S, label in zip(logits_list_T,logits_list_S, labels):
                    if self.d_config.temperature_scheduler is not None:
                        temperature = self.d_config.temperature_scheduler(l_S, l_T, self.d_config.temperature)
                    else:
                        temperature = self.d_config.temperature
                    # total_kd_loss = self.kd_loss(l_S, l_T, temperature)  modified by cjj
                    total_kd_loss += self.kd_loss(l_S, l_T, label, temperature)
            total_loss += total_kd_loss * self.d_config.kd_loss_weight
            losses_dict['unweighted_kd_loss'] = total_kd_loss

        inters_T = {feature: results_T.get(feature,[]) for feature in FEATURES}
        inters_S = {feature: results_S.get(feature,[]) for feature in FEATURES}
        inputs_mask_T = results_T.get('inputs_mask',None)
        inputs_mask_S = results_S.get('inputs_mask',None)
        for ith,inter_match in enumerate(self.d_config.intermediate_matches):
            layer_T = inter_match.layer_T
            layer_S = inter_match.layer_S
            feature = inter_match.feature
            loss_type = inter_match.loss
            match_weight = inter_match.weight
            match_loss = MATCH_LOSS_MAP[loss_type]

            if type(layer_S) is list and type(layer_T) is list:
                inter_S = [inters_S[feature][s] for s in layer_S]
                inter_T = [inters_T[feature][t] for t in layer_T]
                name_S = '-'.join(map(str,layer_S))
                name_T = '-'.join(map(str,layer_T))
                if self.projs[ith]:
                    #inter_T = [self.projs[ith](t) for t in inter_T]
                    inter_S = [self.projs[ith](s) for s in inter_S]
            else:
                inter_S = inters_S[feature][layer_S]
                inter_T = inters_T[feature][layer_T]
                name_S = str(layer_S)
                name_T = str(layer_T)
                if self.projs[ith]:
                    #inter_T = self.projs[ith](inter_T)
                    inter_S = self.projs[ith](inter_S)
            intermediate_loss = match_loss(inter_S, inter_T, mask=inputs_mask_S)
            total_loss += intermediate_loss * match_weight
            losses_dict[f'unweighted_{feature}_{loss_type}_{name_S}_{name_T}'] = intermediate_loss

        if self.has_custom_matches:
            for hook_T, hook_S, match_weight, match_loss, proj_func  in \
                    zip(self.custom_matches_cache['hook_outputs_T'], self.custom_matches_cache['hook_outputs_S'],
                        self.custom_matches_cache['match_weghts'], self.custom_matches_cache['match_losses'],
                        self.custom_matches_cache['match_proj_funcs']):
                if proj_func is not None:
                    hook_S = proj_func(hook_S)
                total_loss += match_weight * match_loss(hook_S,hook_T,inputs_mask_S,inputs_mask_T)
            self.custom_matches_cache['hook_outputs_T'] = []
            self.custom_matches_cache['hook_outputs_S'] = []

        if 'losses' in results_S:
            total_hl_loss = 0
            for loss in results_S['losses']:
                # in case of multi-GPU
                total_hl_loss += loss.mean() 
            total_loss += total_hl_loss * self.d_config.hard_label_weight
            losses_dict['unweighted_hard_label_loss'] = total_hl_loss
        return total_loss, losses_dict

    def add_match(self,match: CustomMatch):
        if type(match.module_T) is str or type(match.module_S) is str:
            raise NotImplementedError
        else:
            module_T   = match.module_T
            module_S   = match.module_S
            weight     = match.weight
            loss       = match.loss
            proj_func = match.proj_func
            proj_group = match.proj_group
        self.add_match_by_module(module_T,module_S,proj_func,proj_group,weight,loss)

    def add_match_by_module(self,module_T : torch.nn.Module,
                                 module_S : torch.nn.Module,
                                 proj_func, proj_group,
                                 match_weight, match_loss):

        self.handles_T = module_T.register_forward_hook(self._hook_T)
        self.handles_S = module_S.register_forward_hook(self._hook_S)
        self.custom_matches_cache['match_weights'].append(match_weight)
        self.custom_matches_cache['match_losses'].append(match_loss)
        self.custom_matches_cache['match_proj_funcs'].append(proj_func)
        if isinstance(proj_func,nn.Module):
            self.custom_matches_cache['match_proj_funcs'][-1].to(self.t_config.device)
        self.custom_matches_cache['match_proj_groups'].append(proj_group)

    def _hook_T(self,module,input, output):
        self.custom_matches_cache['hook_outputs_T'].append(output)

    def _hook_S(self, module, input, output):
        self.custom_matches_cache['hook_outputs_S'].append(output)

