import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Sequence
from collections import defaultdict
from tqdm import tqdm
import numpy as np

from torch.utils.data import DataLoader
from lib.utils import tensor2npy
from .BaseModel import BaseModel
from lib.callback import CallbackList, Callback


class PMF(BaseModel):
    """Matrix-factorization rating model (PMF).

    The score for a (user, item) pair is the dot product of a learned user
    embedding and a learned item embedding. Training minimizes the mean
    squared error against observed ratings plus a weighted norm penalty on
    the embeddings used in each batch.
    """

    # Default hyper-parameters; ``build_cfg`` reads these keys from
    # ``self.model_cfg``.
    # NOTE(review): presumably BaseModel merges this dict into ``model_cfg`` —
    # confirm against BaseModel.
    expose_default_cfg = {
        'emb_dim': 64,
        'reg_user': 0.001,
        'reg_item': 0.001
    }

    def __init__(self, cfg):
        """Initialize via ``BaseModel`` and move the parameters to ``self.device``.

        Args:
            cfg: Configuration object, forwarded unchanged to ``BaseModel``.
        """
        # NOTE(review): assumes BaseModel.__init__ calls build_cfg() /
        # build_model() and sets self.device — confirm.
        super(PMF, self).__init__(cfg=cfg)
        self.to(self.device)

    def build_cfg(self):
        """Copy dataset sizes and hyper-parameters into instance attributes.

        ``dt_info.user_count`` and ``dt_info.item_count`` are required keys of
        the data config; ``emb_dim``, ``reg_user`` and ``reg_item`` are read
        from the model config.
        """
        self.user_count = self.data_cfg.dot_get('dt_info.user_count', require=True)
        self.item_count = self.data_cfg.dot_get('dt_info.item_count', require=True)
        self.emb_size = self.model_cfg['emb_dim']
        self.reg_user = self.model_cfg['reg_user']
        self.reg_item = self.model_cfg['reg_item']

    def build_model(self):
        """Create the user and item embedding tables, ``emb_size`` columns each."""
        self.user_emb = nn.Embedding(
            num_embeddings=self.user_count,
            embedding_dim=self.emb_size
        )
        self.item_emb = nn.Embedding(
            num_embeddings=self.item_count,
            embedding_dim=self.emb_size
        )

    def forward(self, user_idx: torch.LongTensor, item_idx: torch.LongTensor):
        """Score paired user/item indices.

        Args:
            user_idx: 1-D tensor of user indices.
            item_idx: 1-D tensor of item indices, same length as ``user_idx``.

        Returns:
            1-D tensor with one score per pair: the row-wise dot product of the
            corresponding user and item embeddings.

        Raises:
            AssertionError: if an input is not 1-D or the lengths differ
                (only checked when Python assertions are enabled).
        """
        assert len(user_idx.shape) == 1 and len(item_idx.shape) == 1 and user_idx.shape[0] == item_idx.shape[0]
        # "ij,ij->i": multiply matching rows element-wise, sum over embedding dim.
        return torch.einsum("ij,ij->i", self.user_emb(user_idx), self.item_emb(item_idx))

    def get_loss(self, y_pd, y_gt, user_idx, item_idx):
        """Compute the data-fit and regularization terms for one batch.

        Args:
            y_pd: Predicted ratings.
            y_gt: Observed ratings aligned with ``y_pd``.
            user_idx: User indices of the batch.
            item_idx: Item indices of the batch.

        Returns:
            Tuple ``(main_loss, reg_loss)``: the mean squared error, and
            ``reg_user * ||U_batch|| + reg_item * ||V_batch||`` where the norms
            are taken over the batch's embedding matrices.
        """
        main_loss = F.mse_loss(y_pd, y_gt, reduction='mean')
        # NOTE(review): torch.norm here is the plain (not squared) Frobenius
        # norm of the whole batch embedding matrix; the classic PMF Gaussian
        # prior yields squared norms. Kept as-is to preserve training results.
        reg_loss = self.reg_user * torch.norm(self.user_emb(user_idx)) + \
                   self.reg_item * torch.norm(self.item_emb(item_idx))
        return main_loss, reg_loss

    def fit(self,
            train_dataset,
            val_dataset,
            callbacks: Sequence[Callback] = ()
            ):
        """Train the model, optionally validating after every epoch.

        Reads ``lr``, ``epoch_num``, ``batch_size``, ``num_workers``,
        ``eval_batch_size``, ``weight_decay``, ``eps`` and ``optim`` from
        ``self.train_cfg``. Each batch is indexed as a 2-D tensor whose columns
        0, 1 and 2 hold the user index, item index and rating.

        Per epoch, ``loss``, ``main_loss``, ``reg_loss`` and ``rmse`` are
        averaged over batches (ignoring NaN). When ``val_dataset`` is given,
        ``val_<metric>`` entries from :meth:`evaluate` are added. The resulting
        mapping is passed to ``on_epoch_end``. Training stops early once
        ``self.share_obj_dict['stop_training']`` is truthy.

        Args:
            train_dataset: Dataset of training rows.
            val_dataset: Dataset of validation rows, or ``None`` to skip
                validation.
            callbacks: Callbacks, wrapped in a ``CallbackList``.
        """
        lr = self.train_cfg['lr']
        epoch_num = self.train_cfg['epoch_num']
        batch_size = self.train_cfg['batch_size']
        num_workers = self.train_cfg['num_workers']
        eval_batch_size = self.train_cfg['eval_batch_size']
        weight_decay = self.train_cfg['weight_decay']
        eps = self.train_cfg['eps']

        model = self.train()
        optimizer = self._get_optim(optimizer=self.train_cfg['optim'], lr=lr, weight_decay=weight_decay, eps=eps)
        self.optimizer = optimizer

        train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers)
        val_loader = None
        if val_dataset is not None:
            val_loader = DataLoader(val_dataset, shuffle=False, batch_size=eval_batch_size, num_workers=num_workers)

        callback_list = CallbackList(callbacks=callbacks, model=model, logger=self.logger)
        callback_list.on_train_begin()

        for epoch in range(epoch_num):
            # Re-enter training mode every epoch: validation and callbacks run
            # between epochs and may have switched the module to eval mode.
            model.train()
            callback_list.on_epoch_begin(epoch + 1)
            # One slot per batch, NaN-filled so unfilled slots are ignored by nanmean.
            logs = defaultdict(lambda: np.full((len(train_loader),), np.nan, dtype=np.float32))
            for batch_id, batch in enumerate(
                    tqdm(train_loader, ncols=self.environ_cfg['tqdm_ncols'], desc="[EPOCH={:03d}]".format(epoch + 1))
            ):
                batch = batch.to(self.device)
                u = batch[:, 0]
                i = batch[:, 1]
                r = batch[:, 2].float()
                r_pd = model(u, i).flatten()
                main_loss, reg_loss = model.get_loss(r_pd, r, u, i)
                loss = main_loss + reg_loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                logs['loss'][batch_id] = loss.item()
                logs['main_loss'][batch_id] = main_loss.item()
                logs['reg_loss'][batch_id] = reg_loss.item()
                # Per-batch RMSE; the epoch value below is the mean of these,
                # not the RMSE over the whole epoch.
                logs['rmse'][batch_id] = self._get_metrics('rmse')(tensor2npy(r), tensor2npy(r_pd))

            for name in logs:
                logs[name] = np.nanmean(logs[name])

            if val_loader is not None:
                val_metrics = self.evaluate(val_loader)
                logs.update({f"val_{metric}": val_metrics[metric] for metric in val_metrics})

            callback_list.on_epoch_end(epoch + 1, logs=logs)
            # NOTE(review): presumably set by an early-stopping callback — confirm.
            if self.share_obj_dict.get('stop_training', False):
                break

        callback_list.on_train_end()

    @torch.no_grad()
    def predict(self, u, i):
        """Run :meth:`forward` without gradient tracking (train/eval mode unchanged)."""
        return self(u, i)

    @torch.no_grad()
    def evaluate(self, loader):
        """Compute every metric in ``self.eval_cfg['metrics']`` over ``loader``.

        The module is put in eval mode for the pass. Its previous train/eval
        mode is restored afterwards, even if an error is raised, so calling
        this from inside :meth:`fit` does not leave later epochs in eval mode.

        Args:
            loader: Iterable of 2-D batch tensors; columns 0, 1 and 2 are the
                user index, item index and rating.

        Returns:
            Dict mapping each metric name to
            ``self._get_metrics(name)(y_gt, y_pd)`` computed over all batches.
        """
        was_training = self.training
        self.eval()
        try:
            pd_list = []
            gt_list = []
            for batch in tqdm(loader, ncols=self.environ_cfg['tqdm_ncols'], desc="[PREDICT]"):
                batch = batch.to(self.device)
                u = batch[:, 0]
                i = batch[:, 1]
                r = batch[:, 2]
                pd_list.append(self.predict(u, i).flatten())
                gt_list.append(r.flatten())
        finally:
            self.train(was_training)
        y_pd = tensor2npy(torch.hstack(pd_list))
        y_gt = tensor2npy(torch.hstack(gt_list))
        eval_result = {
            metric: self._get_metrics(metric)(y_gt, y_pd) for metric in self.eval_cfg['metrics']
        }
        return eval_result
