from abc import ABC, abstractmethod

import math
import numpy as np
import torch
from torch import nn as nn

from PCVFR.model.model_utiities import gradients, laplace, custom_hessian
from PCVFR.model.modules import FCLayer, FourierEmbedding, NVidiaAttention, AttentionLayer
from PCVFR.model.sampler import IndependentSampler
from PCVFR.util.system_util import to_numpy, to_tensor


class HeatNetwork(nn.Module, ABC):
    """
    Base network for the 2D heat equation ``T_t = K * (T_xx + T_yy)``.

    Contains the PDE and initial-condition losses and the point samplers
    (scrambled Sobol points plus an MCMC sampler targeting ``log T``).
    Subclass this network and implement ``build_network`` to use it.

    Points are rows ``(t, x, y)``; both spatial axes share the
    ``MIN_X``/``MAX_X`` bounds of ``domain``.
    """

    def __init__(self, domain, num_units=32, num_layers=5, dtype=torch.float32, dropout=0., device="cuda", **kwargs):
        """
        :param domain: dict with 'MIN_T', 'MAX_T', 'MIN_X', 'MAX_X' (bounds of the
            sampled box) and 'K' (coefficient of the Laplacian in ``loss_pde``).
        :param num_units: hidden width forwarded to ``build_network``.
        :param num_layers: hidden-layer count forwarded to ``build_network``.
        :param dtype: dtype of sampled points and of the sampler's queries.
        :param dropout: dropout probability applied to the network *output*.
        :param device: unused, kept for backward compatibility. Sampler queries
            go to the device the parameters live on (``self.device``).
        :param kwargs: forwarded to ``build_network``.
        """
        super().__init__()
        self.in_features = 3
        self.out_features = 1
        self.dtype = dtype

        self.net = self.build_network(num_units=num_units, num_layers=num_layers, **kwargs)
        self.sobol_engine = torch.quasirandom.SobolEngine(dimension=self.in_features, scramble=True)

        self.domain = domain
        self.dropout = nn.Dropout(dropout)

        def log_p(x):
            # Resolve the device per call: the module may be moved after construction,
            # and the ``device`` argument is never applied to the parameters.
            # no_grad: the sampler only needs values, not an autograd graph.
            with torch.no_grad():
                x = torch.as_tensor(x, dtype=dtype, device=self.device)
                return to_numpy(self.log_prob(x))

        # MCMC sampler over (t, x, y) inside the domain box.
        self.sampler = IndependentSampler(
            log_p_unnormalized=log_p,
            num_chains=100,
            data_dim=self.in_features,
            min_lim=np.array([self.domain['MIN_T'], self.domain['MIN_X'], self.domain['MIN_X']]),
            max_lim=np.array([self.domain['MAX_T'], self.domain['MAX_X'], self.domain['MAX_X']]),
        )

    @abstractmethod
    def build_network(self, num_units, num_layers, **kwargs) -> torch.nn.Module:
        """Return the module mapping (N, in_features) points to (N, out_features) values."""
        ...

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    def log_prob(self, x):
        """Return ``log T(x)``; entries where the output is exactly 0 give ``-inf``."""
        T = self(x)
        return torch.log(T)

    def forward(self, x):
        """
        Evaluate the network, then apply dropout to its output.

        NOTE(review): with dropout > 0 in training mode some outputs are zeroed,
        which makes ``log_prob`` return ``-inf`` for them.
        """
        T = self.net(x)
        T = self.dropout(T)
        return T

    def allow_input_derivatives(self, coords):
        """
        Return a detached leaf copy of ``coords`` on ``self.device`` with ``requires_grad=True``.

        Moving to the device *before* enabling grad keeps the result a leaf tensor.
        """
        return coords.detach().clone().to(self.device).requires_grad_(True)

    def sample_uniform(self, num_uniform_samples):
        """
        Draw scrambled Sobol points scaled to the domain box.

        :return: (num_uniform_samples, 3) tensor of (t, x, y) in ``self.dtype``;
            an empty (0, 3) tensor when ``num_uniform_samples <= 0``.
        """
        if num_uniform_samples <= 0:
            # SobolEngine.draw cannot produce zero points on a fresh engine.
            return torch.empty(0, self.in_features, dtype=self.dtype, device=self.device)
        samples = self.sobol_engine.draw(n=num_uniform_samples, dtype=self.dtype).to(self.device)
        t_diff = self.domain['MAX_T'] - self.domain['MIN_T']
        x_diff = self.domain['MAX_X'] - self.domain['MIN_X']
        samples[:, 0] = samples[:, 0] * t_diff + self.domain['MIN_T']
        samples[:, 1] = samples[:, 1] * x_diff + self.domain['MIN_X']
        samples[:, 2] = samples[:, 2] * x_diff + self.domain['MIN_X']
        return samples

    def sample_signal_domain(self, num_samples, fraction_mcmc=0.8, burnin=100):
        """
        Draw collocation points, mixing Sobol points with MCMC samples.

        ``int(fraction_mcmc * num_samples)`` points (``fraction_mcmc`` clamped to
        [0, 1]) come from the MCMC sampler; the remainder are Sobol points.

        :return: points on ``self.device`` with ``requires_grad=True``.
        """
        fraction_mcmc = min(max(float(fraction_mcmc), 0.), 1.)
        num_mcmc_samples = int(fraction_mcmc * num_samples)
        # Derive the uniform count from the MCMC count. Truncating both separately
        # can drop a point, e.g. int((1 - 0.8) * 10) == 1.
        num_uniform_samples = num_samples - num_mcmc_samples

        samples = self.sample_uniform(num_uniform_samples)
        if num_mcmc_samples > 0:
            # Guarded: slicing with [-0:] would keep every chain sample.
            samples = torch.cat([samples, self._sample_mcmc(num_mcmc_samples, burnin)], 0)

        return self.allow_input_derivatives(samples)

    def _sample_mcmc(self, num_mcmc_samples, burnin):
        """Run the MCMC chains and keep the last ``num_mcmc_samples`` rows they return."""
        num_chains = self.sampler.num_chains
        # Chain starts: uniform in time, Gaussian (std 0.1) around the spatial centre.
        t_init = np.random.uniform(self.domain['MIN_T'], self.domain['MAX_T'], num_chains)
        center = 0.5 * (self.domain['MAX_X'] + self.domain['MIN_X'])
        x_init = np.random.normal(center, .1, num_chains)
        y_init = np.random.normal(center, .1, num_chains)
        txy_init = np.stack((t_init, x_init, y_init), -1)

        steps_per_chain = math.ceil(num_mcmc_samples / num_chains)
        chains = self.sampler.sample_chains(x_init=txy_init, burnin=burnin,
                                            n_samples=burnin + steps_per_chain)

        chains = to_tensor(chains.reshape(-1, self.in_features), device=self.device)
        # Match the Sobol points' dtype so torch.cat does not promote.
        chains = chains.to(dtype=self.dtype)
        return chains[-num_mcmc_samples:]

    def loss_ic(self, x, T):
        """
        Mean squared error between the network at ``x`` and target values ``T``.

        :param x: points at which the initial condition is enforced.
        :param T: target values, same shape as the network output.
        """
        T_ic_nn = self(x)
        assert T.shape == T_ic_nn.shape
        return ((T_ic_nn - T) ** 2).mean()

    def loss_pde(self, x):
        """
        Mean squared heat-equation residual ``T_t - K * (T_xx + T_yy)`` at ``x``.

        :param x: (N, 3) points (t, x, y); they should require grad, e.g. as
            produced by ``allow_input_derivatives``.
        """
        T = self(x)
        k = self.domain['K']
        assert T.shape == (len(x), 1)

        # Gradient w.r.t. (t, x, y); column 0 is the time derivative T_t.
        dT = gradients(T, x)[0]
        T_t = dT[:, 0]

        # reshape instead of squeeze so a single-point batch keeps its batch axis.
        hess = custom_hessian(T, x)[0].reshape(len(x), self.in_features, self.in_features)
        lapl = hess[:, 1, 1] + hess[:, 2, 2]

        return ((T_t - k * lapl) ** 2).mean()

class HeatTanhNet(HeatNetwork):
    """
    Baseline variant matching the reference implementation: a stack of
    fully connected tanh layers with a sigmoid on the output, nothing else.
    """

    def build_network(self, num_units, num_layers, **kwargs) -> torch.nn.Module:
        return FCLayer(
            in_features=self.in_features,
            out_features=self.out_features,
            num_hidden_layers=num_layers,
            hidden_features=num_units,
            nonlinearity="tanh",
            outermost_activation=torch.nn.Sigmoid,
        )

class HeatSirenNet(HeatNetwork):
    """
    Sinusoidal activations (SIREN) with a sigmoid on the output.

    The sine frequency defaults to 1 and can be overridden by passing
    ``sine_frequency=...`` through the ``HeatNetwork`` constructor kwargs,
    which are forwarded to ``build_network``.
    """

    def build_network(self, num_units, num_layers, sine_frequency=1, **kwargs) -> torch.nn.Module:
        return FCLayer(
            in_features=self.in_features,
            out_features=self.out_features,
            num_hidden_layers=num_layers,
            hidden_features=num_units,
            nonlinearity="sine",
            outermost_activation=torch.nn.Sigmoid,
            sine_frequency=sine_frequency,
        )
