from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pytorch_lightning as pl
import exp_utils as PQ


class ICNN(nn.Module):
    """Input-convex neural network with a scalar-per-sample output.

    Every layer receives the raw input ``x`` through an unconstrained affine
    map (``W``, ``bias``); every layer after the first additionally receives
    the previous hidden activation through weights ``softplus(U)``, which are
    strictly positive.  With a convex, non-decreasing ``activation`` (such as
    the default ReLU) the output is therefore convex in ``x``.

    Args:
        n_units: layer widths ``[dim_in, hidden_1, ..., hidden_k, dim_out]``.
            At least one hidden layer is required, i.e. ``len(n_units) >= 3``.
        activation: callable applied to every hidden pre-activation.  The
            default ``F.relu_`` works in place on freshly computed tensors.

    Raises:
        ValueError: if ``n_units`` has fewer than three entries.
    """

    def __init__(self, n_units, activation=F.relu_):
        super().__init__()
        n_units = list(n_units)
        if len(n_units) < 3:
            # Without a hidden layer there is no U matrix and forward() would
            # fail with an opaque IndexError; reject the configuration early.
            raise ValueError(
                f"ICNN needs at least [input, hidden, output] widths, got {n_units!r}")

        dim_in = n_units[0]
        # W[k]: direct (skip) connection from the input to layer k.
        self.W = nn.ParameterList(
            [nn.Parameter(torch.empty(width, dim_in)) for width in n_units[1:]])
        # U[k]: connection from hidden layer k to layer k + 1 (made positive in forward()).
        self.U = nn.ParameterList(
            [nn.Parameter(torch.empty(n_out, n_in))
             for n_in, n_out in zip(n_units[1:-1], n_units[2:])])
        self.bias = nn.ParameterList(
            [nn.Parameter(torch.empty(width)) for width in n_units[1:]])
        self.activation = activation
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise all parameters with the same scheme as ``nn.Linear``."""
        for W in self.W:
            nn.init.kaiming_uniform_(W, a=5**0.5)
        for U in self.U:
            nn.init.kaiming_uniform_(U, a=5**0.5)
        for W, b in zip(self.W, self.bias):
            # For a 2-D weight the fan-in is its number of columns.
            fan_in = W.shape[1]
            bound = 1 / (fan_in**0.5) if fan_in > 0 else 0
            nn.init.uniform_(b, -bound, bound)

    def forward(self, x):
        """Evaluate the network.

        Args:
            x: tensor whose last dimension equals ``n_units[0]``.

        Returns:
            ``softplus(z) - 1`` where ``z`` is the last layer's pre-activation,
            with a trailing dimension of size 1 squeezed away.
        """
        # Plain lists avoid building new ParameterList modules on every call.
        weights, biases, mixers = list(self.W), list(self.bias), list(self.U)

        z = self.activation(F.linear(x, weights[0], biases[0]))

        # NOTE(review): the hidden path is scaled by U.shape[0] (the layer's
        # output width), not its fan-in U.shape[1] -- confirm this is intended.
        for W, b, U in zip(weights[1:-1], biases[1:-1], mixers[:-1]):
            z = self.activation(F.linear(x, W, b) + F.linear(z, F.softplus(U)) / U.shape[0])

        U = mixers[-1]
        z = F.linear(x, weights[-1], biases[-1]) + F.linear(z, F.softplus(U)) / U.shape[0]
        return F.softplus(z).squeeze(-1) - 1


class Vnet(nn.Module):
    """Non-negative scalar function built as a sum of squared network outputs.

    Args:
        net: callable mapping normalized states to a feature tensor.
        normalizer: callable applied to the raw states before ``net``.
    """

    def __init__(self, net, normalizer):
        super().__init__()
        self.net = net
        self.normalizer = normalizer

    def forward(self, states):
        """Return the sum over the last dimension of ``net(normalizer(states)) ** 2``."""
        features = self.net(self.normalizer(states))
        squared = features**2
        return squared.sum(dim=-1)


class StableDynamics(pl.LightningModule):
    """Wraps a dynamics model ``f`` and projects its predictions against a learned function ``V``.

    The predicted state change is corrected along ``grad V(s)`` so that
    ``<grad V(s), delta> + alpha * V(s) <= 0`` (the source comment states the
    goal as ``\\dot V(s) <= -alpha V(s)``).  ``V`` is an ``ICNN`` trained jointly
    with ``f`` by minimising the negative log-likelihood of observed next states.

    Args:
        f: dynamics model, called as ``f(state, action, det)``.  From its use
            here it returns a tensor when ``det`` is true and an object with
            ``.mean`` / ``.stddev`` otherwise, and provides ``log_std_loss()``.
        dim_state: input width of ``V``.
        alpha: decay rate used in the projection condition.
        buf, buf_dev: stored as attributes; not used inside this class.
        name: prefix for logged metric names.
    """

    LOG_SQRT_2PI = np.log(2 * np.pi) / 2
    n_batches_per_epoch: int

    class FLAGS(PQ.BaseFLAGS):
        batch_size = 256
        weight_decay = 0.0001
        lr = 0.001

    def __init__(self, f, dim_state, alpha, *, buf, buf_dev, name):
        # \dot V(s) <= -alpha V(s)
        super().__init__()
        self.V = ICNN([dim_state, 256, 256, 1])
        # Alternative previously tried for V:
        # import rl_utils
        # self.V = Vnet(rl_utils.MLP([dim_state, 256, 256, 16]), f.normalizer)
        self.f = f
        self.alpha = alpha
        self.buf = buf
        self.buf_dev = buf_dev
        self.name = name

    def project(self, s, fs):
        """Correct the state change ``fs`` so it satisfies the decrease condition on ``V``.

        Args:
            s: current states.  The caller's tensor is left untouched: when it
                does not already require grad, a detached view is used for
                differentiating ``V``.
            fs: proposed state change (callers pass ``prediction - s``).

        Returns:
            ``(projected_fs, Vs)`` where ``Vs = V(s)``.
        """
        # Gradients of V w.r.t. s are needed even when the caller runs under no_grad.
        with torch.enable_grad():
            if not s.requires_grad:
                # detach() yields a new tensor object sharing storage, so the
                # requires_grad flag (and later .grad) never lands on the caller's data.
                s = s.detach().requires_grad_(True)
            Vs = self.V(s)
            # create_graph=True keeps the gradient differentiable w.r.t. V's parameters.
            grad_Vs = torch.autograd.grad(Vs.sum(), s, create_graph=True)[0]

        # Amount by which the condition is violated, divided by |grad V|^2
        # (clamped to avoid division by ~0 where the gradient vanishes).
        violation = F.relu((grad_Vs * fs).sum(dim=-1) + self.alpha * Vs)
        coef = violation / (grad_Vs**2).sum(dim=-1).clamp(min=1e-6)
        coef = coef[..., None]

        # Disabled refinement that shrank coef where the projected step was still invalid:
        # for _ in range(3):
        #     new_Vs = self.V(s + fs - grad_Vs * coef)
        #     invalid = new_Vs > Vs[..., None] * self.alpha
        #     if invalid.sum().item() == 0:
        #         break
        #     coef = coef / (1. + invalid.to(torch.float32))

        fs = fs - grad_Vs * coef
        return fs, Vs

    def forward(self, s, a, det=True):
        """Predict the next state from state ``s`` and action ``a``.

        Returns a tensor when ``det`` is true; otherwise a ``Normal`` whose mean
        is the projected prediction and whose scale is ``f``'s ``stddev``.
        """
        fs = self.f(s, a, det)
        if det:
            return s + self.project(s, fs - s)[0]
        mean, _ = self.project(s, fs.mean - s)
        return torch.distributions.Normal(s + mean, fs.stddev)

    def configure_optimizers(self):
        """Adam over all parameters with the learning rate and weight decay from ``FLAGS``."""
        return torch.optim.Adam(
            self.parameters(), lr=self.FLAGS.lr, weight_decay=self.FLAGS.weight_decay)

    def _compute(self, batch):
        """Run the model on a batch with keys ``'state'``, ``'action'``, ``'next_state'``.

        Returns a dict of the batch and every intermediate quantity; ``'nll'``
        is the mean negative log-likelihood of the observed next states.
        """
        fs = self.f(batch['state'], batch['action'], det=False)
        mean, Vs = self.project(batch['state'], fs.mean - batch['state'])
        predictions = torch.distributions.Normal(batch['state'] + mean, fs.stddev)
        targets = batch['next_state']
        nll = -predictions.log_prob(targets).mean()
        # Explicit dict instead of locals(): same computed keys, without leaking `self`.
        return {
            'batch': batch,
            'fs': fs,
            'mean': mean,
            'Vs': Vs,
            'predictions': predictions,
            'targets': targets,
            'nll': nll,
        }

    def _loss_step(self, batch, metric):
        """Compute the loss for ``batch``, log it per epoch as ``<name>/<metric>`` and return it."""
        results = self._compute(batch)
        loss = results['nll'] + self.f.log_std_loss()
        self.log(f'{self.name}/{metric}', loss.item(), on_step=False, on_epoch=True)
        return {
            'loss': loss,
        }

    def training_step(self, batch, batch_idx):
        """Lightning hook: NLL plus ``f.log_std_loss()``, logged as ``training_loss``."""
        return self._loss_step(batch, 'training_loss')

    def validation_step(self, batch, batch_idx):
        """Lightning hook: same loss as training, logged as ``val_loss``."""
        return self._loss_step(batch, 'val_loss')

    def test(self, policy, state, horizon=1000):
        """Roll out ``policy`` through the model and print ``V`` at about ten points along the path.

        Args:
            policy: callable mapping a state to an action; moved to this module's device.
            state: initial state.
            horizon: number of rollout steps, must be at least 1.

        Raises:
            ValueError: if ``horizon`` is smaller than 1.
        """
        if horizon < 1:
            raise ValueError(f"horizon must be at least 1, got {horizon}")

        states = []
        policy.to(self.device)
        state = state.to(self.device)
        # Evaluation only: without no_grad the autograd graph (including the
        # double-backward graph from project) would grow with every step.
        # project() re-enables grad locally for the part that needs it.
        with torch.no_grad():
            for _ in range(horizon):
                states.append(state)
                action = policy(state)
                state = self(state, action, det=True)
            states = torch.stack(states)
            Vs = self.V(states)

        # A zero slice step is an error, so short horizons print every state.
        stride = max(1, horizon // 10)
        print(Vs[::stride])
