#!/usr/bin/env python
# pylint: disable=W0201
import sys
import argparse
import yaml
import numpy as np

# torch
import torch
import torch.nn as nn

# torchlight
import torchlight
from torchlight import str2bool
from torchlight import DictAction
from torchlight import import_class
from tensorboardX import SummaryWriter
import os
class IO():
    """
        IO Processor

        On construction this parses the arguments, sets up the logging and
        device environment, builds the model, loads its weights and moves
        every module attribute to the selected device.

        The teacher helpers (``load_model_teacher`` and
        ``load_weights_teacher``) are not invoked by ``__init__`` and read
        arguments (``model_teacher``, ``model_teacher_args``,
        ``weights_teacher``, ``load_poly_teacher``) that ``get_parser`` in
        this class does not register.
        NOTE(review): presumably a subclass extends the parser and calls
        them -- confirm against the callers.
    """

    def __init__(self, argv=None):

        self.load_arg(argv)
        self.init_environment()
        self.load_model()
        self.load_weights()
        self.gpu()

    def load_arg(self, argv=None):
        """Parse ``argv`` into ``self.arg``.

        Priority is command line > config file > parser defaults: the
        config file named by ``--config`` only replaces parser defaults,
        and ``argv`` is then parsed a second time on top of them.

        Raises:
            AssertionError: if the config file contains a key that is not
                a known parser argument.
        """
        parser = self.get_parser()

        # load arg from config file
        p = parser.parse_args(argv)
        if p.config is not None:
            # load config file
            # NOTE(review): FullLoader is not the safe loader; only use
            # config files from trusted sources.
            with open(p.config, 'r') as f:
                default_arg = yaml.load(f, Loader=yaml.FullLoader)

            # An empty YAML document loads as None; treat it as "no
            # overrides" instead of failing on ``None.keys()``.
            if default_arg is None:
                default_arg = {}

            # update parser from config file
            key = vars(p).keys()
            for k in default_arg.keys():
                if k not in key:
                    print('Unknown Arguments: {}'.format(k))
                    # Raised explicitly (rather than via ``assert``) so the
                    # check still runs under ``python -O``; the exception
                    # type is kept for callers that catch AssertionError.
                    raise AssertionError('Unknown Arguments: {}'.format(k))

            parser.set_defaults(**default_arg)

        self.arg = parser.parse_args(argv)

    def init_environment(self):
        """Create the torchlight IO helper, pick the device and open the
        TensorBoard writer under ``<work_dir>/tb``."""
        self.io = torchlight.IO(
            self.arg.work_dir,
            save_log=self.arg.save_log,
            print_log=self.arg.print_log)
        self.io.save_arg(self.arg)

        # gpu
        if self.arg.use_gpu:
            gpus = torchlight.visible_gpu(self.arg.device)
            torchlight.occupy_gpu(gpus)
            self.gpus = gpus
            self.dev = "cuda:0"
        else:
            # ``self.gpus`` is intentionally left unset on CPU; ``gpu()``
            # only reads it when ``use_gpu`` is true.
            self.dev = "cpu"

        self.writer = SummaryWriter(log_dir=os.path.join(self.arg.work_dir, "tb"))

    def load_model(self):
        """Instantiate ``self.model`` from ``arg.model`` / ``arg.model_args``."""
        self.model = self.io.load_model(self.arg.model,
                                        **(self.arg.model_args))

    def load_model_teacher(self):
        """Instantiate ``self.model_teacher`` from ``arg.model_teacher`` /
        ``arg.model_teacher_args``."""
        self.model_teacher = self.io.load_model(self.arg.model_teacher,
                                        **(self.arg.model_teacher_args))

    def load_weights(self):
        """Load ``arg.weights`` into the model, if a path was given.

        When the model name contains ``'replace'`` and ``arg.load_poly`` is
        false, the weights are loaded into the wrapped ``self.model.model``
        instead of the outer model.
        """
        if self.arg.weights:
            if ('replace' in self.arg.model) and (not self.arg.load_poly):  # and self.arg.load_wgt_only
                self.model.model = self.io.load_weights(self.model.model, self.arg.weights,
                                                self.arg.ignore_weights)
            else:
                self.model = self.io.load_weights(self.model, self.arg.weights,
                                                self.arg.ignore_weights)

    def load_weights_teacher(self):
        """Load ``arg.weights_teacher`` into the teacher model, if given.

        Mirrors ``load_weights`` but is driven by the teacher arguments
        (``model_teacher``, ``load_poly_teacher``, ``weights_teacher``).
        """
        # Gate on the teacher checkpoint itself: the student's
        # ``arg.weights`` says nothing about whether a teacher checkpoint
        # was supplied.
        if self.arg.weights_teacher:
            if ('replace' in self.arg.model_teacher) and (not self.arg.load_poly_teacher):  # and self.arg.load_wgt_only
                self.model_teacher.model = self.io.load_weights(self.model_teacher.model, self.arg.weights_teacher,
                                                self.arg.ignore_weights)
            else:
                self.model_teacher = self.io.load_weights(self.model_teacher, self.arg.weights_teacher,
                                                self.arg.ignore_weights)

    def gpu(self):
        """Move the model(s) and any ``torch.nn`` module attributes to
        ``self.dev``; wrap the models in ``nn.DataParallel`` when more than
        one GPU is in use."""
        # move modules to gpu
        self.model = self.model.to(self.dev)
        # The teacher is optional: ``__init__`` never creates it, so the
        # attribute may not exist at all.
        teacher = getattr(self, 'model_teacher', None)
        if teacher is not None:
            self.model_teacher = teacher.to(self.dev)
        # Iterate over a snapshot because attributes are reassigned below.
        for name, value in list(vars(self).items()):
            cls_name = str(value.__class__)
            if cls_name.find('torch.nn.modules') != -1:
                setattr(self, name, value.to(self.dev))

        # model parallel
        if self.arg.use_gpu and len(self.gpus) > 1:
            self.model = nn.DataParallel(self.model, device_ids=self.gpus)
            teacher = getattr(self, 'model_teacher', None)
            if teacher is not None:
                self.model_teacher = nn.DataParallel(teacher, device_ids=self.gpus)

    def start(self):
        """Log the full set of parsed parameters."""
        self.io.print_log('Parameters:\n{}\n'.format(str(vars(self.arg))))

    @staticmethod
    def get_parser(add_help=False):
        """Build the base argument parser.

        Args:
            add_help: forwarded to ``argparse.ArgumentParser``.

        Returns:
            The configured ``argparse.ArgumentParser``.
        """

        #region arguments yapf: disable
        # parameter priority: command line > config > default
        parser = argparse.ArgumentParser( add_help=add_help, description='IO Processor')

        parser.add_argument('-w', '--work_dir', default='./work_dir/tmp', help='the work folder for storing results')
        parser.add_argument('-c', '--config', default=None, help='path to the configuration file')

        # processor
        parser.add_argument('--use_gpu', type=str2bool, default=True, help='use GPUs or not')
        # NOTE: the default is a bare int while command-line values arrive
        # as a list (nargs='+'); consumers must accept both forms.
        parser.add_argument('--device', type=int, default=0, nargs='+', help='the indexes of GPUs for training or testing')

        # visualize and debug
        parser.add_argument('--print_log', type=str2bool, default=True, help='print logging or not')
        parser.add_argument('--save_log', type=str2bool, default=True, help='save logging or not')

        # model
        parser.add_argument('--model', default=None, help='the model will be used')
        parser.add_argument('--model_args', action=DictAction, default=dict(), help='the arguments of model')
        parser.add_argument('--weights', default=None, help='the weights for network initialization')
        parser.add_argument('--ignore_weights', type=str, default=[], nargs='+', help='the name of weights which will be ignored in the initialization')
        #endregion yapf: enable

        return parser
