import os
import argparse
import logging
import re
import functools
import torch
import numpy as np

from model.fe_resnet import resnet18_dropout


# Module-wide configuration constants.

# Not referenced in this file; presumably a random seed consumed elsewhere
# (TODO confirm against callers).
SEED = 98
# Per-sample input shape used by ModelWrapper.get_seed_inputs (random inputs)
# and ModelWrapper.input_shape; presumably (channels, height, width).
INPUT_SHAPE = (3, 224, 224)
# Default batch size copied into the namespace built by base_args().
BATCH_SIZE = 64
# Default `iters` for ImageBenchmark.load_trained and the value passed by
# ImageBenchmark.list_models.
TRAIN_ITERS = 100000   
# The four *_ITERS names below all alias DEFAULT_ITERS. None of them is used
# in this file; presumably they are per-transformation iteration counts read
# by other modules (TODO confirm).
DEFAULT_ITERS = 2000
TRANSFER_ITERS = DEFAULT_ITERS
PRUNE_ITERS = DEFAULT_ITERS
DISTILL_ITERS = DEFAULT_ITERS
STEAL_ITERS = DEFAULT_ITERS
# Device that ModelWrapper moves models and inputs to (except for the
# "quantize" method, which is kept on CPU).
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
CONTINUE_TRAIN = False  # whether to continue previous training


def lazy_property(func):
    """
    Decorator turning a zero-argument method into a compute-once property.

    The first access calls ``func(self)`` and stores the result on the
    instance under ``_lazy_<func name>``; later accesses return the stored
    value without calling ``func`` again.

    :param func: the method whose result should be cached per instance
    :return: a read-only ``property`` wrapping ``func``
    """
    cache_name = '_lazy_' + func.__name__

    @functools.wraps(func)
    def getter(self):
        already_computed = hasattr(self, cache_name)
        if not already_computed:
            setattr(self, cache_name, func(self))
        return getattr(self, cache_name)

    return property(getter)


def base_args():
    """
    Build a fresh namespace holding the default settings.

    :return: an ``argparse.Namespace`` with every default attribute set;
             ``batch_size`` is taken from the module constant ``BATCH_SIZE``
    """
    # Insertion order mirrors the order in which the attributes were
    # originally assigned, so the namespace looks the same when printed.
    defaults = {
        'const_lr': False,
        'batch_size': BATCH_SIZE,
        'lr': 5e-3,
        'print_freq': 100,
        'label_smoothing': 0,
        'vgg_output_distill': False,
        'reinit': False,
        'l2sp_lmda': 0,
        'train_all': False,
        'ft_begin_module': None,
        'momentum': 0,
        'weight_decay': 1e-4,
        'beta': 1e-2,
        'feat_lmda': 0,
        'test_interval': 1000,
        'adv_test_interval': -1,
        'feat_layers': '1234',
        'no_save': False,
        'steal': False,
    }
    return argparse.Namespace(**defaults)


class ModelWrapper:
    """
    Handle to one model of the benchmark.

    A wrapper is identified by its string form: the teacher's string (empty
    when there is no teacher) followed by ``trans_str`` and a trailing ``-``.
    That string names the directory under ``benchmark.models_dir`` in which
    the checkpoints are stored. The torch model itself is built and loaded
    lazily on first access of :attr:`torch_model`.
    """

    # File names used inside ``torch_model_path``.
    _FINAL_CKPT = 'final_ckpt.pth'
    _LATEST_CKPT = 'ckpt.pth'
    # ``trans_str`` has the form ``method(param,param,...)``.
    _TRANS_PATTERN = re.compile(r'(\S+)\((\S*)\)')
    # Output-layer size for each dataset id this class knows how to build.
    _NUM_CLASSES = {'ImageNet': 1000, 'SDog120': 120}

    def __init__(self, benchmark, teacher_wrapper, trans_str,
                 arch_id=None, dataset_id=None, iters=100, fc=True):
        """
        :param benchmark: owning benchmark; must expose ``models_dir`` and
                          ``get_dataloader``
        :param teacher_wrapper: wrapper this model derives from, or None
        :param trans_str: transformation string, ``method(params)``
        :param arch_id: architecture name; taken from the teacher when falsy
        :param dataset_id: dataset name; taken from the teacher when falsy
        :param iters: number of iterations (stored only; not used here)
        :param fc: selects which model factory :attr:`torch_model` calls
        :raises ValueError: if arch_id / dataset_id cannot be determined
        """
        self.logger = logging.getLogger('ModelWrapper')
        self.benchmark = benchmark
        self.teacher_wrapper = teacher_wrapper
        self.trans_str = trans_str
        self.arch_id = self._own_or_inherited(arch_id, teacher_wrapper, 'arch_id')
        self.dataset_id = self._own_or_inherited(dataset_id, teacher_wrapper, 'dataset_id')
        self.torch_model_path = os.path.join(benchmark.models_dir, str(self))
        self.iters = iters
        self.fc = fc

    @staticmethod
    def _own_or_inherited(value, teacher_wrapper, attr):
        """Return ``value`` if truthy, else the teacher's ``attr``; raise if neither exists."""
        if value:
            return value
        if teacher_wrapper is None:
            raise ValueError(f'{attr} must be given when there is no teacher_wrapper')
        inherited = getattr(teacher_wrapper, attr)
        if inherited is None:
            raise ValueError(f'{attr} is set neither on this wrapper nor on its teacher')
        return inherited

    def __str__(self):
        teacher_str = '' if self.teacher_wrapper is None else str(self.teacher_wrapper)
        return f'{teacher_str}{self.trans_str}-'

    def name(self):
        """Return the same identifier as ``str(self)``."""
        return str(self)

    def _ckpt_path(self, filename):
        """Return the path of ``filename`` inside this model's directory."""
        return os.path.join(self.torch_model_path, filename)

    def _transform_method(self):
        """
        Return the method part of ``trans_str`` (the text before the ``(``).

        :raises ValueError: if ``trans_str`` is not of the form ``method(params)``
        """
        match = self._TRANS_PATTERN.match(self.trans_str)
        if match is None:
            raise ValueError(f'malformed trans_str: {self.trans_str!r}')
        return match.group(1)

    def _target_device(self):
        """Return the device to run on: CPU for "quantize" models, DEVICE otherwise."""
        return 'cpu' if self._transform_method() == 'quantize' else DEVICE

    def torch_model_exists(self):
        """Return True if the final checkpoint file is present on disk."""
        return os.path.exists(self._ckpt_path(self._FINAL_CKPT))

    def save_torch_model(self, torch_model):
        """
        Write ``torch_model``'s state dict as the final checkpoint,
        creating the model directory if needed.
        """
        os.makedirs(self.torch_model_path, exist_ok=True)
        torch.save(
            {'state_dict': torch_model.state_dict()},
            self._ckpt_path(self._FINAL_CKPT),
        )

    @lazy_property
    def torch_model(self):
        """
        load the model object from torch_model_path
        :return: torch.nn.Module object
        :raises ValueError: if ``dataset_id`` has no known number of classes
        """
        try:
            num_classes = self._NUM_CLASSES[self.dataset_id]
        except KeyError:
            raise ValueError(
                f'unknown dataset_id {self.dataset_id!r}; '
                f'expected one of {sorted(self._NUM_CLASSES)}'
            ) from None

        factory_name = f'{self.arch_id}_dropout' if self.fc else f'fe{self.arch_id}'
        # NOTE(review): eval() resolves the factory by name from arch_id, so
        # arch_id must never come from untrusted input. Only
        # `resnet18_dropout` is imported in this file; the `fe<arch>` factory
        # used when fc is False has to be in scope as well - confirm.
        factory = eval(factory_name)
        torch_model = factory(pretrained=False, num_classes=num_classes)

        # Load onto CPU first so a checkpoint saved from a GPU can still be
        # read on a CPU-only host; load_state_dict copies into the model.
        ckpt = torch.load(self._ckpt_path(self._FINAL_CKPT), map_location='cpu')
        torch_model.load_state_dict(ckpt['state_dict'])
        return torch_model

    @lazy_property
    def torch_model_on_device(self):
        """Return :attr:`torch_model` moved to the device chosen by the method."""
        return self.torch_model.to(self._target_device())

    def load_saved_weights(self, torch_model):
        """
        load weights in the latest checkpoint to torch_model
        :param torch_model: model to receive the weights
        :return: the same ``torch_model`` (unchanged if no checkpoint exists)
        """
        ckpt_path = self._ckpt_path(self._LATEST_CKPT)
        if os.path.exists(ckpt_path):
            # map_location='cpu' keeps this working on hosts without a GPU.
            ckpt = torch.load(ckpt_path, map_location='cpu')
            torch_model.load_state_dict(ckpt['state_dict'])
            self.logger.info('load_saved_weights: loaded a previous checkpoint')
        else:
            self.logger.info('load_saved_weights: no previous checkpoint found')
        return torch_model

    @lazy_property
    def input_shape(self):
        """Return the per-sample input shape (module constant INPUT_SHAPE)."""
        return INPUT_SHAPE

    def get_seed_inputs(self, n, rand=False):
        """
        Return ``n`` input images as a numpy array.

        :param n: number of images
        :param rand: if True, draw float32 values from a standard normal
                     distribution; otherwise take the first batch of a
                     shuffled training loader ('MIT67' is substituted when
                     this model's dataset is 'ImageNet')
        :return: numpy array of images
        """
        if rand:
            batch_input_size = (n, *INPUT_SHAPE)
            return np.random.normal(size=batch_input_size).astype(np.float32)

        dataset_id = 'MIT67' if self.dataset_id == 'ImageNet' else self.dataset_id
        train_loader = self.benchmark.get_dataloader(
            dataset_id, split='train', batch_size=n, shuffle=True)
        images, _ = next(iter(train_loader))
        return images.to('cpu').numpy()

    def batch_forward(self, inputs):
        """
        Run the model in eval mode, without gradients, on ``inputs``.

        :param inputs: numpy array or torch tensor
        :return: the model output
        """
        if isinstance(inputs, np.ndarray):
            inputs = torch.from_numpy(inputs)
        inputs = inputs.to(self._target_device())
        model = self.torch_model_on_device
        model.eval()
        with torch.no_grad():
            return model(inputs)

    def list_tensors(self):
        """Placeholder; currently does nothing and returns None."""
        pass

    def batch_forward_with_ir(self, inputs):
        """
        Run a forward pass and capture the output of every leaf module.

        :param inputs: numpy array or torch tensor
        :return: dict mapping ``"<ModuleClass>/<NNN>"`` to the module's output
                 as a numpy array, where NNN is the zero-padded order in which
                 the leaf modules were executed
        """
        if isinstance(inputs, np.ndarray):
            inputs = torch.from_numpy(inputs)
        model = self.torch_model

        # Module.to() works in place, so the cached model may already have
        # been moved to another device; make the inputs follow it.
        first_param = next(model.parameters(), None)
        if first_param is not None:
            inputs = inputs.to(first_param.device)

        hook_handles = []
        module_ir = {}

        def hook(module, _inputs, output):
            # Every call adds exactly one unique key, so the dict size is the
            # running call index (the old `global idx` raised NameError).
            module_name = f"{type(module).__name__}/{len(module_ir):03d}"
            module_ir[module_name] = output.detach().cpu().numpy()

        def register_hooks(module):
            # Only leaf modules (no children) are recorded.
            if len(list(module.children())) == 0:
                hook_handles.append(module.register_forward_hook(hook))

        model.eval()
        model.apply(register_hooks)
        try:
            with torch.no_grad():
                model(inputs)
        finally:
            # Always detach the hooks, even if the forward pass fails, so
            # they do not pile up on the cached model.
            for handle in hook_handles:
                handle.remove()
        return module_ir

    @lazy_property
    def accuracy(self):
        """
        Top-1 accuracy on the test split of this model's dataset.

        :return: percentage of correctly classified samples (0-100)
        """
        # Same device rule as batch_forward: "quantize" models stay on CPU.
        device = self._target_device()
        model = self.torch_model.to(device)
        test_loader = self.benchmark.get_dataloader(self.dataset_id, split='test')

        model.eval()
        total = 0
        top1 = 0
        with torch.no_grad():
            for batch, label in test_loader:
                batch, label = batch.to(device), label.to(device)
                total += batch.size(0)
                out = model(batch)
                _, pred = out.max(dim=1)
                top1 += int(pred.eq(label).sum().item())
        return float(top1) / total * 100


class ImageBenchmark:
    """
    Factory for the ModelWrapper instances that make up the benchmark.

    Holds the dataset/model directories and the lists of dataset ids and
    architecture ids that :meth:`list_models` enumerates.
    """

    def __init__(self, datasets_dir='data', models_dir='models/'):
        """
        :param datasets_dir: directory of the datasets (stored only)
        :param models_dir: directory under which each model's checkpoint
                           directory is located
        """
        self.logger = logging.getLogger('ImageBench')
        self.datasets_dir = datasets_dir
        self.models_dir = models_dir
        self.datasets = ['SDog120']
        self.archs = ['resnet18']

    def load_pretrained(self, arch_id, fc=True):
        """
        Get the model pretrained on imagenet
        :param arch_id: the name of the arch
        :param fc: forwarded to ModelWrapper
        :return: a ModelWrapper instance
        """
        return ModelWrapper(
            benchmark=self,
            teacher_wrapper=None,
            trans_str=f'pretrain({arch_id},ImageNet)',
            arch_id=arch_id,
            dataset_id='ImageNet',
            fc=fc,
        )

    def load_trained(self, arch_id, dataset_id, iters=TRAIN_ITERS, fc=True):
        """
        Get the model with architecture arch_id trained on dataset dataset_id
        :param arch_id: the name of the arch
        :param dataset_id: the name of the dataset
        :param iters: number of iterations
        :param fc: forwarded to ModelWrapper
        :return: a ModelWrapper instance
        """
        return ModelWrapper(
            benchmark=self,
            teacher_wrapper=None,
            trans_str=f'train({arch_id},{dataset_id})',
            arch_id=arch_id,
            dataset_id=dataset_id,
            iters=iters,
            fc=fc,
        )

    def list_models(self, fc=True):
        """
        Lazily enumerate the benchmark's models.

        Yields one trained-model wrapper per (architecture, dataset) pair,
        architectures in the outer loop.

        :param fc: forwarded to load_trained
        :return: generator of ModelWrapper instances
        """
        # Wrappers are yielded directly rather than also collected in a
        # list, so the generator does not keep every wrapper (and any model
        # it has cached) alive for its whole lifetime.
        for arch_id in self.archs:
            for dataset_id in self.datasets:
                yield self.load_trained(arch_id, dataset_id, TRAIN_ITERS, fc=fc)

