import hydra
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from src.utils import pylogger

# Module-level logger, obtained from the project's pylogger helper and named after this module.
log = pylogger.get_pylogger(__name__)


class FOCABase(pl.LightningModule):
    """Base LightningModule holding a model, a target copy, and a context-inference loop.

    The class instantiates ``model`` and ``model_target`` from the same hydra
    config, infers a context tensor by running an inner Adam optimization
    (``infer_context``), and provides a soft update of the target parameters
    (``update_target``).

    Subclasses must implement ``model_forward`` and one of ``prep_init_ctx`` /
    ``prepare_init_ctx``.
    """

    def __init__(self,
                 model_config: dict = None,
                 opt_config: dict = None,
                 ctx_opt_config: dict = None,
                 ctx_dim: int = None,
                 tau: float = 0.1,
                 **kwargs):
        """Store the configs and build the model / target model.

        Args:
            model_config: hydra config passed to ``hydra.utils.instantiate``;
                it is accessed as ``model_config._target_`` for logging.
            opt_config: outer optimizer settings, read as ``'lr'`` and ``'T_0'``.
                Defaults to ``{'lr': 1e-3, 'T_0': 32}``.
            ctx_opt_config: inner (context) optimization settings; see the
                default dictionary below for the keys that are read.
            ctx_dim: stored as ``self.ctx_dim``; not used by this base class.
            tau: interpolation weight used by ``update_target``.
            **kwargs: accepted and recorded by ``save_hyperparameters`` only.
        """
        super().__init__()
        self.save_hyperparameters()

        self.model_config = model_config

        if opt_config is None:
            opt_config = {'lr': 1e-3, 'T_0': 32}
        self.opt_config = opt_config

        if ctx_opt_config is None:
            ctx_opt_config = {'lr': 1e-3,
                              # restart period of the inner LR scheduler; the key must be
                              # 'T_0' because that is what ``infer_context`` reads.
                              'T_0': 32,
                              'n_iter': 50,
                              'use_target': True,  # whether to use "target" for ctx inference.
                              'return_best': False,  # return the "best" ctx at the inferred ctx
                              'detach': True,  # whether to detach the inferred ctx from the computational chain.
                              'ctx_lambda': 0.0001}  # weight of the ctx-norm penalty added to the inner loss.
        self.ctx_opt_config = ctx_opt_config

        self.ctx_dim = ctx_dim
        self.tau = tau

        self.model: nn.Module = None  # should be initialized in "initialize_model"
        self.model_target: nn.Module = None  # should be initialized in "initialize_model"
        self.initialize_model()

    def initialize_model(self):
        """Instantiate ``model`` and ``model_target`` and sync their weights."""
        log.info(f"Instantiating model <{self.model_config._target_}>")
        self.model = hydra.utils.instantiate(self.model_config, _recursive_=False)
        log.info(f"Instantiating target model <{self.model_config._target_}>")
        self.model_target = hydra.utils.instantiate(self.model_config, _recursive_=False)
        # Start the target as an exact copy of the freshly built model.
        self.model_target.load_state_dict(self.model.state_dict())

    def infer_context(self, ctx_x, ctx_y, *args, **kwargs):
        """Infer a context tensor by minimizing the model's MSE on ``(ctx_x, ctx_y)``.

        An initial context is obtained from ``prepare_init_ctx`` and optimized
        with Adam for ``ctx_opt_config['n_iter']`` steps. When
        ``ctx_opt_config['ctx_lambda'] > 0`` the term ``ctx_lambda * ctx.norm()``
        is added to the loss.

        Args:
            ctx_x: inputs forwarded to ``model_forward``.
            ctx_y: targets compared against the prediction with ``F.mse_loss``.
            *args, **kwargs: forwarded to ``prepare_init_ctx`` and ``model_forward``.

        Returns:
            A tuple ``(ctx, info)``. ``ctx`` is the best-scoring iterate when
            ``ctx_opt_config['return_best']`` is true, otherwise the final
            iterate; it is detached when ``ctx_opt_config['detach']`` is true.
            ``info`` holds the statistics recorded at the best iterate (all
            ``None`` except ``inner_tot_loss`` if no iterate was recorded).
        """
        # Gradients are needed for the inner loop even if the caller disabled them.
        with torch.set_grad_enabled(True):
            ctx = self.prepare_init_ctx(ctx_x, ctx_y, *args, **kwargs)
            opt = torch.optim.Adam([ctx], lr=self.ctx_opt_config['lr'])
            scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt,
                                                                             T_0=self.ctx_opt_config['T_0'])

            best_ctx = None
            best_loss = float('inf')
            best_inner_ctx_loss, best_penalty, best_ctx_norm, best_penalty_weight = None, None, None, None

            for _ in range(self.ctx_opt_config['n_iter']):
                ctx_pred = self.model_forward(ctx_x, ctx, *args, **kwargs)
                loss = F.mse_loss(ctx_pred, ctx_y)
                # ctx regularization
                if self.ctx_opt_config['ctx_lambda'] > 0:
                    penalty_weight = self.ctx_opt_config['ctx_lambda']
                    penalty = penalty_weight * ctx.norm()
                    total_loss = loss + penalty
                else:
                    penalty_weight = 0.0
                    penalty = 0.0
                    total_loss = loss

                # check the best c (a NaN loss never satisfies the comparison)
                total_loss_value = total_loss.item()
                if total_loss_value <= best_loss:
                    best_loss = total_loss_value
                    best_ctx = ctx.detach().clone()
                    best_inner_ctx_loss = loss.item()
                    best_ctx_norm = best_ctx.norm()
                    # Detach so the returned statistics do not keep the autograd graph alive.
                    best_penalty = penalty.detach() if torch.is_tensor(penalty) else penalty
                    best_penalty_weight = penalty_weight

                opt.zero_grad()
                total_loss.backward()
                opt.step()
                scheduler.step()

            info = {
                'inner_tot_loss': best_loss,
                'inner_ctx_loss': best_inner_ctx_loss,
                'inner_penalty_loss': best_penalty,
                'inner_ctx_norm': best_ctx_norm,
                'inner_penalty_weight': best_penalty_weight
            }

            # Fall back to the final ctx when no best iterate was recorded
            # (n_iter == 0 or every loss was NaN).
            if self.ctx_opt_config['return_best'] and best_ctx is not None:
                ret_ctx = best_ctx
            else:
                ret_ctx = ctx

            ret = (ret_ctx.detach() if self.ctx_opt_config['detach'] else ret_ctx, info)
            return ret

    @staticmethod
    def append_step_name(log_dict, step_name):
        """Return a copy of ``log_dict`` whose keys are prefixed with ``'<step_name>_'``."""
        return {'{}_{}'.format(step_name, k): v for k, v in log_dict.items()}

    def configure_optimizers(self):
        """Return Adam over ``self.model`` with a per-step cosine warm-restart scheduler."""
        opt = Adam(self.model.parameters(), lr=self.opt_config['lr'])
        ret = {
            'optimizer': opt,
            'lr_scheduler': {
                'scheduler': CosineAnnealingWarmRestarts(opt, T_0=self.opt_config['T_0']),
                'interval': 'step'
            }
        }
        return ret

    def update_target(self):
        """Soft-update the target parameters: ``target <- (1 - tau) * target + tau * model``."""
        for param_target, param in zip(self.model_target.parameters(), self.model.parameters()):
            param_target.data.copy_(param_target.data * (1.0 - self.tau) + param.data * self.tau)

    def prepare_init_ctx(self, x, y, *args, **kwargs):
        """Build the initial context optimized by ``infer_context``.

        Delegates to ``prep_init_ctx`` so that a subclass may override either name.
        """
        return self.prep_init_ctx(x, y, *args, **kwargs)

    def prep_init_ctx(self, x, y, *args, **kwargs):
        """Subclass hook: return the initial context tensor passed to the inner Adam optimizer."""
        raise NotImplementedError

    def model_forward(self, x, ctx, *args, **kwargs):
        """Subclass hook: return the prediction for ``x`` given the context ``ctx``."""
        raise NotImplementedError
