import enum

import dataclasses
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from omegaconf import DictConfig
import numpy as np
from torch.autograd import grad
from torch.autograd.functional import hessian
import time

from . import models, optimizers

# Registry of stochastic optimizers selectable through the
# ``stochastic_optimizer_type`` argument of ``ImplicitPolicy.initialize``.
allowed_optimizers = dict(
    dfo=optimizers.DerivativeFreeOptimizer,
    random=optimizers.RandomOptimizer,
    langevin=optimizers.LangevinOptimizer,
)

@dataclasses.dataclass
class ImplicitPolicy:
    """An implicit conditional EBM trained with an InfoNCE objective.

    ``model`` scores (observation, action) pairs with an energy: during
    training the demonstrated action is pushed towards low energy and sampled
    negatives towards high energy. ``stochastic_optimizer`` supplies those
    negatives and performs action inference.
    """

    model: nn.Module
    optimizer: torch.optim.Optimizer
    scheduler: torch.optim.lr_scheduler._LRScheduler
    stochastic_optimizer: optimizers.DerivativeFreeOptimizer
    device: torch.device
    steps: int  # number of completed training steps
    gradient_penalty: bool  # add the L-inf gradient penalty on negatives

    @staticmethod
    def initialize(
        model_config: DictConfig,
        optim_config: optimizers.OptimizerConfig,
        stochastic_optim_config: optimizers.StochasticOptimizerConfig,
        device_type: str,
        cnn: bool,
        stochastic_optimizer_type: str,
        gradient_penalty: bool,
    ) -> "ImplicitPolicy":
        """Build a policy with its EBM, Adam optimizer, StepLR scheduler and sampler.

        Args:
            model_config: configuration forwarded to the EBM constructor.
            optim_config: Adam / StepLR hyper-parameters.
            stochastic_optim_config: configuration of the negative sampler.
            device_type: requested torch device; falls back to ``"cpu"`` when
                CUDA is unavailable.
            cnn: use ``models.EBMConvMLP`` instead of ``models.EBM``.
            stochastic_optimizer_type: key into ``allowed_optimizers``.
            gradient_penalty: enable the gradient penalty in ``training_step``.

        Raises:
            ValueError: if ``stochastic_optimizer_type`` is not a known key.
        """
        if stochastic_optimizer_type not in allowed_optimizers:
            raise ValueError(
                f"Unknown stochastic optimizer {stochastic_optimizer_type!r}; "
                f"expected one of {sorted(allowed_optimizers)}"
            )

        device = torch.device(device_type if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

        model_cls = models.EBMConvMLP if cnn else models.EBM
        model = model_cls(config=model_config)
        model.to(device)

        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=optim_config.learning_rate,
            weight_decay=optim_config.weight_decay,
            betas=(optim_config.beta1, optim_config.beta2),
        )

        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=optim_config.lr_scheduler_step,
            gamma=optim_config.lr_scheduler_gamma,
        )

        # Hand the sampler the *resolved* device so it agrees with the model
        # even after the CPU fallback above.
        stochastic_optimizer = allowed_optimizers[stochastic_optimizer_type].initialize(
            stochastic_optim_config,
            str(device),
        )

        return ImplicitPolicy(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            stochastic_optimizer=stochastic_optimizer,
            device=device,
            steps=0,
            gradient_penalty=gradient_penalty,
        )

    def training_step(
        self, input: torch.Tensor, target: torch.Tensor
    ) -> dict:
        """Run one InfoNCE optimisation step (plus scheduler step).

        Args:
            input: batch of observations; moved to ``self.device``.
            target: batch of demonstrated actions; presumably (B, D) — TODO confirm.

        Returns:
            ``{"loss": float, "lr": float}``; ``loss`` includes the gradient
            penalty when it is enabled.
        """
        self.model.train()
        input = input.to(self.device)
        target = target.to(self.device)

        # Generate N negatives, one set for each element in the batch: (B, N, D).
        negatives = self.stochastic_optimizer.sample(input, self.model)

        # Merge target and negatives: (B, N+1, D). Column 0 holds the positive.
        targets = torch.cat([target.unsqueeze(dim=1), negatives], dim=1)
        targets.requires_grad = self.gradient_penalty

        # Shuffle every row independently so the positive's position is random.
        # The permutation is drawn on the CPU RNG (as before) and then moved to
        # the data's device so all index tensors live on one device.
        batch_size, num_candidates = targets.size(0), targets.size(1)
        permutation = torch.rand(batch_size, num_candidates).argsort(dim=1)
        permutation = permutation.to(targets.device)
        rows = torch.arange(batch_size, device=targets.device).unsqueeze(-1)
        targets = targets[rows, permutation]

        # The positive was original index 0; its new column is the class label.
        # ``permutation`` is a permutation of 0..N, so argmin locates that 0.
        ground_truth = permutation.argmin(dim=1)

        # 1 positive should get low energy, N negatives high energy: (B, N+1).
        energy = self.model(input, targets)

        # Interpreting the energy as a negative logit gives a cross-entropy loss.
        logits = -1.0 * energy
        loss = F.cross_entropy(logits, ground_truth)

        if self.gradient_penalty:
            margin = 1.0  # M in the IBC paper
            dE_dy = grad(energy.sum(), targets, create_graph=True)[0]
            # Penalise only the sampled negatives (every slot but the positive);
            # boolean masking keeps the same row-major order as before.
            dE_dy_neg = dE_dy[permutation != 0]
            grad_linf = torch.norm(dE_dy_neg, p=torch.inf, dim=1)
            penalty = (F.relu(grad_linf - margin) ** 2).sum()
            # Out-of-place so the cross-entropy output is never mutated.
            loss = loss + penalty

        self.optimizer.zero_grad(set_to_none=True)
        loss.backward()
        self.optimizer.step()
        self.scheduler.step()

        self.steps += 1

        return {
            "loss": loss.item(),
            "lr": self.scheduler.get_last_lr()[0],
        }

    @torch.no_grad()
    def evaluate(
        self, input: torch.Tensor, target: torch.Tensor
    ) -> dict:
        """Return ``{"test/mse": float}`` between inferred and target actions."""
        self.model.eval()
        input = input.to(self.device)
        target = target.to(self.device)
        out = self.stochastic_optimizer.infer(input, self.model)
        mse = F.mse_loss(out, target).item()

        return {"test/mse": mse}

    def predict(self, input: torch.Tensor) -> torch.Tensor:
        """Return the stochastic optimizer's single best action for each input."""
        self.model.eval()
        return self.stochastic_optimizer.infer(input.to(self.device), self.model)

    def get_action(self, input: np.ndarray) -> np.ndarray:
        """Infer an action for one observation (1-D) or a batch of them.

        Accepts a NumPy array (or tensor); it is converted to a float32 tensor
        before inference. Size-1 dimensions are squeezed from the result.
        """
        if len(input.shape) == 1:
            input = input[None]
        obs = torch.as_tensor(input, dtype=torch.float32)
        return self.predict(obs).detach().squeeze().cpu().numpy()

    def update_stats(self, *args):
        """Forward ``args`` to ``self.model.update_stats``."""
        self.model.update_stats(*args)

    def _log_uniform_density(self) -> float:
        """Log density of the uniform distribution over the action bounds.

        Assumes ``stochastic_optimizer.bounds`` is (2, D) as [low, high].
        """
        bounds = np.asarray(self.stochastic_optimizer.bounds)
        return float(-np.log(np.prod(np.diff(bounds, axis=0))))

    def _log_partition(
        self,
        x: torch.Tensor,
        num_samples: int = 2 ** 11,  # any more and drops below 1 it/s
        num_temperatures: int = 10,
        mh_steps: int = 10,
        step_size: float = 0.01,
    ) -> torch.Tensor:
        """Estimate log Z(x) = log ∫ exp(-E(x, y)) dy by annealed importance sampling.

        Chains start from uniform samples over the action bounds and are
        annealed along ``log f_eta(y) = (1 - eta) * log u - eta * E(x, y)``
        with ``eta`` geometrically spaced from 1e-5 to 1. At each temperature
        the importance weights are updated first, then every chain takes
        ``mh_steps`` Gaussian random-walk Metropolis-Hastings steps targeting
        ``f_eta``.

        Args:
            x: batch of observations on ``self.device``.
            num_samples: requested chains per observation; rounded down to a
                whole number of ``stochastic_optimizer.inference_samples``
                batches (at least one).
            num_temperatures: number of annealing temperatures.
            mh_steps: MH steps per temperature.
            step_size: std of the Gaussian random-walk proposal.

        Returns:
            Tensor of shape (B,) with the log-partition estimate per observation.
        """
        log_u = self._log_uniform_density()

        def log_f(eta: float, energies: torch.Tensor) -> torch.Tensor:
            # Unnormalised log density of the annealing path at temperature eta.
            return (1 - eta) * log_u - eta * energies

        num_batches = max(num_samples // self.stochastic_optimizer.inference_samples, 1)
        y = torch.cat(
            [
                self.stochastic_optimizer.sample(x, self.model, uniform=True, inference=True)
                for _ in range(num_batches)
            ],
            dim=1,
        )
        bounds = torch.as_tensor(
            self.stochastic_optimizer.bounds, dtype=y.dtype, device=y.device
        )
        low, high = bounds[0, :], bounds[1, :]
        etas = np.geomspace(1e-5, 1, num=num_temperatures).tolist()

        # NOTE(review): the model's train/eval mode is left as the caller set it.
        with torch.no_grad():
            energies = self.model(x, y)  # (B, S)
            log_w = torch.zeros_like(energies)
            for eta_prev, eta in zip(etas[:-1], etas[1:]):
                log_w += log_f(eta, energies) - log_f(eta_prev, energies)
                for _ in range(mh_steps):
                    # NOTE(review): clamping makes the proposal asymmetric at the
                    # bounds, so the MH step is only approximately exact there.
                    y_prop = torch.clamp(y + torch.randn_like(y) * step_size, low, high)
                    e_prop = self.model(x, y_prop)
                    # The log-uniform terms cancel in the acceptance ratio.
                    log_accept = -eta * (e_prop - energies)
                    accept = torch.log(torch.rand_like(log_accept)) < log_accept
                    y = torch.where(accept.unsqueeze(-1), y_prop, y)
                    # Keep the energies of the current states in sync instead
                    # of re-evaluating the model on every step.
                    energies = torch.where(accept, e_prop, energies)
            # Average over the chains actually drawn, which can differ from
            # ``num_samples`` after rounding to whole sampler batches.
            return torch.logsumexp(log_w, dim=1) - float(np.log(y.size(1)))

    def infogain(self, input: np.ndarray) -> np.ndarray:
        """Estimate KL(pi(.|x) || uniform) for each observation.

        Uses ``KL = E_pi[-E(x, y)] - log Z(x) - log u``; ``E_pi`` is the mean
        energy of the stochastic optimizer's samples. As a side effect, stores
        the observations in ``self.cache_states`` and the lowest-energy sampled
        action per observation in ``self.cache_actions``.
        """
        if len(input.shape) == 1:
            input = input[None]
        input_np = input
        obs = torch.from_numpy(input).float().to(self.device)
        log_Z = self._log_partition(obs)

        pi_samples = self.stochastic_optimizer.sample(obs, self.model)
        with torch.no_grad():
            pi_energies = self.model(obs, pi_samples)  # (B, N)
            E_energy = pi_energies.mean(dim=1)

        self.cache_states = input_np
        best_idxs = pi_energies.argmin(dim=-1)
        rows = torch.arange(pi_samples.size(0), device=pi_samples.device)
        self.cache_actions = (
            pi_samples[rows, best_idxs, :].detach().squeeze().cpu().numpy()
        )

        KL_div = -E_energy - log_Z - self._log_uniform_density()
        # tolist() works for any batch size, unlike item().
        print("log Z =", log_Z.tolist(), "DKL =", KL_div.tolist())
        return KL_div.detach().squeeze().cpu().numpy()

    def policy_uncertainty(self, input: torch.Tensor) -> np.ndarray:
        """Return 1 / pi(y*|x) for the predicted action y* of each observation.

        With pi(y|x) = exp(-E(x, y)) / Z(x) this equals exp(E(x, y*) + log Z(x)).
        Accepts a tensor or anything ``torch.as_tensor`` can convert.
        """
        if len(input.shape) == 1:
            input = input[None]
        obs = torch.as_tensor(input).float().to(self.device)
        log_Z = self._log_partition(obs)
        y = self.predict(obs)
        with torch.no_grad():
            energies = self.model(obs, y.unsqueeze(1)).squeeze()
        log_y = -energies - log_Z
        return torch.exp(-log_y).detach().squeeze().cpu().numpy()
