import os

from ..utils import tight_float_str, time_str


class Setting(object):
    '''
    Resolve the raw experiment configuration into run-time settings.

    An instance exposes, as plain attributes, everything derived here from
    ``all_config``: keyword arguments for the train and test data sets,
    device ids, model options, learning-rate schedule options, loss
    options, and the log / checkpoint file paths.

    Args:
        all_config: Attribute-style configuration object; every option is
            read as ``all_config.<name>``. NOTE: it is modified in place
            when ``init_class_number == phase`` (see ``__init__``).
        **kwargs: Must contain ``scale``, ``im_mean`` and ``im_std``; they
            are copied unchanged into ``pre_process_kwargs``.

    Raises:
        KeyError: If a required entry of ``kwargs`` is missing, or if
            ``all_config.backbone_feature_loss`` is not a known loss name.
        IndexError: If ``all_config.exp_dir`` is empty while neither
            ``phase_train`` nor ``up2now_test`` is enabled (there is no
            run mode to name the default experiment directory after).
    '''

    def __init__(self, all_config, **kwargs):
        # Loader flags that are fixed here rather than read from the config.
        self.train_final_batch = True
        self.train_shuffle = True
        self.test_final_batch = True
        self.test_shuffle = False

        # ==> executor
        self.phase_train = all_config.phase_train
        self.up2now_test = all_config.up2now_test
        # Names of the enabled run modes, always in this fixed order.
        run_modes = (
            ('phase_train', self.phase_train),
            ('up2now_test', self.up2now_test),
        )
        self.execrator = [name for name, enabled in run_modes if enabled]

        # When the initial class count equals the phase size, the initial
        # batch layout is overwritten with the regular one. This writes to
        # the caller's config object on purpose; the values are read back
        # from it a few lines below.
        if all_config.init_class_number == all_config.phase:
            all_config.init_ids_per_batch = all_config.ids_per_batch
            all_config.init_ims_per_id = all_config.ims_per_id

        # ==> dataset setting
        # Pre-processing options shared by the train and the test set.
        self.pre_process_kwargs = dict(
            resize_h_w=tuple(all_config.resize_h_w),
            scale=kwargs['scale'],
            im_mean=kwargs['im_mean'],
            im_std=kwargs['im_std'],
            batch_dims='NCHW',
            crop_prob=all_config.crop_prob,
            crop_ratio=all_config.crop_ratio,
        )
        self.train_set_kwargs = dict(
            part='phase_train',
            dataset_path=all_config.dataset_path,
            init_class_number=all_config.init_class_number,
            phase=all_config.phase,
            ids_per_batch=all_config.ids_per_batch,
            ims_per_id=all_config.ims_per_id,
            init_ids_per_batch=all_config.init_ids_per_batch,
            init_ims_per_id=all_config.init_ims_per_id,
            # The train batch size is taken from init_ids_per_batch.
            batch_size=all_config.init_ids_per_batch,
            final_batch=self.train_final_batch,
            shuffle=self.train_shuffle,
            rotate_prob=all_config.rotate_prob,
            mirror_type=all_config.mirror_type,
            num_prefetch_threads=all_config.threads,
        )
        self.train_set_kwargs.update(self.pre_process_kwargs)
        # The test set never rotates or mirrors its images.
        self.test_set_kwargs = dict(
            part='up2now_test',
            dataset_path=all_config.dataset_path,
            init_class_number=all_config.init_class_number,
            phase=all_config.phase,
            batch_size=all_config.test_batch_size,
            final_batch=self.test_final_batch,
            shuffle=self.test_shuffle,
            rotate_prob=0.0,
            mirror_type='none',
            num_prefetch_threads=all_config.threads,
        )
        self.test_set_kwargs.update(self.pre_process_kwargs)
        self.dataset = all_config.dataset

        # ==> GPU setting
        self.sys_device_ids = all_config.sys_device_ids

        # ==> model setting
        self.model = all_config.model
        self.model_embeding_size = all_config.model_embeding_size

        # ==> train/test process setting
        self.weight_decay = 0.0005
        self.resume = all_config.resume
        self.epochs_per_task = all_config.epochs_per_task
        self.backbone_t = all_config.backbone_t
        self.mini_batch = all_config.mini_batch
        self.steps_per_log = all_config.steps_per_log
        self.epochs_per_val = all_config.epochs_per_val
        # -> for updating the model in each task
        self.update_scale = all_config.update_scale
        self.distill_factor = all_config.distill_factor

        # ==> lr strategies
        self.base_lr = all_config.base_lr
        self.lr_satus = all_config.lr_satus

        self.lr_decay_type = all_config.lr_decay_type
        self.exp_decay_at_epoch = all_config.exp_decay_at_epoch
        self.staircase_decay_at_epochs = all_config.staircase_decay_at_epochs
        self.epoch_decay_steps = all_config.epoch_decay_steps
        # For the 'exp' schedule the configured decay_factor is ignored and
        # twice the number of epochs per task is used instead.
        if self.lr_decay_type != 'exp':
            self.decay_factor = all_config.decay_factor
        else:
            self.decay_factor = self.epochs_per_task * 2
        self.backbone_lr_kwargs = dict(
            lr_decay_type=self.lr_decay_type,
            exp_decay_at_epoch=self.exp_decay_at_epoch,
            staircase_decay_at_epochs=self.staircase_decay_at_epochs,
            epoch_decay_steps=self.epoch_decay_steps,
            decay_factor=self.decay_factor,
        )

        # ==> loss setting
        self.backbone_feature_loss = all_config.backbone_feature_loss
        self.backbone_local_loss = all_config.backbone_local_loss
        # -> local loss weight
        self.backbone_local_loss_weight = all_config.backbone_local_loss_weight

        # -> arcface setting
        arcface_kwargs = dict(
            s=all_config.arcface_s,
            m2=all_config.arcface_m2,
            m3=all_config.arcface_m3,
        )
        # -> softmax setting (fixed values, not read from all_config)
        softmax_kwargs = dict(
            embedding_size=2048,
            num_class=51332,
        )
        # -> klloss setting
        kl_kwargs = dict(softmax=True)
        # -> jsloss setting
        js_kwargs = dict(softmax=True)
        # -> msaloss setting (same options as arcface)
        msa_kwargs = dict(
            s=all_config.arcface_s,
            m2=all_config.arcface_m2,
            m3=all_config.arcface_m3,
        )
        # -> celoss setting (no options)
        ce_kwargs = dict()
        # -> mseloss setting
        mse_kwargs = dict(reduction='mean')
        # -> huberloss setting
        huber_kwargs = dict(
            reduction='mean',
            delta=1.0,
        )

        # Loss name -> keyword arguments for that loss.
        loss_kwargs = dict(
            arcface=arcface_kwargs,
            softmax_nn=softmax_kwargs,
            softmax_ce=ce_kwargs,
            KLloss=kl_kwargs,
            JSloss=js_kwargs,
            msaloss=msa_kwargs,
            CEloss=ce_kwargs,
            Huberloss=huber_kwargs,
            MSEloss=mse_kwargs,
        )

        # Still a KeyError for an unknown loss name, but one that tells the
        # user which names are accepted.
        if self.backbone_feature_loss not in loss_kwargs:
            raise KeyError(
                'Unknown backbone_feature_loss {!r}; expected one of {}'.format(
                    self.backbone_feature_loss, sorted(loss_kwargs)))
        self.backbone_feature_loss_kwargs = loss_kwargs[self.backbone_feature_loss]
        self.backbone_local_loss_kwargs = dict(
            part=all_config.amm_part,
            margin=all_config.margin,
        )

        # ==> file path setting
        # Root directory of all outputs: <dataset>/Base_<n>/Phase_<p>.
        self.model_file = os.path.join(
            '{}'.format(self.dataset),
            'Base_' + str(all_config.init_class_number),
            'Phase_' + str(all_config.phase),
        )

        if all_config.exp_dir == '':
            # Default experiment directory, named after the first enabled
            # run mode and the feature loss.
            self.exp_dir = os.path.join(
                self.model_file,
                self.execrator[0],
                str(self.backbone_feature_loss) + '_logs',
            )
        else:
            self.exp_dir = all_config.exp_dir

        #######
        # Log #
        #######

        # Take the timestamp once so that the stdout and stderr files of a
        # single run always carry the same suffix.
        timestamp = time_str()
        self.stdout_file = os.path.join(
            self.exp_dir, 'stdout_{}.txt'.format(timestamp))
        self.stderr_file = os.path.join(
            self.exp_dir, 'stderr_{}.txt'.format(timestamp))

        # Checkpoint and bookkeeping files below the output root.
        weight_dir = os.path.join(self.model_file, 'train_model_weight')
        self.current_train_weight_file = os.path.join(
            weight_dir, str(self.backbone_feature_loss) + '.pth')
        self.old_train_weight_file = os.path.join(
            weight_dir, str(self.backbone_feature_loss) + '_old.pth')
        self.class_order_file = os.path.join(self.model_file, 'order.json')
        self.memory_file = os.path.join(self.model_file, 'memory.npz')

        # An empty train_path means "use the default weight file".
        if all_config.train_path != '':
            self.train_path = all_config.train_path
        else:
            self.train_path = self.current_train_weight_file

        # When testing, fall back to the default weight file if the given
        # test path is unset (None) or does not exist on disk.
        self.test_path = all_config.test_path
        if self.up2now_test and (
                self.test_path is None or not os.path.exists(self.test_path)):
            print('The test path is empty, test in default path...')
            self.test_path = self.current_train_weight_file