"""
Defines the MoCo (Momentum Contrast) model architecture.

The MoCo model is a contrastive learning framework that uses a momentum-updated encoder
to build a dynamic dictionary on-the-fly, which allows for a large and consistent dictionary
compared to other contrastive learning methods.

This module defines the MoCo training task. It builds the query encoder and the key (momentum)
encoder, wraps them in `MoCoWrapper` (which computes the contrastive loss), and runs the
optimization step. It also declares the MoCo-specific command-line options.
"""

import torch
import argparse

from models.basic_template import TrainTask
from network import backbone_dict
from .moco_wrapper import MoCoWrapper
from utils.ops import convert_to_ddp
from models import model_dict


@model_dict.register('moco')
class MoCo(TrainTask):
    """
    MoCo (Momentum Contrast) training task for unsupervised visual representation learning.

    - `set_model` builds the query/key encoders, wraps them in `MoCoWrapper`, creates an SGD
      optimizer and converts the wrapper with `convert_to_ddp`.
    - `train` performs one optimization step on a batch of two augmented views.
    - `build_options` declares the MoCo-specific command-line arguments: projection head,
      symmetric contrastive loss, momentum, queue size and temperature.
    """

    def set_model(self):
        """
        Build the MoCo model and its optimizer.

        Instantiates two encoders of the backbone type selected by `opt.encoder_name`, one for
        queries and one for keys. Wraps them in `MoCoWrapper` and creates an SGD optimizer over
        the wrapper's parameters. Converts the wrapper with `convert_to_ddp` and stores the
        results on `self.moco`, `self.optimizer`, `self.logger.modules` and
        `self.feature_extractor`.

        Returns:
            None
        """
        opt = self.opt
        encoder_type, dim_in = backbone_dict[opt.encoder_name]
        encoder_q = encoder_type()
        encoder_k = encoder_type()
        moco = MoCoWrapper(encoder_q, encoder_k, in_dim=dim_in, fea_dim=opt.feat_dim,
                           mlp=opt.mlp, symmetric=opt.symmetric, m=opt.moco_momentum,
                           K=opt.queue_size, T=opt.moco_temp)
        # The optimizer is built on the unwrapped module's parameters. This presumably relies
        # on convert_to_ddp wrapping (not copying) those parameters; confirm in utils.ops.
        optimizer = torch.optim.SGD(params=moco.parameters(),
                                    lr=opt.learning_rate, momentum=opt.momentum, weight_decay=opt.weight_decay)
        moco = convert_to_ddp(moco)
        # NOTE(review): logger.modules is consumed by TrainTask / the logger (presumably for
        # checkpointing); verify there.
        self.logger.modules = [moco, optimizer]
        self.moco = moco
        self.optimizer = optimizer
        # Assumes convert_to_ddp returns a wrapper exposing the original module as `.module`.
        # The key encoder is exposed as the feature extractor.
        self.feature_extractor = moco.module.encoder_k

    @staticmethod
    def build_options():
        """
        Build the parser for MoCo-specific command-line arguments.

        - `--mlp`: enable the projection head (MoCo v2).
        - `--symmetric`: enable the symmetric contrastive loss.
        - `--moco_momentum`: moving-average momentum of the key encoder (default 0.999).
        - `--queue_size`: size of the memory queue (default 65536).
        - `--moco_temp`: temperature of the contrastive loss (default 0.07).

        Returns:
            argparse.ArgumentParser: parser holding the arguments above.
        """

        parser = argparse.ArgumentParser('Private arguments for training of different methods')

        parser.add_argument('--mlp', help='Projection head for moco v2', dest='mlp', action='store_true')
        parser.add_argument('--symmetric', help='Symmetric contrastive loss', dest='symmetric', action='store_true')
        parser.add_argument('--moco_momentum', type=float, default=0.999, help='Moving Average Momentum')
        parser.add_argument('--queue_size', type=int, default=65536, help='Memory queue size')
        parser.add_argument('--moco_temp', type=float, default=0.07, help='temp for contrastive loss, 0.1 for cifar10')
        return parser

    def train(self, inputs, indices, n_iter):
        """
        Run one MoCo training step and log the loss.

        Args:
            inputs (tuple): `(images, labels)`, where `images` is a pair `(im_q, im_k)` holding
                the query and key views of the batch. Labels are not used by this method.
            indices: Indices of the batch samples. Not used by this method; kept for interface
                compatibility with the training loop.
            n_iter (int): Current iteration number, forwarded to the logger.

        Returns:
            None
        """
        (im_q, im_k), _ = inputs
        self.moco.train()

        # The wrapper's forward returns the contrastive loss directly.
        loss = self.moco(im_q, im_k)

        # SGD step
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.logger.msg([loss, ], n_iter)
