import os
import torch
import numpy as np
import torch.optim as optim
# from tensorboardX import SummaryWriter
from torch.nn import DataParallel

from .dataloader_generator import create_dataloader
from .loss_function import *
from .tester import Phase_Tester
from .trainer import Phase_Trainer
from .models import FModel, CosClassifier
from .utils.set_devices import set_devices
from .utils.strategies import get_lr_kwarg


class YoooP(object):
    """Driver for phase-wise (class-incremental) backbone training and testing.

    Builds dataloaders, the backbone model, loss functions, the SGD optimizer
    and the lr-decay schedule from a ``setting`` object, then runs either
    ``phase_train`` or ``phase_test``.
    """

    def __init__(self, setting):
        """Build dataloaders, backbone, losses, optimizer and lr schedule.

        Args:
            setting: configuration object; every attribute read below
                (``sys_device_ids``, ``dataset``, ``model``, ``base_lr``,
                ``lr_satus``, ``backbone_lr_kwargs``, ...) must be present.
        """
        self.setting = setting
        self.device_ids = setting.sys_device_ids
        # TMO transfers a list (of modules/optimizers), TVT transfers a tensor.
        self.TVT, self.TMO = set_devices(self.device_ids)
        # From here on the selected devices are addressed by position 0..k-1
        # (presumably set_devices restricts the visible devices; TODO confirm).
        self.device_ids = list(range(len(self.device_ids)))

        # ==> init dataloaders
        # The class-order file records:
        #   phase = 5/10/20
        #   part = 'backbone'/'mlp'
        #   trained_classes = '0,12,3,5,10...'
        #   classes = '87,32,45,76,199,200...'
        order_file = setting.class_order_file
        self.train_dataloader = create_dataloader(
            name=setting.dataset, order_file=order_file, **setting.train_set_kwargs)
        self.test_dataloader = create_dataloader(
            name=setting.dataset, order_file=order_file, **setting.test_set_kwargs)

        # ==> init backbone model
        # extract_feature_model='pure' is a temporary choice; embedding_size is
        # unused when train_type is 'unsupervised'.
        self.backbone_model = FModel(
            name=setting.model,
            extract_feature_model='pure',
            embedding_size=setting.model_embeding_size,
        )

        # ==> init loss functions
        self.backbone_feature_loss = get_loss_func(
            setting.backbone_feature_loss, **setting.backbone_feature_loss_kwargs)
        self.backbone_local_loss = get_loss_func(
            setting.backbone_local_loss, **setting.backbone_local_loss_kwargs)
        self.backbone_local_loss_weight = setting.backbone_local_loss_weight

        # ==> collect parameters
        # Parameters of f_model.base train at base_lr; every backbone parameter
        # whose name does not contain 'base.' and all loss parameters train at
        # base_lr * lr_satus.
        backbone_base_parameters = list(self.backbone_model.f_model.base.parameters())
        backbone_new_parameters = [
            p for n, p in self.backbone_model.named_parameters() if 'base.' not in n]
        self.backbone_loss_parameters = (
            list(self.backbone_feature_loss.parameters())
            + list(self.backbone_local_loss.parameters()))

        # ==> make optimizer
        lr = setting.base_lr
        new_lr = lr * setting.lr_satus
        backbone_param_groups = [
            {'params': backbone_base_parameters, 'lr': lr},
            {'params': backbone_new_parameters, 'lr': new_lr},
        ]
        if len(self.backbone_loss_parameters) > 0:
            backbone_param_groups.append(
                {'params': self.backbone_loss_parameters, 'lr': new_lr})
        # NOTE(review): no weight_decay is passed here, so SGD uses torch's
        # default of 0. An earlier comment claimed a 0.0005 default in
        # setting.py -- confirm whether decay is meant to be applied.
        self.backbone_optimizer = torch.optim.SGD(backbone_param_groups, momentum=0.9)

        # ==> set lr strategy
        decay_at_epochs, factor = get_lr_kwarg(**setting.backbone_lr_kwargs)
        self.backbone_lr_decay_kwargs = dict(decay_at_epochs=decay_at_epochs, factor=factor)

    def __backbone_init(self):
        """Wrap the backbone in DataParallel, move models/optimizer to the
        device(s) and build ``backbone_module_optimize`` ({'mod': [...], 'opt': [...]}).

        Idempotent: later calls are no-ops. This lets ``phase_train`` and
        ``phase_test`` both run on one instance without wrapping the backbone
        in DataParallel twice.
        """
        if getattr(self, '_backbone_initialized', False):
            return
        # ==> data-parallel model
        self.backbone_model = DataParallel(self.backbone_model, device_ids=self.device_ids)
        # ==> deliver to the device(s)
        self.TMO([self.backbone_model, self.backbone_optimizer,
                  self.backbone_feature_loss, self.backbone_local_loss])
        # ==> modules/optimizers handed to the trainer
        self.backbone_module_optimize = {'mod': [self.backbone_model],
                                         'opt': [self.backbone_optimizer]}
        if len(self.backbone_loss_parameters) > 0:
            # NOTE(review): only the feature loss is registered, although
            # backbone_loss_parameters also contains the local loss's
            # parameters -- confirm this is intentional.
            self.backbone_module_optimize['mod'] += [self.backbone_feature_loss]
        self._backbone_initialized = True

    def _build_tester(self):
        """Create a Phase_Tester over the backbone and the test dataloader."""
        return Phase_Tester(
            self.backbone_model,
            self.test_dataloader,
            self.TVT,
            self.setting.current_train_weight_file,
        )

    def _load_classifier(self, class_path, num_classes):
        """Build a CosClassifier with ``num_classes`` outputs, restore its fc
        weight from the checkpoint at ``class_path`` and place it on device(s).

        The weight is read from ``ckpt['mod_state_dicts'][0]['module.fc.weight']``.
        """
        old_ckpt = torch.load(class_path, map_location='cpu')
        src_param = old_ckpt['mod_state_dicts'][0]['module.fc.weight'].data
        # feature_in_dim is not assigned anywhere in this class; when absent,
        # take the input width from the checkpoint weight (assumes fc.weight is
        # laid out (num_classes, in_dim) as in nn.Linear -- TODO confirm).
        in_dim = getattr(self, 'feature_in_dim', None)
        if in_dim is None:
            in_dim = src_param.shape[-1]
        classifier = CosClassifier(in_dim, num_classes)
        with torch.no_grad():
            classifier.fc.weight[:].copy_(src_param)
        classifier = DataParallel(classifier, device_ids=self.device_ids)
        self.TMO([classifier])
        return classifier

    # ==> Train and test pipeline

    def phase_train(self):
        """Run phase-wise training of the backbone via Phase_Trainer."""
        self.__backbone_init()
        setting = self.setting
        # ==> tester
        Tester = self._build_tester()
        # ==> trainer args and kwargs
        phase_trainer_args = [
            self.backbone_model,
            self.train_dataloader,
            Tester,
            self.TVT,
            self.TMO,
        ]
        phase_trainer_kwargs = dict(
            device_ids=self.device_ids,
            backbone_feature_loss=self.backbone_feature_loss,
            backbone_local_loss=self.backbone_local_loss,
            backbone_local_loss_weight=self.backbone_local_loss_weight,
            backbone_optimizer=self.backbone_optimizer,
            backbone_module_optimize=self.backbone_module_optimize,
            lr_satus=setting.lr_satus,
            lr_decay_type=setting.lr_decay_type,
            lr_decay_kwargs=self.backbone_lr_decay_kwargs,
            resume=setting.resume,
            steps_per_log=setting.steps_per_log,
            epochs_per_val=setting.epochs_per_val,
            current_train_path=setting.current_train_weight_file,
            memory_file=setting.memory_file,
        )
        Trainer = Phase_Trainer(*phase_trainer_args, **phase_trainer_kwargs)
        # ==> meta-like train kwargs
        train_kwargs = dict(
            train_mini_batch=setting.mini_batch,
            epochs_per_task=setting.epochs_per_task,
            backbone_t=setting.backbone_t,
            update_scale=setting.update_scale,
            distill_factor=setting.distill_factor,
        )
        Trainer.phase_train(**train_kwargs)
        print('-'*20+'Finished CIL Phase Training!'+'-'*20)

    @torch.no_grad()
    def phase_test(self, verbose=True):
        """Evaluate per-task accuracy with a cosine classifier restored from
        ``setting.class_path``.

        Memory slots whose ``initial_memory`` flag is False/0 are treated as
        trained classes; classifier output index i maps to the i-th such slot.

        Args:
            verbose: when True (default) print one accuracy line per task.

        Returns:
            (avg_acc, global_acc): mean of the per-task accuracies and the
            accuracy over all test samples; 0.0 when there is nothing to test.

        Raises:
            AssertionError: if the number of trained memory slots differs from
                the number of trained classes (training not finished).
        """
        self.__backbone_init()
        setting = self.setting
        print(setting.test_path)
        Tester = self._build_tester()

        # ==> load memory bookkeeping; the context manager closes the .npz handle
        with np.load(setting.memory_file) as memory_files:
            tmp_memory = memory_files['memory']
            tmp_initial_memo = memory_files['initial_memory']
        memory = self.TVT(torch.from_numpy(tmp_memory))
        initial_memory = self.TVT(torch.from_numpy(tmp_initial_memo))
        # `== False` (not `~`) so numeric 0/1 flags work as well as bool flags.
        trained_mask = initial_memory == False  # noqa: E712
        trained_indx = self.TVT(torch.arange(len(memory)))[trained_mask]
        trained_memo = memory[trained_mask]

        self.test_dataloader.load_order_file()
        trained_class = self.test_dataloader.trained_classes
        phase = self.test_dataloader.phase
        task_number = len(trained_class) // phase
        assert len(trained_memo) == len(trained_class), 'The task training is not finished!'

        # ==> init classifier (one output per trained memory slot, as a plain int)
        classfier = self._load_classifier(setting.class_path, len(trained_memo))
        classfier.eval()

        true_number, whole_number, whole_acc = 0, 0, 0
        for task in range(task_number):
            feats, labels = Tester.test(task)
            labels = self.TVT(labels)
            f_m_affi = classfier(feats)
            _, index = torch.max(f_m_affi, dim=-1)
            # map classifier output positions back to memory-slot indices
            predict = trained_indx[index]
            _true_number = len(labels[labels == predict])
            true_number += _true_number
            whole_number += len(labels)
            task_acc = _true_number / len(labels) if len(labels) else 0.0
            whole_acc += task_acc
            if verbose:
                print('Use Classifier: The acc of task '+str(task+1)+' in task '+str(task_number)+' is %.4f' % (task_acc))
        avg_acc = whole_acc / task_number if task_number else 0.0
        global_acc = true_number / whole_number if whole_number else 0.0
        print('In Classifier: The average acc is %.4f, the global acc is %.4f' % (avg_acc, global_acc))
        return avg_acc, global_acc