"""
Action functions for physical theories and toy models.

This module provides action classes for various physical models including:
- Scalar phi^4 theory
- Complex phi^4 theory  
- Hubbard model
- Gaussian mixture models

Each action class implements the `evaluate` method that computes the action
value S(phi) for a given field configuration, where the target distribution
is p(phi) = exp(-S(phi)) / Z.
"""

import math
import logging
import torch
import os
import yaml
from torch import Tensor

logger = logging.getLogger("SESaMo")


class Action:
    """
    Common interface for all action functionals in this module.

    An action maps a batch of field configurations phi to the values S(phi).
    Concrete actions override `evaluate`; instances are also callable, and
    calling one simply forwards to `evaluate`.
    """

    def __call__(self, phi: Tensor) -> Tensor:
        """
        Evaluate the action on ``phi`` by delegating to `evaluate`.

        Parameters
        ----------
        phi : Tensor
            Field configuration tensor.

        Returns
        -------
        Tensor
            Whatever `evaluate` returns for ``phi``.
        """
        action_values = self.evaluate(phi)
        return action_values

    def evaluate(self, phi: Tensor) -> Tensor:
        """
        Return the action value of every configuration in ``phi``.

        The base class provides no implementation; subclasses override
        this method.

        Parameters
        ----------
        phi : Tensor
            Field configuration tensor.

        Returns
        -------
        Tensor
            Action values for each configuration in the batch.

        Raises
        ------
        NotImplementedError
            Always, when called on the base class.
        """
        raise NotImplementedError("evaluate method not implemented")
    

# ====================================================================================
# PHYSICAL ACTIONS
# ====================================================================================


class ScalarPhi4Action(Action):
    """
    Real scalar phi^4 lattice action in two dimensions.

    For a field phi of shape (batch, nt, nx) the action reads
        S[phi] = sum_x [ -2*kappa*phi(x)*(neighbour along x + neighbour along t)
                        + (1 - 2*lambda)*phi(x)^2 + lambda*phi(x)^4 ]
                + broken * sum_x phi(x)

    Neighbours are obtained with ``torch.roll``, so the lattice is periodic
    in both directions. The last term is linear in phi and therefore breaks
    the Z2 symmetry phi -> -phi whenever ``broken`` is non-zero.

    Parameters
    ----------
    kappa : float
        Hopping parameter multiplying the nearest-neighbour coupling.
    lambd : float
        Quartic self-interaction coupling.
    broken : float, optional
        Coefficient of the linear Z2-breaking term. Default is 0.
    """

    def __init__(self, kappa: float, lambd: float, broken: float = 0):
        """
        Store the couplings and log the configuration.

        Parameters
        ----------
        kappa : float
            Hopping parameter.
        lambd : float
            Self-interaction coupling.
        broken : float, optional
            Z2 breaking coefficient. Default is 0.
        """
        self.kappa, self.lambd, self.broken = kappa, lambd, broken

        logger.info(f"Initialized Phi4Action with kappa={kappa}, lambd={lambd}, breaking={broken}")

    def evaluate(self, phi: Tensor) -> Tensor:
        """
        Compute the phi^4 action, including the Z2-breaking term.

        Parameters
        ----------
        phi : Tensor
            Field configurations of shape (batch, nt, nx).

        Returns
        -------
        Tensor
            One action value per configuration, shape (batch,).

        Raises
        ------
        ValueError
            If ``phi`` is not a rank-3 tensor.
        """
        if phi.ndim != 3:
            raise ValueError(f"field has invalid shape, should be (batch, nt, nx), but got {phi.shape}")

        # sum of the two rolled copies of the field (periodic wrap-around)
        neighbours = torch.roll(phi, 1, -1) + torch.roll(phi, 1, -2)

        # per-site contributions to the action density
        kinetic = (-2 * self.kappa) * phi * neighbours
        mass = (1 - 2 * self.lambd) * phi ** 2
        inter = self.lambd * phi ** 4
        density = kinetic + mass + inter

        # linear source term, summed over every lattice site of each sample
        batch_size = phi.shape[0]
        z2_breaking_term = self.broken * phi.reshape(batch_size, -1).sum(-1)

        return density.sum(-1).sum(-1) + z2_breaking_term


class ComplexPhi4Action(ScalarPhi4Action):
    """
    Action for the complex phi^4 theory on a 2D lattice.

    The complex field phi = phi_1 + i*phi_2 is stored as a real tensor of
    shape (batch, 2, nt, nx), with the two components along dimension 1.
    Hopping and mass terms are computed per component and summed over the
    components; the quartic term uses (phi_1^2 + phi_2^2)^2 and therefore
    couples both components. An optional linear term
    ``broken * sum(phi)`` (summed over components and sites) breaks the
    U(1) symmetry.

    Both `evaluate` and calling the instance return the full action
    including the breaking term, consistent with the other actions in this
    module.

    Parameters
    ----------
    kappa : float
        Hopping parameter controlling nearest-neighbor interactions.
    lambd : float
        Self-interaction coupling constant.
    broken : float, optional
        U(1) symmetry breaking parameter. Default is 0.
    """

    def __init__(self, kappa: float, lambd: float, broken: float = 0):
        """
        Initialize the complex phi^4 action.

        Parameters
        ----------
        kappa : float
            Hopping parameter.
        lambd : float
            Self-interaction coupling.
        broken : float, optional
            U(1) breaking term. Default is 0.
        """
        # The parent constructor is deliberately not called: it would emit
        # its own log line in addition to the one below.
        self.kappa = kappa
        self.lambd = lambd
        self.broken = broken

        logger.info(f"Initialized ComplexPhi4Action with kappa={kappa}, lambd={lambd}, breaking={broken}")

    def evaluate(self, phi: Tensor) -> Tensor:
        """
        Evaluate the complex phi^4 action including the U(1) breaking term.

        Parameters
        ----------
        phi : Tensor
            Field configuration of shape (batch, 2, nt, nx) where the second
            dimension contains real and imaginary parts.

        Returns
        -------
        Tensor
            Action values of shape (batch,).

        Raises
        ------
        ValueError
            If phi is not a rank-4 tensor.
        """
        if phi.ndim != 4:
            raise ValueError(f"field has invalid shape, should be (batch, 2, n_t, n_x), but got {phi.shape}")

        # per-component hopping and mass terms (torch.roll gives periodic neighbours)
        kinetic = (-2 * self.kappa) * phi * (torch.roll(phi, 1, -1) + torch.roll(phi, 1, -2))
        mass = (1 - 2 * self.lambd) * phi ** 2
        # quartic term acts on the squared modulus, i.e. sums the components first
        inter = self.lambd * (phi ** 2).sum(1) ** 2

        symmetric_action = (kinetic.sum(1) + mass.sum(1) + inter).sum(-1).sum(-1)

        # linear source term over all components and sites of each sample
        u1_breaking_term = self.broken * phi.reshape(phi.shape[0], -1).sum(-1)

        return symmetric_action + u1_breaking_term

    def __call__(self, phi: Tensor) -> Tensor:
        """
        Compute the complex phi^4 action with optional U(1) breaking.

        Equivalent to calling `evaluate`.

        Parameters
        ----------
        phi : Tensor
            Field configuration of shape (batch, 2, nt, nx).

        Returns
        -------
        Tensor
            Action values of shape (batch,).

        Raises
        ------
        ValueError
            If phi is not a rank-4 tensor.
        """
        return self.evaluate(phi)





class HubbardAction(Action):
    """
    Hubbard model action for fermionic lattice systems.

    Implements the action for the Hubbard model in the spin basis with
    exponential discretization. Supports 2-site and 18-site hexagonal lattices.
    For more details see: https://arxiv.org/pdf/1812.09268

    Parameters
    ----------
    u : float, optional
        On-site interaction strength. Default is 18.
    beta : float, optional
        Inverse temperature. Default is 1.
    nt : int, optional
        Number of time slices. Default is 1.
    nx : int, optional
        Number of spatial sites (2 or 18). Default is 2.
    """

    def __init__(self, u: float = 18, beta: float = 1, nt: int = 1, nx: int = 2):
        """
        Initialize the Hubbard action.

        For ``nx == 18`` the hopping data is read from the YAML file
        ``hubbard_lattices/18_sites_hex.yml`` located next to this module.

        Parameters
        ----------
        u : float, optional
            On-site coupling constant. Default is 18.
        beta : float, optional
            Inverse temperature. Default is 1.
        nt : int, optional
            Number of time slices. Default is 1.
        nx : int, optional
            Number of spatial sites. Default is 2.

        Raises
        ------
        ValueError
            If nx is not 2 or 18.
        """
        self.u = u
        self.beta = beta
        self.nt = nt
        self.nx = nx
        # one-shot flag so the precision warning is logged at most once per instance
        self.dtype_warning_issued = False

        if nx not in [2, 18]:
            raise ValueError(f"HubbardAction only implemented for nx=2 or nx=18, but got nx={nx}")

        if nx == 18:
            lattice_file = "hubbard_lattices/18_sites_hex.yml"
            with open(os.path.join(os.path.dirname(__file__), lattice_file), 'r') as f:
                self.hopping = yaml.safe_load(f)

        logger.info(f"Initialized HubbardAction with U={u}, beta={beta}")

    def logdet_m(self, phi: Tensor, species: int) -> Tensor:
        """
        Compute the log determinant of the fermion matrix.

        Uses the spin-basis and exponential discretization to compute
        the fermion determinant for a given auxiliary field configuration.
        The number of time slices and sites is taken from ``phi`` itself.

        Parameters
        ----------
        phi : Tensor
            Auxiliary field configuration of shape (N, Nt, Nx).
        species : int
            Spin species: +1 for spin-up, -1 for spin-down.

        Returns
        -------
        Tensor
            Log determinant of the fermion matrix for each configuration.

        Raises
        ------
        ValueError
            If species is not +1 or -1, or if the (Nx, Nt) combination of
            ``phi`` is not supported (supported: Nx=2 with Nt=1, and Nx=18).
        """
        if species not in [-1, 1]:
            raise ValueError(f"Species must be +/- 1 but got {species}")

        N, nt, nx = phi.shape
        device = phi.device
        dtype = phi.dtype

        if phi.dtype != torch.float64 and not self.dtype_warning_issued:
            self.dtype_warning_issued = True
            logger.warning(f"dtype {phi.dtype} may lead to numerical instabilities of the Hubbard action, consider using float64")

        # get the hopping matrix
        exp_kappa = torch.zeros((nx, nx), device=device, dtype=dtype)
        if nx == 2:
            # closed form of exp(delta * [[0, 1], [1, 0]]) with delta = beta / nt
            exp_kappa[0, 0] = math.cosh(self.beta / nt)
            exp_kappa[0, 1] = math.sinh(self.beta / nt)
            exp_kappa[1, 0] = math.sinh(self.beta / nt)
            exp_kappa[1, 1] = math.cosh(self.beta / nt)
        elif nx == 18:
            # fill delta * kappa from the loaded lattice data, then exponentiate
            for index, (i, j) in enumerate(self.hopping["adjacency"]):
                exp_kappa[i, j] = self.hopping["hopping"][index] * self.beta / nt
            exp_kappa = torch.matrix_exp(exp_kappa)
        else:
            raise ValueError(f"Fermion matrix not implemented for nx={nx}")

        # precompute the exponential of the configuration
        exp_phi = torch.exp(species * phi)

        nx_unit = torch.eye(nx, dtype=dtype, device=device)

        if nx == 2 and nt == 1:
            # With a single time slice the fermion matrix reduces to the
            # (N, nx, nx) matrix 1 + exp_kappa[i, j] * exp_phi[n, 0, j].
            m = nx_unit + exp_kappa * exp_phi[:, 0, None, :]

        elif nx == 18:
            # compute sausage matrix according to https://arxiv.org/pdf/1812.09268
            # sausage[n, t, i, j] = exp_kappa[i, j] * exp_phi[n, t - 1, j]
            sausage = torch.einsum('ij,...j->...ij', exp_kappa, exp_phi.roll(1, dims=1))
            sausage = sausage.flip(dims=[1])  # reverse the order in time direction

            # compute matrix multiplications
            m = nx_unit.unsqueeze(0).repeat(N, 1, 1)
            for t in range(nt):
                m = torch.bmm(m, sausage[:, t])
            m += nx_unit.unsqueeze(0)

        else:
            raise ValueError(f"logdet_m not implemented for nx={nx} and nt={nt}")

        return m.logdet()

    def evaluate(self, phi: Tensor) -> Tensor:
        """
        Evaluate the Hubbard action.

        The action is the Gaussian term ``sum(phi^2) / (2 * u_tilde)`` with
        ``u_tilde = u * beta / Nt`` minus the log determinants of the
        fermion matrices of both spin species.

        Parameters
        ----------
        phi : Tensor
            Auxiliary field configuration of shape (batch, Nt, Nx).

        Returns
        -------
        Tensor
            Action values of shape (batch,).

        Raises
        ------
        ValueError
            If phi does not have shape (batch, Nt, Nx) with Nx equal to the
            configured number of sites.
        """
        if len(phi.shape) != 3 or phi.shape[2] != self.nx:
            raise ValueError(f"Input has invalid shape, should be (batch, Nt, {self.nx}), but got {phi.shape}")

        # compute u_tilde
        nt = phi.shape[1]
        u_tilde = self.u * self.beta / nt

        # compute the action
        actions = (phi*phi).sum(dim=(1, 2)) / (2 * u_tilde) - self.logdet_m(phi, +1) - self.logdet_m(phi, -1)

        return actions.real

    def logZ(self) -> float:
        """
        Return the log partition function.

        Only implemented for Nx=2, Nt=1, u=18, and beta=1; the returned
        value is a stored constant for exactly this parameter set.

        Returns
        -------
        float
            Log partition function log(Z).

        Raises
        ------
        ValueError
            If nx != 2, nt != 1, u != 18 or beta != 1.
        """
        # The stored constant belongs to the 2-site, single-time-slice system;
        # refuse to hand it out for any other lattice.
        if self.nx != 2 or self.nt != 1:
            raise ValueError(f"logZ is only defined for nx=2 and nt=1, but got nx={self.nx}, nt={self.nt}")

        if self.u != 18 or self.beta != 1:
            raise ValueError("logZ is only defined for u=18 and beta=1")

        return 24.6398



# ====================================================================================
# TOY MODELS
# ====================================================================================


class GaussianMixtureAction(Action):
    """
    Two-dimensional toy action: a ring of equally weighted Gaussian modes.

    ``n_gaussians`` unit-variance Gaussian bumps are placed at equal angular
    spacing on a circle of radius ``radius``. The action is the negative
    log of the (unnormalised) mixture density plus an optional linear term
    ``broken * (x + y)`` that breaks the discrete rotational symmetry.

    Parameters
    ----------
    n_gaussians : int, optional
        Number of Gaussian modes. Default is 8.
    radius : float, optional
        Radius of the circle on which the modes sit. Default is 12.
    broken : float, optional
        Coefficient of the symmetry-breaking term. Default is 0.
    """

    def __init__(self, n_gaussians: int = 8, radius: float = 12, broken: float = 0):
        """
        Store the parameters and precompute the mode centers.

        Parameters
        ----------
        n_gaussians : int, optional
            Number of Gaussian modes. Default is 8.
        radius : float, optional
            Radius of the circle on which the modes sit. Default is 12.
        broken : float, optional
            Coefficient of the symmetry-breaking term. Default is 0.
        """
        self.n_gaussians, self.radius, self.broken = n_gaussians, radius, broken

        # centers depend on n_gaussians and radius, so they are built last
        self.centers = self.get_gaussians()

        logger.info(f"Initialized GaussianMixtureAction with n_gaussians={n_gaussians}, radius={radius}, breaking={broken}")

    def get_gaussians(self) -> Tensor:
        """
        Build the coordinates of the mode centers.

        Returns
        -------
        Tensor
            Tensor of shape (n_gaussians, 2); row k is
            ``radius * (cos(2*pi*k/n), sin(2*pi*k/n))``.
        """
        angles = [2*math.pi/self.n_gaussians*k for k in range(self.n_gaussians)]
        unit_circle_points = [[math.cos(angle), math.sin(angle)] for angle in angles]

        return self.radius * torch.tensor(unit_circle_points)

    def evaluate(self, phi: Tensor) -> Tensor:
        """
        Compute the mixture action for a batch of 2D points.

        Parameters
        ----------
        phi : Tensor
            Points of shape (batch, 1, 2).

        Returns
        -------
        Tensor
            Action values of shape (batch,).

        Raises
        ------
        ValueError
            If ``phi`` does not have shape (batch, 1, 2).
        """
        if phi.ndim != 3 or tuple(phi.shape[1:]) != (1, 2):
            raise ValueError(f"Input has invalid shape, should be (batch, 1, 2), but got {phi.shape}")

        # keep the cached centers on the same device as the input
        if self.centers.device != phi.device:
            self.centers = self.centers.to(phi.device)

        # broadcasting (batch, 1, 2) against (1, 1, n, 2) yields (1, batch, n, 2)
        offsets = phi - self.centers[None, None, :]
        squared_distance = offsets.norm(dim=-1)**2

        # log of the summed Gaussians; squeeze drops the leading broadcast axis
        log_mixture = torch.logsumexp(-0.5*squared_distance, dim=-1).squeeze(0)
        linear_term = self.broken * phi.sum(dim=(1, 2))

        return -log_mixture + linear_term

    def logZ(self) -> float:
        """
        Return the analytic log partition function.

        Returns
        -------
        float
            log(Z); ``log(2*pi*n_gaussians)`` in the symmetric case, and a
            log-sum-exp over the modes when ``broken`` is non-zero.
        """
        if self.broken != 0:
            # polar angle of every mode center
            phi_k = 2*math.pi / self.n_gaussians * torch.arange(self.n_gaussians)
            per_mode_exponent = self.broken**2 - self.broken*self.radius*math.sqrt(2) * torch.sin(phi_k + math.pi/4)
            return math.log(2*math.pi) + torch.logsumexp(per_mode_exponent, dim=-1).item()

        return math.log(2*math.pi * self.n_gaussians)