#%%
import csv
import numpy as np
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.preprocessing import StandardScaler, normalize
import torchvision
import torchvision.transforms as transforms
from scipy.spatial.distance import cdist
import torch
import pandas as pd
import os
import gzip
import urllib.request
import networkx as nx
import numpy as np
from numba import njit


def read_csv(filepath, skip_header=True):
    """Read a CSV file into a list of rows.

    Args:
        filepath: Path of the CSV file to read.
        skip_header: When true, the first row is discarded.

    Returns:
        A list containing one list of string fields per remaining row. An
        empty file yields an empty list.
    """
    # newline='' hands line-ending handling to the csv module, so newlines
    # embedded inside quoted fields are preserved as written.
    with open(filepath, 'r', newline='') as f:
        reader = csv.reader(f)
        if skip_header:
            # The default keeps an empty file from raising StopIteration.
            next(reader, None)
        return list(reader)
    
# Use Apple's Metal (MPS) backend when it is usable at runtime, otherwise fall
# back to the CPU. `torch.has_mps` is deprecated and only reports whether the
# build includes MPS support, not whether the device can actually be used.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
class Data:
    """Common interface for the dataset wrappers defined in this module.

    Subclasses populate the index/data attributes and override ``dist`` (and
    optionally ``init``, ``independent`` and ``F``).
    """

    def __init__(self) -> None:
        self.X = None       # data indices
        self.E = None       # ground set indices
        self.X_data = None  # actual data
        self.E_Data = None  # actual ground set

    def init(self):
        """Loading hook; the base implementation does nothing."""
        return None

    def dist(self):
        """Placeholder that returns None.

        ``F`` calls ``self.dist(point, S, orig)``, so a subclass relying on
        the base ``F`` must override this with a three-argument version.
        """
        return None

    def independent(self, S):
        """Return True; subclasses override this to impose a constraint."""
        return True

    def __str__(self) -> str:
        return type(self).__name__

    def F(self, points, S, w=None, orig = False, get_Fis=False):
        """Evaluate the objective of ``S`` over ``points``.

        Args:
            points: Indexable collection of points passed one by one to
                ``self.dist``.
            S: Candidate set forwarded to ``self.dist``.
            w: Optional weights combined with the per-point values through
                ``np.dot``.
            orig: Forwarded unchanged to ``self.dist``.
            get_Fis: When true, return the per-point values instead of
                their sum.

        Returns:
            The per-point values when ``get_Fis`` is set, otherwise their sum.

        NOTE(review): with a 1-D ``w`` the dot product already collapses the
        per-point values to a scalar, so ``get_Fis=True`` then returns that
        scalar rather than a vector - confirm this is what callers expect.
        """
        per_point = np.zeros(len(points))
        for position, point in enumerate(points):
            per_point[position] = self.dist(point, S, orig)

        if w is not None:
            per_point = np.dot(per_point, np.asarray(w))

        return per_point if get_Fis else np.sum(per_point)
    

        

class BipartiteGraph(Data):
    """Coverage-style objective on a bipartite graph read from an edge list.

    After ``init`` the instance exposes ``self.graph`` (right node -> set of
    adjacent left nodes), ``self.X`` (right nodes), ``self.E`` (left nodes)
    and ``self.N`` (number of right nodes).
    """

    def __init__(self, file_path = None, N=None):
        """
        Record where the edge list lives; the file itself is read by ``init``.

        Args:
        file_path (str): The path to the data file. Defaults to the
            ``discogs_lstyle`` edge list under ``data/``.
        N: Not used by this class; ``self.N`` is derived from the file in
            ``load_graph``.
        """
        super().__init__()
        if file_path is None:
            file_path = 'data/discogs_lstyle/out.discogs_lstyle_lstyle'
        self.file_path = file_path

    def init(self):
        """Load the graph from ``self.file_path`` into ``self.graph``."""
        self.graph = self.load_graph(self.file_path)

    def load_graph(self, file_path):
        """
        Load the graph and create a dictionary for efficient edge lookup.

        Lines that are blank, start with ``%`` or have fewer than two
        whitespace-separated fields are skipped. Also sets ``self.X``,
        ``self.E`` and ``self.N``.

        Args:
        file_path (str): The path to the edge-list file.

        Returns:
        dict: Maps each right node to the set of left nodes it touches.
        """
        graph = {}
        left_nodes = set()
        with open(file_path, 'r', encoding='utf-8') as file:
            for line in file:
                stripped = line.strip()
                # Strip first so an indented comment line is skipped too
                # instead of reaching int() below.
                if not stripped or stripped.startswith('%'):
                    continue

                parts = stripped.split()
                if len(parts) < 2:
                    continue

                # The columns are reversed in the data file: the first field
                # is the right node, the second the left node.
                left_node, right_node = int(parts[1]), int(parts[0])
                graph.setdefault(right_node, set()).add(left_node)
                left_nodes.add(left_node)

        self.X = list(graph)
        self.E = list(left_nodes)
        self.N = len(self.X)
        return graph

    def dist(self, right_node, left_subset, orig=False):
        """
        Check if there is an edge between a node on the right and any node in the subset of the left.

        Args:
        right_node: Node whose neighbourhood is looked up in ``self.graph``.
        left_subset: Iterable of left nodes.
        orig: Unused; kept for a uniform ``dist`` signature.

        Returns:
        int: 1 if ``right_node`` is adjacent to a node in ``left_subset``,
        otherwise 0 (also for nodes missing from the graph).
        """
        connected_nodes = self.graph.get(right_node)
        if not connected_nodes:
            return 0
        # isdisjoint stops at the first common element instead of building
        # the whole intersection.
        return 0 if connected_nodes.isdisjoint(left_subset) else 1
    



class TaxiData(Data):
    """Pick-up locations with k-means centres as the ground set.

    After ``init``: ``self.X_data`` holds one coordinate tuple per point
    (columns 1 and 2 of the CSV), ``self.X`` the matching indices, ``self.E``
    the k-means centres as tuples, ``self.weights`` the per-point weights and
    ``self.max_dists`` the largest L1 distance from each point to any centre.
    """

    def __init__(self, n_clusters=100, N=None, use_weights=False, file_path = None):
        """Store the configuration; data is loaded later by ``init``.

        Args:
            n_clusters: Number of k-means centres forming the ground set.
            N: Optional number of rows to sample without replacement; all
                rows are used when None.
            use_weights: When true, random normalised weights replace the
                default weight of one per point.
            file_path: CSV to read; defaults to the May 2014 Uber file
                under ``data/``.
        """
        super().__init__()
        if file_path is None:
            file_path = 'data/uber-raw-data-may14.csv'
        self.file_path = file_path
        self.N = N
        self.use_weights = use_weights
        self.n_clusters = n_clusters

    @staticmethod
    def _l1(a, b):
        """Return the L1 distance between two 2-D coordinate pairs."""
        return abs(a[0] - b[0]) + abs(a[1] - b[1])

    def init(self):
        """Load the points, draw weights, run k-means and cache max distances."""
        positions = self.load(self.file_path)
        if self.N is not None:
            # Must be self.N: a bare ``N`` is not defined in this scope and
            # raised NameError whenever subsampling was requested.
            rand_indices = np.random.choice(positions.shape[0], self.N, replace=False)
            positions = positions[rand_indices]
        self.N = len(positions)
        self.X_data = [tuple(x) for x in positions]
        self.X = list(range(self.N))

        self.weights = np.ones(self.N)
        if self.use_weights:
            self.weights = np.exp(np.random.normal(1, np.sqrt(5), self.N))
            self.weights /= self.weights.sum()
            # Sanity check: normalised positive weights cannot exceed 1.
            assert not np.any(self.weights > 1)

        # Run k-means clustering
        kmeans = KMeans(n_clusters=self.n_clusters)
        kmeans.fit(self.X_data, sample_weight=self.weights)
        self.cluster_centers = kmeans.cluster_centers_
        self.E = [tuple(x) for x in self.cluster_centers]

        # Largest L1 distance from every point to any centre; ``dist``
        # subtracts the nearest-centre distance from it.
        self.max_dists = {
            p: max(self._l1(self.X_data[p], center) for center in self.E)
            for p in self.X
        }

    def __str__(self) -> str:
        return f"{self.__class__.__name__}_use_weights{self.use_weights}_N_{self.N}"

    def load(self, file_name):
        """Read columns 1 and 2 of every data row as floats.

        Args:
            file_name: Path of the CSV file; its first row is skipped.

        Returns:
            A NumPy array with one ``[row[1], row[2]]`` pair per row.
        """
        lat_lng_positions = []
        with open(file_name, 'r', newline='') as csvfile:
            csvreader = csv.reader(csvfile)
            next(csvreader, None)  # Skip the header row
            for row in csvreader:
                if not row:
                    # csv yields [] for blank lines; skip instead of crashing.
                    continue
                lat_lng_positions.append([float(row[1]), float(row[2])])
        return np.array(lat_lng_positions)

    def visualize(self):
        """Print the cluster centres and scatter-plot points and centres."""
        print("Cluster Centers:")
        print(self.cluster_centers)

        # self.X only holds integer indices; the coordinates are in X_data.
        points = np.asarray(self.X_data)
        plt.scatter(points[:, 1], points[:, 0])
        plt.scatter(self.cluster_centers[:, 1], self.cluster_centers[:, 0], c='black', marker='x', s=100)
        plt.xlabel('Longitude')
        plt.ylabel('Latitude')
        plt.title('K-means Clustering')
        plt.show()

    def dist(self, point, centers, orig=True):
        """Weighted reduction in L1 distance achieved by ``centers``.

        Args:
            point: Index into ``self.X_data``.
            centers: Collection of 2-D centre coordinates.
            orig: Unused; kept for a uniform ``dist`` signature.

        Returns:
            0 when ``centers`` is empty, otherwise
            ``weights[point] * (max_dists[point] - nearest L1 distance)``.
        """
        if len(centers) == 0:
            return 0
        location = self.X_data[point]
        nearest = min(self._l1(location, center) for center in centers)
        return self.weights[point] * (self.max_dists[point] - nearest)
    

class ImageData(Data):
    """Image dataset wrapper backed by a precomputed pairwise-distance matrix.

    After ``init`` the instance holds ``self.distances`` (squared Euclidean
    distances between all pairs of preprocessed images) and
    ``self.max_dists`` (the row-wise maxima). The value of a set ``A`` at an
    index ``e`` is ``max_dists[e] - min(distances[e, a] for a in A)``.
    """

    def __init__(self, type="cifar100", N=None):
        """Store the configuration; images are loaded later by ``init``.

        Args:
            type: Dataset name: ``"cifar100"``, ``"cifar10"`` or
                ``"fashion_mnist"``.
            N: Number of images to use; defaults to 50000.
        """
        super().__init__()
        if N is None:
            N = 50000
        self.N = N
        self.type = type
        # Lazily created device copies used by the *_gpu methods.
        self._distances_dev = None
        self._max_dists_dev = None

    def init(self):
        """Load the configured dataset and precompute pairwise distances.

        Raises:
            ValueError: If ``self.type`` is not a supported dataset name.
        """
        loaders = {
            "cifar100": self.load_cifar100,
            "cifar10": self.load_cifar10,
            "fashion_mnist": self.load_fashion_mnist,
        }
        if self.type not in loaders:
            # Previously an unknown type left ``images`` unbound.
            raise ValueError(
                f"Unknown image dataset type {self.type!r}; "
                f"expected one of {sorted(loaders)}"
            )
        images = loaders[self.type]()
        # A slice-based loader may return fewer rows than requested; keep N
        # consistent with the distance matrix that is about to be built.
        self.N = len(images)
        print("computing pairwise distances")
        self.calc_cdist(images)
        print("finished computing pairwise distances")
        self.X = list(range(self.N))
        self.E = list(range(self.N))

    def __str__(self) -> str:
        return f"{self.__class__.__name__}_{self.type}_N_{self.N}"

    def calc_cdist(self, images):
        """Compute all pairwise squared Euclidean distances.

        Args:
            images: 2-D array-like with one flattened image per row.

        Sets ``self.distances`` (NumPy matrix) and ``self.max_dists``
        (row-wise maxima of that matrix).
        """
        # Use the module-level device rather than a hard-coded 'mps' so the
        # computation also runs where MPS is unavailable.
        images_tensor = torch.tensor(images, dtype=torch.float32).to(device)

        # ||x - y||^2 = ||x||^2 - 2 x.y + ||y||^2
        r = torch.sum(images_tensor**2, 1).reshape(-1, 1)
        distances = r - 2 * torch.mm(images_tensor, images_tensor.T) + r.T
        distances = distances.clamp(min=0)  # Replace negative values with 0

        # Bring the matrix back to host memory in row chunks.
        self.distances = torch.cat([chunk.cpu() for chunk in torch.split(distances, 1000)]).numpy()
        self.max_dists = np.max(self.distances, axis=1)
        # Any device copies made earlier are stale now.
        self._distances_dev = None
        self._max_dists_dev = None

    def load_fashion_mnist(self):
        """Fetch Fashion-MNIST from OpenML and preprocess the first N images."""
        X, _ = fetch_openml('Fashion-MNIST', version=1, return_X_y=True, as_frame=False)
        X = X[:self.N]  # Use only the first N images for simplicity
        X = self.preprocess_data(X)
        return X

    def _sample_and_flatten(self, train_set, test_set):
        """Pool two torchvision datasets, sample N images and flatten them."""
        # Extracting only the image data
        X_train = np.array([data[0].numpy() for data in train_set])
        X_test = np.array([data[0].numpy() for data in test_set])
        combined = np.concatenate((X_train, X_test))

        # Sample row indices: np.random.choice only accepts 1-D input, so the
        # image array itself cannot be passed to it.
        indices = np.random.choice(combined.shape[0], self.N, replace=False)
        X = combined[indices]

        X = X.reshape((X.shape[0], -1))
        return self.preprocess_data(X)

    def load_cifar10(self):
        """Download CIFAR-10 (train + test) and return N preprocessed images."""
        transform = transforms.Compose([transforms.ToTensor()])
        train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
        test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
        return self._sample_and_flatten(train_set, test_set)

    def load_cifar100(self):
        """Download CIFAR-100 (train + test) and return N preprocessed images."""
        transform = transforms.Compose([transforms.ToTensor()])
        train_set = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
        test_set = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
        return self._sample_and_flatten(train_set, test_set)

    # Preprocess the data: subtract the mean, then L2-normalize each row
    def preprocess_data(self, X):
        """Centre every feature and scale every row to unit L2 norm."""
        scaler = StandardScaler(with_mean=True, with_std=False)
        X_scaled = scaler.fit_transform(X)
        X_scaled = normalize(X_scaled, norm='l2')
        return X_scaled

    # Define the distance function d(x, x') = ||x - x'||^2
    def distance_function(self, x, x_prime):
        """Return the squared Euclidean distance between two vectors."""
        return np.linalg.norm(x - x_prime)**2

    def dist(self, e,A, orig=False):
        """Unused by this class (``F`` is overridden); always returns None."""
        return None #shouldnt be called at all

    def dist_set(self, e,A, orig=False):
        """Per-point values of the set ``A`` for the indices ``e`` (NumPy).

        Args:
            e: Sequence of point indices.
            A: Non-empty sequence of selected indices.
            orig: Unused.

        Returns:
            Array of ``max_dists[e] - min_a distances[e, a]``.
        """
        A = np.array(A)
        e = np.array(e)
        # Use advanced indexing and broadcasting to get all distances for indices in e to A
        distances_matrix = self.distances[e[:, None], A]

        # Find the minimum distance for each index in e to all elements in A
        min_distances = np.min(distances_matrix, axis=1)

        # Gain relative to the farthest point: max distance minus nearest distance
        return self.max_dists[e] - min_distances

    def F(self, points, S, w=None, orig = False, get_Fis=False):
        """Evaluate the objective of ``S`` over ``points`` on the CPU.

        Returns 0 for an empty ``S``. Otherwise returns the per-point values
        (combined with ``w`` through ``np.dot`` when given) if ``get_Fis`` is
        set, else their sum.
        """
        if len(S) ==0:
            return 0
        fis = self.dist_set(points,S, orig)
        if w is not None:
            fis = np.dot(fis, np.array(w))

        if get_Fis:
            return fis

        return np.sum(fis)

    def _device_tables(self):
        """Return device copies of ``distances`` and ``max_dists``, cached."""
        if getattr(self, "_distances_dev", None) is None:
            self._distances_dev = torch.as_tensor(self.distances, device=device)
            self._max_dists_dev = torch.as_tensor(self.max_dists, device=device)
        return self._distances_dev, self._max_dists_dev

    def dist_set_gpu(self, e, A, orig=False):
        """Torch counterpart of ``dist_set`` evaluated on ``device``.

        Args:
            e: Point indices (sequence or tensor).
            A: Non-empty selected indices (sequence or tensor).
            orig: Unused.

        Returns:
            Tensor of ``max_dists[e] - min_a distances[e, a]`` on ``device``.
        """
        # self.distances is a NumPy array, so torch indexing/reductions need
        # a tensor copy of it on the target device.
        distances, max_dists = self._device_tables()
        e = torch.as_tensor(e, dtype=torch.long, device=device)
        A = torch.as_tensor(A, dtype=torch.long, device=device)

        # Use advanced indexing and broadcasting to get all distances for indices in e to A
        distances_matrix = distances[e[:, None], A]
        # Find the minimum distance for each index in e to all elements in A
        min_distances = torch.min(distances_matrix, dim=1)[0]

        # Gain relative to the farthest point: max distance minus nearest distance
        return max_dists[e] - min_distances

    def F_gpu(self, points, S, w=None, orig=False, get_Fis=False):
        """Evaluate the objective of ``S`` over ``points`` on ``device``.

        Mirrors ``F``: returns 0 for an empty ``S``; otherwise a NumPy array
        when ``get_Fis`` is set, else a Python float.
        """
        if len(S) == 0:
            return 0

        # Must go through the torch implementation: the NumPy ``dist_set``
        # cannot consume device tensors.
        fis = self.dist_set_gpu(points, S, orig)

        # If weights are provided, apply the dot product
        if w is not None:
            # torch.dot needs matching dtype and device on both operands.
            w = torch.as_tensor(w, dtype=fis.dtype, device=fis.device)
            fis = torch.dot(fis, w)

        if get_Fis:
            return fis.cpu().numpy()

        return torch.sum(fis).item()


@njit
def _one_cascade_numba(indptr, indices, seeds, p):
    """Simulate one Independent Cascade run on a graph in CSR form.

    Args:
        indptr: CSR row pointer array of length ``N + 1``.
        indices: CSR column indices (neighbour ids) for every node.
        seeds: Array of initially active node indices.
        p: Probability with which an active node activates each not yet
            visited neighbour.

    Returns:
        A ``uint8`` array of length ``N`` with 1 for every activated node.
    """
    # NumPy arrays for visited and frontier
    N = indptr.shape[0] - 1
    visited = np.zeros(N, np.uint8)
    frontier = np.empty(N, np.int64)
    next_frontier = np.empty(N, np.int64)
    # initialize frontier with seed indices; skipping repeats keeps a
    # duplicated seed from spreading twice and from overrunning the buffer
    f_size = 0
    for i in range(seeds.shape[0]):
        s = seeds[i]
        if visited[s] == 0:
            visited[s] = 1
            frontier[f_size] = s
            f_size += 1
    # propagate
    while f_size > 0:
        nf_size = 0
        for i in range(f_size):
            u = frontier[i]
            for nbr_idx in range(indptr[u], indptr[u + 1]):
                v = indices[nbr_idx]
                # sparse check and activation
                if visited[v] == 0 and np.random.random() < p:
                    visited[v] = 1
                    next_frontier[nf_size] = v
                    nf_size += 1
        # prepare next wave: swap the buffers instead of copying entries
        frontier, next_frontier = next_frontier, frontier
        f_size = nf_size
    return visited

class InfluenceMaxData(Data):
    """
    Influence Maximization on a sparse real network (Independent Cascade),
    accelerated with Numba.
    Downloads the SNAP Facebook Combined graph if missing, builds a CSR
    representation, and uses a numba-jitted cascade.
    """
    SNAP_URL = 'https://snap.stanford.edu/data/facebook_combined.txt.gz'

    def __init__(self, file_path='data/facebook_combined.txt',
                 p=0.01, n_sim=500):
        """Load (downloading first if needed) the edge list and index it.

        Args:
            file_path: Location of the uncompressed edge list.
            p: Activation probability used by the cascade.
            n_sim: Number of Monte Carlo cascades per estimate.
        """
        super().__init__()
        self._ensure_dataset(file_path)
        # load into NetworkX
        self.G = nx.read_edgelist(file_path, nodetype=int)
        self.nodes = list(self.G.nodes())
        self.node_to_idx = {v: i for i, v in enumerate(self.nodes)}
        self.N = len(self.nodes)
        # build CSR for sparse adjacency
        self._build_csr()
        # Data interface
        self.E = self.nodes[:]
        self.X = self.nodes[:]
        self.p = p
        self.n_sim = n_sim

    def _ensure_dataset(self, file_path):
        """Download and extract the SNAP archive unless ``file_path`` exists."""
        if os.path.exists(file_path):
            return
        directory = os.path.dirname(file_path)
        # dirname is '' for a bare filename and os.makedirs('') raises.
        if directory:
            os.makedirs(directory, exist_ok=True)
        gz_path = file_path + '.gz'
        part_path = file_path + '.part'
        try:
            urllib.request.urlretrieve(self.SNAP_URL, gz_path)
            with gzip.open(gz_path, 'rb') as f_in, open(part_path, 'wb') as f_out:
                f_out.write(f_in.read())
            # Publish the file only once it is complete, so an interrupted
            # run never leaves a truncated edge list behind.
            os.replace(part_path, file_path)
        finally:
            for leftover in (gz_path, part_path):
                if os.path.exists(leftover):
                    os.remove(leftover)

    def _build_csr(self):
        """Flatten the adjacency lists into ``self.indptr``/``self.indices``."""
        idx_list = []
        indptr = [0]
        for node in self.nodes:
            idx_list.extend(self.node_to_idx[n] for n in self.G.neighbors(node))
            indptr.append(len(idx_list))
        self.indptr = np.array(indptr, dtype=np.int64)
        self.indices = np.array(idx_list, dtype=np.int64)

    def __str__(self):
        return f"{self.__class__.__name__}_N{self.N}_p{self.p}_sim{self.n_sim}"

    def _estimate_probs(self, S):
        """Estimate per-node activation probabilities for the seed set ``S``.

        Returns:
            Array of length ``N``: the fraction of the ``n_sim`` cascades in
            which each node was activated.
        """
        # convert seed nodes to indices; dict.fromkeys drops repeated seeds
        # while keeping their order
        seeds = np.array([self.node_to_idx[v] for v in dict.fromkeys(S)],
                         dtype=np.int64)
        counts = np.zeros(self.N, dtype=np.float64)
        # run Monte Carlo cascades with numba
        for _ in range(self.n_sim):
            counts += _one_cascade_numba(self.indptr, self.indices, seeds, self.p)
        return counts / self.n_sim

    def F(self, points, S, w=None, orig=False, get_Fis=False):
        """Monte Carlo estimate of the spread of ``S`` restricted to ``points``.

        Args:
            points: Nodes whose activation probabilities are evaluated.
            S: Seed nodes.
            w: Optional weights combined with the per-node values via ``@``.
            orig: Unused.
            get_Fis: When true, return the per-node values instead of their sum.

        Returns:
            The per-node values when ``get_Fis`` is set, otherwise their sum.
            Results vary between calls because cascades are random.
        """
        # expected spread estimation
        probs = self._estimate_probs(S)
        fis = np.array([probs[self.node_to_idx[v]] for v in points])
        if w is not None:
            fis = fis @ np.array(w)
        return fis if get_Fis else fis.sum()

    def dist(self, *args, **kwargs):
        """Not supported; this class evaluates sets through ``F`` only."""
        raise NotImplementedError

    def independent(self, S):
        """
        Return True iff no two nodes in S share an edge (i.e., S is an independent set).
        """
        S_set = set(S)
        # ensure no two nodes in S are neighbors
        return not any(
            nbr in S_set
            for v in S_set
            for nbr in self.G.neighbors(v)
        )