import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ._base import Distiller


def set_seed(seed):
    """Seed the NumPy, PyTorch CPU and PyTorch CUDA random generators.

    Args:
        seed: Integer seed handed unchanged to each generator.
    """
    # Same order as before: NumPy, then torch (CPU), then torch (CUDA).
    for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)


def cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.):
    """Return the mean cross-entropy between ``pred`` and smoothed one-hot targets.

    The one-hot target is blended with a uniform distribution:
    ``one_hot * (1 - label_smoothing) + label_smoothing / n_classes``.

    Args:
        pred: Raw (unnormalised) logits with the class axis at dim 1.
        target: Integer class indices; a class axis is inserted at dim 1
            before scattering, so it must have one dimension fewer than
            ``pred``.
        label_smoothing: Weight of the uniform component. ``0.`` yields the
            plain cross-entropy.

    Returns:
        Scalar tensor: the loss summed over dim 1 and averaged over the rest.
    """
    n_classes = pred.size(1)

    # Build the one-hot target on the same device/dtype as the logits.
    one_hot = torch.zeros_like(pred).scatter_(1, target.unsqueeze(1), 1)
    soft_target = one_hot * (1 - label_smoothing) + label_smoothing / n_classes

    # The class axis is given explicitly: the implicit-dim form is deprecated
    # and would normalise over dim 0 for 3-D inputs.
    log_prob = F.log_softmax(pred, dim=1)
    return torch.mean(torch.sum(-soft_target * log_prob, 1))


class AttrDict(dict):
    """Dictionary whose items can also be read and written as attributes.

    An immutability flag is kept in the instance ``__dict__`` (not among the
    dictionary items); while it is set, attribute assignment raises
    ``AttributeError``.
    """

    IMMUTABLE = '__immutable__'

    def __init__(self, *args, **kwargs):
        """Build the underlying dict and start in the mutable state."""
        super().__init__(*args, **kwargs)
        self.__dict__[AttrDict.IMMUTABLE] = False

    def __getattr__(self, name):
        """Resolve ``name`` from the instance ``__dict__`` first, then from the items."""
        instance_attrs = self.__dict__
        if name in instance_attrs:
            return instance_attrs[name]
        if name in self:
            return self[name]
        raise AttributeError(name)

    def __setattr__(self, name, value):
        """Store ``value`` under ``name`` unless this AttrDict is immutable.

        Names already present in the instance ``__dict__`` are updated there;
        every other name becomes a dictionary item.
        """
        instance_attrs = self.__dict__
        if instance_attrs[AttrDict.IMMUTABLE]:
            raise AttributeError('Attempted to set "{}" to "{}", but AttrDict is immutable'.format(name, value))
        if name in instance_attrs:
            instance_attrs[name] = value
        else:
            self[name] = value

    def immutable(self, is_immutable):
        """Set immutability to is_immutable and recursively apply the setting
        to all nested AttrDicts.
        """
        self.__dict__[AttrDict.IMMUTABLE] = is_immutable
        # Walk the instance attributes first, then the dictionary items.
        for container in (self.__dict__, self):
            for child in container.values():
                if isinstance(child, AttrDict):
                    child.immutable(is_immutable)

    def is_immutable(self):
        """Return the current immutability flag."""
        return self.__dict__[AttrDict.IMMUTABLE]


# Module-level default configuration; `config` is the public alias of `__C`.
__C = AttrDict()

config = __C

__C.net_type = 'resnet'   # choose resnet or mobilenet

# Training-loop settings.
__C.train_params = AttrDict(
    epochs=100,
    use_seed=False,
    seed=0,
    print_freq=50,
)

# Optimiser settings, including the nested cosine-schedule options.
__C.optim = AttrDict(
    init_lr=0.1,
    min_lr=0,
    momentum=0.9,
    weight_decay=1e-4,
    use_grad_clip=False,
    grad_clip=10,
    label_smooth=False,
    smooth_alpha=0.1,
    if_resume=False,
    resume_path='',
    cosine=AttrDict(
        use_restart=False,
        restart=AttrDict(
            lr_period=[10, 20, 40, 80, 160, 320],
            lr_step=[0, 10, 30, 70, 150, 310],
        ),
    ),
)

# Dataset and data-loading settings.
__C.data = AttrDict(
    data_path='PATH/to/DataSet',
    num_workers=32,
    batch_size=256,
    dataset='imagenet',
    train_data_type='lmdb',  # set 'img' to read original images
    val_data_type='lmdb',  # set 'img' to read original images
    patch_dataset=False,
    num_examples=1281167,
    input_size=(3, 224, 224),
    type_of_data_aug='random_sized',  # random_sized / rand_scale
    random_sized=AttrDict(min_scale=0.08),
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
    color=False,
)


class WSLD(Distiller):
    """Distiller whose soft-label loss is re-weighted per sample.

    Each sample's temperature-scaled soft cross-entropy (teacher
    probabilities against student log-probabilities) is multiplied by
    ``1 - exp(-CE_student / CE_teacher)``, where both cross-entropies are
    taken against the ground-truth label on detached, unscaled logits.
    """

    def __init__(self, student, teacher, cfg, wrap_student_in_ddp=False, local_rank=None):
        """Store the distillation hyper-parameters and build the loss helpers.

        Args:
            student: Student network, forwarded to ``Distiller``.
            teacher: Teacher network, forwarded to ``Distiller``.
            cfg: Config object providing ``WSLD.TEMPERATURE`` and
                ``WSLD.LOSS.ALPHA``.
            wrap_student_in_ddp: Forwarded to ``Distiller``.
            local_rank: Forwarded to ``Distiller``.
        """
        super(WSLD, self).__init__(student, teacher, wrap_student_in_ddp=wrap_student_in_ddp, local_rank=local_rank)

        self.T = cfg.WSLD.TEMPERATURE
        self.alpha = cfg.WSLD.LOSS.ALPHA

        # Parameter-free modules follow the device of their inputs, so no
        # explicit device placement is needed. The class axis is spelled out
        # because the implicit-dim form of LogSoftmax is deprecated.
        self.softmax = nn.Softmax(dim=1)
        self.logsoftmax = nn.LogSoftmax(dim=1)

        if config.optim.label_smooth:
            # Bind the configured smoothing factor; calling the helper with
            # its default (0.) would silently disable label smoothing.
            from functools import partial
            self.hard_loss = partial(
                cross_entropy_with_label_smoothing,
                label_smoothing=config.optim.smooth_alpha,
            )
        else:
            self.hard_loss = nn.CrossEntropyLoss()

    def forward_train(self, image, perturbedInput, target, **kwargs):
        """Compute the student/teacher logits and the training losses.

        Args:
            image: Batch fed to the student.
            perturbedInput: Batch fed to the teacher (evaluated without
                gradient tracking).
            target: Integer class indices, one per sample.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            Tuple ``(logits_student, logits_teacher, losses_dict)`` where
            ``losses_dict`` holds ``"loss_ce"`` (hard-label loss) and
            ``"loss_kd"`` (``alpha``-scaled weighted soft-label loss).
        """
        # Both networks are expected to return a pair whose first element is
        # the logits; the second element is not used here.
        logits_student, _ = self.student(image)
        with torch.no_grad():
            logits_teacher, _ = self.teacher(perturbedInput)

        # Per-sample soft cross-entropy at temperature T, shape (batch, 1).
        t_soft_label = self.softmax(logits_teacher / self.T)
        s_log_prob = self.logsoftmax(logits_student / self.T)
        softmax_loss = - torch.sum(t_soft_label * s_log_prob, 1, keepdim=True)

        # The weight is computed on detached logits so that no gradient flows
        # through it.
        fc_s_auto = logits_student.detach()
        fc_t_auto = logits_teacher.detach()
        log_softmax_s = self.logsoftmax(fc_s_auto)
        log_softmax_t = self.logsoftmax(fc_t_auto)
        # Take the class count from the logits rather than hard-coding it, so
        # datasets with any number of classes work.
        num_classes = logits_student.size(1)
        one_hot_label = F.one_hot(target, num_classes=num_classes).float()
        softmax_loss_s = - torch.sum(one_hot_label * log_softmax_s, 1, keepdim=True)
        softmax_loss_t = - torch.sum(one_hot_label * log_softmax_t, 1, keepdim=True)

        # Ratio of student to teacher hard-label loss; epsilon avoids a
        # division by zero when the teacher loss vanishes.
        focal_weight = softmax_loss_s / (softmax_loss_t + 1e-7)
        # Device-agnostic lower bound at zero.
        focal_weight = torch.clamp(focal_weight, min=0)
        focal_weight = 1 - torch.exp(- focal_weight)
        softmax_loss = focal_weight * softmax_loss

        soft_loss = (self.T ** 2) * torch.mean(softmax_loss)
        hard_loss = self.hard_loss(logits_student, target)
        loss_kd = self.alpha * soft_loss

        losses_dict = {"loss_ce": hard_loss, "loss_kd": loss_kd}

        return logits_student, logits_teacher, losses_dict