from __future__ import annotations
import math
import time
from abc import ABC, abstractmethod

import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from tqdm import tqdm

from src.methods.models.layers.scalers.standardizer import DeStandardizer

from src.solvers.solver import Solver
from src.utils.strings import *


class DFL(ABC):
    """Abstract base class coupling a prediction network with an optimization ``Solver``.

    The network predicts the (uncertain) problem targets; the solver turns those
    predictions into decisions, which are then scored against the true targets
    (cost, regret, feasibility). Subclasses implement the training loop in
    ``_train_procedure`` and the config-based factory ``build_from_config``.
    """

    def __init__(self, name: str, network: torch.nn.Module, lr: float, solver: Solver | None = None,
                 problem_params: dict[str, np.ndarray] | None = None, apply_early_stopping: bool = True,
                 early_stopping_epochs: int = 25, destandardizer: DeStandardizer | None = None,
                 pretrain_epochs: int = 0, device: torch.device = torch.device("cpu")):
        """Store the configuration and create the Adam optimizer.

        Args:
            name: Identifier exposed through the ``name`` property.
            network: Prediction model; its parameters are optimized with Adam.
            lr: Learning rate of the Adam optimizer (also used for pretraining).
            solver: Optimization solver; may be provided later via ``set_solver``.
            problem_params: Parameters forwarded to every solver call; may be
                provided later via ``set_problem_params``.
            apply_early_stopping: When True the validation split drives early
                stopping; when False it is merged into the training split.
            early_stopping_epochs: Number of consecutive non-improving checks
                tolerated before stopping.
            destandardizer: Optional layer applied to raw network outputs.
            pretrain_epochs: MSE pretraining epochs run at the start of
                ``train`` (0 disables pretraining).
            device: Device the input batches are moved to.
        """
        self._name = name
        self._network = network
        self._lr = lr
        self._solver = solver
        self._problem_params = problem_params
        self._apply_early_stopping = apply_early_stopping
        self._early_stopping_epochs = early_stopping_epochs
        self._destandardizer = destandardizer
        self._pretrain_epochs = pretrain_epochs
        self._device = device

        self._optimizer = torch.optim.Adam(self._network.parameters(), lr=self._lr)

        # Bookkeeping reported by ``training_metrics``.
        self._training_runtime = 0.0
        self._pre_processing_runtime = 0.0
        self._training_epochs = 0
        self._train_set_size = 0
        self._trained = False

        # Early-stopping state: best validation regret seen so far and the
        # number of consecutive checks without improvement.
        self._early_stopping_value = None
        self._early_stopping_count = 0

        # Objective sign: +1 for minimization, -1 for maximization, so that a
        # positive regret always means "worse than the optimum". Defaults to +1
        # until a solver is known (``set_solver`` overwrites it).
        self._mm = 1
        if self._solver is not None:
            self._mm = 1 if self._solver.is_minimization_problem else -1

    @property
    def network(self) -> torch.nn.Module:
        """The wrapped prediction network."""
        return self._network

    @property
    def device(self) -> torch.device:
        """Device the input batches are moved to."""
        return self._device

    @property
    def solver(self) -> Solver:
        """The optimization solver (``None`` until one is provided)."""
        return self._solver

    @property
    def name(self) -> str:
        """Identifier of this method instance."""
        return self._name

    @property
    def training_metrics(self) -> dict:
        """Summary of the last ``train`` run.

        Returns:
            Dict with the training runtime, the pre-training runtime, the total
            number of solver calls, the average number of solver calls per
            training sample per epoch, and the number of epochs. The average is
            NaN when no sample/epoch was processed (instead of dividing by zero).
        """
        if not self._trained:
            print("WARNING: Training metrics should be invoked after training")

        calls = self._solver.calls
        processed = self._train_set_size * self._training_epochs
        avg_calls = round(calls / processed, 2) if processed > 0 else math.nan

        return {
            "runtime": round(self._training_runtime, 2),
            "pre-training runtime": round(self._pre_processing_runtime, 2),
            "calls": calls,
            "avg. calls": avg_calls,
            "epochs": self._training_epochs,
        }

    def set_y_stats(self, y_mean: np.ndarray | torch.Tensor, y_std: np.ndarray | torch.Tensor) -> None:
        """Forward target mean/std to the destandardizer, if one is configured."""
        if self._destandardizer is not None:
            self._destandardizer.set_stats(y_mean, y_std)

    def set_solver(self, solver: Solver) -> None:
        """Replace the solver and update the objective sign accordingly."""
        self._solver = solver
        self._mm = 1 if self._solver.is_minimization_problem else -1

    def set_problem_params(self, problem_params: dict[str, np.ndarray]) -> None:
        """Replace the parameters forwarded to every solver call."""
        self._problem_params = problem_params

    def reset(self) -> None:
        """Re-initialize the network (via its ``initialize()`` method) and recreate the optimizer."""
        self._network.initialize()
        self._optimizer = torch.optim.Adam(self._network.parameters(), lr=self._lr)

    def compute_metrics(self, x: np.ndarray, y: np.ndarray, z: np.ndarray, cost: np.ndarray,
                        batch_size: int = 32) -> dict:
        """Evaluate the network's predictions through the solver.

        For every sample the predicted target is passed to the solver and the
        resulting decision is scored with the *true* target ``y[i]``.

        Args:
            x: Input features, one entry per sample.
            y: True targets, one entry per sample.
            z: Optimal decisions; only used for the length consistency check.
            cost: Optimal objective value per sample (reference for the regret).
            batch_size: Batch size used for network inference.

        Returns:
            Dict with the averaged prediction MSE, regret, relative regret (over
            samples with a positive optimal cost), cost, optimal cost, relative
            regret restricted to feasible decisions, solver runtime, and the
            fraction of infeasible decisions. Relative-regret entries are NaN
            when no sample contributes to them.
        """
        assert len(x) == len(y) == len(z) == len(cost)
        assert self._solver is not None and self._problem_params is not None, "Missing solver"

        n_samples = len(x)
        mm = 1 if self._solver.is_minimization_problem else -1
        y_hat = self._network_inference(x, batch_size)

        mse = 0.0
        regrets, rel_regrets, rel_regrets_feasible = [], [], []
        costs, optimal_costs, runtimes = [], [], []
        infeasible_sol_count = 0

        for x_i, y_i, y_hat_i, optimal_cost in zip(x, y, y_hat, cost):
            mse += np.mean((y_hat_i - y_i) ** 2)

            z_hat, runtime = self._solver.solve(x_i, y_hat_i, params=self._problem_params)
            runtimes.append(runtime)

            # Decision made with the prediction, evaluated with the true target.
            hat_metrics = self._solver.compute_metrics(y_i, z_hat, self._problem_params)
            hat_cost = hat_metrics[TOTAL_COST]
            feasible = hat_metrics[FEASIBLE]

            optimal_costs.append(optimal_cost)
            costs.append(hat_cost)

            regret = mm * (hat_cost - optimal_cost)
            regrets.append(regret)

            # Relative regret is only defined for a strictly positive optimum.
            if optimal_cost > 0.0:
                relative_regret = regret / optimal_cost
                rel_regrets.append(relative_regret)
                if feasible:
                    rel_regrets_feasible.append(relative_regret)

            if not feasible:
                infeasible_sol_count += 1

        return {
            "mse": float(mse / n_samples),
            "avg. regret": float(np.mean(regrets)),
            "avg. relative regret": float(np.mean(rel_regrets)) if len(rel_regrets) > 0 else math.nan,
            "avg. cost": float(np.mean(costs)),
            "avg. optimal cost": float(np.mean(optimal_costs)),
            "avg. feasible relative regrets": float(np.mean(rel_regrets_feasible)) if len(rel_regrets_feasible) > 0 else math.nan,
            "avg. runtime": float(np.mean(runtimes)),
            "infeasible solutions ratio": float(infeasible_sol_count / n_samples)
        }

    def pretrain(self, x: np.ndarray, y: np.ndarray, epochs: int = 100, batch_size: int = 32, lr: float = 1e-3) -> None:
        """Fit the network to the targets with a plain MSE loss.

        Uses a fresh Adam optimizer (independent of ``self._optimizer``), prints
        the mean batch loss of every epoch and leaves the network in training mode.
        """
        optimizer = torch.optim.Adam(self._network.parameters(), lr=lr)
        mse_loss = torch.nn.MSELoss()

        dataset = TensorDataset(torch.Tensor(x), torch.Tensor(y))
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        self._network.train(True)

        print("PFL - Pretraining...")

        for epoch in range(epochs):
            epoch_loss = 0.0

            for batch_x, batch_y in tqdm(loader):
                batch_x = batch_x.to(self._device)
                batch_y = batch_y.to(self._device)

                optimizer.zero_grad()

                y_hat = self._network(batch_x)
                if self._destandardizer is not None:
                    y_hat = self._destandardizer(y_hat)
                # Align e.g. (B, 1) outputs with (B,) targets: MSELoss would
                # otherwise broadcast them to (B, B) and compute a wrong loss.
                if y_hat.shape != batch_y.shape and y_hat.numel() == batch_y.numel():
                    y_hat = y_hat.reshape(batch_y.shape)

                # MSELoss already averages over all elements (reduction="mean").
                loss = mse_loss(y_hat, batch_y)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()

            epoch_loss /= len(loader)

            print("Epoch {} - MSE: {}".format(epoch + 1, epoch_loss))

    def train(self, x: np.ndarray, y: np.ndarray, z: np.ndarray, cost: np.ndarray,
              x_val: np.ndarray, y_val: np.ndarray, z_val: np.ndarray, cost_val: np.ndarray,
              epochs: int, batch_size: int = 32, time_limit: float | None = None) -> dict:
        """Run the full training pipeline and return metrics on the training data.

        Steps:
        1. optional MSE pretraining;
        2. merge the validation split into the training split when early
           stopping is disabled;
        3. reset bookkeeping and the solver call counter;
        4. run ``_train_procedure`` (timed);
        5. evaluate on the (possibly merged) training data.

        Returns:
            ``compute_metrics`` on the training data merged with ``training_metrics``.
        """
        if time_limit is not None:
            assert time_limit > 0.0
        # Fail fast, before spending time on pretraining.
        assert self._solver is not None, "Missing solver"

        if self._pretrain_epochs > 0:
            self.pretrain(x, y, self._pretrain_epochs, batch_size, self._lr)

        if not self._apply_early_stopping:
            # The validation split is unused without early stopping, so it is
            # added to the training data. ``np.concatenate`` is used instead of
            # the NumPy>=2.0-only ``np.concat`` alias to stay compatible with 1.x.
            x = np.concatenate([x, x_val], axis=0)
            y = np.concatenate([y, y_val], axis=0)
            z = np.concatenate([z, z_val], axis=0)
            cost = np.concatenate([cost, cost_val], axis=0)

        self._reset_training_metrics()
        self._solver.reset_calls()

        self._train_set_size = len(x)

        start_time = time.time()
        self._training_epochs = self._train_procedure(x, y, z, cost, x_val, y_val, z_val, cost_val, epochs,
                                                      batch_size, time_limit)
        self._training_runtime = time.time() - start_time
        self._trained = True

        training_metrics = self.training_metrics
        return self.compute_metrics(x, y, z, cost, batch_size) | training_metrics

    def test(self, x: np.ndarray, y: np.ndarray, z: np.ndarray, cost: np.ndarray, batch_size: int = 32) -> dict:
        """Evaluate on held-out data without counting the solver calls.

        Leaves the network in evaluation mode.
        """
        self._network.eval()

        self._solver.freeze_calls_count()
        try:
            metrics = self.compute_metrics(x, y, z, cost, batch_size)
        finally:
            # Never leave the counter frozen, even if evaluation fails.
            self._solver.unfreeze_calls_count()

        return metrics

    def _compute_regret(self, x_i: torch.Tensor, y_hat_i: torch.Tensor,
                        y_i: torch.Tensor, optimal_cost_i: torch.Tensor) -> float:
        """Regret of the decision obtained from ``y_hat_i`` for a single sample.

        The decision is evaluated with the true target ``y_i`` and compared with
        ``optimal_cost_i``; the sign is adjusted with ``self._mm``. Tensors are
        detached and moved to CPU before being handed to the solver.
        """
        assert self._solver is not None and self._problem_params is not None, "Missing solver"

        y_hat_i = y_hat_i.detach().cpu().numpy()
        y_i = y_i.detach().cpu().numpy()
        x_i = x_i.detach().cpu().numpy()

        sol_hat_i, _ = self._solver.solve(x_i, y_hat_i, self._problem_params)
        metrics_i = self._solver.compute_metrics(y_i, sol_hat_i, self._problem_params)
        cost_hat_i = metrics_i[TOTAL_COST]
        optimal_cost_i = optimal_cost_i.detach().cpu().numpy()

        return self._mm * (cost_hat_i - optimal_cost_i)

    @abstractmethod
    def _train_procedure(self, x: np.ndarray, y: np.ndarray, z: np.ndarray, cost: np.ndarray,
                         x_val: np.ndarray, y_val: np.ndarray, z_val: np.ndarray, cost_val: np.ndarray,
                         epochs: int, batch_size: int = 32, time_limit: float | None = None) -> int:
        """Method-specific training loop.

        Returns:
            The number of epochs performed (stored as ``_training_epochs``).
        """
        pass

    @staticmethod
    def _squeeze_non_batch_dims(prediction: torch.Tensor) -> torch.Tensor:
        """Drop size-1 dimensions while always keeping the leading batch dimension."""
        if prediction.dim() <= 1:
            return prediction.reshape(-1)
        kept = [size for size in prediction.shape[1:] if size != 1]
        return prediction.reshape(prediction.shape[0], *kept)

    def _network_inference(self, x: np.ndarray, batch_size: int = 32) -> np.ndarray:
        """Predict targets for ``x`` in batches, without gradient tracking.

        The network is put in evaluation mode for the duration of the call (so
        dropout is disabled and batch-norm statistics are not updated) and its
        previous mode is restored afterwards. Predictions are de-standardized
        when a destandardizer is set, moved to CPU and returned with one entry
        per sample (size-1 feature dimensions are squeezed, the batch dimension
        never is, so single-sample batches are handled correctly).
        """
        dataset = TensorDataset(torch.Tensor(x))
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        predictions = []
        was_training = self._network.training
        self._network.eval()
        try:
            with torch.no_grad():
                for (batch_x,) in tqdm(loader):
                    prediction = self._network(batch_x.to(self._device))
                    if self._destandardizer is not None:
                        prediction = self._destandardizer(prediction)
                    prediction = self._squeeze_non_batch_dims(prediction)
                    predictions.extend(list(prediction.to("cpu").numpy()))
        finally:
            self._network.train(was_training)

        return np.array(predictions)

    def _early_stopping_check(self, x_val: np.ndarray, y_val: np.ndarray,
                              z_val: np.ndarray, cost_val: np.ndarray) -> bool:
        """Return True when training should stop, based on the validation regret.

        Always False when early stopping is disabled. Solver calls made for the
        validation evaluation are not counted.
        """
        if not self._apply_early_stopping:
            return False

        with torch.no_grad():
            self._solver.freeze_calls_count()
            try:
                val_metrics = self.compute_metrics(x_val, y_val, z_val, cost_val)
            finally:
                self._solver.unfreeze_calls_count()
        regret = val_metrics["avg. regret"]

        if self._early_stopping_value is None:
            # First check: record the baseline.
            self._early_stopping_value = regret
            return False

        if regret < self._early_stopping_value:
            self._early_stopping_value = regret
            self._early_stopping_count = 0
            return False

        # Stop right away if the best regret is already zero, or if the current
        # regret is more than 25% worse than the best one.
        # NOTE(review): the ratio test only triggers for a positive best regret.
        if self._early_stopping_value == 0.0 or regret / self._early_stopping_value > 1.25:
            return True

        self._early_stopping_count += 1
        return self._early_stopping_count >= self._early_stopping_epochs

    def _reset_training_metrics(self) -> None:
        """Clear the bookkeeping and early-stopping state before a training run."""
        self._training_runtime = 0.0
        self._training_epochs = 0
        self._train_set_size = 0
        self._trained = False
        self._early_stopping_value = None
        self._early_stopping_count = 0

    @staticmethod
    def compute_regret(solver: Solver, problem_params: dict[str, np.ndarray], mm: int,
                       x_i: torch.Tensor, y_hat_i: torch.Tensor,
                       y_i: torch.Tensor, optimal_cost_i: torch.Tensor) -> float:
        """Stateless variant of ``_compute_regret`` with explicit solver, params and sign.

        Unlike ``_compute_regret``, only the first row of ``x_i`` is passed to the solver.
        """
        assert solver is not None and problem_params is not None, "Missing solver"

        y_hat_i = y_hat_i.detach().cpu().numpy()
        y_i = y_i.detach().cpu().numpy()
        x_i = x_i.detach().cpu().numpy()[0]

        sol_hat_i, _ = solver.solve(x_i, y_hat_i, problem_params)
        metrics_i = solver.compute_metrics(y_i, sol_hat_i, problem_params)
        cost_hat_i = metrics_i[TOTAL_COST]
        optimal_cost_i = optimal_cost_i.detach().cpu().numpy()

        return mm * (cost_hat_i - optimal_cost_i)

    @staticmethod
    @abstractmethod
    def build_from_config(parameters: dict, input_dim: int, output_dim: int) -> DFL:
        """Construct a concrete method instance from a configuration dict."""
        pass
