require(gganimate)
require(ggplot2)
require(ggthemes)
require(ggforce)
library(gridExtra)
library(plotly)
library(geosphere)
require(reshape2)
library(viridis)

library(torchvision)
library(torchdatasets)
library(torch)
library(luz)


library(cooltools) # Mollweide proj


## Map points given as angles on the sphere S^2 to 3D Euclidean coordinates.
## @param angles torch tensor; column 1 is the polar angle theta and column 2
##   the azimuth phi, in radians (presumably shape (n, 2) — TODO confirm).
## @return torch tensor whose columns are
##   sin(theta) * cos(phi), sin(theta) * sin(phi) and cos(theta),
##   stacked along dimension 2.
RtoY <- function(angles)
{
  theta <- angles[, 1]
  phi <- angles[, 2]
  # sin(theta) feeds both the x and y components
  sin_theta <- torch_sin(theta)
  x <- sin_theta * torch_cos(phi)
  y <- sin_theta * torch_sin(phi)
  z <- torch_cos(theta)
  torch_stack(list(x, y, z), 2)
}
## Non-torch inverse of RtoY, used for later object manipulation: converts 3D
## Euclidean coordinates on S^2 back to angles.
## @param Y numeric matrix whose columns 1-3 are x, y, z; rows are assumed to
##   be (approximately) unit vectors.
## @return two-column matrix: column 1 is the polar angle acos(z) in [0, pi],
##   column 2 the azimuth atan2(y, x) in (-pi, pi].
YtoR_plain <- function(Y)
{
  # Clamp z into [-1, 1]: rounding can push a unit vector's z component just
  # outside the domain of acos(), which would otherwise silently yield NaN.
  z <- pmin(pmax(Y[, 3], -1), 1)
  cbind(acos(z), atan2(Y[, 2], Y[, 1]))
}

## Bound angles to the canonical domain: polar angle (column 1) in [0, pi] and
## azimuth (column 2) in [0, 2*pi).
## A polar angle that wraps into [pi, 2*pi) is reflected back to 2*pi - angle.
## Because that reflection flips the sign of sin(polar angle), the azimuth is
## rotated by pi at the same time, so the returned pair still denotes the same
## point on S^2. This matches angle_merc_translate.
## @param R numeric matrix with columns (polar angle, azimuth), in radians.
## @return two-column matrix with columns named "latitude" and "longitude".
angle_bounding <- function(R)
{
  latitude <- R[, 1] %% (2*pi)
  reflected <- latitude >= pi
  # antipodal azimuth shift for reflected rows, then wrap into [0, 2*pi)
  longitude <- ifelse(reflected, R[, 2] + pi, R[, 2]) %% (2*pi)
  latitude <- ifelse(reflected, pi - (latitude - pi), latitude)
  cbind(latitude, longitude)
}

## Helper for the map projections: translates (polar angle, azimuth) pairs to
## the domain the projections expect.
## The polar angle is wrapped to [0, 2*pi), folded into [0, pi], and then
## shifted by -pi/2, giving a value in [-pi/2, pi/2]. The azimuth is wrapped to
## [0, 2*pi). Whenever the polar angle is folded, the azimuth is rotated by pi,
## so each row keeps denoting the same point on S^2.
## @param R numeric matrix with columns (polar angle, azimuth), in radians.
## @return two-column matrix: shifted latitude (unnamed column) and
##   "longitude".
angle_merc_translate <- function(R)
{
  latitude <- R[, 1] %% (2*pi)
  # Decide on the antipodal shift from the *wrapped* polar angle: only rows
  # that land in [pi, 2*pi) get reflected, so only those need the shift.
  # Testing the raw angle wrongly shifted e.g. -4 rad or 7 rad.
  reflected <- latitude >= pi
  longitude <- ifelse(reflected, R[, 2] + pi, R[, 2]) %% (2*pi)
  latitude <- ifelse(reflected, pi - (latitude - pi), latitude)
  cbind(latitude-(pi/2), longitude)
}
## Project angles on S^2 to 2D map coordinates.
## @param R two-column matrix of (polar angle, azimuth) in radians; normalised
##   internally with angle_merc_translate().
## @param long0 longitude offset of the map, in radians (forwarded as lon0 to
##   mollweide(); see the note on the Mercator branch below).
## @param method "mercator" or "mollweide".
## @param cutoff absolute latitude (radians) beyond which Mercator points are
##   set to NA. The default (~89 degrees) is the cutoff used in traditional maps
##   to avoid the large distortions close to the poles.
## @return two-column matrix of map coordinates. Stops with an error for an
##   unknown method, instead of returning a message string in place of the
##   coordinates.
Rto2D <- function(R, long0=0, method='mercator', cutoff=1.553343)
{
  if (length(method) != 1 || !(method %in% c('mercator', 'mollweide')))
  {
    stop('Unknown projection method', call. = FALSE)
  }
  R <- angle_merc_translate(R)
  longitude <- R[, 2]
  latitude <- R[, 1]
  if (method == 'mercator')
  {
    lat_cut <- latitude
    lat_cut[latitude < -cutoff | latitude > cutoff] <- NA
    # NOTE(review): x = ((lon - long0) mod 2*pi) - pi puts longitude long0 at
    # the left edge (x = -pi), i.e. the map is centred on long0 + pi; confirm
    # this matches the intended "long0 = centre of map" meaning.
    cbind((longitude - long0) %% (2*pi) - pi, atanh(sin(lat_cut)))
  } else {
    mollweide(longitude, latitude, lon0 = long0)
  }
}
## Scatter plot of points on S^2 in Mercator projection.
## @param angles two-column matrix of (polar angle, azimuth), passed to Rto2D().
## @param labels optional vector of cluster labels, one per row of angles;
##   when given, points are coloured by label.
## @param long0 longitude offset forwarded to Rto2D().
## @param pointsize point size for geom_point().
## @return a ggplot object framed by the +/- cutoff latitude band.
plot_mercat_2D <- function(angles, labels=NULL, long0=0, pointsize=.5)
{
  mercator_proj <- Rto2D(angles, long0 = long0)
  # ~89 degree cutoff used in traditional maps to avoid large distortions close
  # to the poles (same value Rto2D uses by default for mercator)
  cutoff <- 1.553343

  plot_df <- data.frame(MER1 = mercator_proj[, 1], MER2 = mercator_proj[, 2])
  point_aes <- aes(x = MER1, y = MER2)
  if (!is.null(labels))
  {
    plot_df$cluster <- as.factor(labels)
    point_aes <- aes(x = MER1, y = MER2, color = cluster)
  }
  # rectangle outlining the drawable map area
  frame_df <- data.frame(xmin = -pi, xmax = pi,
                         ymin = -atanh(sin(cutoff)), ymax = atanh(sin(cutoff)))

  ggplot(plot_df, point_aes) +
    theme_minimal() +
    # na.rm: points beyond the cutoff are deliberately NA (set by Rto2D)
    geom_point(alpha = .7, size = pointsize, na.rm = TRUE) +
    scale_color_viridis_d() +
    theme(line = element_blank(), axis.text = element_blank()) +
    geom_rect(data = frame_df, inherit.aes = FALSE,
              aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax),
              color = 'black', alpha = .01) +
    coord_fixed()
}
## Scatter plot of points on S^2 in Mollweide projection, inside the map's
## elliptical outline.
## @param angles two-column matrix of (polar angle, azimuth), passed to Rto2D().
## @param labels optional vector of cluster labels, one per row of angles;
##   when given, points are coloured by label.
## @param pointsize point size for geom_point() (default keeps previous look).
## @param long0 longitude offset forwarded to Rto2D() (default pi, as before).
## @return a ggplot object.
plot_moll_2D <- function(angles, labels=NULL, pointsize=.5, long0=pi)
{
  moll_proj <- Rto2D(angles, long0 = long0, method = 'mollweide')

  plot_df <- data.frame(MOL1 = moll_proj[, 1], MOL2 = moll_proj[, 2])
  point_aes <- aes(x = MOL1, y = MOL2)
  if (!is.null(labels))
  {
    plot_df$cluster <- as.factor(labels)
    point_aes <- aes(x = MOL1, y = MOL2, color = cluster)
  }
  # Map outline: ellipse with semi-axes 2*sqrt(2) and sqrt(2). It gets its own
  # one-row data frame and does not inherit the point mapping, so the outline
  # is specified once rather than evaluated against every row of the points.
  outline_df <- data.frame(x0 = 0, y0 = 0, a = 2*sqrt(2), b = sqrt(2), angle = 0)

  ggplot(plot_df, point_aes) +
    geom_ellipse(data = outline_df, inherit.aes = FALSE,
                 aes(x0 = x0, y0 = y0, a = a, b = b, angle = angle),
                 color = 'black') +
    theme_minimal() +
    geom_point(alpha = .7, size = pointsize) +
    scale_color_viridis_d() +
    theme(line = element_blank(), axis.text = element_blank()) +
    coord_fixed()
}



## Compute a visualization-friendly rotation of 3D points.
##
## Idea: evaluate a grid of rotations (a rotation about the z axis by a
## "longitude" in [0, pi], combined with a rotation about the y axis by a
## "latitude" in [-pi/2, pi/2]). Keep the one that minimises the sum of squared
## latitudes of the rotated points, i.e. that pulls the points closest to the
## equator.
## @param Y numeric matrix with 3 columns; rows are assumed to be unit vectors.
## @param granularity number of grid steps per pi radians in each angle.
## @return Y multiplied on the right by the best rotation matrix.
optimal_rot_vis <- function(Y, granularity=40)
{
  latd <- seq(-pi/2, pi/2, pi/granularity)
  longtd <- seq(0, pi, pi/granularity)
  angle_grid <- data.frame(latd = rep(latd, length(longtd)),
                           longtd = rep(longtd, each = length(latd)))

  # Combined rotation for one grid cell (matrices filled column-major).
  rotation_for <- function(lat, lon)
  {
    rot_z <- matrix(c(cos(lon), -sin(lon), 0,
                      sin(lon), cos(lon), 0,
                      0, 0, 1),
                    nrow = 3)
    rot_y <- matrix(c(cos(lat), 0, -sin(lat),
                      0, 1, 0,
                      sin(lat), 0, cos(lat)),
                    nrow = 3)
    rot_z %*% rot_y
  }

  # Only the score is kept per grid cell. Storing every rotated copy of Y would
  # need (grid size) x nrow(Y) x 3 doubles.
  ssl <- vapply(seq_len(nrow(angle_grid)), function(i)
  {
    rot_Y <- Y %*% rotation_for(angle_grid$latd[i], angle_grid$longtd[i])
    # Clamp: rounding can push |z| slightly above 1. acos() would then return
    # NaN and which.min() would silently skip this rotation.
    z <- pmin(pmax(rot_Y[, 3], -1), 1)
    lat_res <- abs(acos(z) - (pi/2))
    sum(lat_res^2)
  }, numeric(1))

  best_idx <- which.min(ssl)
  if (length(best_idx) == 0L)
  {
    stop('No finite rotation score; does Y contain NA values?', call. = FALSE)
  }
  Y %*% rotation_for(angle_grid$latd[best_idx], angle_grid$longtd[best_idx])
}


