import torch
from torch import nn, Tensor
import os
import re
from collections import defaultdict
from tqdm import tqdm
import pandas as pd
import yaml
import multiprocessing as mp

import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from matplotlib.path import Path
from scipy.spatial import ConvexHull
from scipy.spatial.distance import directed_hausdorff
from scipy.special import softmax
from scipy.stats import wasserstein_distance
from geotorch.sphere import uniform_init_sphere_ as unif_sphere
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

from src.sumformer import *
from eval_utils import *
import argparse



    

def generate_unit_directions(d, k):
    """Draw ``k`` random unit vectors in ``d``-dimensional space.

    Each row is a standard-normal sample rescaled to Euclidean length 1
    (uses NumPy's global RNG, so ``np.random.seed`` makes it reproducible).

    Returns:
        Array of shape (k, d) whose rows have unit norm.
    """
    samples = np.random.randn(k, d)
    row_lengths = np.linalg.norm(samples, axis=1, keepdims=True)
    return samples / row_lengths

def epsilon_kernel_approximation(points, epsilon, k = None, directions=None):
    """Approximate the convex hull of ``points`` by extreme points along directions.

    For every direction, the point with the largest and the point with the
    smallest projection are kept; the union of these points is returned.

    Args:
        points: array of shape (n, d).
        epsilon: approximation parameter. Only used to derive ``k`` when both
            ``k`` and ``directions`` are None; it must then lie in (0, 2].
        k: number of random unit directions to sample. Defaults to
            ``ceil(2 * pi / arccos(1 - epsilon))``. Ignored if ``directions``
            is given.
        directions: optional (s, d) array of directions to use instead of
            sampling random ones (e.g. to share directions between methods).

    Returns:
        The selected rows of ``points``, each at most once, in increasing
        row-index order.

    Raises:
        ValueError: if ``k`` must be derived and ``epsilon`` is outside (0, 2].
    """
    points = np.asarray(points)
    if directions is None:
        if k is None:
            # arccos(1 - epsilon) is 0 for epsilon == 0 (division by zero) and
            # NaN for epsilon > 2, so k is only defined on (0, 2].
            if not 0 < epsilon <= 2:
                raise ValueError(f"epsilon must be in (0, 2] to derive k, got {epsilon}")
            k = int(np.ceil(2 * np.pi / np.arccos(1 - epsilon)))
        directions = generate_unit_directions(points.shape[1], k)
    directions = np.asarray(directions)

    # (n, s): column j holds every point projected onto direction j.
    projections = points @ directions.T
    # Extreme point indices per direction; union1d removes duplicates and sorts,
    # giving a deterministic output order.
    extreme = np.union1d(projections.argmax(axis=0), projections.argmin(axis=0))
    return points[extreme]

def rho_epsilon(A, eps):
    """Keep only the part of ``A`` within ``eps * width`` of its maximum.

    Computes ``max(0, A + eps * (max(A) - min(A)) - max(A))`` elementwise.
    The max and min are taken over ALL entries of ``A`` (not per row).
    """
    top = A.max()
    width = top - A.min()
    shifted = A + eps * width - top
    return np.maximum(0, shifted)

# def relaxed_epsilon_kernel(P, eps, k = None):
#     n, d = P.shape
#     if not torch.is_tensor(P):
#         P = torch.tensor(P)
#     if k is None:
#         s = int(np.ceil(1 / eps**d))
#     else:
#         s = k
#     Omega = generate_unit_directions(d, s)

#     Q = []

#     for u in Omega:
#         projections = P @ u
#         w = projections.max() - projections.min()
#         M = projections.max()

#         rho = projections - M + eps * w
#         rho = F.relu(rho)
        
#         rho = rho_epsilon(projections, eps)

#         # tau = 10 * eps  # Temperature coupled to epsilon
#         # weights = F.softmax(rho / (10 * eps), dim=0)
#         weights = F.normalize(rho, p = 1.0, dim = 0)
        
#         q_u = weights @ P
#         Q.append(q_u)

#     return np.array(Q)

def relaxed_epsilon_kernel(P, eps, k=None, directions=None):
    """Compute a soft epsilon-kernel: one weighted-average point per direction.

    For each unit direction u with projections ``p_i = <P_i, u>``, let
    ``M = max p`` and ``w = max p - min p``. Point i gets weight proportional to
    ``relu(p_i - M + eps * w)``, i.e. only points within ``eps * w`` of the top
    projection contribute. The weights of each direction are L1-normalized and
    used to average the points.

    Args:
        P: point set of shape (n, d); a tensor, or anything ``torch.tensor``
            accepts (then converted to float32).
        eps: band width as a fraction of the directional width.
        k: number of random directions; defaults to ``ceil(1 / eps**d)``.
            Ignored if ``directions`` is given.
        directions: optional (s, d) array/tensor of directions to use instead
            of sampling random ones.

    Returns:
        numpy array of shape (s, d).
    """
    if not torch.is_tensor(P):
        P = torch.tensor(P, dtype=torch.float32)

    _, d = P.shape
    if directions is None:
        s = k if k is not None else int(np.ceil(1 / eps**d))
        directions = generate_unit_directions(d, s)
    Omega = torch.as_tensor(directions, dtype=P.dtype, device=P.device)  # (s, d)

    projections = Omega @ P.T  # (s, n)
    M = projections.max(dim=1).values  # (s,)
    w = M - projections.min(dim=1).values  # (s,)

    # Row-wise rho_epsilon, applied exactly once. Running the global
    # rho_epsilon over the whole (s, n) matrix as well would mix widths across
    # directions and zero out every direction narrower than the widest.
    rho = F.relu(projections - M[:, None] + eps * w[:, None])

    # A row with no positive weight (zero width, e.g. a single point, or
    # eps == 0) falls back to the points attaining the max projection, so it
    # still yields a real point rather than the origin.
    at_max = (projections == M[:, None]).to(P.dtype)
    rho = torch.where(rho.sum(dim=1, keepdim=True) > 0, rho, at_max)

    weights = F.normalize(rho, p=1.0, dim=1)  # each row sums to 1
    Q = weights @ P  # (s, d)
    return Q.detach().cpu().numpy()

def plot_results(original_points, approx_points):
    """Plot the original points, the kernel points, and the kernel's hull outline.

    Only the first two coordinates are drawn. The hull is computed on those same
    two coordinates, so the outline matches the scatter even for d > 2. The hull
    outline is skipped when there are fewer than 3 kernel points.
    """
    plt.scatter(original_points[:, 0], original_points[:, 1], color='gray', alpha=0.5, label='Original Points')
    plt.scatter(approx_points[:, 0], approx_points[:, 1], color='red', label='Epsilon-Kernel Points')

    if len(approx_points) >= 3:
        # Use the ConvexHull imported at module level; the bare ``scipy`` name
        # is not explicitly imported in this module.
        hull = ConvexHull(approx_points[:, :2])
        for simplex in hull.simplices:
            plt.plot(approx_points[simplex, 0], approx_points[simplex, 1], 'r-')

    plt.legend()
    plt.show()

def eval_ellipse_baseline(epsilon, raw_data, in_dim, ptset_size = None, k = None):
    """Evaluate the relaxed epsilon-kernel baseline on a collection of point sets.

    Despite the name, ``main`` uses this for every dataset, not only ellipses.

    Args:
        epsilon: band parameter forwarded to ``relaxed_epsilon_kernel``.
        raw_data: sequence of point-set arrays. The last column of each is
            dropped and the remaining columns are used as coordinates
            (presumably a label/flag column — confirm against the data files).
        in_dim: point dimension, forwarded to ``avg_err``.
        ptset_size: unused; kept so existing call sites keep working.
        k: number of directions forwarded to ``relaxed_epsilon_kernel``.

    Returns:
        The result of ``avg_err`` for the kernels vs. the point sets. ``main``
        indexes it as ``[0]`` and ``[1]``, presumably (mean, std) — confirm in
        eval_utils.
    """
    ptsets = np.array([ptset[:, :-1] for ptset in raw_data])

    # Ground-truth hulls are only consumed by the (currently disabled)
    # Wasserstein metric below; they are still collected so it can be re-enabled.
    _, gt = npz_to_batches(raw_data, 128)

    chull = [relaxed_epsilon_kernel(points, epsilon, k) for points in ptsets]

    gt_hulls = [hull for batch in gt for hull in batch]

    dir_width = avg_err(chull, ptsets, in_dim=in_dim)
    # wasserstein_dist = avg_wasserstein_nd(chull, gt_hulls)

    return dir_width


def main():
    """Run the relaxed epsilon-kernel baseline on the 5-D test sets and save a CSV.

    For every number of directions in ``output_dims`` and every dataset, calls
    ``eval_ellipse_baseline`` and records the two values it returns. It then
    writes one row per direction count to
    ``/data/oren/coreset/out/5d_eps{eps}_relaxed_baseline_chull_results.csv``.
    If that file cannot be written, the CSV is printed to stdout instead.
    """
    import sys

    parser = argparse.ArgumentParser()
    # Required: eps is used arithmetically inside the kernel, so a missing
    # value would otherwise fail much later with an opaque TypeError.
    parser.add_argument('--epsilon', type=float, required=True)
    args = parser.parse_args()

    eps = args.epsilon

    # Keys are the CSV column prefixes. Insertion order is the evaluation order,
    # which is kept stable so the sequence of random direction draws is too.
    datasets = {
        'circle': np.load('../../../../../data/oren/coreset/data/5d_uniform_500_test.npy'),
        'ellipse': np.load('../../../../../data/oren/coreset/data/5d_ellipse_500_test.npy'),
        'gaussian': np.load('../../../../../data/oren/coreset/data/single_gauss_5d_test.npy'),
        'mixture gaussian': np.load('../../../../../data/oren/coreset/data/mix_gauss_5d_test.npy'),
        'mixed': np.load('/data/oren/coreset/data/mixed_5d_500_test.npy'),
        'manifold': np.load('/data/oren/coreset/data/5d_manifold_ellipse_test.npy'),
    }
    # Column order in the output CSV (differs from the evaluation order above).
    column_order = ['circle', 'ellipse', 'gaussian', 'mixture gaussian', 'manifold', 'mixed']

    output_dims = [16, 32, 64, 100, 200, 300, 400, 500]

    errs = {name: {} for name in datasets}
    for od in tqdm(output_dims):
        for name, raw_data in datasets.items():
            errs[name][od] = eval_ellipse_baseline(
                epsilon=eps, raw_data=raw_data, in_dim=5, ptset_size=None, k=od
            )

    data = {'model': list(output_dims)}
    for name in column_order:
        data[f'{name} directional error'] = [errs[name][od][0] for od in output_dims]
        data[f'{name} std'] = [errs[name][od][1] for od in output_dims]
    df = pd.DataFrame(data)

    out_path = f'/data/oren/coreset/out/5d_eps{eps}_relaxed_baseline_chull_results.csv'
    try:
        df.to_csv(out_path, index=False)
    except OSError as e:
        # Don't lose a long evaluation run to an unwritable output path:
        # report the failure and dump the results to stdout instead.
        print(f'Could not write {out_path}: {e}', file=sys.stderr)
        print(df.to_csv(index=False))


if __name__ == "__main__":
    main()