import numpy as np
import random
from collections import deque

import math
import sys
import json
import sys
import os

from Borassi import Borassi_cost
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(project_root)

from kmedianpp import WeightedKMedoids




# Global variables
# Sliding-window state shared by Insert() and Query().  The two lists are
# kept in lockstep: intervals[i] is the (begin, end) timestamp pair covered
# by coresets[i].  Insert() appends at the tail and merges/discards the last
# two entries.
intervals = []
coresets = []
# epsilon = 0.4

# Length of the sliding window, counted in timestamps.  Read by Insert(),
# Query() and Query_window_coreset().  The commented values below are
# alternative settings.
windowSize = 500
# windowSize = 2000
# windowSize = 1000
# windowSize = 5000

class Point:
    """A stream element: coordinates plus arrival time, colour and weight.

    Equality and hashing are left at the default (identity based), which is
    what allows instances to be used as dictionary keys by
    BicriteriaSolution.
    """

    # Attribute names exported by to_dict(), in output order.
    _EXPORTED = ("coordinates", "timestamp", "weight", "color")

    def __init__(self, coordinates, timestamp, color, weight=1):
        self.coordinates, self.timestamp = coordinates, timestamp
        self.weight, self.color = weight, color

    def __getitem__(self, index):
        # Indexing a point indexes into its coordinate sequence.
        coords = self.coordinates
        return coords[index]

    def __repr__(self):
        coords, when = self.coordinates, self.timestamp
        return f"Point(coords={coords}, time={when})"

    def to_dict(self):
        """Return the point's attributes as a plain dictionary."""
        return {name: getattr(self, name) for name in self._EXPORTED}

def distance(p1, p2):
    """Return the Euclidean distance between the coordinates of two points.

    Coordinates are paired positionally with zip(), so when the two points
    have different dimensions the extra coordinates of the longer one are
    ignored.
    """
    differences = [a - b for a, b in zip(p1.coordinates, p2.coordinates)]
    squared_norm = sum(delta ** 2 for delta in differences)
    return math.sqrt(squared_norm)

class BicriteriaSolution:
    def __init__(self, P, k, d):
        """
        Clusters P with WeightedKMedoids and records, per center, the points
        assigned to it and the range of their distances to that center.

        Parameters:
        P (list): List of Points, where each data point is a Point class.
        k (int): Target number of centers (capped at len(P)).
        d (int): Dimension of Points (stored, not used by this class).

        Attributes set:
        centers (list): The medoid Points.
        assignments (dict): center -> list of Points assigned to it.
        dmin_map (dict): center -> smallest strictly positive distance from
            the center to one of its points, or 0 if every assigned point
            coincides with the center.
        dmax_map (dict): center -> largest distance from the center to one
            of its points.
        """
        self.point = P
        # Extract the coordinates property of all Point objects
        self.coordinates_list = [point.coordinates for point in self.point]
        self.weight_list = [point.weight for point in self.point]
        self.k = k # Number of clusters
        self.dmin_map = {}
        self.dmax_map = {}
        self.d = d
        self.centers, self.assignments = self.KMedoids() # Storage clustering center, Store the allocation for each point

    def KMedoids(self):
        """Fit the weighted k-medoids model and return (centers, assignments)."""
        # Never ask for more clusters than there are points.
        n_clusters = min(self.k, len(self.coordinates_list))
        kmedoids = WeightedKMedoids(n_clusters=n_clusters, init='k-medoids++', random_state=42)
        # train model
        kmedoids.fit(self.coordinates_list, sample_weight=self.weight_list)
        # get result
        # NOTE(review): assumes labels_[i] indexes into medoid_indices_ - confirm
        # against the kmedianpp implementation.
        return self.assignedTo(kmedoids.labels_, kmedoids.medoid_indices_)

    def assignedTo(self, labels, medoid_indices):
        """
        Group the points by their center and fill dmin_map / dmax_map.

        Parameters:
        labels: For each position in self.point, the index (into
            medoid_indices) of the cluster that point belongs to.
        medoid_indices: Positions in self.point of the chosen medoids.

        Returns:
        (centers, assignments): list of center Points and a dict mapping
        each center to the list of Points assigned to it.
        """
        centers = []
        assignments = {}
        for c_idx in medoid_indices:
            c = self.point[c_idx]
            assignments[c] = []
            self.dmin_map[c] = 0
            self.dmax_map[c] = 0
            centers.append(c)

        for p, label in zip(self.point, labels):
            c = self.point[medoid_indices[label]]
            assignments[c].append(p)
            cp_dist = distance(c, p)
            # Only strictly positive distances may update the minimum.  The
            # center itself is at distance 0; letting that 0 through would
            # collide with the "nothing recorded yet" marker and make the
            # result depend on the order in which the points are visited.
            if cp_dist > 0:
                current_min = self.dmin_map[c]
                if current_min == 0:
                    self.dmin_map[c] = cp_dist
                else:
                    self.dmin_map[c] = min(current_min, cp_dist)
            self.dmax_map[c] = max(self.dmax_map[c], cp_dist)
        return centers, assignments


def ring(center, r1, r2):
    """Build a membership test for the half-open ring around ``center``.

    The returned predicate is true for a point whose distance to ``center``
    is at least r1 and strictly less than r2 (callers in this file pass
    r1 = r2 / 2).
    """
    def contains(p):
        return r1 <= distance(p, center) < r2

    return contains

def intersect(P_i, ring_func):
    """Return, in their original order, the points of P_i accepted by ring_func."""
    return list(filter(ring_func, P_i))

def sortByTime(points, descending=True):
    """Return a new list of the points ordered by timestamp.

    Newest first by default; pass descending=False for oldest first.  The
    input is left untouched and ties keep their original relative order.
    """
    ordered = list(points)
    ordered.sort(key=lambda p: p.timestamp, reverse=descending)
    return ordered

def RingCoreset(Ring_i_j, T):
    """
    From "Fair Clustering on Sliding Windows" Algorithm 1: Coreset for single ring.

    Walks the ring in the given order while accumulating the weight seen so
    far, and keeps point i with probability
    min(T * w_i / accumulated_weight, 1).  A kept point has its weight
    divided by that probability.

    Parameters:
    - Ring_i_j: A list of Points in Ring_i_j (OnlineCoreset passes them
      newest first).
    - T: Sampling factor; larger values keep more points.

    Returns:
    - The list of kept points, in input order.  Their ``weight`` attribute
      is overwritten IN PLACE with the rescaled weight.  A ring with fewer
      than two points is returned as-is (the same list object).
    """
    if len(Ring_i_j) <= 1:
        return Ring_i_j

    sampled = []
    accumulated_weight = 0

    for point in Ring_i_j:
        weight = point.weight
        accumulated_weight += weight  # Accumulate the weight of the current point

        # No positive weight seen yet: the probability would be 0/0.  Such a
        # point carries no weight, so it is simply left out.
        if accumulated_weight <= 0:
            continue

        keep_prob = min(T * weight / accumulated_weight, 1)

        # random.random() may return exactly 0.0, so a zero probability must
        # be rejected explicitly; otherwise the rescaling below divides by 0.
        # The draw is made first so the random stream is consumed per point.
        if random.random() <= keep_prob and keep_prob > 0:
            point.weight = weight / keep_prob
            sampled.append(point)
    return sampled

def OnlineCoreset(P, k, d, T = 50):
    """
    Build a coreset of P by sampling each cluster ring by ring.

    P is clustered with BicriteriaSolution.  For every center the assigned
    points are split into rings [r/2, r) with r starting at the center's
    dmin_map value and doubling while r <= 2 * dmax_map; each non-empty
    group is sorted newest first and sampled with RingCoreset.

    Points lying closer to the center than the innermost ring (distance
    < r/2, which always includes the center itself at distance 0) belong to
    no ring.  They are sampled as one extra innermost group so that their
    weight is not silently lost.  When the recorded minimum distance is 0
    no ring radius is available and the whole cluster is sampled as a
    single group.

    Parameters:
    - P: List of Points to compress.
    - k: Target number of centers for the clustering.
    - d: Dimension of the points (forwarded to BicriteriaSolution).
    - T: Sampling factor forwarded to RingCoreset.

    Returns:
    - List of sampled Points sorted by timestamp, newest first.  Weights of
      sampled points may have been rescaled in place by RingCoreset.
    """
    sol = BicriteriaSolution(P, k, d)
    # sol = BicriteriaSolution(P, k, 2, 2)

    coreset = []

    for center in sol.centers:
        members = sol.assignments[center]

        r = sol.dmin_map[center]
        limit = sol.dmax_map[center] * 2

        if r == 0:
            # Doubling 0 never grows, so rings cannot be built.
            groups = [members]
        else:
            innermost = r / 2
            groups = [[p for p in members if distance(p, center) < innermost]]
            while r <= limit:
                groups.append(intersect(members, ring(center, r / 2, r)))
                r *= 2

        for group in groups:
            if group:
                # Descending order: RingCoreset scans newest first.
                coreset += RingCoreset(sortByTime(group), T)

    coreset = sortByTime(coreset)  # Descending order
    return coreset

def MergeAndReduce(P1, P2, k, d):
    """Concatenate two coresets (P1 first, then P2) and compress the result.

    Both inputs are lists of Point objects; the combined list is reduced
    with OnlineCoreset using its default sampling factor.
    """
    return OnlineCoreset(P1 + P2, k, d)

def Insert(coordinates, timestamp, color, k, d):
    """Add one stream element and re-balance the global coreset stack.

    The new point enters as a singleton interval/coreset at the tail of the
    global ``intervals`` / ``coresets`` lists.  While the last two intervals
    span the same length they are combined:

    * if that length (in timestamps) is at least ``windowSize``, the older
      of the two is dropped and the newer one kept unchanged;
    * otherwise their coresets are merged with MergeAndReduce and the
      interval becomes (older begin, newer end).
    """
    new_point = Point(coordinates, timestamp, color, weight=1)
    intervals.append((timestamp, timestamp))  # (begin, end)
    coresets.append([new_point])

    while len(intervals) >= 2:
        (older_begin, older_end), (newer_begin, newer_end) = intervals[-2:]
        span = older_end - older_begin

        # Only intervals of equal length are combined.
        if span != newer_end - newer_begin:
            break

        if span + 1 >= windowSize:
            # The newer interval alone is long enough; forget the older one.
            del coresets[-2]
            del intervals[-2]
        else:
            # Newer coreset goes first - the order of the arguments matters.
            coresets[-2:] = [MergeAndReduce(coresets[-1], coresets[-2], k, d)]
            intervals[-2:] = [(older_begin, newer_end)]

def Query(current_time):
    """Return the stored coreset points that fall inside the current window.

    All coresets in the global ``coresets`` list are scanned in order and a
    point is kept when its timestamp lies in
    (current_time - windowSize, current_time].
    """
    window_start = current_time - windowSize  # exclusive lower bound
    return [
        p
        for cs in coresets
        for p in cs
        if window_start < p.timestamp <= current_time
    ]

def Query_window_coreset(current_time, coordinates, colors, k, d):
    """Build a coreset directly from the raw points of the current window.

    The position of each entry in ``coordinates`` is used as its timestamp.
    Entries whose timestamp lies in (current_time - windowSize, current_time]
    become unit-weight Points, which are then compressed with OnlineCoreset
    using a sampling factor of 100.
    """
    window_start = current_time - windowSize  # exclusive lower bound
    window_points = [
        Point(coord, timestamp, colors[timestamp], weight=1)
        for timestamp, coord in enumerate(coordinates)
        if window_start < timestamp <= current_time
    ]
    return OnlineCoreset(window_points, k, d, 100)




# Command line: <script> k p q input_csv
#   k          target number of centers
#   p, q       the two balance integers (stored as p = smaller, q = larger)
#   input_csv  one point per line: color,weight,coord_1,...,coord_d
k = int(sys.argv[1])
colors = []
points = []
weights = []  # NOTE(review): parsed but not used below; Insert() always uses weight=1
i = 0
skipped_lines = 0
try:
    p = min(int(sys.argv[2]), int(sys.argv[3]))
    q = max(int(sys.argv[2]), int(sys.argv[3]))
except (IndexError, ValueError):
    print("First two parameters must be non-negative integers that specify the target balance; terminating")
    sys.exit(0)

print("Loading data from input CSV file")
input_csv_filename = sys.argv[4]

with open(input_csv_filename) as input_csv:
    for line in input_csv:

        if len(line.strip()) == 0:
            skipped_lines += 1
            continue
        # Strip only the line terminator.  Slicing off the last character
        # would eat a digit when the final line has no trailing newline.
        tokens = line.rstrip("\r\n").split(",")
        try:
            color = int(tokens[0])
            weight = int(tokens[1])
        except (IndexError, ValueError):
            # i + skipped_lines is the 0-based index of the current line.
            print("Invalid color label in line", i + skipped_lines, ", skipping")
            skipped_lines += 1
            continue
        try:
            point = [float(x) for x in tokens[2:]]
        except ValueError:
            print("Invalid point coordinates in line", i + skipped_lines, ", skipping")
            skipped_lines += 1
            continue
        colors.append(color)
        points.append(point)
        weights.append(weight)
        i += 1

n_points = len(points)
if  n_points == 0:
    print("No successfully parsed points in input file, terminating")
    sys.exit(0)
# The first parsed point fixes the dimension; longer rows are truncated to
# it and shorter rows abort the run.
dimension = len(points[0])

dataset = np.zeros((n_points, dimension))
for i in range(n_points):
    if len(points[i]) < dimension:
        print("Insufficient dimension in line", i+skipped_lines, ", terminating")
        sys.exit(0)
    for j in range(dimension):
        dataset[i,j] = points[i][j]

print("Number of distinct streaming data points:", n_points)
print("Dimension:", dimension)


# Sliding window model
print("Inserting...")
ith = 0  # index of the current query, passed to Borassi_cost
for time in range(n_points):
    Insert(points[time], time, colors[time], k, dimension)
    if time >= windowSize - 1:
        # Query once the first window is full and then every `step` insertions.
        # Steps used for the other datasets:
        # if (time - windowSize + 1) % 450 == 0: # suppose step is 450 BANK
        # if (time - windowSize + 1) % 2000 == 0: # suppose step is 2000 Athlete
        # if (time - windowSize + 1) % 1000 == 0: # suppose step is 1000 Diabetes
        # if (time - windowSize + 1) % 25000 == 0: # suppose step is 25000 census
        if (time - windowSize + 1) % 350 == 0: # suppose step is 350
            # Coreset maintained incrementally by Insert(), restricted to the window
            result = Query(time)
            # Coreset built from scratch on the raw points of the window
            sliding_window_coreset = Query_window_coreset(time, points, colors, k, dimension)

            ### Baseline3: Borassi.py
            Borassicost = Borassi_cost(p, q, k, sliding_window_coreset, "../../Borassi2020/output/Adult_centers.txt", ith)
            # Borassicost = Borassi_cost(p, q, k, sliding_window_coreset, "/Borassi2020/output/Bank_centers.txt", ith)
            # Borassicost = Borassi_cost(p, q, k, sliding_window_coreset, "/Borassi2020/output/Athlete_centers.txt", ith)
            # Borassicost = Borassi_cost(p, q, k, sliding_window_coreset, "/Borassi2020/output/Diabete_centers.txt", ith)
            # Borassicost = Borassi_cost(p, q, k, sliding_window_coreset, "/Borassi2020/output/Census_centers.txt", ith)

            sliding_window_coreset_json = json.dumps([pt.to_dict() for pt in sliding_window_coreset])
            result_json = json.dumps([pt.to_dict() for pt in result])
            print(f"TIME:{time}")
            # print(f"WCORES:{sliding_window_coreset_json}")
            print(f"RESULT:{sliding_window_coreset_json}###{result_json}")
            ith = ith + 1
            print(f"COST:{Borassicost}")
