from methods.base import BaseAdaptation
from methods import register_method
from utils.utils import inputs_to_device
import torch
from layers import disp_to_depth, compute_depth_errors_adadepth, compute_depth_errors
from networks import get_supervised_models, get_self_supervised_models
from utils.losses import DepthLoss
from layers import update_ema_variables
import torch.nn.functional as F
import copy
import torch.nn as nn


@register_method(name='mic')
class MIC(BaseAdaptation):
    """Adaptation method that trains a masked student against an EMA teacher.

    On every batch the frozen EMA copy of the models predicts depth from the
    full image, the trainable models predict depth from the same image with a
    random subset of patches flagged as masked, and an L1 ``DepthLoss``
    between the two predictions is minimized. The EMA copy is then updated
    from the trainable weights and used to produce the reported depth.
    """

    # Decay passed to ``update_ema_variables`` when refreshing the EMA models.
    EMA_DECAY = 0.999

    def __init__(self, opt, **kwargs):
        """Build the trainable models, their EMA copy, the optimizer and loss.

        Args:
            opt: Options object. This method reads ``model_type``, ``device``
                and ``learning_rate`` from it (``process_batch`` additionally
                reads ``mean`` and ``std``).
            **kwargs: Forwarded unchanged to ``BaseAdaptation.__init__``.

        Raises:
            NotImplementedError: If ``opt.model_type`` is not ``"supervised"``.
        """
        super().__init__(opt, **kwargs)
        if self.opt.model_type == "supervised":
            self.models = get_supervised_models(self.opt)
        else:
            raise NotImplementedError("Self-supervised models not implemented yet")

        # The trainable models are kept in eval mode even though they are
        # optimized below (presumably to keep normalization/dropout behavior
        # fixed during adaptation -- confirm with the method's author).
        for m in self.models.values():
            m.eval()

        # Create EMA models: a deep copy whose parameters never receive
        # gradients; they are only changed through update_ema_variables.
        self.models_ema = copy.deepcopy(self.models)
        for model in self.models_ema.values():
            model.eval()
            for param in model.parameters():
                param.requires_grad = False
                param.detach_()

        head_dim = 192  # embed_dim of swin transformer
        self.patch_size = 4
        self.mask_ratio = 0.5

        # Mask token handed to the encoder for masked patches.
        # Tensor.to() returns a new tensor instead of moving in place, so the
        # result has to be assigned back for the token to live on opt.device.
        self.mask_token = torch.zeros(1, 1, head_dim).to(self.opt.device)

        parameters_to_train = []
        parameters_to_train += list(self.models["encoder"].parameters())
        parameters_to_train += list(self.models["depth"].parameters())
        self.optimizer = torch.optim.Adam(parameters_to_train, self.opt.learning_rate)

        self.loss = DepthLoss(loss='l1')

    def process_batch(self, inputs):
        """Run one adaptation step on ``inputs`` and evaluate the EMA models.

        Args:
            inputs: Batch mapping. Reads ``inputs["color_uncrop", 0, 0]`` (a
                4-D image tensor, presumably ``(B, C, H, W)``) and
                ``inputs['depth_gt_uncrop']``. It is passed through
                ``inputs_to_device`` first.

        Returns:
            A tuple ``(outputs, metrics, losses)`` where

            * ``outputs['depth']`` is the EMA models' prediction after the
              weight update and ``outputs['depth_model_pred']`` is the EMA
              prediction made before the update (the training target);
            * ``metrics['error']`` / ``metrics['error_local']`` are lists of
              numpy values from ``compute_depth_errors_adadepth`` without and
              with median scaling respectively;
            * ``losses`` is an empty dict.
        """
        inputs_to_device(inputs, self.opt.device)

        input_img = (inputs["color_uncrop", 0, 0] - self.opt.mean) / self.opt.std
        _, _, height, width = input_img.shape
        patch_grid = (height // self.patch_size, width // self.patch_size)

        # Target: prediction of the EMA models on the unmasked image.
        with torch.no_grad():
            ref_features = self.models_ema["encoder"](input_img)
            pred_depth = self.models_ema["depth"](ref_features)

        # One random 0/1 flag per patch, shaped (1, num_patches), so a single
        # mask is drawn per batch. It is sampled on the CPU (keeping the CPU
        # random stream unchanged) and then moved next to the image so the
        # encoder does not receive tensors on different devices.
        patch_scores = torch.rand(patch_grid).flatten().unsqueeze(0)
        mask_chosed = (patch_scores < self.mask_ratio).to(
            device=input_img.device, dtype=torch.float32
        )
        # No-op when the token already lives on the image's device.
        mask_token = self.mask_token.to(input_img.device)

        feats = self.models["encoder"](input_img, mask_token, mask_chosed)
        masked_pred = self.models["depth"](feats)

        loss = self.loss(pred_depth, masked_pred).mean()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Pair trainable and EMA models by key rather than by iteration order.
        for name, model in self.models.items():
            update_ema_variables(model, self.models_ema[name], self.EMA_DECAY)

        # Report the prediction of the freshly updated EMA models.
        with torch.no_grad():
            feats = self.models_ema["encoder"](input_img)
            depth_out = self.models_ema["depth"](feats)

        depth_gt = inputs['depth_gt_uncrop']
        error_local = [
            term.detach().cpu().numpy()
            for term in compute_depth_errors_adadepth(
                self.opt, depth_gt, depth_out, median_scaling=True
            )
        ]
        error = [
            term.detach().cpu().numpy()
            for term in compute_depth_errors_adadepth(
                self.opt, depth_gt, depth_out, median_scaling=False
            )
        ]

        outputs = {
            'depth': depth_out,
            'depth_model_pred': pred_depth,
        }
        losses = {}

        metrics = {
            'error': error,
            'error_local': error_local,
        }

        return outputs, metrics, losses
