import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from eval_utils import *
from torch.utils.data import DataLoader
import os
from tqdm import tqdm
import pandas as pd
import json
from collections import defaultdict
import re
import yaml


from src.sumformer import *


def count_pts_outside(points, center, radius):
    """
    Return how many rows of ``points`` lie strictly outside the ball.

    The ball is given by ``center`` and ``radius``. Squared distances are
    compared against the squared radius, so no square root is needed.
    Points exactly on the sphere are not counted.

    Returns a 0-dim integer tensor (the sum of a boolean mask).
    """
    sq_dist = ((points - center) ** 2).sum(dim=1)
    outside_mask = sq_dist > radius ** 2
    return outside_mask.sum()

def meb_eval(gt, preds, pts):
    """Score one predicted minimum enclosing ball against the ground truth.

    Both ``gt`` and ``preds`` pack the ball as the center coordinates followed
    by the radius in the last entry.

    Args:
        gt: ground-truth vector ``[c_0, ..., c_{d-1}, r]``.
        preds: predicted vector in the same layout (a tensor; it is detached).
        pts: the point set the ball should enclose, one point per row.

    Returns:
        ``(center_error, radius_error, proportion_outside)``:
        - the Euclidean distance between the two centers;
        - the absolute radius difference;
        - the fraction of ``pts`` lying strictly outside the predicted ball.
    """
    pred_center, pred_rad = preds[:-1].detach(), preds[-1].detach()
    true_center, true_rad = gt[:-1], gt[-1]

    return (
        np.linalg.norm(true_center - pred_center),
        abs(true_rad - pred_rad).item(),
        count_pts_outside(pts, pred_center, pred_rad).item() / len(pts),
    )
    
def evaluate(batches, gt, model):
    """Average the MEB errors of ``model`` over every point set in ``batches``.

    Args:
        batches: iterable of batches. Each exposes ``n_nodes`` (points per set)
            and ``data``. Rows ``start:start + n_nodes[i]`` of ``data`` belong
            to set ``i``.
        gt: ground truth indexed as ``gt[batch_idx][set_idx]``, in the layout
            ``meb_eval`` expects.
        model: callable mapping a batch to one prediction per point set.

    Returns:
        Tuple of means: (center error, radius error, proportion of points outside).
    """
    import torch

    # Inference only. Put nn.Modules in eval mode (dropout/batch-norm) and skip
    # building autograd graphs. The caller's training mode is restored afterwards.
    is_module = isinstance(model, torch.nn.Module)
    was_training = is_module and model.training
    if is_module:
        model.eval()

    errs = []
    try:
        with torch.no_grad():
            for idx, batch in enumerate(tqdm(batches)):
                out = model(batch)

                start = 0
                for i, num in enumerate(batch.n_nodes):
                    end = start + num
                    errs.append(meb_eval(gt[idx][i], out[i], batch.data[start:end]))
                    start = end
    finally:
        if is_module:
            model.train(was_training)

    df = pd.DataFrame(errs, columns=["center error", "radius error", "proportion outside"])
    return df['center error'].mean(), df['radius error'].mean(), df['proportion outside'].mean()

def extract_hyperparameters(filepath):
    """Parse ``<name><integer>`` hyperparameter tokens out of a path string.

    A name is a run of lowercase letters/hyphens immediately followed by
    digits. Separator hyphens around a name are stripped, so a path such as
    ``"depth-2-ed-64-hd-256"`` yields ``{"depth": [2], "ed": [64], "hd": [256]}``
    rather than keys like ``"-hd-"``. Hyphens inside a name (``"mlp-out"``) are
    kept.

    A name that occurs several times collects all its values in order of
    appearance. Digit runs with no name (e.g. the ``-3`` in ``"x-2-3"``) are
    ignored.

    Returns:
        dict mapping hyperparameter name -> list of int values.
    """
    hyperparameters = defaultdict(list)

    for match in re.finditer(r"(?P<name>[a-z-]+)(?P<value>\d+)", filepath):
        name = match.group("name").strip("-")
        if not name:
            continue
        hyperparameters[name].append(int(match.group("value")))

    return dict(hyperparameters)

# def define_model(params, fp):
#     model = EncoderProcessDecoder(input_dim= 3,
#     encoder_depth= params['depth'][0],
#     encoder_width= params['hd'][0],
#     encoder_output_dim= params['sum'][0],
#     processor_layer= 'ConvexHullNN', processor_configs={
#         'depth': params['depth'][1],
#         'embedding_dim': params['ed'][0],
#         'hidden_dim': params['hd'][1],
#         'input_dim': params['sum'][0],
#         'output_dim': params['od'][1]}, 
#     decoder_layer = 'PointEncoder', decoder_configs=
#         {'input_dim': params['sum'][0],
#         'embed_dim': params['ed'][1],
#         'mlp_hdim': params['mlp'][0],
#         'mlp_out_dim': params['mlpout'][0],
#         'mlp_layers': params['mlplayers'][0],
#         'phi_hdim': params['phi'][0],
#         'phi_out_dim': params['phiout'][0],
#         'phi_layers': params['philayers'][0],
#         'batchnorm': False,
#         'mean': False,
#         'use_max': False,
#         'activation': 'nn.LeakyReLU'},
#         processor_path = '/data/oren/coreset/models/elliptical-50/ConvexHullNN/direction/depth-2-ed-64-hd-256-od-25/model1_10d/final_model.pt')
    
#     state_dict_path = os.path.join(fp, 'final_model.pt')
    
#     model.load_state_dict(torch.load(state_dict_path), strict = False)
#     return model

##just have to change proc path

def define_model(params, fp):
    """Build an ``EncoderProcessDecoder`` and load its weights from ``fp``.

    Args:
        params: keyword arguments for ``EncoderProcessDecoder`` (one model's
            config dict).
        fp: directory containing ``final_model.pt``.

    Returns:
        The model with the checkpoint's state dict loaded.
    """
    import warnings

    model = EncoderProcessDecoder(**params)

    state_dict_path = os.path.join(fp, 'final_model.pt')
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
    # load_state_dict then copies the values onto whatever device the model's
    # parameters already live on.
    state_dict = torch.load(state_dict_path, map_location='cpu')

    # strict=False is kept for compatibility (presumably some weights are
    # expected to be absent; confirm). Mismatches are reported rather than
    # silently leaving parameters at their initial values.
    load_result = model.load_state_dict(state_dict, strict=False)
    missing = getattr(load_result, "missing_keys", [])
    unexpected = getattr(load_result, "unexpected_keys", [])
    if missing or unexpected:
        warnings.warn(
            f"{state_dict_path}: {len(missing)} missing keys {missing[:5]}, "
            f"{len(unexpected)} unexpected keys {unexpected[:5]}"
        )
    return model

# def get_models(yml_file_path):
#     with open(yml_file_path, 'r') as file:
#         data = yaml.safe_load(file)
#     print(data)
#     model_names = list(data.keys())
#     return model_names, data

def get_models(yml_file_path):
    """Load model configurations from a YAML file.

    The file's top level must be a mapping from model name to that model's
    config dict. ``main`` joins the name onto the model directory as a
    relative path.

    Returns:
        A list of ``(model_name, config_dict)`` tuples in the order they appear.
        The list is empty if the file contains no YAML document.

    Raises:
        TypeError: if the top-level YAML value is not a mapping.
    """
    with open(yml_file_path, 'r', encoding='utf-8') as file:
        data = yaml.safe_load(file)

    # An empty or comment-only YAML file loads as None.
    if data is None:
        return []
    if not isinstance(data, dict):
        raise TypeError(
            f"{yml_file_path}: expected a mapping of model name -> config, "
            f"got {type(data).__name__}"
        )
    return list(data.items())


def main():
    """Evaluate each configured MEB model on scaled ModelNet subsets of varying size.

    For each point-set size, the matching JSON dataset is loaded and batched.
    Every model listed in the YAML config is then built and evaluated on it.
    At the end, one CSV row per (size, model) pair is written.
    """
    # Alternative model configs used in other experiments:
    #   transformer: 'model-configs/3d_meb_5dtransformer.yml',
    #                'model-configs/3d_meb_5dtransformer_e2e.yml'
    #   sumformer:   'model-configs/3d_meb_5dproc_e2e.yml',
    #                'model-configs/3d_meb_random_configs.yml'
    model_configs = get_models('model-configs/3d_meb_5dproc.yml')

    models_root = '/data/oren/coreset/models'
    data_fp = 'min_enclosing_ball'
    model_subdir = 'EncoderProcessDecoder/mse'
    experiment = 'meb_scaled_modelnet_5d_sumformer'
    # experiment = 'meb_scaled_modelnet_5d_sumformer_e2e'

    sizes = list(range(200, 2000, 200))
    results = []

    for size in sizes:
        data_path = f'/data/oren/coreset/data/scaled_subsampled_{size}_modelnet_meb.json'
        # Context manager so each dataset file is closed after loading.
        with open(data_path, encoding='utf-8') as f:
            raw_data = json.load(f)
        batches, gt = json_to_batches(raw_data, 128)

        for path, params in tqdm(model_configs):
            # The model name is the config key: the path relative to model_subdir.
            modelname = path
            model_fp = os.path.join(model_subdir, path)
            fp = os.path.join(models_root, data_fp, model_fp, experiment)

            model = define_model(params, fp)
            result = evaluate(batches, gt, model)
            results.append((size, modelname, result))
            print(size, modelname, result)

    df = pd.DataFrame(
        [(s, name, ce, rad_err, po) for s, name, (ce, rad_err, po) in results],
        columns=["ptset size", "model name",
                 "avg test center error", "avg test radius error",
                 "avg test proportion outside"],
    )

    # Other output files written by related experiments (same directory):
    #   5dsumformer_scaled_modelnettrain_meb_modelnet_results.csv
    #   5dsumformer_e2e_modelnettrain_meb_modelnet_results.csv
    #   5dtf_meb_modelnet_results.csv, 5dtf_e2e_meb_modelnet_results.csv
    #   5dtf_modelnettrain_meb_modelnet_results.csv
    #   5dtf_e2e_modelnettrain_meb_modelnet_results.csv
    df.to_csv('/data/oren/coreset/out/5dsumformer_modelnettrain_meb_modelnet_varied_size.csv', index=False)


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()