import torch
from torch import nn, Tensor
import os
import re
from collections import defaultdict
from tqdm import tqdm
import pandas as pd
import yaml
import multiprocessing as mp

import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from matplotlib.path import Path
from scipy.spatial import ConvexHull
from scipy.spatial.distance import directed_hausdorff
from scipy.stats import wasserstein_distance
from geotorch.sphere import uniform_init_sphere_ as unif_sphere
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

from src.sumformer import *
from eval_utils import *



    

def generate_unit_directions(d, k):
    """Sample ``k`` random unit vectors in ``d``-dimensional space.

    Standard-normal samples are drawn from NumPy's global RNG and each row is
    divided by its Euclidean norm.

    Returns:
        ndarray of shape (k, d) whose rows have unit length.
    """
    samples = np.random.randn(k, d)
    lengths = np.linalg.norm(samples, axis=1, keepdims=True)
    return samples / lengths

def epsilon_kernel_approximation(points, epsilon, k = None):
    """Approximate the convex hull of ``points`` by its extreme points in random directions.

    ``k`` random unit directions are sampled with ``generate_unit_directions``.
    For each direction, the points with the largest and the smallest
    projection are kept. The result is the union of those extreme points,
    with duplicates removed.

    Args:
        points: array of shape (n, d).
        epsilon: approximation parameter. It is only used when ``k`` is None,
            to derive ``k = ceil(2*pi / arccos(1 - epsilon))``, and must then
            lie in (0, 2].
        k: number of random directions. When None, it is derived from
            ``epsilon`` as described above.

    Returns:
        ndarray of shape (m, d) holding the selected rows of ``points``, in
        increasing row-index order, where m <= min(n, 2k).

    Raises:
        ValueError: if ``k`` is None and ``epsilon`` is not in (0, 2].
    """
    points = np.asarray(points)
    if k is None:
        # arccos(1 - epsilon) is 0 at epsilon == 0 (division by zero) and NaN
        # outside [0, 2]; reject those up front instead of failing in int().
        if not 0 < epsilon <= 2:
            raise ValueError(f"epsilon must be in (0, 2] to derive k, got {epsilon}")
        k = int(np.ceil(2 * np.pi / np.arccos(1 - epsilon)))  # Number of directions

    dim = points.shape[1]  # Dimension of space
    directions = generate_unit_directions(dim, k)

    # (n, k) matrix: column j holds every point projected onto direction j,
    # so per-column argmax/argmin give the extreme points for each direction.
    projections = points @ directions.T
    extreme_indices = np.union1d(projections.argmax(axis=0), projections.argmin(axis=0))
    return points[extreme_indices]

def plot_results(original_points, approx_points):
    """Scatter-plot a point set and its epsilon-kernel subset, outlining the subset's hull.

    Only the first two coordinates are drawn. The hull outline is computed on
    those same two coordinates so that it matches what is shown. The outline
    is skipped when fewer than three subset points exist, because a planar
    hull needs at least three.

    Args:
        original_points: array of shape (n, d) with d >= 2.
        approx_points: array of shape (m, d) holding the selected subset.
    """
    plt.scatter(original_points[:, 0], original_points[:, 1], color='gray', alpha=0.5, label='Original Points')
    plt.scatter(approx_points[:, 0], approx_points[:, 1], color='red', label='Epsilon-Kernel Points')

    planar = np.asarray(approx_points)[:, :2]
    if len(planar) >= 3:
        # Use the directly imported ConvexHull. The module-level import is
        # `from scipy.spatial import ConvexHull`, which does not bind `scipy`.
        hull = ConvexHull(planar)
        for edge in hull.simplices:
            plt.plot(planar[edge, 0], planar[edge, 1], 'r-')

    plt.legend()
    plt.show()

def eval_ellipse_baseline(epsilon, raw_data, in_dim, ptset_size = None, k = None):
    """Evaluate the epsilon-kernel convex-hull baseline on a dataset of point sets.

    Nothing in this function is ellipse-specific; ``main`` calls it for every
    dataset.

    Args:
        epsilon: forwarded to ``epsilon_kernel_approximation``. It only
            matters when ``k`` is None.
        raw_data: sequence of 2-D arrays, one per point set. The last column
            of each is dropped before approximation (presumably a
            label/indicator column; TODO confirm against the data format).
            The whole of ``raw_data`` is also passed to ``npz_to_batches`` to
            obtain ground-truth hulls.
        in_dim: forwarded to ``avg_err``.
        ptset_size: unused. It is kept so existing call sites still work.
        k: number of random directions per point set.

    Returns:
        Tuple ``(dir_width, wasserstein_dist)`` containing the results of
        ``avg_err`` and ``avg_wasserstein_nd``.
    """
    ptsets = np.array([ptset[:, :-1] for ptset in raw_data])

    # npz_to_batches also returns a dataloader; only the ground-truth batches
    # are needed here. It is called before the approximations so that any
    # global-RNG use keeps the same order as before.
    _, gt = npz_to_batches(raw_data, 128)

    chull = [epsilon_kernel_approximation(points, epsilon, k) for points in ptsets]

    # Flatten the batched ground truth into one flat list of hulls.
    gt_hulls = []
    for batch in gt:
        gt_hulls.extend(batch)

    dir_width = avg_err(chull, ptsets, in_dim=in_dim)
    wasserstein_dist = avg_wasserstein_nd(chull, gt_hulls)

    return dir_width, wasserstein_dist


def main():
    """Run the epsilon-kernel baseline on several 3-D datasets and save a CSV.

    For each number of random directions in ``output_dims``, the baseline is
    evaluated on three datasets: the ModelNet subsample, the uniform ("ball")
    set and the thin-ellipse set. The directional error and Wasserstein
    distance returned by ``eval_ellipse_baseline`` are written to
    ``out_path``, one row per direction count. That count is stored in the
    ``'model'`` column.
    """
    import sys

    data_dir = '../../../../../data/oren/coreset/data'
    out_path = '/data/oren/coreset/out/full_baseline_chull_results.csv'
    epsilon = 0.75
    in_dim = 3

    thin_ellipses = np.load(os.path.join(data_dir, '3d_ellipse_50.npy'))
    circles = np.load(os.path.join(data_dir, '3d_uniform_50.npy'))
    modelnet = np.load(os.path.join(data_dir, 'subsampled_modelnet_coreset.npy'))
    # NOTE: the rectangle dataset (3d_rectangle_50.npy) is currently excluded;
    # add it to `eval_order` below to evaluate it as well.

    output_dims = [8, 16, 25, 32, 50, 64, 128, 150]

    # Evaluation order matters: epsilon_kernel_approximation draws directions
    # from NumPy's global RNG, so reordering these calls changes the results.
    eval_order = [('modelnet', modelnet), ('ball', circles), ('ellipse', thin_ellipses)]
    results = {name: {} for name, _ in eval_order}

    for od in tqdm(output_dims):
        for name, raw_data in eval_order:
            results[name][od] = eval_ellipse_baseline(
                epsilon=epsilon, raw_data=raw_data, in_dim=in_dim, ptset_size=None, k=od
            )

    # Column order of the output CSV: the direction count, then ball, ellipse, modelnet.
    data = {'model': list(output_dims)}
    for name in ('ball', 'ellipse', 'modelnet'):
        data[f'{name} directional error'] = [results[name][od][0] for od in output_dims]
        data[f'{name} wasserstein dist'] = [results[name][od][1] for od in output_dims]

    df = pd.DataFrame(data)
    try:
        df.to_csv(out_path, index=False)
    except OSError as e:
        # Best effort: report the failed write instead of aborting after the
        # (long) evaluation has finished.
        print(f'Could not write results to {out_path}: {e}', file=sys.stderr)


if __name__ == "__main__":
    main()