import torch
import torch.nn.functional as F
import os
from utils.MCS import MCS

class MCSTSW:
    """Tree-sliced Wasserstein-style distance between two point clouds.

    All geometry (reshaping of the inputs, point-to-root distances,
    point-to-line distances and tree sampling) is delegated to an ``MCS``
    instance; this class turns those quantities into per-line masses and
    integrates the mass difference along the sorted tree coordinate.
    """

    def __init__(self, ncomp, dcomp, initK, ntrees=200, nlines=5, p=1, delta=1, fixK=True, eps=1e-6, device="cuda"):
        """
        Args:
            ncomp: forwarded to ``MCS`` as ``N``; also passed to ``_ensure_2d`` for the inputs
            dcomp: forwarded to ``MCS`` as ``M``; also passed to ``_ensure_2d`` for the inputs
            initK: forwarded to ``MCS`` as ``K``
            ntrees: Number of trees
            nlines: Number of lines per tree
            p: level of the norm
            delta: scale applied to the negated point-to-line distances before the
                softmax that splits each point's mass across the lines of a tree
            fixK: when False, ``initK`` is wrapped in a ``torch.nn.Parameter``
            eps: stored on the instance; not read by any method of this class
            device: device to run the code, follow torch convention
        """
        self.ntrees = ntrees
        self.device = device
        self.nlines = nlines
        self.p = p
        self.delta = delta
        self.eps = eps
        self.ncomp = ncomp
        self.dcomp = dcomp

        self.K = initK
        if not fixK:
            self.K = torch.nn.Parameter(self.K)

        self.mcs = MCS(N=self.ncomp, M=self.dcomp, K=self.K, device=self.device)

        # Opt-in compilation of the hot paths, enabled with COMPILE=1 or COMPILE=TRUE.
        if os.environ.get('COMPILE') in ('1', 'TRUE'):
            self.stw_concurrent_lines = torch.compile(self.stw_concurrent_lines)
            self.project = torch.compile(self.project)
            self.get_mass_and_coordinate = torch.compile(self.get_mass_and_coordinate)

    def __call__(self, X, Y, root, intercept, intercept_index):
        """
        Compute the distance between X and Y on the given tree frames.

        Args:
            X: point cloud from distribution 1; reshaped by ``mcs._ensure_2d``
               (expected (..., A, N*M) or (..., A, N, M) -- TODO confirm against MCS)
            Y: point cloud from distribution 2, same convention as X
            root, intercept, intercept_index: tree frames, e.g. the three values
               returned by ``mcs.generate_trees_frames``
        Returns:
            the distance between X and Y (first output of ``stw_concurrent_lines``)
        """
        assert torch.isfinite(X).all() and torch.isfinite(Y).all()
        X = X.to(self.device)
        Y = Y.to(self.device)
        # No shape unpacking here: the layout of the inputs is normalised by
        # _ensure_2d, so flattened inputs must not be rejected beforehand.
        X = self.mcs._ensure_2d(X, self.ncomp, self.dcomp)
        Y = self.mcs._ensure_2d(Y, self.ncomp, self.dcomp)

        combined_axis_coordinate, mass_X, mass_Y = self.get_mass_and_coordinate(
            X, Y, root, intercept, intercept_index
        )
        return self.stw_concurrent_lines(mass_X, mass_Y, combined_axis_coordinate)[0]

    def stw_concurrent_lines(self, mass_X, mass_Y, combined_axis_coordinate):
        """
        Integrate the mass difference along the sorted coordinate of every tree.

        Args:
            mass_X: (..., num_trees, num_lines, P)
            mass_Y: (..., num_trees, num_lines, P)
            combined_axis_coordinate: (..., num_trees, P)
        Returns:
            tw: (...) distance, averaged over trees
            sub_mass_target_cumsum: (..., num_trees, num_lines, P) suffix sums of
                (mass_X - mass_Y) in sorted-coordinate order
            edge_length: (..., num_trees, 1, P) gaps between consecutive sorted
                coordinates, the first gap being measured from 0
        """
        num_lines = mass_X.shape[-2]
        coord_sorted, order = torch.sort(combined_axis_coordinate, dim=-1)  # (..., ntrees, P)

        # Every line of a tree shares the same ordering: broadcast the indices
        # as a view instead of materialising num_lines copies.
        order = order.unsqueeze(-2)
        order = order.expand(*order.shape[:-2], num_lines, order.shape[-1])  # (..., ntrees, nlines, P)

        sub_mass = torch.gather(mass_X, -1, order) - torch.gather(mass_Y, -1, order)
        # Suffix sum: entry i holds the sum of sub_mass[i:], obtained as
        # sub_mass + total - prefix_sum.
        sub_mass_cumsum = torch.cumsum(sub_mass, dim=-1)
        sub_mass_target_cumsum = sub_mass + torch.sum(sub_mass, dim=-1, keepdim=True) - sub_mass_cumsum

        # Edge lengths; new_zeros keeps the dtype and device of the coordinates.
        origin = coord_sorted.new_zeros((*coord_sorted.shape[:-1], 1))
        edge_length = torch.diff(coord_sorted, prepend=origin, dim=-1).unsqueeze(-2)  # (..., ntrees, 1, P)

        # Sum |suffix mass|^p * edge length over lines and points, average over trees.
        subtract_mass = (torch.abs(sub_mass_target_cumsum) ** self.p) * edge_length
        subtract_mass_sum = torch.sum(subtract_mass, dim=[-1, -2])

        # The result is floored at 1e-12 before taking the 1/p root.
        tw = torch.mean(subtract_mass_sum, dim=-1).clamp(min=1e-12) ** (1 / self.p)
        return tw, sub_mass_target_cumsum, edge_length

    def get_mass_and_coordinate(self, X, Y, root, intercept, intercept_index):
        """
        Project both clouds and lay them out side by side on a shared axis.

        Args:
            X: (..., Nx, N, M) tensor, point cloud from distribution 1
            Y: (..., Ny, N, M) tensor, point cloud from distribution 2
            root, intercept, intercept_index: tree frames, forwarded to ``project``
        Returns:
            combined_axis_coordinate: (..., num_trees, Nx + Ny)
            mass_X: (..., num_trees, num_lines, Nx + Ny)
            mass_Y: (..., num_trees, num_lines, Nx + Ny)

            for the last dimension
            0, 1, ..., Nx - 1 are from distribution 1 (X)
            Nx, Nx + 1, ..., Nx + Ny - 1 are from distribution 2 (Y)
        """
        mass_X, axis_coordinate_X = self.project(X, root=root, intercept=intercept, intercept_index=intercept_index)
        mass_Y, axis_coordinate_Y = self.project(Y, root=root, intercept=intercept, intercept_index=intercept_index)

        # Each cloud carries zero mass at the positions owned by the other one.
        # Padding with zeros_like of the *other* tensor keeps both layouts at
        # Nx + Ny entries (also when Nx != Ny) and inherits dtype and device.
        pad_for_X = torch.zeros_like(mass_Y)
        pad_for_Y = torch.zeros_like(mass_X)
        mass_X = torch.cat((mass_X, pad_for_X), dim=-1)
        mass_Y = torch.cat((pad_for_Y, mass_Y), dim=-1)
        combined_axis_coordinate = torch.cat((axis_coordinate_X, axis_coordinate_Y), dim=-1)
        return combined_axis_coordinate, mass_X, mass_Y

    def project(self, input, root, intercept, intercept_index):
        """
        getting the mass and coordinate on the lines

        Args:
            input: point cloud with at least three dimensions; ``input.shape[-3]``
                is used as the number of points
            root, intercept, intercept_index: tree frames, forwarded to ``MCS``
                (shapes are defined by MCS -- TODO confirm)

        Returns:
            mass_input: softmax over dim -2 of ``-delta * distance_points_lines``,
                divided by the number of points
                (expected (..., ntrees, nlines, npoints) -- TODO confirm)
            axis_coordinate: ``mcs.distance(root, input)``
                (expected (..., ntrees, npoints) -- TODO confirm)
        """
        proj_coord = self.mcs.distance(root, input)
        distance = self.mcs.distance_points_lines(input, root, intercept, intercept_index)
        weight = -self.delta * distance
        mass_input = F.softmax(weight, dim=-2) / input.shape[-3]

        return mass_input, proj_coord

    def generate_trees_frames(self):
        """Sample tree frames from ``MCS`` using this instance's ntrees and nlines."""
        return self.mcs.generate_trees_frames(self.ntrees, self.nlines)


def mcstsw(X, Y, ncomp, dcomp, K, ntrees=250, nlines=4, p=2, delta=2, device='cuda'):
    """One-shot helper: build an ``MCSTSW`` solver, sample tree frames, and return the distance.

    Args:
        X, Y: the two point clouds to compare, passed unchanged to the solver
        ncomp, dcomp, K: forwarded to ``MCSTSW`` as ``ncomp``, ``dcomp`` and ``initK``
        ntrees, nlines: number of trees and lines per tree, used both for the
            solver and for the sampled frames
        p, delta, device: forwarded to ``MCSTSW``
    Returns:
        whatever calling the solver on (X, Y) with the sampled frames returns
    """
    solver = MCSTSW(
        ncomp=ncomp,
        dcomp=dcomp,
        initK=K,
        ntrees=ntrees,
        nlines=nlines,
        p=p,
        delta=delta,
        device=device,
    )
    frames = solver.mcs.generate_trees_frames(nlines=nlines, ntrees=ntrees)
    root, intercepts, intercept_indices = frames
    return solver(X, Y, root, intercepts, intercept_indices)

