from diffusers import ConfigMixin, ModelMixin
from torch import nn

from ..modules.base_module import BaseModule
from ..utils import PatchParallelismCommManager, DistriConfig


class BaseModel(ModelMixin, ConfigMixin):
    def __init__(self, model: nn.Module, distri_config: DistriConfig):
        super(BaseModel, self).__init__()
        self.model = model
        self.distri_config = distri_config
        self.comm_manager = None

        self.buffer_list = None
        self.output_buffer = None
        self.counter = 0

        # for cuda graph
        self.static_inputs = None
        self.static_outputs = None
        self.cuda_graphs = None

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def set_counter(self, counter: int = 0):
        self.counter = counter
        for module in self.model.modules():
            if isinstance(module, BaseModule):
                module.set_counter(counter)

    def set_comm_manager(self, comm_manager: PatchParallelismCommManager):
        self.comm_manager = comm_manager
        for module in self.model.modules():
            if isinstance(module, BaseModule):
                module.set_comm_manager(comm_manager)

    def setup_cuda_graph(self, static_outputs, cuda_graphs):
        self.static_outputs = static_outputs
        self.cuda_graphs = cuda_graphs

    @property
    def config(self):
        return self.model.config

    def synchronize(self):
        if self.comm_manager is not None and self.comm_manager.handles is not None:
            for i in range(len(self.comm_manager.handles)):
                if self.comm_manager.handles[i] is not None:
                    self.comm_manager.handles[i].wait()
                    self.comm_manager.handles[i] = None
