from time import time
import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from tqdm import tqdm

from src.methods.dfl_abstract import DFL
from src.solvers.solver import Solver
from src.methods.models.layers.surrogates.lancer_surrogate import LancerSurrogateLayer
from src.methods.models.layers.feed_forward import FeedForwardLayer
from src.methods.models.layers.scalers.standardizer import DeStandardizer, Standardizer

from src.utils.strings import *


class Lancer(DFL):
    """Decision-focused learning with a learned surrogate of the downstream decision cost.

    Every epoch of ``_train_procedure`` alternates three phases:

    1. ``_fill_buffer_replay``: run the predictor on the training set, solve the
       optimization problem with each prediction ``y_hat`` and append the triple
       ``(y, y_hat, TOTAL_COST of the resulting solution)`` to a replay buffer.
       The buffer keeps growing across epochs and is only cleared by ``reset``.
    2. ``_surrogate_train``: fit the surrogate so that ``surrogate([y_hat, y])``
       regresses the buffered costs (MSE loss).
    3. ``_predictor_train``: update the predictor by backpropagating the
       surrogate's cost estimate (multiplied by ``self._mm``) into the network.
    """

    def __init__(self, network: torch.nn.Module, lr: float, surrogate: torch.nn.Module, t: int | None,
                 surrogate_batch_size: int, surrogate_lr: float, standardize_costs: bool, reset_surrogate: bool,
                 solver: Solver | None = None, problem_params: dict[str, np.ndarray] | None = None,
                 apply_early_stopping: bool = True, early_stopping_epochs: int = 10,
                 destandardizer: DeStandardizer | None = None, pretrain_epochs: int = 0,
                 device: torch.device = torch.device("cpu"), name: str = "SFGE"):
        """
        Args:
            network: predictor mapping features ``x`` to (possibly standardized) predictions.
            lr: learning rate of the predictor, forwarded to ``DFL``.
            surrogate: module called as ``surrogate([y_hat, y])``; it must also expose
                ``initialize()`` so its weights can be re-drawn.
            t: if not ``None``, maximum number of mini-batches used by each
                surrogate/predictor phase per epoch.
            surrogate_batch_size: mini-batch size used when fitting the surrogate.
            surrogate_lr: Adam learning rate of the surrogate.
            standardize_costs: standardize the buffered costs before fitting the surrogate.
            reset_surrogate: re-initialize the surrogate (weights and optimizer) before
                every surrogate-fitting phase.
            solver, problem_params, apply_early_stopping, early_stopping_epochs,
            destandardizer, pretrain_epochs, device, name: forwarded to ``DFL``.
        """
        # NOTE(review): the default name "SFGE" looks copy-pasted from another method;
        # build_from_config always passes an explicit name, so it is kept for compatibility.
        super().__init__(name, network, lr, solver, problem_params, apply_early_stopping, early_stopping_epochs,
                         destandardizer, pretrain_epochs, device)

        self._surrogate = surrogate
        self._t = t
        self._surrogate_batch_size = surrogate_batch_size
        self._surrogate_lr = surrogate_lr
        self._standardize_costs = standardize_costs
        self._reset_surrogate = reset_surrogate

        self._surrogate_optimizer = torch.optim.Adam(self._surrogate.parameters(), lr=self._surrogate_lr)

        # Default reduction="mean": the loss is already a scalar.
        self._mse_loss = torch.nn.MSELoss()

        # Replay buffer: one entry per (training instance, epoch) visited in _fill_buffer_replay.
        self._buffer_replay_y = []
        self._buffer_replay_y_hat = []
        self._buffer_replay_cost = []

    def _train_procedure(self, x: np.ndarray, y: np.ndarray, z: np.ndarray, cost: np.ndarray,
                         x_val: np.ndarray, y_val: np.ndarray, z_val: np.ndarray, cost_val: np.ndarray,
                         epochs: int, batch_size: int = 64, time_limit: float | None = None) -> int:
        """Run up to ``epochs`` epochs of buffer filling, surrogate fitting and predictor training.

        Training stops early when ``_early_stopping_check`` says so or when
        ``time_limit`` seconds have elapsed (checked at the end of each epoch).

        Returns:
            The number of epochs actually executed.
        """
        dataset = TensorDataset(torch.Tensor(x), torch.Tensor(y), torch.Tensor(cost))

        # A DataLoader with shuffle=True draws a fresh permutation on every iteration,
        # so both loaders can be built once and reused across epochs.
        loader_surrogate = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        loader_predictor = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        start_time = time()

        for epoch in range(epochs):
            epoch_loss_regret = self._fill_buffer_replay(loader_surrogate)
            self._surrogate_train()

            epoch_loss_surrogate = self._predictor_train(loader_predictor)

            stop = self._early_stopping_check(x_val, y_val, z_val, cost_val)
            print("Epoch {} - Train loss regret: {} - Train loss surrogate {} - Val regret: {}".format(epoch + 1,
                                                                                                       epoch_loss_regret,
                                                                                                       epoch_loss_surrogate,
                                                                                                       self._early_stopping_value))

            if stop or (time_limit is not None and time() - start_time > time_limit):
                return epoch + 1

        return epochs

    def _reinitialize_surrogate(self) -> None:
        """Re-draw the surrogate's weights and give it a fresh Adam optimizer."""
        self._surrogate.initialize()
        self._surrogate_optimizer = torch.optim.Adam(self._surrogate.parameters(), lr=self._surrogate_lr)

    def _surrogate_train(self) -> None:
        """Fit the surrogate on the replay buffer for one shuffled pass.

        The surrogate learns to regress the buffered (optionally standardized)
        cost from ``(y_hat, y)``. When ``self._t`` is set, at most ``self._t``
        mini-batches are used.
        """
        if self._reset_surrogate:
            self._reinitialize_surrogate()

        # Only the surrogate is optimized in this phase.
        self._network.train(False)
        self._surrogate.train(True)

        y = np.array(self._buffer_replay_y)
        y_hat = np.array(self._buffer_replay_y_hat)
        cost = np.array(self._buffer_replay_cost)

        if self._standardize_costs:
            cost = Standardizer.standardize(cost)

        surrogate_dataset = TensorDataset(torch.Tensor(y), torch.Tensor(y_hat), torch.Tensor(cost))
        surrogate_loader = DataLoader(surrogate_dataset, batch_size=self._surrogate_batch_size, shuffle=True)

        for batch_counter, (_y, _y_hat, _cost) in enumerate(tqdm(surrogate_loader), start=1):
            _y = _y.to(self._device)
            _y_hat = _y_hat.to(self._device)
            _cost = _cost.to(self._device)

            self._surrogate_optimizer.zero_grad()

            costs_hat = self._surrogate([_y_hat, _y])

            # squeeze() aligns a trailing singleton dim on the surrogate output with the cost vector.
            loss = self._mse_loss(costs_hat.squeeze(), _cost.squeeze())
            loss.backward()

            self._surrogate_optimizer.step()

            if self._t is not None and batch_counter >= self._t:
                break

    def _predictor_train(self, loader: DataLoader) -> float:
        """Train the predictor against the (frozen-mode) surrogate for one pass over ``loader``.

        The predictor minimizes the mean of ``surrogate([y_hat, y]) * self._mm``.
        When ``self._t`` is set, at most ``self._t`` mini-batches are used.

        Returns:
            The mean over processed instances of the surrogate-estimated regret,
            i.e. ``(surrogate cost - optimal cost) * self._mm``.
        """
        self._network.train(True)
        self._surrogate.train(False)

        surrogate_epoch_loss = 0.0
        epoch_elem_count = 0

        for batch_counter, (_x, _y, _optimal_cost) in enumerate(tqdm(loader), start=1):
            _x = _x.to(self._device)
            _y = _y.to(self._device)
            _optimal_cost = _optimal_cost.to(self._device)
            epoch_elem_count += len(_x)

            self._optimizer.zero_grad()

            y_hat_normalized = self._network(_x)
            y_hat = self._destandardizer(y_hat_normalized) if self._destandardizer else y_hat_normalized

            loss = self._surrogate([y_hat, _y]) * self._mm
            # Flatten both operands so a (B, 1) surrogate output and a (B,) cost vector are
            # subtracted element-wise; without this they broadcast to (B, B) and the
            # reported loss is inflated by the batch size.
            estimated_regret = loss.reshape(-1) - _optimal_cost.reshape(-1) * self._mm
            surrogate_epoch_loss += torch.sum(estimated_regret).item()
            loss = torch.mean(loss)
            loss.backward()

            self._optimizer.step()

            if self._t is not None and batch_counter >= self._t:
                break

        return surrogate_epoch_loss / epoch_elem_count

    def _fill_buffer_replay(self, loader: DataLoader) -> float:
        """Solve the problem for every prediction in ``loader`` and record the results.

        For each instance, the solution obtained from ``y_hat`` is scored with
        ``self._solver.compute_metrics`` against the true ``y``. The triple
        ``(y, y_hat, TOTAL_COST)`` is appended to the replay buffer.

        Returns:
            The mean regret ``self._mm * (TOTAL_COST - optimal cost)`` over the loader.
        """
        # NOTE(review): the network is not switched to eval mode here, so after the first epoch
        # it is still in train mode from _predictor_train (affects dropout/batch-norm, if any).
        with torch.no_grad():

            epoch_loss = 0.0
            epoch_elem_count = 0

            for _x, _y, _optimal_cost in tqdm(loader, desc="Computing regrets"):
                _x = _x.to(self._device)
                _y = _y.to(self._device)
                _optimal_cost = _optimal_cost.to(self._device)

                y_hat_normalized = self._network(_x)
                y_hat = self._destandardizer(y_hat_normalized) if self._destandardizer else y_hat_normalized

                # Tensor.numpy() only works on CPU tensors: move each batch back once.
                x_np = _x.cpu().numpy()
                y_np = _y.cpu().numpy()
                y_hat_np = y_hat.cpu().numpy()
                optimal_cost_np = _optimal_cost.cpu().numpy()

                regrets = []
                for i in range(len(y_hat_np)):
                    y_hat_i = y_hat_np[i]
                    y_i = y_np[i]
                    x_i = x_np[i]

                    sol_hat_i, _ = self._solver.solve(x_i, y_hat_i, self._problem_params)
                    metrics_i = self._solver.compute_metrics(y_i, sol_hat_i, self._problem_params)
                    cost_hat_i = metrics_i[TOTAL_COST]
                    regret_i = self._mm * (cost_hat_i - optimal_cost_np[i])
                    regrets.append(regret_i)

                    self._buffer_replay_y.append(y_i)
                    self._buffer_replay_y_hat.append(y_hat_i)
                    self._buffer_replay_cost.append(cost_hat_i)

                regrets = np.array(regrets)
                epoch_loss += np.sum(regrets)
                epoch_elem_count += len(regrets)

            return epoch_loss / epoch_elem_count

    def reset(self) -> None:
        """Reset the base model, re-initialize the surrogate and empty the replay buffer."""
        super().reset()
        self._reinitialize_surrogate()

        self._buffer_replay_y = []
        self._buffer_replay_y_hat = []
        self._buffer_replay_cost = []

    @staticmethod
    def build_from_config(parameters: dict, input_dim: int, output_dim: int) -> DFL:
        """Build a ``Lancer`` model from a configuration dictionary.

        Expected keys: "hidden_units", "lancer_hidden_units", "lr", "t",
        "surrogate batch size", "surrogate lr", "standardize costs",
        "reset surrogate", "pretrain_epochs", "name".
        """
        network = FeedForwardLayer(input_dim, output_dim, parameters["hidden_units"])
        surrogate = LancerSurrogateLayer(output_dim, parameters["lancer_hidden_units"])
        destandardizer = DeStandardizer()

        model = Lancer(network, parameters["lr"], surrogate, parameters["t"], parameters["surrogate batch size"],
                       parameters["surrogate lr"], parameters["standardize costs"], parameters["reset surrogate"],
                       destandardizer=destandardizer, pretrain_epochs=parameters["pretrain_epochs"],
                       name=parameters["name"])

        return model
