import numpy as np
import pandas as pd
import libpysal

import warnings

import scipy
from scipy.spatial import distance

from matplotlib import pyplot as plt

from sklearn.datasets import make_sparse_spd_matrix

from sklearn.preprocessing import normalize
from sklearn.covariance import GraphicalLasso

from sklearn.cluster import DBSCAN
from collections import Counter

from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.metrics.cluster import adjusted_mutual_info_score
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture

from sklearn.metrics.pairwise import haversine_distances
from math import radians

from utils import *

warnings.filterwarnings("ignore")

def load_1D(dname, fdim):
    """Load a comma-separated dataset and its labels from ``data/<dname>/``.

    Parameters
    ----------
    dname : str
        Dataset name; used both as the sub-directory and as the file prefix.
    fdim : int
        Number of leading columns of the data file treated as features.

    Returns
    -------
    tuple
        ``(features, locations, ground_truth)`` where ``features`` is the
        first ``fdim`` columns of the data file, ``locations`` is the
        remaining columns, and ``ground_truth`` is the content of the label
        file.
    """
    table = np.loadtxt(f"data/{dname}/{dname}-data.txt", delimiter=",")
    labels = np.loadtxt(f"data/{dname}/{dname}-label.txt", delimiter=",")

    # Column split: feature columns first, everything after is location.
    return table[:, :fdim], table[:, fdim:], labels


def load_2D(dname, fdim, k, metric="Euclidean"):
    """Load a dataset and build a k-nearest-neighbour table over its locations.

    Parameters
    ----------
    dname : str
        Dataset name; files are read from ``data/<dname>/``.
    fdim : int
        Number of leading columns of the data file treated as features.
    k : int
        Number of neighbours requested from ``libpysal.weights.KNN``.
    metric : str, optional
        Passed to ``libpysal.cg.KDTree`` as ``distance_metric``.

    Returns
    -------
    tuple
        ``(features, locations, ground_truth, nearest_pt)``. ``nearest_pt`` is
        a DataFrame built from the KNN ``neighbors`` mapping (one row per key)
        whose columns are named ``n_pt_0``, ``n_pt_1``, ...
    """
    # Same file layout and column split as the 1-D loader.
    features, locations, ground_truth = load_1D(dname, fdim)

    tree = libpysal.cg.KDTree(locations, distance_metric=metric)
    knn = libpysal.weights.KNN(tree, k)

    nearest_pt = pd.DataFrame.from_dict(knn.neighbors, orient="index")
    # Relabel the positional columns 0..m-1 as n_pt_0..n_pt_{m-1} in one pass.
    column_names = {col: f"n_pt_{col}" for col in range(nearest_pt.shape[1])}
    nearest_pt = nearest_pt.rename(column_names, axis=1)

    return features, locations, ground_truth, nearest_pt

def fit_1D():
    """Placeholder for a 1-D fitting routine; not implemented, returns None."""
    return None


def fit_2D(features, nearest_pt, radius):
    """Fit one Graphical Lasso model per sample on its nearest neighbours.

    For every row index ``i`` of ``features`` the rows listed in the first
    ``radius`` columns of ``nearest_pt.iloc[i]`` are gathered and a
    ``GraphicalLasso(alpha=0.01, max_iter=1000)`` is fitted on them.

    Parameters
    ----------
    features : numpy.ndarray
        2-D array; rows are samples.
    nearest_pt : pandas.DataFrame
        Row ``i`` holds the row indices (into ``features``) of the
        neighbours of sample ``i``.
    radius : int
        Number of leading neighbour columns used for each local fit.

    Returns
    -------
    tuple
        ``(valid_idx, estimated_models)``: the indices whose fit completed
        and, aligned with them, ``(location_, covariance_)`` pairs.

    Notes
    -----
    Warnings emitted during a fit are recorded and counted, not raised.
    Raising them aborted the fit, after which the code went on to store a
    model left over from a previous neighbourhood (or crashed on the first
    iteration because no model existed yet). Fits that raise an exception
    are counted as failures and skipped.
    """
    glasso = GraphicalLasso(alpha=0.01, max_iter=1000)
    n_samples = features.shape[0]

    estimated_models = []
    valid_idx = []

    fails, warns = 0, 0

    for i in range(n_samples):
        data = features[nearest_pt.iloc[i, :radius], :]

        # Record warnings so that the fit runs to completion and the stored
        # model really belongs to sample ``i``; the previous filter state is
        # restored when the context exits, even if an exception propagates.
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            try:
                fit = glasso.fit(data)
            except Exception:
                # Deliberately best-effort: a failed neighbourhood is skipped.
                # KeyboardInterrupt/SystemExit are no longer swallowed.
                fails += 1
                continue

        if caught:
            warns += 1

        print("\rProcessed {} out of {}, {} warnings, {} failures.".format(i, n_samples, warns, fails), end='')

        valid_idx.append(i)
        estimated_models.append((fit.location_, fit.covariance_))

    # Keep the original post-condition: warnings are silenced afterwards.
    warnings.filterwarnings("ignore")

    return valid_idx, estimated_models

def construct_matrices(cov_seq, locations, ground_truth):
    """Build the spatial, Wasserstein and cluster-match matrices.

    Each input is handed to one helper (defined outside this block):
    ``locations`` to ``construct_arc_spatial_dist_matrix``, ``cov_seq`` to
    ``construct_wasser_dist_matrix_vectorized`` and ``ground_truth`` to
    ``construct_cluster_match_matrix``.

    NOTE(review): the caller in this file passes a list of
    ``(mean, covariance)`` pairs as ``cov_seq`` -- confirm that this is what
    the Wasserstein helper expects.

    Returns
    -------
    tuple
        ``(spatial_dist_matrix, wasser_dist_matrix, cluster_match_matrix)``.
    """
    # Tuple elements are evaluated left to right, so the helpers run in the
    # same order as before: spatial, Wasserstein, cluster match.
    return (
        construct_arc_spatial_dist_matrix(locations),
        construct_wasser_dist_matrix_vectorized(cov_seq),
        construct_cluster_match_matrix(ground_truth),
    )


import os

# Driver: fit local models for one dataset, then checkpoint the fitted
# models and the derived pairwise matrices.
dname = "inat"
radius = 30

# Create the checkpoint directory before the expensive fitting step so that a
# missing directory cannot discard the results when np.savez is reached.
checkpoint_dir = "checkpoints/{}".format(dname)
os.makedirs(checkpoint_dir, exist_ok=True)

features, locations, ground_truth, nearest_pt = load_2D(dname, 16, 50)
valid_idx, estimated_models = fit_2D(features, nearest_pt, radius)

print("Valid estimated models: {}".format(len(valid_idx)))

# Keep only the samples for which a model was estimated.
valid_features, valid_locations, valid_labels = features[valid_idx], locations[valid_idx], ground_truth[valid_idx]

np.savez(
    "checkpoints/{}/{}-r{}-fitting".format(dname, dname, radius),
    idx=valid_idx,
    features=valid_features,
    locations=valid_locations,
    labels=valid_labels,
    mean=np.array([m[0] for m in estimated_models]),
    cov=np.array([m[1] for m in estimated_models]),
)

# To resume from a saved fitting checkpoint instead of refitting, comment out
# the loading/fitting/saving steps above and enable the lines below.
# data = np.load("checkpoints/{}/{}-r{}-fitting.npz".format(dname, dname, radius))

# means, covs, valid_locations, valid_labels = data["mean"], data["cov"], data["locations"], data["labels"]
# estimated_models = list(zip(means, covs))

spatial_dist_matrix, wasser_dist_matrix, cluster_match_matrix = construct_matrices(estimated_models, valid_locations, valid_labels)

np.savez(
    "checkpoints/{}/{}-r{}-matrix".format(dname, dname, radius),
    spatial=spatial_dist_matrix,
    wasser=wasser_dist_matrix,
    match=cluster_match_matrix,
)







