import torch
import torch.nn as nn
import torch.nn.functional as F


class Distiller(nn.Module):
    """Base module that holds a student network and a teacher network.

    Subclasses implement ``forward_train``; this class supplies the
    train/eval dispatch, keeps the teacher in eval mode whenever ``train`` is
    called, and exposes the student's parameters for optimisation.
    """

    def __init__(self, student, teacher, wrap_student_in_ddp=False, local_rank=None):
        """Store both networks, optionally moving/wrapping them first.

        Args:
            student: network whose parameters ``get_learnable_parameters``
                returns.
            teacher: network that ``train`` always leaves in eval mode.
            wrap_student_in_ddp: when true, the student is wrapped in
                ``DistributedDataParallel`` and the teacher in
                ``DataParallel``, both on ``local_rank``.
            local_rank: CUDA device index; when given, both networks are
                moved there with ``.cuda(local_rank)``.

        Raises:
            ValueError: if ``wrap_student_in_ddp`` is set without a
                ``local_rank``.
        """
        super().__init__()

        has_device = local_rank is not None
        if wrap_student_in_ddp:
            if not has_device:
                raise ValueError("local_rank must be specified when using DDP")
            parallel = torch.nn.parallel
            student = parallel.DistributedDataParallel(
                student.cuda(local_rank), device_ids=[local_rank]
            )
            teacher = parallel.DataParallel(
                teacher.cuda(local_rank), device_ids=[local_rank]
            )
        elif has_device:
            student = student.cuda(local_rank)
            teacher = teacher.cuda(local_rank)

        # Registration order (student, then teacher) fixes the order of
        # self.children().
        self.student = student
        self.teacher = teacher

    def train(self, mode=True):
        """Set the training flag on this module, then force the teacher to eval.

        Raises:
            ValueError: if ``mode`` is not a bool.
        """
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        # The stock implementation sets self.training and propagates `mode`
        # to every child, teacher included; the teacher is reset right after.
        super().train(mode)
        self.teacher.eval()
        return self

    def set_teacher_train(self):
        """Put the teacher into training mode (undone by the next ``train`` call)."""
        self.teacher.train()

    def get_learnable_parameters(self):
        """Return the student's parameters as a list.

        Methods that introduce extra trainable parameters should override this.
        """
        return list(self.student.parameters())

    def get_extra_parameters(self):
        """Return the count of extra parameters added by the distiller (0 here)."""
        return 0

    def forward_train(self, **kwargs):
        """Training-time forward pass; must be provided by subclasses."""
        raise NotImplementedError()

    def forward_test(self, image):
        """Run the student on ``image`` and return the first item of its output."""
        student_out = self.student(image)
        return student_out[0]

    def forward(self, **kwargs):
        """Dispatch to ``forward_train`` when training, else ``forward_test``.

        In eval mode only ``kwargs["image"]`` is used.
        """
        if not self.training:
            return self.forward_test(kwargs["image"])
        return self.forward_train(**kwargs)


class Vanilla(nn.Module):
    """Student-only wrapper trained with a plain cross-entropy loss."""

    def __init__(self, student):
        """Store ``student``, the network to be trained and evaluated."""
        super().__init__()
        self.student = student

    def get_learnable_parameters(self):
        """Return the student's parameters as a list."""
        return list(self.student.parameters())

    def forward_train(self, image, target, **kwargs):
        """Run the student and compute its cross-entropy loss against ``target``.

        The student's output is unpacked as a pair; only the first item is
        used. Extra keyword arguments are accepted and ignored.

        Returns:
            A tuple ``(logits, losses)`` where ``losses`` is ``{"ce": loss}``.
        """
        student_out = self.student(image)
        logits_student, _ = student_out
        ce_loss = F.cross_entropy(logits_student, target)
        loss_dict = {"ce": ce_loss}
        return logits_student, loss_dict

    def forward(self, **kwargs):
        """Dispatch to ``forward_train`` when training, else ``forward_test``.

        In eval mode only ``kwargs["image"]`` is used.
        """
        if not self.training:
            return self.forward_test(kwargs["image"])
        return self.forward_train(**kwargs)

    def forward_test(self, image):
        """Run the student on ``image`` and return the first item of its output."""
        student_out = self.student(image)
        return student_out[0]
