import os
import torch
import numpy as np
import pandas as pd
import re
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import matplotlib.tri as tri
import seaborn as sns
import torch.nn.functional as F
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True  # let PIL load truncated image files instead of raising
# Select the compute device: CUDA when torch reports it available, otherwise CPU.
# NOTE(review): `device` is not referenced anywhere in the visible code.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def eps_rho_sizeplot(arr,
                     scatter_points=True,
                     levels=50,
                     cmap='viridis',
                     figsize=(8, 6),
                     point_size=70,
                     alpha=0.8,
                     highlight_val=None,
                     savefig_path=None,
                     highlight_offset=0.03):
    """Draw a filled contour of a value over the (epsilon, rho) plane and save it.

    Parameters
    ----------
    arr : array-like, shape (n_points, >=3)
        Column 0 is the value to contour (labelled "Size" on the colorbar),
        column 1 is the x coordinate (plotted as epsilon) and column 2 is the
        y coordinate (plotted as rho). Points may be irregularly spaced.
    scatter_points : bool
        If True, overlay the raw data points, coloured by their value.
    levels : int or array-like
        Contour levels passed to ``tricontourf``.
    cmap : str
        Colormap used for both the filled contour and the scatter overlay.
    figsize : tuple
        Figure size passed to ``plt.subplots``.
    point_size : float
        Marker size of the scatter overlay.
    alpha : float
        Opacity of the scatter overlay.
    highlight_val : float or None
        If given, draw a dashed black contour line at
        ``highlight_val - highlight_offset``. If None, no line is drawn.
    savefig_path : str or path-like or None
        Destination of the saved figure (300 dpi). Defaults to ``'wilds.png'``.
    highlight_offset : float
        Amount subtracted from ``highlight_val`` before drawing the line.
        NOTE(review): the reason for this shift is not visible in this file.

    Returns
    -------
    tuple
        ``(fig, ax)`` of the created figure.
    """
    # Accept lists / object arrays (e.g. from DataFrame.to_numpy()) as well.
    arr = np.asarray(arr, dtype=float)
    val = arr[:, 0]
    x = arr[:, 1]
    y = arr[:, 2]

    # Triangulate the scattered (x, y) points so irregular grids can be contoured.
    triang = tri.Triangulation(x, y)

    fig, ax = plt.subplots(figsize=figsize)
    contour_f = ax.tricontourf(triang, val, levels=levels, cmap=cmap)

    if scatter_points:
        ax.scatter(x, y, c=val, cmap=cmap, edgecolor='white',
                   s=point_size, alpha=alpha)

    # highlight_val defaults to None; only draw the dashed line when a value
    # is supplied (subtracting from None would raise TypeError).
    if highlight_val is not None:
        ax.tricontour(triang, val,
                      levels=[highlight_val - highlight_offset],
                      colors='black', linewidths=2, linestyles='--')

    cbar = fig.colorbar(contour_f, ax=ax)
    cbar.set_label('Size')

    ax.set_xlabel(r'$\epsilon$', fontsize=25)
    ax.set_ylabel(r'$\rho$', fontsize=25)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    # Keep the historical fixed filename as the fallback so callers that
    # never pass savefig_path still get wilds.png.
    out_path = 'wilds.png' if savefig_path is None else savefig_path
    fig.savefig(out_path, dpi=300, bbox_inches='tight')

    return fig, ax


def parse_tensor(x):
    """Convert a CSV cell holding a scalar tensor repr (or a plain number) to float.

    Accepts ``"tensor(0.2000)"`` as well as reprs that carry keyword
    arguments after the value, e.g. ``"tensor(0.2000, device='cuda:0')"``
    or ``"tensor(0.2000, dtype=torch.float64)"``. Any other input is passed
    to ``float`` unchanged.

    Raises
    ------
    ValueError
        If the cell is neither a recognisable tensor repr nor a number.
    """
    # Capture only the first positional argument: the scalar stops at the
    # first comma (keyword arguments follow) or at the closing paren.
    match = re.match(r"\s*tensor\(\s*([^,)]+?)\s*[,)]", str(x))
    if match:
        return float(match.group(1))
    # Not a tensor repr (e.g. an already-numeric cell): convert directly.
    return float(x)


if __name__ == '__main__':
    # Load the sweep results. The "eps" and "rho" cells may hold stringified
    # tensors such as "tensor(0.2000)", so route them through parse_tensor.
    tensor_columns = ("eps", "rho")
    frame = pd.read_csv(
        "wilds_result_hist.csv",
        converters={column: parse_tensor for column in tensor_columns},
    )
    # Keep positional columns 1..4 of the loaded table.
    data = frame.to_numpy()[:, [1, 2, 3, 4]]

    # Find the row whose first kept column is nearest to 0.9 and take its
    # second kept column (the value plotted as "Size") as the highlight level.
    nearest = np.argmin(np.abs(data[:, 0] - 0.9))
    highlight_val = data[nearest, 1]

    eps_rho_sizeplot(data[:, 1:], highlight_val=highlight_val)