import torch
from torch import nn, Tensor
import os
import re
from collections import defaultdict
from tqdm import tqdm
import pandas as pd

import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from matplotlib.path import Path
from scipy.spatial import ConvexHull
from scipy.spatial.distance import directed_hausdorff
from scipy.stats import wasserstein_distance
from geotorch.sphere import uniform_init_sphere_ as unif_sphere
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

from src.sumformer import *
from eval_utils import *



def extract_hyperparameters(filepath):
    """Collect every ``<lowercase-letters/dashes><digits>`` run found in *filepath*.

    The name part (including any surrounding dashes) becomes the key and the
    digits are parsed as an int. A name that occurs several times keeps all of
    its values, in order of appearance. Keys appear in first-seen order.

    Example: ``"depth-3-ed-64"`` -> ``{'depth-': [3], '-ed-': [64]}``.

    Args:
        filepath: any string, typically a run directory path.

    Returns:
        A plain ``dict`` mapping each name to the list of its integer values.
    """
    found = {}
    for m in re.finditer(r"(?P<name>[a-z-]+)(?P<value>\d+)", filepath):
        # setdefault creates the list on first sight of a name, then appends.
        found.setdefault(m["name"], []).append(int(m["value"]))
    return found
    
def eval_ellipse(fp, raw_data, input_dim=10, points_per_set=100, batch_size=128):
    """Evaluate the checkpoint stored in directory *fp* on *raw_data*.

    The architecture hyperparameters are parsed from *fp* by
    ``extract_hyperparameters``. The path must contain ``depth-<n>``,
    ``-ed-<n>``, ``-hd-<n>`` and ``-od-<n>``. The weights are loaded from
    ``<fp>/final_model.pt``. The approximate hulls produced by the model are
    then scored with ``avg_err`` (against the point sets) and
    ``avg_wasserstein_nd`` (against the ground-truth hulls from
    ``npz_to_batches``).

    Args:
        fp: run directory whose path encodes the hyperparameters.
        raw_data: sequence of point-set arrays. Each set is indexed as
            ``pts[:, :-1]``, i.e. its last column is dropped before computing
            the directional error.
        input_dim: point dimension given to the model and to ``avg_err``
            (default 10, the previously hard-coded value).
        points_per_set: number of points per set, used to regroup the flat
            model output before the per-set softmax (default 100).
        batch_size: batch size passed to ``npz_to_batches`` (default 128).

    Returns:
        ``(dir_width, wasserstein_dist)`` as returned by ``avg_err`` and
        ``avg_wasserstein_nd``.

    Raises:
        KeyError: if *fp* lacks one of the required hyperparameter tags.
    """
    params = extract_hyperparameters(fp)
    depth = params['depth-'][0]
    ed = params['-ed-'][0]
    hd = params['-hd-'][0]
    od = params['-od-'][0]

    # Last column of every set is excluded; presumably a label/flag column — TODO confirm.
    ptsets = [pts[:, :-1] for pts in raw_data]
    dataloader, gt = npz_to_batches(raw_data, batch_size)

    device = 'cpu'
    model = ConvexHullNN(input_dim=input_dim, depth=depth, embedding_dim=ed, hidden_dim=hd, output_dim=od)

    state_dict_path = os.path.join(fp, 'final_model.pt')
    # map_location lets checkpoints that were saved on a GPU load on this CPU-only run.
    state_dict = torch.load(state_dict_path, map_location=device)
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        # strict=False would otherwise silently evaluate partially initialised weights.
        print(f'Warning: {state_dict_path}: missing keys {incompatible.missing_keys}, '
              f'unexpected keys {incompatible.unexpected_keys}')
    model = model.to(device)
    model.eval()  # inference mode for dropout / batch-norm layers, if the model has any

    chull = []
    with torch.no_grad():  # evaluation only: no autograd graph needed
        for batch in dataloader:
            batch = batch.to(device)
            out = model(batch)
            # Regroup as (num_sets, points_per_set, od) so the softmax normalises
            # the weights over the points of each set (dim=1).
            out = out.view(-1, points_per_set, out.size(-1))
            out = F.softmax(out, dim=1)
            out = out.view(-1, out.size(-1))
            chull += [t.detach().cpu().numpy() for t in get_approx_chull(out, batch)]

    gt_hulls = [t.detach().cpu().numpy() for gt_batch in gt for t in gt_batch]

    print('computing dir')
    dir_width = avg_err(chull, ptsets, in_dim=input_dim)

    print('computing wass')
    wasserstein_dist = avg_wasserstein_nd(chull, gt_hulls)

    return dir_width, wasserstein_dist


def main():
    """Evaluate the selected ConvexHullNN runs on the ball dataset and save the metrics.

    Each run directory in ``filepaths`` is resolved under
    ``/data/oren/coreset/models/elliptical-50/ConvexHullNN/direction/<run>/model1_10d``
    and scored with ``eval_ellipse``. The resulting
    ``{run: (directional error, wasserstein distance)}`` mapping is written by
    ``_save_results``.
    """
    # Run directories to evaluate. Earlier sweeps over depth 2/3, hd 256/512 and
    # od 25/35/50 (all with ed 64) were evaluated by editing this list.
    filepaths = ['depth-3-ed-64-hd-512-od-75', 'depth-3-ed-64-hd-512-od-90', 'depth-3-ed-64-hd-512-od-40']
    circles = np.load('../../../../../data/oren/coreset/data/10d_ball_100_shuffled.npy')
    # The thin-ellipse set ('../../../../../data/oren/coreset/data/thin_ellipses.npy')
    # can be evaluated the same way: call eval_ellipse on it and add its columns when saving.

    ## Evaluating models
    errs = {}
    for fp in tqdm(filepaths):
        print(f'Evaluating {fp}')
        read = os.path.join('/data/oren/coreset/models/elliptical-50/ConvexHullNN/direction/', fp, 'model1_10d')
        errs[fp] = eval_ellipse(read, circles)

    _save_results(errs)


def _save_results(errs):
    """Write *errs* ({run: (dir error, wasserstein dist)}) to CSV, falling back to JSON.

    If the CSV cannot be produced, the reason is printed and the raw mapping is
    dumped to JSON instead, so the (expensive) evaluation results are not lost.
    """
    # json is not imported at the top of this file (it could only arrive via the
    # star imports), so import it explicitly for the fallback path.
    import json

    try:
        df = pd.DataFrame({
            'model': list(errs.keys()),
            'ball directional error': [val[0] for val in errs.values()],
            'ball wasserstein dist': [val[1] for val in errs.values()],
        })
        df.to_csv('/data/oren/coreset/out/new_10d_chull_results.csv', index=False)
    except Exception as err:  # not bare: let KeyboardInterrupt/SystemExit propagate
        print(f'Saving CSV failed ({err!r}); writing JSON instead')
        with open("/data/oren/coreset/out/10d_ball_errs.json", "w") as out_file:
            # Metric values may be numpy/torch scalars (assumption) that json
            # cannot encode directly; coerce them to Python floats.
            json.dump(errs, out_file, default=float)


if __name__ == "__main__":
    main()