import torch
import torch.nn.functional as F

from src.models.foca_base import FOCABase
from src.utils.utils import prepend_step_name, append_step_name


class FOCA(FOCABase):
    """FOCA model trained with a one-step loss plus a pushforward loss.

    Context inference (``infer_context``), the target-network update
    (``update_target``) and the attributes ``model``, ``model_target``,
    ``ctx_dim`` and ``ctx_opt_config`` are provided by ``FOCABase``.
    """

    def __init__(self,
                 bundling_k: int = None,
                 pushforward_n: int = None,
                 pred_config: dict = None,
                 **kwargs):
        """
        Args:
            bundling_k: divisor of the target's time axis during validation;
                also the length of the "short-term" slice of the rollout loss.
            pushforward_n: total number of model applications used for the
                pushforward loss in ``training_step``.
            pred_config: optional dict; its ``'min'`` / ``'max'`` entries, when
                set, clamp every prediction made in ``rollout``.
            **kwargs: forwarded unchanged to ``FOCABase.__init__``.
        """
        super().__init__(**kwargs)

        self.bundling_k = bundling_k
        self.pushforward_n = pushforward_n
        # A fresh dict per instance: a literal `{}` default would be shared
        # between every model constructed without an explicit pred_config.
        self.pred_config = {} if pred_config is None else pred_config

    def prepare_init_ctx(self, x, y):
        """
        Return an all-zero initial context that tracks gradients.

        x: [batch, T, dim]
        y: [batch, T, dim]  (not used here)

        Returns a tensor of shape [batch, ctx_dim] on the same device as x.
        """
        ctx = torch.zeros(x.shape[0],
                          self.ctx_dim,
                          requires_grad=True, device=x.device)  # [batch_size, c_dim]
        return ctx

    def model_forward(self, x, ctx, *args, **kwargs):
        """
        Apply a model to input ``x`` under context ``ctx``.

        If a non-None ``model`` keyword is given it is used; otherwise the
        target model is used when ``ctx_opt_config['use_target']`` is truthy,
        and the online model if not. Extra positional arguments are ignored.
        """
        model = kwargs.get('model')
        if model is None:
            model = self.model_target if self.ctx_opt_config['use_target'] else self.model
        return model(ctx=ctx, us=x)

    def training_step(self, batch, batch_id):
        """
        Compute and log the training loss (one-step MSE + pushforward MSE).

        batch: sequence whose first four entries are
            (prev, cur, target, pf_target); any further entries are ignored.

        Returns the summed loss.
        """
        self.update_target()
        prev, cur, target, pf_target, *_unused = batch
        ctx, ctx_info = self.infer_context(prev, cur)

        # 1-step prediction
        pred = self.model(ctx=ctx, us=cur)
        one_step_loss = F.mse_loss(pred, target)

        # Pushforward loss: the first (pushforward_n - 1) applications run
        # without gradient tracking, so only the final application below is
        # differentiated.
        with torch.no_grad():
            for _ in range(self.pushforward_n - 1):
                cur = self.model(ctx=ctx, us=cur)
        pf_pred = self.model(ctx=ctx, us=cur)
        pf_loss = F.mse_loss(pf_pred, pf_target)
        loss = one_step_loss + pf_loss

        metrics = {'total_loss': loss,
                   'one_step_loss': one_step_loss,
                   'pf_loss': pf_loss}
        metrics.update(ctx_info)
        metrics = prepend_step_name(metrics, 'train')
        self.log_dict(metrics)
        return loss

    def rollout(self, x, ctx, rollout_n):
        """
        Autoregressively apply the model ``rollout_n`` times starting from ``x``.

        Each prediction is optionally clamped to the ``'min'`` / ``'max'``
        bounds in ``pred_config`` and fed back as the next input. The
        predictions are concatenated along dim 1.
        """
        # The bounds do not change during the rollout; look them up once.
        lower = self.pred_config.get('min')
        upper = self.pred_config.get('max')
        clamp_needed = lower is not None or upper is not None

        preds = []
        pred = x
        for _ in range(rollout_n):
            pred = self.model(ctx=ctx, us=pred)
            if clamp_needed:
                pred = pred.clamp(min=lower, max=upper)
            preds.append(pred)
        return torch.cat(preds, dim=1)

    def validation_step(self, batch, batch_id, step_name='val'):
        """
        Roll the model out over the target horizon and log MSE metrics.

        batch: sequence whose first three entries are (prev, cur, target);
            any further entries are ignored.
        step_name: name combined with every metric key via
            ``append_step_name``.

        Returns the element-wise (unreduced) squared-error tensor.
        """
        prev, cur, target, *unused = batch
        assert target.shape[1] % self.bundling_k == 0

        ctx, ctx_info = self.infer_context(prev, cur)
        pred = self.rollout(ctx=ctx, x=cur, rollout_n=target.shape[1] // self.bundling_k)
        loss = F.mse_loss(pred, target, reduction='none')  # [batch x time x obs]

        # First bundling_k steps vs. everything after them.
        short_term_loss = loss[:, :self.bundling_k, :].mean()
        long_term_loss = loss[:, self.bundling_k:, :].mean()
        rollout_loss = loss.mean()

        metrics = {'loss': rollout_loss,
                   'long_term_loss': long_term_loss,
                   'short_term_loss': short_term_loss}
        metrics.update(ctx_info)
        metrics = append_step_name(metrics, step_name, '_')
        self.log_dict(metrics)
        # NOTE(review): returns the unreduced loss, not `rollout_loss`; kept
        # as-is because callers may depend on the shape -- confirm intent.
        return loss

    def test_step(self, *args, **kwargs):
        """Run ``validation_step`` with metrics logged under the 'test' name."""
        return self.validation_step(*args, **kwargs, step_name='test')


class FOCA_OOD(FOCA):
    """Base model with OOD validation.

    Validation and test metrics are logged separately for in-distribution
    ("id") and out-of-distribution ("ood") samples, selected by a boolean
    mask taken from the last entry of the batch.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _log_split_metrics(self, loss, short_term_loss, long_term_loss,
                           mask, split_name, step_name):
        """Log mean losses over the samples selected by ``mask``.

        Nothing is logged when the mask selects no samples, because the mean
        of an empty tensor would be NaN.
        """
        selected = loss[mask]
        if len(selected) == 0:
            return
        metrics = {'loss': selected.mean(),
                   'long_term_loss': long_term_loss[mask].mean(),
                   'short_term_loss': short_term_loss[mask].mean()}
        metrics = prepend_step_name(metrics, split_name)
        metrics = prepend_step_name(metrics, step_name, separator='_')
        self.log_dict(metrics)

    def validation_step(self, batch, batch_id, step_name='val'):
        """
        Roll the model out over the target horizon and log id / ood metrics.

        batch: sequence whose first three entries are (prev, cur, target) and
            whose last entry is ``ood``, a boolean mask used to index the
            first dimension of the loss (presumably one flag per sample --
            TODO confirm against the data loader).
        step_name: name prepended (with '_' as separator) to every metric key.

        Returns the mean squared error over the whole batch.
        """
        prev, cur, target, *_unused, ood = batch
        assert target.shape[1] % self.bundling_k == 0
        ctx, ctx_info = self.infer_context(prev, cur)
        pred = self.rollout(ctx=ctx, x=cur, rollout_n=target.shape[1] // self.bundling_k)

        loss = F.mse_loss(pred, target, reduction='none')  # [batch x time x obs]
        # First bundling_k steps vs. everything after them.
        short_term_loss = loss[:, :self.bundling_k, ...]
        long_term_loss = loss[:, self.bundling_k:, ...]

        # Log id and ood metrics. Don't log if tensors would have 0 dim
        self._log_split_metrics(loss, short_term_loss, long_term_loss,
                                ~ood, 'id', step_name)
        self._log_split_metrics(loss, short_term_loss, long_term_loss,
                                ood, 'ood', step_name)
        return loss.mean()

    def test_step(self, *args, **kwargs):
        """Run ``validation_step`` with metrics logged under the 'test' name."""
        return self.validation_step(*args, **kwargs, step_name='test')
