import torch
import math
import cupy as cp
from cupyx.scipy.sparse.linalg import gmres, LinearOperator
import os
os.environ["DDE_BACKEND"] = "pytorch"
import deepxde as dde
os.environ["CUDA_HOME"] = "/usr/local/cuda-11.8"
os.environ["CUDA_PATH"] = "/usr/local/cuda-11.8"
os.environ["PATH"] = "/usr/local/cuda-11.8/bin:" + os.environ["PATH"]
os.environ["LD_LIBRARY_PATH"] = "/usr/local/cuda-11.8/lib64:" + os.environ.get("LD_LIBRARY_PATH", "")
import pykeops
# pykeops.clean_pykeops()
import ot
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

class MaternGP:
    """
    Matern Gaussian Process, sampled through a truncated sine-series expansion
    on a uniform grid over [0, 1] (dim=1) or [0, 1] x [0, 1] (dim=2).

    If dim=1, then returns samples of size 'size'.
    If dim=2, then returns samples of size 'size x size'.

    Parameters
    ----------
    size : number of grid points per axis; the same number of sine modes is
        used per axis.
    dim : spatial dimension, must be 1 or 2.
    alpha, tau : shape the mode scaling (pi^2 |k|^2 + tau^2)^(-alpha/2).
    sigma : overall amplitude; defaults to tau**(0.5*(2*alpha - dim)).
    device : torch device on which every tensor is created.

    Raises
    ------
    ValueError : if dim is neither 1 nor 2.
    """
    def __init__(self, size, dim=1, alpha=2, tau=3, sigma=None, device=None):
        # Fail early: any other dim would leave the object without
        # size / sqrt_eig / basis and only break later inside sample().
        if dim not in (1, 2):
            raise ValueError(f"dim must be 1 or 2, got {dim!r}")

        self.device = device
        self.dim = dim

        if sigma is None:
            sigma = tau**(0.5*(2*alpha - self.dim))

        if self.dim == 1:
            i_max = size
            i = torch.arange(start=1, end=i_max+1, step=1, device=device)
            self.basis_idx = i

            # Per-mode scaling applied to the white-noise coefficients
            self.sqrt_eig = sigma*(((math.pi**2)*(i**2) + tau**2)**(-alpha/2.0))

            self.size = (size,)

            x = torch.linspace(0, 1, size, device=device)
            self.grid = x
            # Basis Functions: row k holds sin(k*pi*x) on the grid
            self.phi = torch.sin(math.pi * i[:, None] * x[None, :])
        else:
            i_max = size
            j_max = size
            i = torch.arange(start=1, end=i_max+1, step=1, device=device)
            j = torch.arange(start=1, end=j_max+1, step=1, device=device)
            I, J = torch.meshgrid(i,j,indexing='ij')

            # Per-mode scaling for the (i, j) tensor-product modes
            self.sqrt_eig = sigma*(((math.pi**2)*(I**2) + (math.pi**2)*(J**2) + tau**2)**(-alpha/2.0))

            self.size = (size, size)

            x = torch.linspace(0, 1, size, device=device)
            y = torch.linspace(0, 1, size, device=device)
            self.grid = torch.meshgrid(x,y,indexing='xy')

            # Basis Functions (one 1D sine table per axis)
            self.phi_x = torch.sin(math.pi * i[:, None] * x[None, :])  # (i_max, M)
            self.phi_y = torch.sin(math.pi * j[:, None] * y[None, :])  # (j_max, M)

    def sample(self, N, z=None):
        """
        Returns N samples from the Matern GP

        z : optional pre-drawn coefficients of shape (N, *self.size). When
            given, they are used instead of fresh standard-normal draws and
            the leading dimension of z (not N) determines the batch size.
        """
        if z is None:
            z = torch.randn(N,*self.size, device=self.device)
        coeff = self.sqrt_eig*z  # if dim==1, N x size. if dim==2, N x size x size

        if self.dim == 1:
            return math.sqrt(2) * (coeff @ self.phi)  # N x size

        # dim == 2: contract both mode indices against the per-axis sine tables
        u = torch.einsum('nij,ik,jl->nkl', coeff, self.phi_x, self.phi_y)
        return 2 * u  # (N, size, size)
    
class RTE:
    def __init__(self, L, dr, dtheta, ep_index, device, n_projections=1024):
        self.L = L
        self.dr = dr
        self.device = device
        self.x = torch.arange(0, L+dr, self.dr, device=device)
        self.dtheta = dtheta
        self.theta0 = torch.arange(-math.pi + dtheta/2, math.pi - dtheta/2 + dtheta, dtheta, device=device)
        self.ep_index = ep_index
        self.epsilon = 1.0 / (2 ** (ep_index))
        self.Nr = len(self.x)
        self.Nt = len(self.theta0)
        # Trunk
        # Create 2D meshgrids
        xx, yy = torch.meshgrid(self.x, self.x, indexing='xy')
        # Reshape coordinates in the same way
        x_vec = xx.reshape(-1)
        y_vec = yy.reshape(-1).flipud()
        # Stack into a single (Nr^2, 2) tensor
        self.trunk = torch.stack([x_vec, y_vec], dim=1)
        self.Lip = None
        self.R = None
        self.mean_test_m2 = None
        self.true_model_zero = None
        self.c_test = None
        self.a_test = None
        self.u_test = None
    
    def MAB_multi(self,p,a):
        e = torch.ones(self.Nt,device=self.device,dtype=p.dtype)
        p = p.view(self.Nt, self.Nr, self.Nr).permute(2, 1, 0).contiguous()

        # Pre-compute sin/cos for all angles
        sinT = torch.sin(self.theta0)
        cosT = torch.cos(self.theta0)

        # Compute gradients for all time steps at once
        Rp_pre_x = (p[:,1:,:] - p[:,:-1,:])/self.dr
        Rp_pre_y = (p[1:,:,:] - p[:-1,:,:])/self.dr

        # Initialize output tensors
        Rp_x = torch.zeros_like(p,device=self.device)
        Rp_y = torch.zeros_like(p,device=self.device)
        boundary_D = torch.zeros_like(p,device=self.device)

        # Vectorized y-direction updates
        mask_sin_neg = sinT <= 0
        mask_sin_pos = sinT > 0

        Rp_y[:-1,:,:] = torch.where(mask_sin_neg.view(1,1,-1), 
                                    sinT.view(1,1,-1) * Rp_pre_y, 
                                    torch.zeros_like(Rp_pre_y))
        Rp_y[1:,:,:] += torch.where(mask_sin_pos.view(1,1,-1),
                                    sinT.view(1,1,-1) * Rp_pre_y,
                                    torch.zeros_like(Rp_pre_y))
        Rp_y[:,0,:] = 0
        Rp_y[:,-1,:] = 0

        # Vectorized x-direction updates  
        mask_cos_pos = cosT >= 0
        mask_cos_neg = cosT < 0

        Rp_x[:,1:,:] = torch.where(mask_cos_pos.view(1,1,-1),
                                   cosT.view(1,1,-1) * Rp_pre_x,
                                   torch.zeros_like(Rp_pre_x))
        Rp_x[:,:-1,:] += torch.where(mask_cos_neg.view(1,1,-1),
                                    cosT.view(1,1,-1) * Rp_pre_x,
                                    torch.zeros_like(Rp_pre_x))
        Rp_x[0,:,:] = 0
        Rp_x[-1,:,:] = 0

        # Boundary conditions (vectorized)
        boundary_D[-1,:,:] = torch.where(mask_sin_neg, p[-1,:,:], boundary_D[-1,:,:])
        boundary_D[0,:,:] = torch.where(mask_sin_pos, p[0,:,:], boundary_D[0,:,:])
        boundary_D[:,0,:] = torch.where(mask_cos_pos, p[:,0,:], boundary_D[:,0,:])
        boundary_D[:,-1,:] = torch.where(mask_cos_neg, p[:,-1,:], boundary_D[:,-1,:])

        # BCp calculation
        p_interior = p[1:, 1:, :]
        p_means_broadcast = torch.mean(p_interior, dim=2, keepdim=True)
        p_centered = p_interior - p_means_broadcast
        a_interior = a[1:, 1:].unsqueeze(-1)
        BCp = torch.zeros_like(p,device=self.device)
        BCp[1:, 1:, :] = a_interior * p_centered

        Mb = self.epsilon*Rp_x + self.epsilon*Rp_y + BCp + boundary_D
        Mb = Mb.permute(2, 1, 0).contiguous().view(-1)

        return Mb

    def SolveRTE(self,a,verbose=False):
        '''Given scattering coefficient 'a' as a matrix, this function
        solves the RTE and returns the solution 'u' integrated over velocity.

        a : matrix forwarded unchanged to MAB_multi on every GMRES matvec.
        verbose : when True, print whether GMRES converged (and the relative
            residual on success).

        Returns an (Nr, Nr) tensor: the solution summed over the angular
        axis and multiplied by dtheta.
        '''
        sin_t = torch.sin(self.theta0)
        cos_t = torch.cos(self.theta0)
        at_zero = self.x == 0
        # Build the comparison value with the grid's dtype so an integer L does
        # not cause a dtype mismatch inside isclose.
        at_L = torch.isclose(self.x, torch.tensor(self.L, device=self.device, dtype=self.x.dtype))

        # Flag the inflow boundary entries: at each edge, the angles pointing into
        # the domain (corner points along the sliced axis are excluded by 1:-1).
        boundary_index = torch.zeros(self.Nr, self.Nr, self.Nt, device=self.device)
        boundary_index[1:-1, at_zero, cos_t > 0] = 1
        boundary_index[1:-1, at_L, cos_t < 0] = 1
        boundary_index[at_L, 1:-1, sin_t < 0] = 1
        boundary_index[at_zero, 1:-1, sin_t > 0] = 1
        boundary_index = boundary_index.permute(2, 1, 0).contiguous().view(-1)

        # BC: phi(z,v) = g(z)*max(v*n(z),0)
        #     phi(z,v) = (1+0.5*sin(pi*x/L))*max(0, cos(theta0(kt)))
        # g varies with the coordinate stored along dim 1 (xx0[i, j] = x[j]).
        # NOTE(review): the original comment says this is applied to the left inflow
        # boundary only, but the values are copied to every flagged inflow entry and
        # are non-zero wherever cos(theta) > 0 — confirm this is intended.
        xx0, _ = torch.meshgrid(self.x, self.x, indexing='xy')
        g = 1+0.5*torch.sin(math.pi*xx0/self.L)
        ang_factors = torch.clamp(cos_t, min=0)  # Shape: (Nt,)
        phi3 = g.unsqueeze(-1) * ang_factors  # Broadcasting: (Nr, Nr, 1) * (Nt,) -> (Nr, Nr, Nt)

        # Now make a BC a vector (same flattening as boundary_index); zero off the inflow boundary
        inflow = boundary_index == 1
        phi_vec = torch.zeros(boundary_index.shape,device=self.device)
        phi3_vec = phi3.permute(2, 1, 0).contiguous().view(-1)
        phi_vec[inflow] = phi3_vec[inflow]

        # Use GMRES to solve Ax=b where Ax operator A is given with matvec_gpu
        # Convert PyTorch tensor to CuPy array
        phi_vec_cp = cp.asarray(phi_vec.detach())
        def matvec_gpu(p_cp):
            # Convert CuPy array to torch
            p_torch = torch.as_tensor(p_cp, device=self.device)
            result_torch = self.MAB_multi(p_torch, a)
            # Convert back to CuPy array
            result_cp = cp.asarray(result_torch.detach())
            return result_cp
        N = self.Nr * self.Nr * self.Nt
        A_gpu = LinearOperator((N, N), matvec=matvec_gpu, dtype=cp.float64)
        # call CuPy GMRES (works purely on GPU, no host copies)
        f_cp, info = gmres(A_gpu, phi_vec_cp, restart=300, tol=1e-8)
        # Print convergence information
        if verbose:
            if info == 0:
                final_residual = cp.linalg.norm(A_gpu @ f_cp - phi_vec_cp) / cp.linalg.norm(phi_vec_cp)
                print(f"GMRES converged to relative residual {final_residual:.1e}")
            elif info > 0:
                print(f"GMRES reached maximum iterations ({info}) without convergence")
            else:
                print(f"GMRES failed with error code {info}")
        # convert result back to torch if you need it there
        f_torch = torch.as_tensor(f_cp, device=self.device)

        # Convert soln to tensor and integrate over velocity
        u3 = f_torch.view(self.Nt, self.Nr, self.Nr).permute(2, 1, 0).contiguous() # Reshape into tensor using MATLAB ordering
        u_mat = torch.sum(u3, 2) * self.dtheta

        return u_mat
    
    def compute_u(self, a):
        '''
        Solve the Radiative Transfer Equation (fixed boundary condition) once
        per particle.

        'a' has size [N,Nr,Nr]: N particles, each a function discretized on
        the grid. Each particle a[i] is passed to SolveRTE and the result is
        written to slot i of the output, which has the same size [N,Nr,Nr].
        Solves are dispatched to a pool of two worker threads and collected
        in completion order while a progress bar is updated.
        '''
        n_particles = a.shape[0]
        solutions = torch.zeros_like(a, device=self.device)

        def solve_one(idx, a_idx):
            # Carry the index along so results can be placed regardless of finish order
            return idx, self.SolveRTE(a_idx)

        with ThreadPoolExecutor(max_workers=2) as pool:
            pending = [pool.submit(solve_one, idx, a[idx]) for idx in range(n_particles)]

            with tqdm(total=n_particles, desc="Processing samples") as progress:
                for finished in as_completed(pending):
                    idx, u_idx = finished.result()
                    solutions[idx] = u_idx
                    progress.update(1)
                    progress.set_postfix({'last_completed': f'Sample {idx+1}'})
        return solutions
    
    def train_model(self,a,u,train_fraction=0.8,iterations=5000,display_every=1000,lr=1e-3):
        """
        Fit a DeepONet (Cartesian-product form) mapping a -> u on the grid.

        a, u : tensors whose first dimension indexes samples; each sample is
            flattened to a vector of length Nr**2.
        train_fraction : leading fraction of the samples used for training;
            the remainder is the test split (default 0.8).
        iterations : number of training iterations (default 5000).
        display_every : reporting interval passed to model.train (default 1000).
        lr : Adam learning rate (default 1e-3).

        Returns (model, losshistory, train_state) as produced by DeepXDE.
        The defaults reproduce the previously hard-coded settings.
        """
        N_samples = a.shape[0]

        # Train/test split (samples are taken in order, no shuffling)
        ntrain = int(train_fraction * N_samples)
        a_vec = a.reshape(N_samples, self.Nr**2) # Reshape images into long vectors
        u_vec = u.reshape(N_samples, self.Nr**2)
        a_train, a_test = a_vec[:ntrain], a_vec[ntrain:]
        u_train, u_test = u_vec[:ntrain], u_vec[ntrain:]

        # DeepXDE data object
        data = dde.data.TripleCartesianProd(
            X_train=(a_train, self.trunk),
            y_train=u_train,
            X_test=(a_test,  self.trunk),
            y_test=u_test,
        )

        # Simpler architecture: branch and trunk both end in p features
        p = 32
        net = dde.nn.pytorch.DeepONetCartesianProd(
            layer_sizes_branch=[self.Nr**2,100,100,100,64, p],
            layer_sizes_trunk=[2,100,100,100,64, p],
            activation="relu",
            kernel_initializer="Glorot normal",
        )
        net.to(self.device)  # Ensure the model params live on GPU

        # Compile model
        model = dde.Model(data, net)
        model.compile(
            "adam",
            lr=lr,
            loss="mean l2 relative error",
            decay=("inverse time", 1000, 0.5),  # (type, decay_steps, decay_rate)
        )

        print('STARTING MODEL TRAINING')
        # Train model
        losshistory, train_state = model.train(
            iterations=iterations,
            display_every=display_every,
        )

        return model, losshistory, train_state
    
    def c_to_a(self,c):
        '''
        This functions takes in coefficients 'c' and outputs conductivities
        a = exp(sum_i sum_j c_ij 2*sin(i pi x)*sin(j pi y))
        '''
        Ntrunc = c.shape[1]
        y = self.x.clone().detach()
        i = torch.linspace(1,Ntrunc,Ntrunc)
        j = torch.linspace(1,Ntrunc,Ntrunc)
        # [particle q, index i, index j, pos X, pos Y]
        terms = c[:,:,:,None,None]*2*torch.sin(i[None,:,None,None,None]*math.pi*self.x[None,None,None,:,None])*\
                                    torch.sin(j[None,None,:,None,None]*math.pi*y[None,None,None,None,:])
        a = torch.exp(torch.sum(terms,dim=(1,2)))
        
        return a
    
    def c_to_loga(self,c):
        '''
        This functions takes in coefficients 'c' and outputs log-conductivities
        log(a) = sum_i sum_j c_ij 2*sin(i pi x)*sin(j pi y)
        '''
        Ntrunc = c.shape[1]
        y = self.x.clone().detach()
        i = torch.linspace(1,Ntrunc,Ntrunc)
        j = torch.linspace(1,Ntrunc,Ntrunc)
        # [particle q, index i, index j, pos X, pos Y]
        terms = c[:,:,:,None,None]*2*torch.sin(i[None,:,None,None,None]*math.pi*self.x[None,None,None,:,None])*\
                                    torch.sin(j[None,None,:,None,None]*math.pi*y[None,None,None,None,:])
        loga = torch.sum(terms,dim=(1,2))
        
        return loga
    
    def norm(self, a):
        '''
        This function computes the L2 norm of each of the N particles in a distribution. 
        Each particle a_q in a for q=1,...,N is given as a discretized matrix with grid spacing dr
        in both the x and y direction.
        The output is a vector of size N.
        '''
        # Square all elements
        a_squared = a**2  # Shape: (N, Nr_y, Nr_x)

        # Integrate over x (last dimension, columns)
        integral_x = torch.trapezoid(a_squared, dx=self.dr, dim=-1)  # Shape: (N, Nr_y)

        # Integrate over y (now last dimension, rows)
        integral_xy = torch.trapezoid(integral_x, dx=self.dr, dim=-1)  # Shape: (N,)
        
        return torch.sqrt(integral_xy)
    
    def m2(self, a):
        '''
        This function computes the 2nd moment of a distribution made up of N particles. 
        By definition m2(\nu) = integral ||a||^2 d\nu(a)
        Each particle a_q in a for q=1,...,N is given as a discretized matrix with grid spacing dr
        in both the x and y direction.
        '''
        squared_norms = self.norm(a)**2 # Shape: (N,)
        return torch.mean(squared_norms)
    
    def set_parameters(self, Lip, R, mean_test_m2, true_model_zero, c_test, a_test, u_test, a_zero_coef):
        '''
        Store the problem constants and the test distribution on the instance.
        Every argument is saved under an attribute of the same name.
        '''
        stored = {
            "Lip": Lip,
            "R": R,
            "mean_test_m2": mean_test_m2,
            "true_model_zero": true_model_zero,
            "c_test": c_test,
            "a_test": a_test,
            "u_test": u_test,
            "a_zero_coef": a_zero_coef,
        }
        for attr_name, value in stored.items():
            setattr(self, attr_name, value)
        
    
    def W2_squared(self, a, b, get_kp=False, chunk_size=20):
        '''
        This function computes the W2^2 distance between two distributions made up of particles.
        These particles are functions given as matrices discretized over the grid.

        a, b : particle sets of size [N_a, Nr, Nr] and [N_b, Nr, Nr]; both
            receive uniform weights.
        get_kp : when truthy, also return the 'u' entry of the log dictionary
            produced by ot.emd2 (presumably the dual potential attached to the
            particles of 'a' — confirm against the POT documentation).
        chunk_size : number of particles of 'a' whose pairwise costs are
            computed at once; bounds the size of the temporary difference
            tensor. Must be at least 1.

        Returns emd2, or (emd2, log['u']) when get_kp is truthy.
        '''
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be at least 1, got {chunk_size!r}")

        N_a = a.shape[0]
        N_b = b.shape[0]
        # Cost matrix: squared discrete L2 distance between every pair of particles
        C = torch.zeros(N_a, N_b, device=self.device)

        for i in range(0, N_a, chunk_size):
            i_end = min(i + chunk_size, N_a)

            # Compute differences for this chunk: (chunk, N_b, Nr, Nr)
            diff = a[i:i_end, None, :, :] - b[None, :, :, :]
            C[i:i_end, :] = self.norm(diff)**2

        train_weights = torch.ones(N_a, device=self.device) / N_a  # uniform weights
        test_weights = torch.ones(N_b, device=self.device) / N_b  # uniform weights

        if get_kp:
            emd2, log = ot.emd2(train_weights, test_weights, C, log=True)
            return emd2, log['u']
        return ot.emd2(train_weights, test_weights, C)
    
    def F(self, c_train, u_train, model):
        '''
        Objective function evaluated at the coefficients c_train of the training particles.

        c_train is mapped to particles with c_to_a; u_train holds the true
        solutions for those particles; model is the surrogate whose net is
        evaluated on the flattened particles and the trunk grid. The value is
        the mean squared prediction error plus
        sqrt(A*m2(a_train) + D) * sqrt(mean_k W2^2(a_train, a_test[k])),
        with A and D built from Lip, R, true_model_zero, mean_test_m2 and the
        model's prediction at a_zero_coef.
        '''
        n_cells = self.Nr**2
        particles = self.c_to_a(c_train)

        # Mean squared error of the surrogate on the training particles
        predicted = model.net((particles.reshape(-1, n_cells), self.trunk)).reshape(-1, self.Nr, self.Nr)
        fit_term = torch.mean(self.norm(u_train - predicted)**2)

        # Constants of the transport penalty
        bound = self.Lip + self.R
        A = 4*bound**4
        zero_prediction = model.net((self.a_zero_coef.reshape(-1, n_cells), self.trunk)).reshape(-1, self.Nr, self.Nr)
        model_at_zero = self.norm(zero_prediction[0])**2
        B = bound**2*16*(self.true_model_zero + model_at_zero)
        D = B + A*self.mean_test_m2

        # Squared W2 distance from the training particles to each test distribution
        n_test = self.a_test.shape[0]
        distances = torch.zeros(n_test, device=self.device)
        for k, test_particles in enumerate(self.a_test):
            distances[k] = self.W2_squared(particles, test_particles)
        transport_term = torch.sqrt(A*self.m2(particles)+D)*torch.sqrt(torch.mean(distances))

        return fit_term + transport_term
    
    def gradient_F(self, c_train,u_train,model):
        """
        Compute gradient of F evaluated at the coefficients c of size [# of samples,Ntrunc,Ntrunc]
        which are the same coefficients of the training data particles.
        """
        c_train_ = c_train.detach().requires_grad_(True)

        # Compute DF in vectorized form
        F_vals = self.F(c_train_,u_train,model)

        # Compute gradients
        grads = torch.autograd.grad(
            outputs=F_vals,
            inputs=c_train_,
            grad_outputs=torch.ones_like(F_vals),  # Provide ones for each output
            create_graph=False,
            retain_graph=False 
        )[0]

        return grads
    
    def computeOODError(self,model):
        K = self.a_test.shape[0]
        OOD_errors = torch.zeros(K, device=self.device)
        for k in range(K):
            test_errors = self.norm(self.u_test[k] - model.net((self.a_test[k].reshape(-1, self.Nr**2),self.trunk)).reshape(-1, self.Nr, self.Nr))**2
            OOD_errors[k] = torch.mean(test_errors)
        mean_OOD_error = torch.mean(OOD_errors)
        return mean_OOD_error
    
    def computeRelativeOODError(self,model):
        K = self.a_test.shape[0]
        num_OOD_errors = torch.zeros(K, device=self.device)
        den_OOD_errors = torch.zeros(K, device=self.device)
        for k in range(K):
            num_test_errors = self.norm(self.u_test[k] - model.net((self.a_test[k].reshape(-1, self.Nr**2),self.trunk)).reshape(-1, self.Nr, self.Nr))**2
            num_OOD_errors[k] = torch.mean(num_test_errors)
            den = self.norm(self.u_test[k])**2
        num_mean_OOD_error = torch.mean(num_OOD_errors)
        den_mean = torch.mean(den)
        return num_mean_OOD_error/den_mean

    def DF(self,c_train,u_train,model):
        '''
        Functional derivative DF evaluated at each particle of the training
        distribution: one value per training particle (a vector of size N).

        The value for a particle is its squared prediction error, plus
        0.5*A*||a||^2*factor, plus mean_kp/factor, where mean_kp averages over
        the test distributions the second output of W2_squared(..., get_kp=True)
        and factor = sqrt(mean W2^2 / (A*m2(a_train) + D)).
        '''
        n_cells = self.Nr**2
        particles = self.c_to_a(c_train)

        # Per-particle squared prediction error
        predicted = model.net((particles.reshape(-1, n_cells), self.trunk)).reshape(-1, self.Nr, self.Nr)
        fit_per_particle = self.norm(u_train - predicted)**2

        # Constants of the transport penalty
        bound = self.Lip + self.R
        A = 4*bound**4
        zero_prediction = model.net((self.a_zero_coef.reshape(-1, n_cells), self.trunk)).reshape(-1, self.Nr, self.Nr)
        model_at_zero = self.norm(zero_prediction[0])**2
        B = bound**2*16*(self.true_model_zero + model_at_zero)
        D = B + A*self.mean_test_m2

        # Squared W2 distance and per-particle potential for every test distribution
        n_test = self.a_test.shape[0]
        n_train = particles.shape[0]
        distances = torch.zeros(n_test, device=self.device)
        potentials = torch.zeros(n_test, n_train, device=self.device)
        for k in range(n_test):
            distances[k], potentials[k] = self.W2_squared(particles, self.a_test[k], get_kp=True)
        mean_distance = torch.mean(distances)
        mean_potential = torch.mean(potentials, dim=0)

        factor = torch.sqrt(mean_distance/(A*self.m2(particles)+D))
        # NOTE(review): differentiating sqrt(W) would normally bring a factor 1/2 on the
        # potential term as well; confirm the normalisation of the potentials used here.
        return fit_per_particle + 0.5*A*self.norm(particles)**2*factor + mean_potential/factor
    
    def gradient_DF(self, c_train,u_train,model):
        """
        Compute gradient of DF evaluated at the coefficients c of size [# of samples,Ntrunc,Ntrunc].
        """
        c_train_ = c_train.detach().requires_grad_(True)

        # Compute DF in vectorized form
        DF_vals = self.DF(c_train_,u_train,model)

        # Compute gradients
        grads = torch.autograd.grad(
            outputs=DF_vals,
            inputs=c_train_,
            grad_outputs=torch.ones_like(DF_vals),  # Provide ones for each output
            create_graph=False,
            retain_graph=True 
        )[0]

        return grads        