from copy import deepcopy

import hydra
import torch
import torch.nn.functional as F

from src.models.foca_base import FOCABase
from src.utils import pylogger

log = pylogger.get_pylogger(__name__)


class FOCA(FOCABase):
    """FOCA variant that conditions the model on a per-task context vector.

    The context vector is concatenated to every input point before it is fed
    to the model (see ``model_forward``), so the instantiated model's
    ``input_dim`` is widened by ``self.ctx_dim`` (see ``initialize_model``).
    For each task in a batch, the first ``n_inf_samples`` points are used to
    infer the context and the remaining points are used for the MSE loss.

    ``model_config``, ``ctx_dim``, ``ctx_opt_config``, ``infer_context``,
    ``update_target`` and ``append_step_name`` come from ``FOCABase`` (not
    shown in this file).
    """

    def __init__(self,
                 n_inf_samples: int,
                 **kwargs
                 ):
        """
        Args:
            n_inf_samples: number of points per task (dim 1 of the batch
                tensors) used for context inference; the remaining points are
                used as prediction targets for the loss.
            **kwargs: forwarded unchanged to ``FOCABase.__init__``.
        """
        super().__init__(**kwargs)
        self.n_inf_samples = n_inf_samples

    def initialize_model(self):
        """Instantiate ``self.model`` and ``self.model_target`` from ``self.model_config``.

        The config is deep-copied so that widening ``input_dim`` does not
        mutate the stored ``self.model_config``. Both networks are built from
        the same (widened) config; keeping them in sync is left to other code
        (presumably ``update_target`` in ``FOCABase`` -- not visible here).
        """
        model_config = deepcopy(self.model_config)
        # The model consumes [x, ctx] concatenated on the feature axis.
        model_config.input_dim = model_config.input_dim + self.ctx_dim

        log.info(f"Instantiating model <{model_config._target_}>")
        self.model = hydra.utils.instantiate(model_config, _recursive_=False)
        log.info(f"Instantiating model target <{model_config._target_}>")
        self.model_target = hydra.utils.instantiate(model_config, _recursive_=False)

    def prepare_init_ctx(self, x: torch.Tensor,
                         y: torch.Tensor,
                         *args, **kwargs):
        """Return the starting point for context inference: an all-zero context.

        x: [batch, n samples, x dim]
        y: [batch, n samples, y dim] (unused; kept for interface compatibility)

        Returns:
            ctx: [batch, ctx dim] zeros with ``requires_grad=True`` on
            ``x.device``. Its dtype follows ``x`` when ``x`` is floating point,
            so the later ``torch.cat([x, ctx])`` in ``model_forward`` does not
            silently promote (e.g. under half/bfloat16/float64 precision).
            For non-floating ``x`` the default dtype is used, because only
            floating/complex tensors may require gradients.
        """
        dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
        ctx = torch.zeros(x.shape[0],
                          self.ctx_dim,
                          dtype=dtype,
                          device=x.device,
                          requires_grad=True)  # [batch_size, c_dim]
        return ctx

    def model_forward(self, x, ctx, *args, **kwargs):
        """Predict outputs for ``x`` conditioned on the per-task context ``ctx``.

        x: [batch, n samples, x dim]
        ctx: [batch, ctx dim]

        Keyword Args:
            model: network to apply. If omitted or None, ``self.model_target``
                is used when ``self.ctx_opt_config['use_target']`` is truthy,
                otherwise ``self.model``.

        Returns:
            The model output for the concatenated ``[x, ctx]`` input.
        """
        # Repeat the context for every sample: [batch, n samples, ctx dim].
        ctx = ctx.unsqueeze(dim=1).expand(-1, x.shape[1], -1)
        xc = torch.cat([x, ctx], dim=-1)
        model = kwargs.get('model')
        if model is None:
            model = self.model_target if self.ctx_opt_config['use_target'] else self.model
        return model(xc)

    def _split_inf_target(self, x, y):
        """Split points on dim 1 into an inference part and a target part.

        The first ``n_inf_samples`` points form the inference (context) set;
        the rest form the target set used for the loss.

        Returns:
            ((x_inf, y_inf), (x_tgt, y_tgt))

        Raises:
            ValueError: if no points would remain for the target set. Without
                this check ``F.mse_loss`` over empty tensors returns NaN,
                which would then be logged silently.
        """
        n_points = x.shape[1]
        k = self.n_inf_samples
        if n_points <= k:
            raise ValueError(
                f"Batch has {n_points} samples per task but n_inf_samples={k}; "
                f"no samples are left to compute the loss."
            )
        return (x[:, :k, ...], y[:, :k, ...]), (x[:, k:, ...], y[:, k:, ...])

    def _inf_loss_step(self, x, y, stage):
        """Infer context on the inference split, then compute and log the target MSE.

        ``stage`` is passed to ``append_step_name`` to tag the logged metrics.
        """
        (x_ctx, y_ctx), (x_tgt, y_tgt) = self._split_inf_target(x, y)

        ctx, ctx_info = self.infer_context(x_ctx, y_ctx)  # ctx: [batch, ctx dim]
        pred = self.model_forward(x_tgt, ctx, model=self.model)
        loss = F.mse_loss(pred, y_tgt)

        metrics = {'loss': loss}
        metrics.update(ctx_info)
        metrics = self.append_step_name(metrics, stage)
        self.log_dict(metrics)
        return loss

    def training_step(self, batch, batch_id):
        """Run one training step and return the MSE loss on the target points.

        ``batch`` is a 4-tuple whose first two entries are x and y, each of
        shape [batch, n_samples, x/y dim]. The other two entries are unused
        here.
        """
        self.update_target()

        x, y, _, _ = batch

        # Shuffle the sample order, using one permutation shared by every task
        # in the batch, so the inference/target split changes between steps.
        perm = torch.randperm(x.shape[1])
        x = x[:, perm, ...]
        y = y[:, perm, ...]

        return self._inf_loss_step(x, y, 'train')

    def validation_step(self, batch, batch_id):
        """Run one validation step and return the MSE loss on the target points.

        Unlike training, the samples are not shuffled, so the split is
        deterministic: the first ``n_inf_samples`` points are used for
        inference.
        """
        x, y, _, _ = batch
        return self._inf_loss_step(x, y, 'val')


if __name__ == '__main__':
    # Manual smoke check: build the module with 15 inference samples per task.
    # No keyword arguments are forwarded to FOCABase.__init__ here; whether
    # the base class accepts that is not visible in this file.
    FOCA(n_inf_samples=15)
