import argparse
import time
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from math import sqrt
from scipy.stats import pearsonr
from scipy.spatial.distance import directed_hausdorff
from scipy.optimize import linear_sum_assignment
from sklearn.metrics.pairwise import euclidean_distances
from ot import emd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader

from modelnet import ModelFetcher

sns.set_context("paper", font_scale=1.5)
sns.set_style("whitegrid")

def clip_grad(model, max_norm):
    """Rescale the gradients of ``model`` in place so their global L2 norm is at most ``max_norm``.

    Parameters whose ``grad`` is ``None`` (they took no part in the last
    backward pass) are ignored instead of raising ``AttributeError``.

    Args:
        model: Object exposing ``parameters()`` that yields tensors with a
            ``grad`` attribute (e.g. an ``nn.Module``).
        max_norm: Upper limit for the combined L2 norm of all gradients.

    Returns:
        The global gradient norm measured *before* clipping: a 0-d tensor,
        or ``0.0`` when no parameter has a gradient.
    """
    # Collect once so both passes see the same set of gradients.
    grads = [p.grad for p in model.parameters() if p.grad is not None]

    total_norm = 0
    for grad in grads:
        total_norm += grad.detach().norm(2) ** 2
    total_norm = total_norm ** 0.5

    # The small epsilon keeps the ratio finite when every gradient is zero.
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for grad in grads:
            # detach() shares storage, so this scales the real gradient in place.
            grad.detach().mul_(clip_coef)
    return total_norm

# Command-line interface. Options are registered from a table so that the
# order below is the order shown by --help.
parser = argparse.ArgumentParser()
cli_options = [
    ('--lr', dict(type=float, default=1e-3, help='Learning rate')),
    ('--dropout', dict(type=float, default=0.0, help='Dropout rate')),
    ('--batch-size', dict(type=int, default=64, help='Batch size')),
    ('--epochs', dict(type=int, default=50, help='Number of epochs')),
    ('--hidden-dim', dict(type=int, default=256, help='Hidden dimension size (only for MLP)')),
    ('--aggregator', dict(default='attention', choices=['sum', 'mean', 'max', 'attention'], help='How to aggregate node representations')),
    ('--n-samples', dict(type=int, default=100, help='Number of test samples to use for correlation')),
    ('--downsample', dict(type=int, default=100, help='For 5000 points use 2, for 1000 use 10, for 100 use 100')),
]
for option_flag, option_kwargs in cli_options:
    parser.add_argument(option_flag, **option_kwargs)
# Parsed once at import time; the rest of the script reads the module-level ``args``.
args = parser.parse_args()

class NN(nn.Module):
    """Element-wise layer, pooling over axis 1, then a two-layer classification head.

    ``forward`` applies ``fc1`` to every element, pools over ``dim=1`` with the
    configured aggregator, and feeds the pooled vector through ``fc2`` and
    ``fc3``.

    Args:
        input_dim: Size of the last axis of the input.
        hidden_dim: Width of the hidden layers.
        aggregator: Pooling mode: 'sum', 'mean', 'max' or 'attention'. Any
            other value skips the pooling step entirely.
        n_class: Number of output logits.
        dropout: Dropout probability.
    """

    def __init__(self, input_dim, hidden_dim, aggregator, n_class, dropout):
        super().__init__()
        self.aggregator = aggregator
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, n_class)
        # Attention scorer: hidden -> hidden -> scalar score per element.
        self.fc_att = nn.Linear(hidden_dim, hidden_dim)
        self.q = nn.Linear(hidden_dim, 1, bias=False)
        self.dropout = nn.Dropout(p=dropout)
        self.relu = nn.ReLU()
        # Configured from the module-level ``args``; the script pulls its
        # training and test batches through this attribute.
        self.model_fetcher = ModelFetcher('ModelNet40_cloud.h5', args.batch_size, args.downsample, do_standardize=True, do_augmentation=True)

    def forward(self, x):
        """Compute class logits and the pooled representation.

        Args:
            x: Input tensor pooled over ``dim=1``; expected layout is
                (batch, n_elements, input_dim) — TODO confirm against
                ModelFetcher.

        Returns:
            Tuple ``(out, x)``: the ``fc3`` logits and the post-ReLU ``fc2``
            activation (taken before dropout).
        """
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        if self.aggregator == 'sum':
            x = torch.sum(x, dim=1)
        elif self.aggregator == 'mean':
            x = torch.mean(x, dim=1)
        elif self.aggregator == 'max':
            x, _ = torch.max(x, dim=1)
        elif self.aggregator == 'attention':
            att = self.relu(self.fc_att(x))
            att = self.q(att)
            # Normalise the scores over the elements of each set (dim=1), the
            # same axis every other aggregator reduces. Normalising over dim=0
            # would couple each sample's weights to the rest of the batch.
            # torch.softmax is used so this method does not depend on the
            # module-level alias ``F``.
            att = torch.softmax(att, dim=1)
            x = torch.multiply(att, x)
            x = torch.sum(x, dim=1)
        x = self.relu(self.fc2(x))
        out = self.dropout(x)
        out = self.fc3(out)
        return out, x

# Use the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 3 input coordinates per point, 40 output classes.
model = NN(input_dim=3, hidden_dim=args.hidden_dim, aggregator=args.aggregator, n_class=40, dropout=args.dropout).to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)
loss_func = nn.CrossEntropyLoss()

model.train()
# Loss of the most recent batch; it is handed to the fetcher when requesting
# each epoch's batches (presumably for progress display — confirm in ModelFetcher).
loss_val = float('inf')
for epoch in range(args.epochs):
    counter = 0
    sum_acc = 0.0
    train_data = model.model_fetcher.train_data(loss_val)
    # Each batch is a triple; the middle item is not used here.
    for x, _, y in train_data:
        x = torch.FloatTensor(x).to(device)
        y = torch.LongTensor(y).to(device)
        optimizer.zero_grad()
        output, _ = model(x)
        loss = loss_func(output, y)
        loss_val = loss.data.cpu().numpy()
        loss.backward()
        # Count correct top-1 predictions for the running training accuracy.
        sum_acc += (output.max(dim=1)[1] == y).float().sum().data.cpu().numpy()
        counter += len(y)
        optimizer.step()
    train_acc = sum_acc/counter
    # Report every 10th epoch. The previous test `epoch % 10 == 10` could never
    # be true, so nothing was ever printed.
    if (epoch + 1) % 10 == 0:
        tqdm.write('After epoch {0} Train Accuracy: {1:0.3f} '.format(epoch+1, train_acc))

# Evaluation: measure test accuracy and collect inputs and representations.
model.eval()

# CPU copies of the two hidden-layer weight matrices for the spectral-norm step.
W_fc1 = model.fc1.weight.detach().cpu()
W_fc2 = model.fc2.weight.detach().cpu()

correct = 0
counter = 0
x_test = list()  # input batches, moved to CPU
z_test = list()  # matching representations (second output of the model)
# No gradients are needed for inference, so autograd bookkeeping is skipped.
with torch.no_grad():
    for x, _, y in model.model_fetcher.test_data():
        x = torch.FloatTensor(x).to(device)
        y = torch.LongTensor(y).to(device)
        x_test.append(x.cpu())
        output, z = model(x)
        z_test.append(z.detach().cpu())
        pred = output.max(1)[1]
        correct += pred.eq(y).sum().item()
        counter += len(y)
test_acc = correct / counter
print('Test Accuracy: {:.3f} '.format(test_acc))

# The largest singular value of a weight matrix is its spectral norm, i.e. the
# Lipschitz constant of the corresponding linear map under the Euclidean norm.
singular_values_fc1 = torch.linalg.svd(W_fc1)[1]
Lip_fc1 = singular_values_fc1.max()

singular_values_fc2 = torch.linalg.svd(W_fc2)[1]
Lip_fc2 = singular_values_fc2.max()

# Merge the per-batch representations into one tensor along axis 0.
z_test = torch.cat(z_test, dim=0)

# Draw a random subset of test samples (at most --n-samples of them).
n_total = z_test.size(0)
n_samples = min(args.n_samples, n_total)
idx = np.random.permutation(n_total)[:n_samples]

# Pairwise Euclidean distances between the selected representations.
z_test_subset = z_test[idx, :]
pairwise_repr = torch.cdist(z_test_subset, z_test_subset)
dist_vec = pairwise_repr.cpu().detach().numpy()

# The same samples taken from the inputs, as a NumPy array.
x_test = torch.cat(x_test, dim=0)
x_test_subset = x_test[idx, :].detach().numpy()

# Pairwise Earth Mover's Distance between the sampled inputs, using uniform
# mass on the points of each sample.
print('Computing EMD')
dist_x_emd = np.zeros((n_samples, n_samples))
for i in range(n_samples):
    for j in range(i+1, n_samples):
        # Ground cost: Euclidean distance between every pair of points.
        D = euclidean_distances(x_test_subset[i], x_test_subset[j])
        a = np.ones(x_test_subset[i].shape[0])/x_test_subset[i].shape[0]
        b = np.ones(x_test_subset[j].shape[0])/x_test_subset[j].shape[0]
        # Named `flow`, not `F`: `F` is this module's alias for
        # torch.nn.functional, and rebinding it here would break any later
        # forward pass through the model.
        flow = emd(a, b, D)
        # Transport cost = sum over point pairs of (mass moved * distance).
        dist_x_emd[i,j] = np.sum(np.multiply(flow, D))
        dist_x_emd[j,i] = dist_x_emd[i,j]

# Pairwise symmetric Hausdorff distance between the sampled inputs.
print('Computing Hausdorff')
dist_x_hausdorff = np.zeros((n_samples, n_samples))
# Walk the strict upper triangle once and mirror each value.
for i, j in zip(*np.triu_indices(n_samples, k=1)):
    cloud_i, cloud_j = x_test_subset[i], x_test_subset[j]
    forward = directed_hausdorff(cloud_i, cloud_j)[0]
    backward = directed_hausdorff(cloud_j, cloud_i)[0]
    # The symmetric distance is the larger of the two directed ones.
    dist_x_hausdorff[i, j] = dist_x_hausdorff[j, i] = max(forward, backward)

# Pairwise matching distance: cost of a minimum-cost one-to-one assignment
# between the points of two samples; any point left unassigned contributes its
# own Euclidean norm.
print('Computing Matching Distance')
dist_x_md = np.zeros((n_samples, n_samples))
for i in range(n_samples):
    cloud_i = x_test_subset[i]
    for j in range(i+1, n_samples):
        cloud_j = x_test_subset[j]
        cost = euclidean_distances(cloud_i, cloud_j)
        n_rows, n_cols = cost.shape
        rows, cols = linear_sum_assignment(cost)
        dist_x_md[i, j] += cost[rows, cols].sum()

        # With a rectangular cost matrix the larger side has unassigned points.
        if n_rows > n_cols:
            leftover_cloud = cloud_i
            leftover = set(range(n_rows)) - set(np.unique(rows).tolist())
        elif n_cols > n_rows:
            leftover_cloud = cloud_j
            leftover = set(range(n_cols)) - set(np.unique(cols).tolist())
        else:
            leftover_cloud = None
            leftover = set()
        for ind in leftover:
            dist_x_md[i, j] += np.linalg.norm(leftover_cloud[ind, :])

        dist_x_md[j, i] = dist_x_md[i, j]

# Indices of the strict upper triangle, so each unordered pair is used once.
inds = np.triu_indices(n_samples, k=1)

# Size of axis 1 of the sampled inputs; appears as M in the plotted bounds.
M = x_test_subset.shape[1]

for lip_name, lip_value in (("Lip_fc1", Lip_fc1), ("Lip_fc2", Lip_fc2)):
    print(f"{lip_name}: {lip_value}")

# Scatter plot: EMD between inputs (x) against distance between representations (y).
plt.figure()
emd_pairs = dist_x_emd[inds]
repr_pairs = dist_vec[inds]
g = sns.scatterplot(x=emd_pairs, y=repr_pairs)

# aggregator -> (slope of the dashed bound line, colour, text position, label, rotation)
emd_bound_specs = {
    'mean': (1*Lip_fc1*Lip_fc2, 'r', (1, 50), rf"$bound = Lip(FC_1) \cdot Lip(FC_2)$", 45),
    'sum': (M*Lip_fc1*Lip_fc2, 'b', (0.5, 400), rf"$bound = M \cdot Lip(FC_1) \cdot Lip(FC_2)$", 75),
    'max': (1*M*Lip_fc1*Lip_fc2, 'g', (0.25, 60), rf"$bound = M \cdot Lip(FC_1) \cdot Lip(FC_2)$", 88),
}
emd_bound = emd_bound_specs.get(args.aggregator)
if emd_bound is not None:
    slope, colour, (text_x, text_y), label, angle = emd_bound
    g.axline((0, 0), (1, slope), color=colour, linestyle='--')
    plt.text(text_x, text_y, label, color=colour, fontsize=14, rotation=angle, rotation_mode='anchor', ha='center')

# Pearson correlation between the two distances, shown in the corner of the plot.
r, p = pearsonr(emd_pairs, repr_pairs)
plt.xlabel(r'$ d_{EMD}(X, Y) $ ', fontsize=14)
plt.ylabel(rf"$\Vert \mathbf{{v}}_{{{args.aggregator}}} - \mathbf{{u}}_{{{args.aggregator}}} \Vert$", fontsize=14)
g.set_xlim([0, np.max(emd_pairs)])
g.set_ylim([0, np.max(repr_pairs)+0.1])
plt.text(.8, .05, f'$r={r:.2f}$', transform=g.axes.transAxes, fontsize=14)
plt.grid(True, linestyle='--', alpha=0.7)
plt.tick_params(axis='both', which='major', labelsize=14)
plt.savefig(f"{args.aggregator}_emd.png", bbox_inches='tight', dpi=300)
plt.show()
 
# Scatter plot: Hausdorff distance between inputs (x) against distance between
# representations (y).
plt.figure()
hausdorff_pairs = dist_x_hausdorff[inds]
repr_pairs = dist_vec[inds]
g = sns.scatterplot(x=hausdorff_pairs, y=repr_pairs)

# A bound line is drawn for the 'max' aggregator only.
if args.aggregator == 'max':
    hausdorff_slope = sqrt(args.hidden_dim)*Lip_fc1*Lip_fc2
    g.axline((0, 0), (1, hausdorff_slope), color='r', linestyle='--')
    plt.text(0.4, 60, rf"$bound = \sqrt{{d}} \cdot Lip(FC_1) \cdot Lip(FC_2)$", color='r', fontsize=14, rotation=85, rotation_mode='anchor', ha='center')

# Pearson correlation between the two distances, shown in the corner of the plot.
r, p = pearsonr(hausdorff_pairs, repr_pairs)
plt.xlabel(r'$ d_H(X, Y) $ ', fontsize=14)
plt.ylabel(rf"$\Vert \mathbf{{v}}_{{{args.aggregator}}} - \mathbf{{u}}_{{{args.aggregator}}} \Vert$", fontsize=14)
g.set_xlim([-0.1, np.max(hausdorff_pairs)])
g.set_ylim([0, np.max(repr_pairs)+0.1])
plt.text(.8, .05, f'$r={r:.2f}$', transform=g.axes.transAxes, fontsize=14)
plt.grid(True, linestyle='--', alpha=0.7)  # dashed grid lines
plt.tick_params(axis='both', which='major', labelsize=14)
plt.savefig(f"{args.aggregator}_hausdorff.png", bbox_inches='tight', dpi=300)
plt.show()

# Scatter plot: matching distance between inputs (x) against distance between
# representations (y).
plt.figure()
md_pairs = dist_x_md[inds]
repr_pairs = dist_vec[inds]
g = sns.scatterplot(x=md_pairs, y=repr_pairs)

# aggregator -> (slope of the dashed bound line, colour, text position, label, rotation)
md_bound_specs = {
    'sum': (1*Lip_fc1*Lip_fc2, 'r', (50, 400), rf"$bound = Lip(FC_1) \cdot Lip(FC_2)$", 75),
    'mean': ((1./M)*Lip_fc1*Lip_fc2, 'b', (100, 50), rf"$bound = \frac{{1}}{{M}} \cdot Lip(FC_1) \cdot Lip(FC_2)$", 45),
    'max': (1*Lip_fc1*Lip_fc2, 'g', (20, 60), rf"$bound = Lip(FC_1) \cdot Lip(FC_2)$", 88),
}
md_bound = md_bound_specs.get(args.aggregator)
if md_bound is not None:
    slope, colour, (text_x, text_y), label, angle = md_bound
    g.axline((0, 0), (1, slope), color=colour, linestyle='--')
    plt.text(text_x, text_y, label, color=colour, fontsize=14, rotation=angle, rotation_mode='anchor', ha='center')
    # The x-range is pinned only when a bound line is drawn.
    g.set_xlim([0, np.max(md_pairs)])

# Pearson correlation between the two distances, shown in the corner of the plot.
r, p = pearsonr(md_pairs, repr_pairs)
plt.xlabel(r'$ d_M(X, Y) $ ', fontsize=14)
plt.ylabel(rf"$\Vert \mathbf{{v}}_{{{args.aggregator}}} - \mathbf{{u}}_{{{args.aggregator}}} \Vert$", fontsize=14)
g.set_ylim([0, np.max(repr_pairs)+0.1])
plt.text(.8, .05, f'$r={r:.2f}$', transform=g.axes.transAxes, fontsize=14)
plt.grid(True, linestyle='--', alpha=0.7)  # dashed grid lines
plt.tick_params(axis='both', which='major', labelsize=14)
plt.savefig(f"{args.aggregator}_md.png", bbox_inches='tight', dpi=300)
plt.show()
