# Copyright (c) Facebook, Inc. and its affiliates.

import tqdm
import torch
from torch import nn
from torch import optim

from models import TKBCModel
from regularizers import Regularizer
from datasets import TemporalDataset
from utils import logger, args


class MyOptimizer(object):
    """Runs one-epoch training passes under one of two objectives.

    The objective is chosen per batch from the global ``args.objective_mode``:
    ``'classification'`` trains on the subject / object / time predictions
    returned by ``model.forward``; any other value trains on the summed score
    returned by ``model.forward_``.
    """

    def __init__(
            self, model: TKBCModel,
            emb_regularizer: Regularizer,
            optimizer: optim.Optimizer,
            time_loss,
            batch_size: int = 256,
            verbose: bool = True,
    ):
        """Store the training components.

        Args:
            model: model whose ``forward`` / ``forward_`` is called per batch.
            emb_regularizer: callable applied to the ``factors`` returned by
                the model to obtain the regularization loss.
            optimizer: optimizer stepped once per batch.
            time_loss: callable ``(t_pred, targets)`` giving the loss on the
                time predictions (classification mode only).
            batch_size: number of example rows per mini-batch.
            verbose: when False the tqdm progress bar is disabled.
        """
        self.model = model
        self.emb_regularizer = emb_regularizer
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.verbose = verbose
        self.time_loss = time_loss

    def epoch(self, examples: torch.LongTensor, time_range=None):
        """Train for one pass over ``examples`` in shuffled order.

        Args:
            examples: 2-D tensor with one example per row. In classification
                mode columns 0, 2 and 3 are used as targets for the subject,
                object and time predictions respectively.
            time_range: forwarded unchanged to ``model.forward_`` in
                regression mode; unused in classification mode.

        Returns:
            The example-weighted mean of ``l_score`` over the epoch as a
            float (0.0 in classification mode, where no score loss is
            computed), or 0.0 when ``examples`` is empty.
        """
        n_examples = examples.shape[0]
        # Running sums are plain Python floats: accumulating the loss tensors
        # themselves would keep every batch's autograd history referenced.
        sum_s = sum_o = sum_score = sum_t = sum_pos = sum_reg = 0.0
        cur_size = 0
        # NOTE: anomaly detection is a debugging aid that slows training; it
        # is kept here because callers may rely on it to surface NaN gradients.
        with torch.autograd.detect_anomaly():
            actual_examples = examples[torch.randperm(n_examples), :]
            ent_loss = nn.CrossEntropyLoss(reduction='mean')
            tim_loss = self.time_loss
            emb_reg = self.emb_regularizer
            with tqdm.tqdm(total=n_examples, unit='ex', disable=not self.verbose) as bar:
                bar.set_description('train loss')
                b_begin = 0
                while b_begin < n_examples:
                    input_batch = actual_examples[b_begin:b_begin + self.batch_size].cuda()
                    batch_n = input_batch.shape[0]

                    # Any component the model does not produce contributes 0.
                    l_s = l_o = l_score = l_t = l_reg = l_pos = 0

                    if args.objective_mode == 'classification':
                        s_pred, o_pred, t_pred, l_pos, _, factors = self.model.forward(input_batch)

                        if s_pred is not None:
                            l_s = ent_loss(s_pred, input_batch[:, 0])
                        if o_pred is not None:
                            l_o = ent_loss(o_pred, input_batch[:, 2])
                        if t_pred is not None:
                            l_t = tim_loss(t_pred, input_batch[:, 3])
                        if factors is not None:
                            l_reg = emb_reg(factors)

                        # l_pos is only reported, it is not part of the
                        # optimized objective.
                        l = l_s + l_o + l_t + l_reg
                    else:  # Regression
                        score_pred, factors = self.model.forward_(input_batch, time_range)

                        if score_pred is not None:
                            l_score = score_pred.sum()
                        if factors is not None:
                            l_reg = emb_reg(factors)

                        l = l_score + l_reg

                    self.optimizer.zero_grad()
                    l.backward()
                    # Clip the global gradient norm to 10 before stepping.
                    nn.utils.clip_grad_norm_(self.model.parameters(), 10)
                    self.optimizer.step()

                    b_begin += self.batch_size

                    # float() detaches one-element tensors and passes plain
                    # numbers through, so graphs are released every batch.
                    sum_s += float(l_s) * batch_n
                    sum_o += float(l_o) * batch_n
                    sum_score += float(l_score) * batch_n
                    sum_t += float(l_t) * batch_n
                    sum_reg += float(l_reg) * batch_n
                    sum_pos += float(l_pos) * batch_n
                    cur_size += batch_n

                    bar.update(batch_n)
                    bar.set_postfix(
                        l_s=f'{sum_s / cur_size:.3f}',
                        l_score=f'{sum_score / cur_size:.5f}',
                        l_o=f'{sum_o / cur_size:.5f}',
                        l_t=f'{sum_t / cur_size:.3f}',
                        l_pos=f'{sum_pos / cur_size:.3f}',
                        reg=f'{sum_reg / cur_size:.5f}'
                    )
            if cur_size > 0:
                logger.info(f'l_s: {sum_s / cur_size:.3f}, l_score: {sum_score / cur_size:.5f}, l_o: {sum_o / cur_size:.5f}, l_t: {sum_t / cur_size:.3f}, reg: {sum_reg / cur_size:.5f}')
        if cur_size == 0:
            # No batch was processed; avoid dividing by zero.
            return 0.0
        return sum_score / cur_size


class TKBCOptimizer(object):
    """Epoch-level trainer: cross-entropy fit loss plus two regularizers.

    Each batch is scored with ``model.forward``; the total loss is the fit
    loss against column 2 of the batch, the embedding regularizer applied to
    the returned factors, and (when the model returns a non-None third
    value) the temporal regularizer applied to it.
    """

    def __init__(
            self, model: TKBCModel,
            emb_regularizer: Regularizer, temporal_regularizer: Regularizer,
            optimizer: optim.Optimizer, batch_size: int = 256,
            verbose: bool = True
    ):
        """Keep references to the model, regularizers and optimizer.

        Args:
            model: model whose ``forward`` returns
                ``(predictions, factors, time)``.
            emb_regularizer: regularizer whose ``forward`` takes ``factors``.
            temporal_regularizer: regularizer whose ``forward`` takes ``time``.
            optimizer: optimizer stepped once per batch.
            batch_size: number of example rows per mini-batch.
            verbose: when False the tqdm progress bar is disabled.
        """
        self.model = model
        self.optimizer = optimizer
        self.emb_regularizer = emb_regularizer
        self.temporal_regularizer = temporal_regularizer
        self.batch_size = batch_size
        self.verbose = verbose

    def epoch(self, examples: torch.LongTensor):
        """Run one shuffled pass over ``examples``, stepping once per batch.

        Args:
            examples: 2-D tensor, one example per row; column 2 is used as
                the classification target.
        """
        n_total = examples.shape[0]
        shuffled = examples[torch.randperm(n_total), :]
        criterion = nn.CrossEntropyLoss(reduction='mean')
        with tqdm.tqdm(total=n_total, unit='ex', disable=not self.verbose) as bar:
            bar.set_description('train loss')
            offset = 0
            while offset < n_total:
                batch = shuffled[offset:offset + self.batch_size].cuda()
                predictions, factors, time = self.model.forward(batch)

                fit_term = criterion(predictions, batch[:, 2])
                reg_term = self.emb_regularizer.forward(factors)
                # Without a temporal component the term is a zero with the
                # same shape/device as the embedding term.
                if time is None:
                    time_term = torch.zeros_like(reg_term)
                else:
                    time_term = self.temporal_regularizer.forward(time)
                total = fit_term + reg_term + time_term

                self.optimizer.zero_grad()
                total.backward()
                self.optimizer.step()

                offset += self.batch_size
                bar.update(batch.shape[0])
                bar.set_postfix(
                    loss=f'{fit_term.item():.3f}',
                    reg=f'{reg_term.item():.3f}',
                    cont=f'{time_term.item():.3f}'
                )


class IKBCOptimizer(object):
    """Epoch-level trainer for examples carrying a time interval.

    Every example row holds an interval in columns 3 (begin) and 4 (end).
    For each batch a timestamp is sampled inside each interval and used for
    the fit loss; when ``model.has_time()`` is true an additional
    cross-entropy loss over ``model.forward_over_time`` is added.
    """

    def __init__(
            self, model: TKBCModel,
            emb_regularizer: Regularizer, temporal_regularizer: Regularizer,
            optimizer: optim.Optimizer, dataset: TemporalDataset, batch_size: int = 256,
            verbose: bool = True
    ):
        """Keep references to the model, regularizers, optimizer and dataset.

        Args:
            model: model providing ``forward``, ``has_time`` and
                ``forward_over_time``.
            emb_regularizer: regularizer whose ``forward`` takes ``factors``.
            temporal_regularizer: regularizer whose ``forward`` takes ``time``.
            optimizer: optimizer stepped once per batch.
            dataset: only ``dataset.n_timestamps`` is read, to recognise
                intervals that span the whole time axis.
            batch_size: number of example rows per mini-batch.
            verbose: when False the tqdm progress bar is disabled.
        """
        self.model = model
        self.dataset = dataset
        self.emb_regularizer = emb_regularizer
        self.temporal_regularizer = temporal_regularizer
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.verbose = verbose

    @staticmethod
    def _sample_timestamps(intervals: torch.Tensor) -> torch.Tensor:
        """Draw one random timestamp per row between columns 3 and 4.

        The value is ``begin + u * (end - begin)`` with ``u`` from
        ``torch.rand``, rounded to the nearest integer.
        """
        begin = intervals[:, 3].float()
        span = (intervals[:, 4] - intervals[:, 3]).float()
        # Sampled on CPU and then moved, so the CPU random stream is used.
        noise = torch.rand(intervals.shape[0]).to(intervals.device)
        return (noise * span + begin).round().long()

    def epoch(self, examples: torch.LongTensor):
        """Run one shuffled pass over ``examples``, stepping once per batch.

        Args:
            examples: 2-D tensor, one example per row; columns 0-2 are passed
                to the model, column 2 is the fit target and columns 3 and 4
                are the interval begin and end.
        """
        actual_examples = examples[torch.randperm(examples.shape[0]), :]
        loss = nn.CrossEntropyLoss(reduction='mean')
        with tqdm.tqdm(total=examples.shape[0], unit='ex', disable=not self.verbose) as bar:
            bar.set_description('train loss')
            b_begin = 0
            while b_begin < examples.shape[0]:
                time_range = actual_examples[b_begin:b_begin + self.batch_size].cuda()

                # RHS prediction loss at a timestamp sampled in the interval.
                sampled_time = self._sample_timestamps(time_range)
                with_time = torch.cat((time_range[:, 0:3], sampled_time.unsqueeze(1)), 1)

                predictions, factors, time = self.model.forward(with_time)
                l_fit = loss(predictions, with_time[:, 2])

                # Time prediction loss (cross entropy over timestamps).
                time_loss = 0.
                if self.model.has_time():
                    # Rows with neither a begin nor an end carry no time
                    # information and are left out.
                    unbounded = (
                        (time_range[:, 3] == 0) &
                        (time_range[:, 4] == (self.dataset.n_timestamps - 1))
                    )
                    these_examples = time_range[~unbounded, :]
                    # A mean-reduced cross entropy over zero rows is NaN, so
                    # the term is skipped when nothing survives the filter.
                    if these_examples.shape[0] > 0:
                        time_truth = self._sample_timestamps(these_examples)
                        time_predictions = self.model.forward_over_time(these_examples[:, :3].long())
                        time_loss = loss(time_predictions, time_truth)

                l_reg = self.emb_regularizer.forward(factors)
                l_time = torch.zeros_like(l_reg)
                if time is not None:
                    l_time = self.temporal_regularizer.forward(time)
                l = l_fit + l_reg + l_time + time_loss

                self.optimizer.zero_grad()
                l.backward()
                self.optimizer.step()
                b_begin += self.batch_size

                # time_loss stays a plain float when the term was not computed.
                time_loss_value = time_loss if isinstance(time_loss, float) else time_loss.item()
                bar.update(with_time.shape[0])
                bar.set_postfix(
                    loss=f'{l_fit.item():.0f}',
                    loss_time=f'{time_loss_value:.0f}',
                    reg=f'{l_reg.item():.0f}',
                    cont=f'{l_time.item():.4f}'
                )


class DEOptimizer(object):
    """Epoch-level trainer using cross entropy over grouped candidate scores.

    The flat score vector returned by ``model.forward`` is reshaped into rows
    of ``neg_ratio + 1`` scores, and every row is trained with target class 0.
    """

    def __init__(
            self, model: TKBCModel,
            optimizer: optim.Optimizer, batch_size: int = 256,
            verbose: bool = True
    ):
        """Keep references to the model and optimizer.

        Args:
            model: model whose ``forward`` returns a flat score tensor; its
                ``neg_ratio`` attribute is read once here.
            optimizer: optimizer stepped once per batch.
            batch_size: number of example rows per mini-batch.
            verbose: when False the tqdm progress bar is disabled.
        """
        self.model = model
        self.neg_ratio = model.neg_ratio
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.verbose = verbose

    def epoch(self, examples: torch.LongTensor):
        """Run one shuffled pass over ``examples``, stepping once per batch.

        Args:
            examples: 2-D tensor with one example per row.

        Returns:
            The loss tensor of the last batch, or ``None`` when ``examples``
            is empty and no batch was processed.
        """
        actual_examples = examples[torch.randperm(examples.shape[0]), :]
        loss_f = nn.CrossEntropyLoss()
        group_size = self.neg_ratio + 1
        # Stays None if the loop never runs, so the return below is defined.
        loss = None
        with tqdm.tqdm(total=examples.shape[0], unit='ex', disable=not self.verbose) as bar:
            bar.set_description('train loss')
            b_begin = 0
            while b_begin < examples.shape[0]:
                input_batch = actual_examples[
                    b_begin:b_begin + self.batch_size
                ].cuda()
                scores = self.model.forward(input_batch)
                num_examples = scores.size(0) // group_size
                scores_reshaped = scores.view(num_examples, group_size)
                # Target class is index 0 for every row; presumably the model
                # puts the positive example's score first - confirm in model.
                targets = torch.zeros(num_examples, dtype=torch.long, device=scores.device)
                loss = loss_f(scores_reshaped, targets)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                b_begin += self.batch_size
                bar.update(input_batch.shape[0])
                bar.set_postfix(
                    loss=f'{loss.item():.3f}'
                )
        return loss


class TEOptimizer(object):
    """Epoch-level trainer for models exposing their own ``log_rank_loss``."""

    def __init__(
            self, model_name: str, model: TKBCModel,
            optimizer: optim.Optimizer, batch_size: int = 256,
            verbose: bool = True
    ):
        """Keep references to the model and optimizer.

        Args:
            model_name: model identifier; ``'ATISE'`` enables the extra
                ``regularization_embeddings`` call after each step.
            model: model providing ``log_rank_loss``; its ``neg_ratio``
                attribute is read once here.
            optimizer: optimizer stepped once per batch.
            batch_size: number of example rows per mini-batch.
            verbose: when False the tqdm progress bar is disabled.
        """
        self.model_name = model_name
        self.model = model
        self.neg_ratio = model.neg_ratio
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.verbose = verbose

    def epoch(self, examples: torch.LongTensor):
        """Run one shuffled pass over ``examples``, stepping once per batch.

        Args:
            examples: 2-D tensor with one example per row.

        Returns:
            The loss tensor of the last batch, or ``None`` when ``examples``
            is empty and no batch was processed.
        """
        actual_examples = examples[torch.randperm(examples.shape[0]), :]
        # Stays None if the loop never runs, so the return below is defined.
        loss = None
        with tqdm.tqdm(total=examples.shape[0], unit='ex', disable=not self.verbose) as bar:
            bar.set_description('train loss')
            b_begin = 0
            while b_begin < examples.shape[0]:
                input_batch = actual_examples[
                    b_begin:b_begin + self.batch_size
                ].cuda()
                loss = self.model.log_rank_loss(input_batch)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Model-specific post-step hook, applied after the update.
                if self.model_name == 'ATISE':
                    self.model.regularization_embeddings()

                b_begin += self.batch_size
                bar.update(input_batch.shape[0])
                bar.set_postfix(
                    loss=f'{loss.item():.3f}'
                )
        return loss