import torch
from torch import nn, Tensor
import os
import re
from collections import defaultdict
from tqdm import tqdm
import pandas as pd

import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from matplotlib.path import Path
from scipy.spatial import ConvexHull
from scipy.spatial.distance import directed_hausdorff
from scipy.stats import wasserstein_distance
from geotorch.sphere import uniform_init_sphere_ as unif_sphere
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

from src.sumformer import *
from eval_utils import *



def extract_hyperparameters(filepath):
    """Collect every ``<name><integer>`` token found in ``filepath``.

    A token is a run of lowercase letters/underscores immediately followed by
    a run of digits (e.g. ``_hd128``). Names may repeat within one path, so
    each name maps to the list of its integer values in order of appearance.

    Args:
        filepath: String to scan (typically a model directory path).

    Returns:
        A plain ``dict`` mapping each name to a list of ``int`` values; empty
        when no token is present.
    """
    collected = {}
    for found in re.finditer(r"(?P<name>[a-z_]+)(?P<value>\d+)", filepath):
        key = found.group("name")
        number = int(found.group("value"))
        # A name can occur more than once (e.g. "_hd"), so accumulate a list.
        collected.setdefault(key, []).append(number)
    return collected
    
def eval_ellipse(fp, raw_data, points_per_set=25):
    """Evaluate a saved ``ConvexHullEncoder`` checkpoint on a set of point clouds.

    The architecture hyperparameters are parsed from the path ``fp`` with
    ``extract_hyperparameters``, the weights are read from
    ``<fp>/final_model.pt``, and the hulls produced by ``get_approx_chull``
    are scored with ``avg_err`` and ``avg_wasserstein``.

    Args:
        fp: Directory containing ``final_model.pt``. The path must contain the
            tokens ``encoder_depth<N>``, ``_proc_depth<N>``, ``_ed<N>`` and two
            occurrences each of ``_hd<N>`` and ``_od<N>`` (encoder first,
            processor second).
        raw_data: Indexable collection of arrays, one per point set. The last
            column of every array is dropped before scoring (presumably a
            label/hull indicator -- TODO confirm against the data files).
        points_per_set: Number of points in each set, used to regroup the
            model output per set before the softmax. The default of 25 matches
            the circle/ellipse data; the original comment notes 50 for the
            triangle data.

    Returns:
        Tuple ``(dir_width, wasserstein_dist)`` holding the results of
        ``avg_err(chull, ptsets)`` and ``avg_wasserstein(chull, gt_hulls)``.

    Raises:
        KeyError, IndexError: If ``fp`` lacks one of the expected tokens.
    """
    import warnings

    params = extract_hyperparameters(fp)

    encoder_depth = params['encoder_depth'][0]
    encoder_width = params['_hd'][0]
    encoder_od = params['_od'][0]
    proc_depth = params['_proc_depth'][0]
    proc_ed = params['_ed'][0]
    proc_hd = params['_hd'][1]
    od = params['_od'][1]

    ptsets = [pts[:, :-1] for pts in raw_data]
    dataloader, gt = npz_to_batches(raw_data, 128)

    # Use the processor depth encoded in the path; hard-coding it would build
    # the wrong architecture for checkpoints trained with another depth.
    model = ConvexHullEncoder(input_dim=2, encoder_depth=encoder_depth, encoder_width=encoder_width,
                              encoder_output_dim=encoder_od, processor_depth=proc_depth,
                              processor_embedding_dim=proc_ed, processor_hidden_dim=proc_hd,
                              processor_output_dim=od)

    device = 'cpu'

    state_dict_path = os.path.join(fp, 'final_model.pt')

    # map_location lets checkpoints saved on a GPU load on a CPU-only host.
    state_dict = torch.load(state_dict_path, map_location=device)

    # strict=False is kept as a best-effort load, but a mismatch means some
    # weights are left at their random initialisation, so make it visible.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn(
            "Checkpoint {} does not fully match the model: missing keys {}, unexpected keys {}".format(
                state_dict_path, list(load_result.missing_keys), list(load_result.unexpected_keys)))

    model = model.to(device)
    model.eval()  # inference mode for layers such as dropout / batch norm

    chull = []

    # No gradients are needed for evaluation; skipping them saves memory.
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            out = model(batch)
            # Regroup per point set so the softmax normalises over the points of one set.
            out = out.view(-1, points_per_set, out.size(-1))
            out = F.softmax(out, dim=1)
            # Flatten back to one row per point, the layout passed to get_approx_chull.
            out = out.view(-1, out.size(-1))

            chull += [tensor.cpu().detach().numpy() for tensor in get_approx_chull(out, batch)]

    gt_hulls = [tensor.cpu().detach().numpy() for batch in gt for tensor in batch]

    dir_width = avg_err(chull, ptsets)
    wasserstein_dist = avg_wasserstein(chull, gt_hulls)

    return dir_width, wasserstein_dist


def main():
    """Evaluate every listed checkpoint on the circle and thin-ellipse data.

    For each model directory name, ``eval_ellipse`` is run on both datasets.
    The results are written to a single CSV file; if building or writing the
    CSV fails, the two raw result dictionaries are dumped as JSON instead so
    the evaluation work is not lost.
    """
    # Imported here because the file has no explicit top-level ``import json``
    # and only the fallback writer needs it.
    import json

    filepaths = ['encoder_depth2_hd128_od64_proc_depth2_ed128_hd128_od16',
        'encoder_depth2_hd128_od64_proc_depth2_ed128_hd128_od8',
        'encoder_depth2_hd256_od64_proc_depth2_ed256_hd128_od16',
        'encoder_depth2_hd256_od64_proc_depth2_ed256_hd128_od8',
        'encoder_depth3_hd128_od64_proc_depth2_ed128_hd128_od16',
        'encoder_depth3_hd128_od64_proc_depth2_ed128_hd128_od8',
        'encoder_depth3_hd128_od64_proc_depth3_ed128_hd128_od16',
        'encoder_depth3_hd128_od64_proc_depth3_ed128_hd128_od8',
        'encoder_depth3_hd256_od64_proc_depth2_ed256_hd128_od12',
        'encoder_depth3_hd256_od64_proc_depth2_ed256_hd128_od8',
        'encoder_depth4_hd256_od64_proc_depth2_ed256_hd128_od12',
        'encoder_depth4_hd256_od64_proc_depth2_ed256_hd128_od16',
        'encoder_depth4_hd256_od64_proc_depth2_ed256_hd128_od8',
        'encoder_depth5_hd256_od64_proc_depth2_ed256_hd128_od12',
        'encoder_depth5_hd256_od64_proc_depth2_ed256_hd128_od8']

    thin_ellipses = np.load('../../../../../data/oren/coreset/data/thin_ellipses.npy')
    circles = np.load('../../../../../data/oren/coreset/data/circle_25.npy')

    ## Evaluating models
    errs = {}
    ellipse_errs = {}

    for fp in tqdm(filepaths):
        read = os.path.join('/data/oren/coreset/models/elliptical-50/ConvexHullEncoder/direction/', fp, 'encoder_model1_chull')
        errs[fp] = eval_ellipse(read, circles)
        ellipse_errs[fp] = eval_ellipse(read, thin_ellipses)

    try:
        data = {
            'model': list(errs.keys()),
            'circle directional error': [val[0] for val in errs.values()],
            'circle wasserstein dist': [val[1] for val in errs.values()],
            'ellipse directional error': [val[0] for val in ellipse_errs.values()],
            'ellipse wasserstein dist': [val[1] for val in ellipse_errs.values()],
        }

        df = pd.DataFrame(data)
        df.to_csv('/data/oren/coreset/out/full_encoder_chull_results.csv', index = False)

    except Exception as exc:
        # Catch Exception rather than everything so Ctrl-C still aborts, and
        # report why the CSV path failed before falling back.
        print("CSV export failed, writing JSON fallback:", repr(exc))

        # default=float converts scalar types json cannot encode natively
        # (e.g. numpy float32 or 0-dim tensors) instead of failing again.
        with open("/data/oren/coreset/out/enc_circle_errs.json", "w") as out_file:
            json.dump(errs, out_file, default=float)

        with open("/data/oren/coreset/out/enc_ellipse_errs.json", "w") as out_file:
            json.dump(ellipse_errs, out_file, default=float)

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()