import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import silhouette_score, davies_bouldin_score
import time

# 1. Command-line interface
#    --csv_path : required, location of the weather CSV to analyse
#    --K        : optional, number of KMeans clusters (integer, 300 if omitted)
parser = argparse.ArgumentParser()
parser.add_argument("--csv_path", required=True, type=str, help="Path to the weatherHistory.csv file")
parser.add_argument("--K", default=300, type=int, help="Number of clusters for KMeans (default: 300)")
args = parser.parse_args()

# 2. Start the wall-clock timer
start_time = time.time()

# 3. Load the CSV and put the rows in chronological order
df = pd.read_csv(args.csv_path)

# Parse the timestamps as UTC, then strip the timezone so the column holds
# naive datetimes before sorting on it.
parsed_dates = pd.to_datetime(df["Formatted Date"], utc=True)
df["Formatted Date"] = parsed_dates.dt.tz_convert(None)
df = df.sort_values(by="Formatted Date").reset_index(drop=True)

orig_weather_labels = df["Summary"].values

# Reduced weather class -> raw "Summary" values that collapse into it.
summary_groups = {
    "clear": (
        "Clear",
        "Breezy",
        "Dry",
        "Windy",
        "Breezy and Dry",
        "Windy and Dry",
    ),
    "partly cloudy": (
        "Partly Cloudy",
        "Breezy and Partly Cloudy",
        "Windy and Partly Cloudy",
        "Humid and Partly Cloudy",
        "Dangerously Windy and Partly Cloudy",
        "Dry and Partly Cloudy",
    ),
    "mostly cloudy": (
        "Mostly Cloudy",
        "Overcast",
        "Breezy and Overcast",
        "Windy and Overcast",
        "Humid and Mostly Cloudy",
        "Breezy and Mostly Cloudy",
        "Windy and Mostly Cloudy",
        "Dry and Mostly Cloudy",
        "Humid and Overcast",
    ),
    "foggy": (
        "Foggy",
        "Windy and Foggy",
        "Breezy and Foggy",
    ),
    "rainy": (
        "Light Rain",
        "Drizzle",
        "Rain",
    ),
}

# Flatten the groups into the raw-summary -> reduced-class lookup table.
mapping = {
    raw: reduced
    for reduced, raws in summary_groups.items()
    for raw in raws
}

# Summaries that are not in the table are kept unchanged.
weather_labels = np.array([mapping.get(raw, raw) for raw in orig_weather_labels])

# 4. Feature engineering
# Derived columns:
#   Temp_Diff   - measured minus apparent temperature
#   WindX/WindY - wind speed projected on the cosine / sine of the bearing
#   Fog_Index   - humidity weighted by how far visibility is below its maximum
#   Humid_Temp  - humidity x temperature interaction term
df["Temp_Diff"] = df["Temperature (C)"] - df["Apparent Temperature (C)"]
bearing_rad = np.radians(df["Wind Bearing (degrees)"])
df["WindX"] = df["Wind Speed (km/h)"] * np.cos(bearing_rad)
df["WindY"] = df["Wind Speed (km/h)"] * np.sin(bearing_rad)
df["Fog_Index"] = df["Humidity"] * (1 - df["Visibility (km)"] / df["Visibility (km)"].max())
df["Humid_Temp"] = df["Humidity"] * df["Temperature (C)"]

features = [
    "Temperature (C)",
    "Apparent Temperature (C)",
    "Humidity",
    "Wind Speed (km/h)",
    "Wind Bearing (degrees)",
    "Visibility (km)",
    "Pressure (millibars)",
    "Temp_Diff",
    "WindX",
    "WindY",
    "Fog_Index",
    "Humid_Temp"
]

# Rows with a missing feature are removed from df and from the label arrays
# together, not just from X. Later steps index X_scaled, weather_labels and
# df["Formatted Date"] with the same row positions, so dropping rows from X
# alone would pair points with the wrong label and timestamp.
complete_rows = df[features].notna().all(axis=1).to_numpy()
if not complete_rows.all():
    df = df.loc[complete_rows].reset_index(drop=True)
    orig_weather_labels = orig_weather_labels[complete_rows]
    weather_labels = weather_labels[complete_rows]
X = df[features].values

# Standardize every feature; float32 keeps the pairwise-distance arrays small.
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X).astype(np.float32)

# 5. Prepare incremental simulation
# The first half of the (chronologically sorted) rows forms the initialization
# set; the remaining rows are replayed later in calendar-month batches.
n_total = X_scaled.shape[0]
init_size = int(0.5 * n_total)
X_init = X_scaled[:init_size]
y_init = weather_labels[:init_size]

# "YYYY-MM" label for every row, formatted in one vectorized call.
timestamps = df["Formatted Date"].values[:n_total]
month_dt = pd.to_datetime(timestamps, utc=True)
month_labels = month_dt.strftime("%Y-%m").tolist()

unique_months = sorted(set(month_labels))
init_last_month = month_labels[init_size - 1]
# ">=" keeps the month that straddles the split. Its rows at positions
# >= init_size are not part of the initialization set, so excluding the month
# would silently drop them. The batching loop filters on position, so rows
# already used for initialization are not replayed.
incremental_months = [m for m in unique_months if m >= init_last_month]

# 6. Fit KMeans and t-SNE on the initialization set
K = args.K

# Partition the initialization points into K clusters (fixed seed).
kmeans = KMeans(n_clusters=K, random_state=42)
kmeans.fit(X_init)
kmeans_labels = kmeans.labels_

# Two-dimensional Barnes-Hut t-SNE embedding of the same points (fixed seed).
tsne = TSNE(n_components=2, method="barnes_hut", random_state=42)
tsne.fit(X_init)
Y_init = tsne.embedding_

# 7. Define Cluster and Bi-RSNE classes
class Cluster:
    """Running summary of one cluster in high- and low-dimensional space.

    Attributes:
        high_dim_mean: mean of the member points in feature space.
        low_dim_mean: position associated with the cluster in the embedding;
            ``update`` never changes it.
        sum_squares: running mean of the squared Euclidean norms of the
            member points.
        std_dev: spread of the cluster; after an ``update`` it equals
            ``sqrt(sum_squares - ||high_dim_mean||^2)``.
        count: number of points summarized so far.
    """

    def __init__(self, high_dim_mean, low_dim_mean, sum_squares, std_dev, count=1):
        self.high_dim_mean = high_dim_mean
        self.low_dim_mean = low_dim_mean
        self.sum_squares = sum_squares
        self.std_dev = std_dev
        self.count = count

    def update(self, new_point):
        """Fold one new high-dimensional point into the running statistics.

        Args:
            new_point: 1-D array with the same length as ``high_dim_mean``.
        """
        prev_count = self.count
        self.count = prev_count + 1
        # Weights of the old aggregate and of the new sample in the new mean.
        old_weight = prev_count / self.count
        new_weight = 1 / self.count

        self.high_dim_mean = old_weight * self.high_dim_mean + new_weight * new_point
        self.sum_squares = (
            old_weight * self.sum_squares
            + new_weight * np.linalg.norm(new_point) ** 2
        )

        # E[||x||^2] - ||mean||^2 cannot be negative mathematically, but
        # floating-point rounding (tight clusters, float32 inputs) can push it
        # slightly below zero. Clamp it so std_dev never becomes NaN.
        variance = self.sum_squares - np.linalg.norm(self.high_dim_mean) ** 2
        self.std_dev = np.sqrt(max(variance, 0.0))

class IncrementalTSNEBatch:
    """Place batches of new points into an existing 2-D embedding.

    New points are positioned relative to a fixed set of low-dimensional
    cluster anchors (``cluster.low_dim_mean``). The high-dimensional cluster
    statistics are updated with every point through ``cluster.update``.

    Args:
        clusters: non-empty sequence of objects exposing ``high_dim_mean``,
            ``low_dim_mean``, ``std_dev`` and ``update(point)``.
        X0: initial high-dimensional points (iterated row by row).
        Y0: initial low-dimensional coordinates, one row per row of ``X0``.
        y0: labels of the initial points.
        eta: step size of the gradient update.
        max_iters: number of gradient steps applied to every batch.
        random_state: optional seed for the placement jitter. ``None`` (the
            default) draws from NumPy's global random state.
    """

    def __init__(self, clusters, X0, Y0, y0, eta=10, max_iters=1, random_state=None):
        self.clusters = clusters
        self.eta = eta
        self.max_iters = max_iters
        self.X_list = list(X0)
        self.Y_list = list(Y0)
        self.y_all_labels = list(y0)
        # The np.random module and RandomState instances both provide randn().
        self._rng = np.random if random_state is None else np.random.RandomState(random_state)

    def find_nearest_clusters(self, batch_points):
        """Return, for each row of ``batch_points``, the index of the cluster
        whose ``high_dim_mean`` is closest in Euclidean distance."""
        C = np.stack([c.high_dim_mean for c in self.clusters])
        dist = np.linalg.norm(batch_points[:, None, :] - C[None, :, :], axis=2)
        return np.argmin(dist, axis=1)

    def add_new_batch(self, batch_points, batch_labels):
        """Embed a batch of points and append it to the stored results.

        Each point starts at the low-dimensional anchor of its nearest
        cluster plus small Gaussian jitter. It is then moved by ``max_iters``
        gradient steps that compare a Gaussian affinity to the clusters in
        feature space (P) with a 1 / (1 + d^2) affinity to the anchors in the
        embedding (Q).

        Args:
            batch_points: array of shape (B, D) with the new points.
            batch_labels: indexable holding at least B labels.
        """
        Xb = np.asarray(batch_points).astype(np.float32)
        B = Xb.shape[0]
        if B == 0:
            return  # nothing to embed

        nearest = self.find_nearest_clusters(Xb)

        # Anchors are read once: low_dim_mean is not modified by update().
        C_low = np.stack([c.low_dim_mean for c in self.clusters])
        y_batch = C_low[nearest] + self._rng.randn(B, 2).astype(np.float32) * 0.1

        # Sequentially fold every point into its nearest cluster's statistics.
        for idx, x_new in zip(nearest, Xb):
            self.clusters[idx].update(x_new)

        C_high = np.stack([c.high_dim_mean for c in self.clusters])
        sigma = np.array([c.std_dev for c in self.clusters], dtype=np.float32)
        # np.clip lets NaN through and one NaN bandwidth would turn every
        # affinity row into NaN, so replace NaN with the floor value first.
        sigma[np.isnan(sigma)] = 1e-3
        sigma = np.clip(sigma, 1e-3, None)

        # P depends only on high-dimensional quantities, so it is computed
        # once instead of once per gradient step. Shifting each row by its
        # maximum before exponentiating leaves the normalized values unchanged
        # but prevents whole rows from underflowing to zero for small sigma.
        d2h = np.sum((Xb[:, None, :] - C_high[None, :, :]) ** 2, axis=2)
        log_p = -d2h / (2 * sigma[None, :] ** 2)
        log_p -= log_p.max(axis=1, keepdims=True)
        P = np.exp(log_p)
        P /= P.sum(axis=1, keepdims=True)  # row sum is >= 1 after the shift

        for _ in range(self.max_iters):
            diff = y_batch[:, None, :] - C_low[None, :, :]
            d2l = np.sum(diff ** 2, axis=2)
            Q = 1.0 / (1.0 + d2l)
            Q /= (Q.sum(axis=1, keepdims=True) + 1e-12)

            coef = 2.0 * (P - Q) / (1.0 + d2l)
            grads = np.einsum("ik,ikj->ij", coef, diff)
            y_batch -= self.eta * grads

        self.X_list.extend(Xb)
        self.Y_list.extend(y_batch)
        self.y_all_labels.extend(batch_labels[i] for i in range(B))

# 8. Initialize clusters from initialization data
# One Cluster per non-empty KMeans cluster, summarizing its members in feature
# space (mean, mean squared norm, spread) and in the t-SNE embedding (mean).
clusters = []
for i in range(K):
    member_mask = kmeans_labels == i
    pts = X_init[member_mask]
    if pts.shape[0] == 0:
        continue  # KMeans left this cluster empty; there is nothing to summarize
    high_mean = np.mean(pts, axis=0)
    sum_sq = np.mean(np.linalg.norm(pts, axis=1) ** 2)
    # Use the same definition Cluster.update maintains,
    # sqrt(E[||x||^2] - ||mean||^2), so clusters that never receive a new
    # point carry a spread comparable to the updated ones. The clamp guards
    # against a slightly negative difference caused by rounding.
    std = np.sqrt(max(sum_sq - np.linalg.norm(high_mean) ** 2, 0.0))
    low_mean = np.mean(Y_init[member_mask], axis=0)
    clusters.append(Cluster(high_mean, low_mean, sum_sq, std, count=pts.shape[0]))

# 9. Incremental processing by month
inc_tsne = IncrementalTSNEBatch(clusters, X_init, Y_init, y_init, eta=10, max_iters=1)

# The placement jitter is drawn from NumPy's global random state. Seed it so
# repeated runs give the same embedding, as KMeans and t-SNE already do.
np.random.seed(42)

# Compare month labels as an array once per month instead of scanning every
# row in Python for each month. Only rows after the initialization split are
# considered, so initialization points are never replayed.
tail_months = np.asarray(month_labels)[init_size:]
tail_positions = np.arange(init_size, init_size + tail_months.shape[0])

for month in incremental_months:
    idx = tail_positions[tail_months == month]
    if len(idx) < 50:
        continue  # months with fewer than 50 remaining rows are skipped
    batch_points = X_scaled[idx]
    batch_labels = weather_labels[idx]
    inc_tsne.add_new_batch(batch_points, batch_labels)

elapsed_time = time.time() - start_time

# 10. Collect the embedded points and score the embedding
# Stack the per-point rows gathered by the incremental embedder into matrices.
X_all, Y_all = (np.vstack(rows) for rows in (inc_tsne.X_list, inc_tsne.Y_list))

# Integer code for every weather class that was actually embedded.
label_encoder = LabelEncoder()
labels_numeric = label_encoder.fit_transform(inc_tsne.y_all_labels)

# Both metrics are computed on the 2-D embedding against the weather classes.
sil_score = silhouette_score(Y_all, labels_numeric)
db_score = davies_bouldin_score(Y_all, labels_numeric)

for report_line in (
    f"Total time taken: {elapsed_time:.2f} seconds",
    f"Silhouette Score: {sil_score:.4f}",
    f"Davies-Bouldin Index: {db_score:.4f}",
):
    print(report_line)

# 11. Visualization
# Scatter plot of the final embedding, one colour per reduced weather class.
plt.figure(figsize=(12, 10))
scatter = plt.scatter(
    Y_all[:, 0],
    Y_all[:, 1],
    c=labels_numeric,
    cmap="tab10",
    s=2,
    alpha=0.7
)
plt.title("Bi-RSNE Embedding of Weather Dataset\nColored by Reduced Weather Class")
plt.xlabel("Dimension 1")
plt.ylabel("Dimension 2")

# num=None requests one handle per distinct class code, however many classes
# there are.
handles, _ = scatter.legend_elements(prop="colors", num=None)
# Legend names must come from the labels that were embedded, because those
# produced the colour codes. The full label array can contain classes that
# only occur in skipped batches, which would shift names onto wrong colours.
plotted_classes = LabelEncoder().fit(inc_tsne.y_all_labels).classes_
plt.legend(
    handles,
    plotted_classes,
    title="Weather Types",
    bbox_to_anchor=(1.05, 1),
    loc="upper left"
)
plt.tight_layout()
plt.show()
