import pickle
from collections import Counter
from typing import Dict, List, Set, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib.lines import Line2D

# Load the crime-records CSV (expected in the working directory) into memory.
file_path = "./Crimes_-_2001_to_Present.csv"
df = pd.read_csv(file_path)
# NOTE(review): the result of head() is discarded here; it only displays
# something in an interactive/notebook session.
df.head()

# ---- Figure: number of records per year -------------------------------
_years = df["Year"]
year2freq = Counter(_years)
plt.xlabel("Year", fontsize="x-large")
plt.ylabel("Frequency", fontsize="x-large")

# Bars are drawn in ascending year order.
sorted_keys = sorted(year2freq.keys())
sorted_values = [year2freq[key] for key in sorted_keys]

plt.bar(sorted_keys, sorted_values)

plt.xticks(rotation=90)

plt.show()
# Clear the current figure so the next plot starts from a blank canvas.
plt.clf()

# ---- Figure: number of records per primary type (log-scaled y axis) ----
_primary_types = df["Primary Type"]
primary2freq = Counter(_primary_types)
plt.xlabel("Primary Type", fontsize="x-large")
plt.ylabel("Frequency", fontsize="x-large")
plt.yscale("log")

# most_common() orders the (type, count) pairs from most to least frequent;
# rebuilding a dict keeps that order for the bar chart.
sorted_items = primary2freq.most_common()
primary2freq = dict(sorted_items)
plt.bar(primary2freq.keys(), primary2freq.values())

plt.xticks(rotation=90)

plt.show()
plt.clf()


def plot_primary_type_yearly(primary_type: str, df: pd.DataFrame):
    """Show a bar chart of the yearly record count for one primary type.

    Args:
        primary_type: Value of the "Primary Type" column to select.
        df: Frame providing at least the "Primary Type" and "Year" columns.
    """
    matching_rows = df[df["Primary Type"] == primary_type]
    counts_by_year = Counter(matching_rows["Year"])

    # Bars are laid out in ascending year order.
    ordered_years = sorted(counts_by_year)
    bar_heights = [counts_by_year[yr] for yr in ordered_years]

    plt.xlabel("Year", fontsize="x-large")
    plt.ylabel("Frequency", fontsize="x-large")
    plt.title(f"Primary type: {primary_type}")
    plt.bar(ordered_years, bar_heights)
    plt.xticks(rotation=90)

    plt.show()
    # Reset the current figure so later plots do not inherit these bars.
    plt.clf()


def plot_primary_type_yearly_target(primary_type: str, df: pd.DataFrame):
    """Show the yearly record count of one primary type, highlighting some years.

    Bars for 2001, 2004, ..., 2019 (every third year) are drawn in orange,
    all other years in blue. The primary type is printed instead of being
    used as a figure title.

    Args:
        primary_type: Value of the "Primary Type" column to select.
        df: Frame providing at least the "Primary Type" and "Year" columns.
    """
    year_column = df.loc[df["Primary Type"] == primary_type, "Year"]
    counts_by_year = Counter(year_column)
    print(f"Primary type: {primary_type}")

    # 2001, 2004, 2007, 2010, 2013, 2016, 2019
    emphasised_years = set(range(2001, 2020, 3))

    ordered_years = sorted(counts_by_year)
    bar_heights = []
    bar_colors = []
    for yr in ordered_years:
        bar_heights.append(counts_by_year[yr])
        bar_colors.append("orange" if yr in emphasised_years else "blue")

    plt.bar(ordered_years, bar_heights, color=bar_colors)
    plt.xticks(fontsize="xx-large")
    plt.yticks(fontsize="xx-large")

    plt.tight_layout()
    plt.show()
    plt.clf()


# Highlighted yearly bar charts for four selected primary types.
for primary_type in [
    "WEAPONS VIOLATION",
    "INTERFERENCE WITH PUBLIC OFFICER",
    "NARCOTICS",
    "PROSTITUTION",
]:
    plot_primary_type_yearly_target(primary_type, df)

# Plain yearly bar chart for every distinct primary type in the data.
# Iterating a set means the figures appear in no particular order.
for primary_type in set(df["Primary Type"]):
    plot_primary_type_yearly(primary_type, df)


def get_sample_xy(
    primary_type: str,
    df: pd.DataFrame,
    years: Tuple[int, ...] = (2001, 2004, 2007, 2010, 2013, 2016, 2019),
) -> Dict[int, Dict[str, pd.Series]]:
    """Collect the X/Y coordinate columns of one primary type for given years.

    Rows whose "X Coordinate" is missing or not strictly positive are
    discarded; a NaN never compares greater than zero, so the comparison
    alone removes missing values.

    Args:
        primary_type: Value of the "Primary Type" column to select.
        df: Frame providing the "Primary Type", "Year", "X Coordinate" and
            "Y Coordinate" columns.
        years: Years to extract, in the order they should appear in the
            result. Defaults to every third year from 2001 to 2019.

    Returns:
        Mapping ``year -> {"x": Series, "y": Series}`` holding the retained
        coordinate columns. A year without matching rows maps to two empty
        Series.
    """
    df_target_prime = df[df["Primary Type"] == primary_type]
    year2xy = {}
    for year in years:
        df_year = df_target_prime[df_target_prime["Year"] == year]
        # Build the mask from df_year itself so it is guaranteed to align
        # with the rows being indexed.
        # NOTE(review): only the X coordinate is validated; a row with a valid
        # X but a missing Y would pass through — confirm this cannot occur.
        has_valid_x = df_year["X Coordinate"] > 0.0
        _df_year = df_year[has_valid_x]
        # Diagnostic: row count before and after the coordinate filter.
        print(len(df_year), len(_df_year))
        year2xy[year] = {
            "x": _df_year["X Coordinate"],
            "y": _df_year["Y Coordinate"],
        }

    return year2xy


def postprocess_sample_xy(
    year2xy: Dict[int, Dict[str, List[int]]],
    has_cap: bool = True,
    cap_per_year: int = 200,
):
    """Optionally truncate every year's samples to the same maximum size.

    Each year receives the same budget (``cap_per_year`` leading samples),
    independent of how many samples the other years contain.

    Args:
        year2xy: Mapping ``year -> {"x": ..., "y": ...}`` of sliceable,
            sized sequences.
        has_cap: When true, keep at most ``cap_per_year`` leading samples per
            year; when false, the sequences are passed through unchanged.
        cap_per_year: Maximum number of samples kept per year.

    Returns:
        A new mapping with the same years (in the same order) and the possibly
        truncated "x"/"y" sequences.

    Raises:
        ValueError: If ``cap_per_year`` is negative.
    """
    if cap_per_year < 0:
        raise ValueError("cap_per_year must be non-negative")

    _year2xy = {}
    for year, xy in year2xy.items():
        x = xy["x"]
        y = xy["y"]
        original_count = len(x)

        if has_cap:
            # An integer bound avoids deriving the slice index from float
            # arithmetic, which can truncate one element short.
            x = x[:cap_per_year]
            y = y[:cap_per_year]
            kept_count = min(original_count, cap_per_year)
        else:
            kept_count = original_count

        # Diagnostic: sample count before -> after capping.
        print(f"{year}: {original_count} -> {kept_count}")
        _year2xy[year] = {"x": x, "y": y}

    return _year2xy


def conduct_pca(X: np.ndarray, k: int = 2):
    """Centre ``X`` and reduce it to ``k`` columns.

    Args:
        X: 2-D array with one sample per row.
        k: Number of output columns.

    Returns:
        * If ``X`` already has exactly ``k`` columns: the centred columns
          divided by their standard deviation (no rotation is applied).
        * Otherwise: the first ``k`` left singular vectors of the centred
          data, each scaled by the square root of its singular value.
    """
    Xbar = X - np.mean(X, axis=0)

    if Xbar.shape[1] == k:
        # NOTE(review): a constant column has zero standard deviation and
        # yields NaN/inf here — confirm inputs never contain one.
        return Xbar / np.std(Xbar, axis=0)

    # Thin SVD: only the leading singular vectors are used, and the default
    # full_matrices=True would allocate an N x N matrix for N samples.
    U, S, _ = np.linalg.svd(Xbar, full_matrices=False)
    # Broadcasting scales column j by sqrt(S[j]), equivalent to multiplying
    # by diag(sqrt(S[:k])).
    return U[:, :k] * np.sqrt(S[:k])


def centerise_xy(year2xy: Dict[int, Dict[str, List[int]]]):
    """Transform the coordinates of all years jointly with ``conduct_pca``.

    The (x, y) pairs of every year are stacked in ascending year order,
    passed through ``conduct_pca`` as a single matrix, and then split back
    into per-year pieces of the original sizes.

    Args:
        year2xy: Mapping ``year -> {"x": ..., "y": ...}`` with equally long
            coordinate sequences per year.

    Returns:
        Mapping ``year -> {"x": ndarray, "y": ndarray}`` holding columns 0 and
        1 of the transformed matrix for that year's rows.
    """
    ordered_years = sorted(year2xy)

    # One (n_year, 2) matrix per year, stacked into a single (n_total, 2) one.
    per_year_points = [
        np.hstack(
            (
                np.array(year2xy[yr]["x"]).reshape(-1, 1),
                np.array(year2xy[yr]["y"]).reshape(-1, 1),
            )
        )
        for yr in ordered_years
    ]
    transformed = conduct_pca(np.vstack(per_year_points))

    result = {}
    start = 0
    for yr in ordered_years:
        count = len(year2xy[yr]["x"])
        stop = start + count
        xs = transformed[start:stop, 0]
        ys = transformed[start:stop, 1]
        result[yr] = {"x": xs, "y": ys}

        # Sanity checks: every year gets back exactly as many points as it
        # contributed.
        assert len(xs) == len(ys)
        assert len(xs) == len(year2xy[yr]["x"])
        assert len(ys) == len(year2xy[yr]["y"])
        start = stop

    return result


def plot_sample_xy(primary_type: str, df: pd.DataFrame, has_cap: bool = True):
    """Show scatter plots of the sampled coordinates of one primary type.

    Figures are shown in this order:
      1. all sampled years overlaid, after capping and joint transformation;
      2. only 2001, 2010 and 2019 overlaid (semi-transparent), same data;
      3. one figure per sampled year using the raw coordinates.

    Args:
        primary_type: Value of the "Primary Type" column to select.
        df: Frame passed on to ``get_sample_xy``.
        has_cap: Passed on to ``postprocess_sample_xy``.
    """

    def _legend_markers(years, palette):
        # Round proxy markers, one per year, used instead of the scatter
        # artists themselves as legend handles.
        return [
            Line2D(
                [0],
                [0],
                color="w",
                marker="o",
                label=f"Year: {yr}",
                markerfacecolor=shade,
                markersize=20,
            )
            for yr, shade in zip(years, palette)
        ]

    def _finish_overlay(years, palette):
        # Shared tail of the two overlay figures: ticks, legend, show, reset.
        plt.xticks(fontsize="xx-large")
        plt.yticks(fontsize="xx-large")
        markers = _legend_markers(years, palette)
        plt.legend(fontsize="xx-large", handles=markers, loc="lower left")
        plt.tight_layout()
        plt.show()
        plt.clf()

    raw_by_year = get_sample_xy(primary_type, df)
    capped_by_year = postprocess_sample_xy(raw_by_year, has_cap)
    centred_by_year = centerise_xy(capped_by_year)

    # Seven evenly spaced colours from the diverging map, one per sampled year.
    colormap = plt.get_cmap("RdBu_r")
    n_colors = 7
    palette = [colormap(step / n_colors) for step in range(n_colors)]

    print(f"Primary type: {primary_type}")

    # --- Figure 1: every sampled year overlaid -----------------------------
    centred_years = sorted(centred_by_year.keys())
    for shade, yr in zip(palette, centred_years):
        points = centred_by_year[yr]
        plt.scatter(
            points["x"], points["y"], label=f"Year: {yr}", s=50, color=shade
        )
    _finish_overlay(centred_years, palette)

    # --- Figure 2: first, middle and last sampled year only ----------------
    subset_palette = [palette[0], palette[3], palette[6]]
    subset_years = [2001, 2010, 2019]
    for shade, yr in zip(subset_palette, subset_years):
        points = centred_by_year[yr]
        plt.scatter(
            points["x"],
            points["y"],
            label=f"Year: {yr}",
            alpha=0.4,
            s=50,
            color=shade,
        )
    _finish_overlay(subset_years, subset_palette)

    # --- Figures 3+: raw coordinates, one figure per year ------------------
    for shade, yr in zip(palette, sorted(raw_by_year.keys())):
        plt.title(f"Primary type: {primary_type}")
        plt.xlabel("X Coordinate", fontsize="xx-large")
        plt.ylabel("Y Coordinate", fontsize="xx-large")
        raw_points = raw_by_year[yr]
        plt.scatter(
            raw_points["x"],
            raw_points["y"],
            label=f"Year: {yr}",
            alpha=0.5,
            color=shade,
        )
        plt.xticks(fontsize="xx-large")
        plt.yticks(fontsize="xx-large")

        plt.legend(fontsize="xx-large", loc="lower left")

        plt.show()
        plt.clf()


# Scatter plots for four selected primary types with the per-year sample cap
# applied (has_cap defaults to True).
plot_sample_xy("WEAPONS VIOLATION", df)
plot_sample_xy("INTERFERENCE WITH PUBLIC OFFICER", df)
plot_sample_xy("PROSTITUTION", df)
plot_sample_xy("NARCOTICS", df)

# The same plots again, this time without the cap.
plot_sample_xy("WEAPONS VIOLATION", df, has_cap=False)
plot_sample_xy("INTERFERENCE WITH PUBLIC OFFICER", df, has_cap=False)
plot_sample_xy("PROSTITUTION", df, has_cap=False)
plot_sample_xy("NARCOTICS", df, has_cap=False)

# NOTE(review): this module-level year2xy is reassigned by the export loop
# further down before being read; it looks like leftover exploration —
# confirm before removing.
year2xy = get_sample_xy("WEAPONS VIOLATION", df)


def convert_to_year2vec(year2xy: Dict[int, Dict[str, List[int]]]):
    """Pack each year's coordinates into one two-column array.

    Args:
        year2xy: Mapping ``year -> {"x": ..., "y": ...}`` with equally long
            coordinate sequences per year.

    Returns:
        Mapping ``year -> ndarray`` of shape ``(n, 2)`` whose column 0 holds
        the x values and column 1 the y values, in the input's key order.
    """
    return {
        year: np.hstack(
            (
                np.array(coords["x"]).reshape(-1, 1),
                np.array(coords["y"]).reshape(-1, 1),
            )
        )
        for year, coords in year2xy.items()
    }


def save_label2vecs(label2yearvecs):
    """Write one pickle file per year containing every label's vectors.

    For each year a dict ``label -> list of 1-D torch tensors`` (one tensor
    per row of that label's array) is pickled to ``label2vecs_{year}.pkl`` in
    the current working directory.

    Args:
        label2yearvecs: Mapping ``label -> {year -> 2-D numpy array}``. The
            set of years is taken from the first label; every other label
            must provide the same years.

    Raises:
        KeyError: If a label lacks one of the first label's years.
    """
    if not label2yearvecs:
        # Nothing to export; without a first label there are no years either.
        return

    labels = list(label2yearvecs.keys())
    years = label2yearvecs[labels[0]].keys()
    for year in years:
        label2vecs = {
            label: [torch.from_numpy(vec) for vec in label2yearvecs[label][year]]
            for label in labels
        }
        # The context manager guarantees the file is flushed and closed even
        # if pickling fails.
        with open(f"label2vecs_{year}.pkl", "wb") as out_file:
            pickle.dump(label2vecs, out_file)


# Primary types whose per-year coordinate vectors are exported.
labels = [
    "WEAPONS VIOLATION",
    "INTERFERENCE WITH PUBLIC OFFICER",
    "NARCOTICS",
    "PROSTITUTION",
]
# label -> {year -> (n, 2) array of X/Y coordinates}
label2yearvecs = {}
for label in labels:
    year2xy = get_sample_xy(label, df)
    year2vec = convert_to_year2vec(year2xy)
    label2yearvecs[label] = year2vec

# Writes one label2vecs_{year}.pkl file per year into the working directory.
save_label2vecs(label2yearvecs)
