import torch
from torch import nn, Tensor
import os

import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from matplotlib.path import Path
from scipy.spatial import ConvexHull
from scipy.spatial.distance import directed_hausdorff
from geotorch.sphere import uniform_init_sphere_ as unif_sphere



from src.sumformer import *



def npz_to_batches(raw_data, batch_size=128):
    """Group labelled point sets into ``Batch`` objects plus ground-truth hull lists.

    Each ``raw_data[i]`` is indexed as a 2-D array: every column except the
    last holds point coordinates, and the last column is a flag where ``1.0``
    marks a ground-truth convex-hull point.

    Args:
        raw_data: Indexable collection (e.g. a NumPy array) of per-sample 2-D
            arrays laid out as described above.
        batch_size: Maximum number of samples per batch. Every batch holds
            exactly ``batch_size`` samples except possibly the last one.

    Returns:
        ``(batch_list, gt_list)`` where ``batch_list[k]`` is
        ``Batch.from_list(points, order=1)`` built from the float tensors of the
        k-th group of samples, and ``gt_list[k]`` is the list of float tensors
        holding each of those samples' hull-flagged points, in the same order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    batch_list = []
    gt_list = []
    in_batch = []
    in_batch_gt = []
    n_samples = len(raw_data)

    for i in range(n_samples):
        sample = raw_data[i]
        ptset = sample[:, :-1]
        on_hull = sample[:, -1] == 1.0

        in_batch.append(torch.tensor(ptset, dtype=torch.float))
        in_batch_gt.append(torch.tensor(ptset[on_hull], dtype=torch.float))

        # Flush on the actual batch length. A modulo test on a never-reset
        # counter yields batches of batch_size - 1 after the first one.
        if len(in_batch) == batch_size or i == n_samples - 1:
            batch_list.append(Batch.from_list(in_batch, order=1))
            gt_list.append(in_batch_gt)
            in_batch = []
            in_batch_gt = []

    return batch_list, gt_list

def get_approx_chull(probabilities, batch):
    """Build one weighted point combination per sample of a concatenated batch.

    ``batch.data`` and ``probabilities.data`` are sliced in lockstep using the
    per-sample row counts in ``batch.n_nodes``. For each sample, the result is
    ``weights.T @ points``, i.e. one row per weight column, each row a
    weighted sum of that sample's points.

    Args:
        probabilities: Object whose ``.data`` rows line up with ``batch.data``
            (presumably per-point weights; TODO confirm against the model).
        batch: Object exposing ``.n_nodes`` (row count per sample) and
            ``.data`` (all samples' points stacked along dim 0).

    Returns:
        List of tensors, one per sample, each ``weights.T @ points``.
    """
    hulls = []
    offset = 0
    for count in batch.n_nodes:
        stop = offset + count
        weights = probabilities.data[offset:stop]
        points = batch.data[offset:stop]
        hulls.append(torch.mm(weights.T, points))
        offset = stop
    return hulls



def sort_points_by_angle(points):
    """Return ``points`` ordered by polar angle around their centroid.

    Angles come from ``np.arctan2`` on the offsets of the first two columns
    from the column-wise mean, so the order runs counter-clockwise starting
    just above -pi. This is useful for drawing a convex polygon's vertices as
    a connected outline.

    Args:
        points: Array of shape (m, >=2); only columns 0 and 1 affect the order.

    Returns:
        The rows of ``points`` reordered by increasing angle.
    """
    center = np.mean(points, axis=0)
    dx = points[:, 0] - center[0]
    dy = points[:, 1] - center[1]
    order = np.argsort(np.arctan2(dy, dx))
    return points[order]


def directional_width(P, u):
    """Return the spread of point set ``P`` projected onto direction ``u``.

    The spread is ``max(P @ u) - min(P @ u)``. When ``u`` is a unit vector,
    this is the width of ``P`` measured along ``u``.

    Args:
        P: Array-like of shape (m, d) holding m points.
        u: Array-like of shape (d,) giving the direction.

    Returns:
        The spread as a NumPy scalar. It is NaN if any projection is NaN.

    Raises:
        ValueError: If ``P`` contains no points.
    """
    prods = np.asarray(P) @ np.asarray(u)
    # Vectorized reductions replace the builtin max/min, which looped in
    # Python and handled NaN in an order-dependent way.
    return np.max(prods) - np.min(prods)

def directional_err(Q, P, in_dim, n=1000, directions=None):
    """Return the worst relative directional-width error of ``Q`` against ``P``.

    For each direction ``u``, the error is ``|w_P(u) - w_Q(u)| / w_P(u)``,
    where ``w_X(u)`` is the spread ``max(X @ u) - min(X @ u)``. This is the
    same quantity ``directional_width`` computes. The maximum over all
    directions is returned.

    Args:
        Q: Array of shape (m_Q, in_dim) holding the approximating points,
            e.g. estimated hull vertices.
        P: Array of shape (m_P, in_dim) holding the reference point set.
        in_dim: Point dimension. Only used when sampling directions.
        n: Number of directions to sample when ``directions`` is None.
        directions: Optional array of shape (k, in_dim) with directions to use
            instead of sampling. Passing a fixed array makes the metric
            reproducible and comparable across calls.

    Returns:
        The largest error over all directions.
        - A direction where ``P`` has zero width (and ``Q`` does not) gives
          ``inf``.
        - Directions where both widths are zero give NaN and are ignored.
        - Returns ``-inf`` if no direction yields a non-NaN error, including
          when there are no directions.
    """
    if directions is None:
        # Sample n directions with geotorch's uniform_init_sphere_, as before.
        directions = unif_sphere(torch.zeros(n, in_dim)).numpy()
    directions = np.atleast_2d(np.asarray(directions))
    P = np.asarray(P)
    Q = np.asarray(Q)

    # Column j holds every point's projection onto directions[j].
    proj_p = P @ directions.T
    proj_q = Q @ directions.T
    width_p = proj_p.max(axis=0) - proj_p.min(axis=0)
    width_q = proj_q.max(axis=0) - proj_q.min(axis=0)

    # Zero-width directions are handled explicitly below, so silence the
    # divide/invalid warnings they would otherwise raise.
    with np.errstate(divide='ignore', invalid='ignore'):
        errs = np.abs(width_p - width_q) / width_p

    # Match the original loop semantics: NaN errors never win the maximum.
    errs = errs[~np.isnan(errs)]
    if errs.size == 0:
        return -float('inf')
    return errs.max()


def _close_polygon(vertices):
    """Return ``vertices`` with its first row appended so a line plot closes the loop."""
    if len(vertices) == 0:
        return vertices
    return np.vstack([vertices, vertices[0]])


def plot_hull(i, tol=0.1):
    """Plot sample ``i``'s approximate and ground-truth hulls and print the width error.

    Reads the module-level ``chull`` (approximate hull vertices per sample),
    ``gt_hulls`` (ground-truth hull points per sample) and ``raw_data``
    (raw samples; columns 0 and 1 are plotted as x/y). NOTE(review): these
    globals are not defined in this section; confirm where they are set.

    Args:
        i: Index of the sample to plot.
        tol: Euclidean distance below which a data point is drawn as a
            "common point", i.e. one coinciding with an approximate-hull vertex.
    """
    approx = chull[i]
    # Both outlines are closed; previously only the approximate hull was,
    # so the ground-truth polygon was missing its last edge.
    sorted_chull = _close_polygon(sort_points_by_angle(approx))
    sorted_gt = _close_polygon(sort_points_by_angle(gt_hulls[i]))

    pt_set = raw_data[i][:, :2]

    # Compute the proximity mask once; it splits points into plain data
    # points and points within tol of some approximate-hull vertex.
    mask = np.array(
        [np.any(np.sqrt(np.sum(np.square(approx - pt), axis=1)) < tol) for pt in pt_set],
        dtype=bool,
    )
    common_points = pt_set[mask]

    plt.figure(figsize=(8, 6))
    plt.plot(sorted_chull[:, 0], sorted_chull[:, 1], 'r-', linewidth=2, label='Approx. Convex Hull', marker='x')
    plt.scatter(pt_set[~mask][:, 0], pt_set[~mask][:, 1], color='orange', label='Data Points')

    # Plot common points in black if there are any.
    if len(common_points) > 0:
        plt.scatter(common_points[:, 0], common_points[:, 1], color='black', label='Common Points')

    plt.plot(sorted_gt[:, 0], sorted_gt[:, 1], linewidth=2, color='green', label='Ground Truth Hull')

    plt.xlabel('X-coordinate')
    plt.ylabel('Y-coordinate')
    plt.title('Estimated Convex Hull (train set)')
    plt.legend(loc='lower right', bbox_to_anchor=(1.3, 0))
    plt.show()

    # The repeated closing vertex in sorted_chull does not change any
    # directional width, so it is safe to pass the closed outline here.
    print(f'Directional Width Error is {directional_err(sorted_chull, pt_set, 2, n=1000)}')