import os
import argparse
import numpy as np
import torch
from scipy.cluster.hierarchy import linkage, leaves_list
from scipy.spatial.distance import squareform
import matplotlib.pyplot as plt

from cca_core import robust_cca_similarity

# Command-line interface. Only --task is mandatory; the seed list comes from
# --seed_arr when given, otherwise from a consecutive range of seeds.
parser = argparse.ArgumentParser(description="Compute CCA distances")
parser.add_argument('--task', type=str,               required=True,
                    help="Task name; selects <model_dir>/<task>/hxs and tags output file names.")
parser.add_argument('--model_dir', type=str,         default='RNN-degeneracy/degeneracy/data',
                    help="Root directory; hidden states are read from <model_dir>/<task>/hxs.")
parser.add_argument('--save_dir', type=str,          default='RNN-degeneracy/degeneracy/data',
                    help="Directory where output files are written (created if missing).")
parser.add_argument('--n_trials', type=int,          default=100,
                    help="Keep only the first N entries along axis 0 of each hidden-state file.")
parser.add_argument('--n_networks', type=int,        default=100,
                    help="Stop after loading this many models.")
parser.add_argument('--n_total_seeds', type=int,     default=50,
                    help="Number of consecutive seeds starting at --start_seed "
                         "(ignored when --seed_arr is given).")
parser.add_argument('--start_seed', type=int,        default=0,
                    help="First seed of the consecutive seed range.")
parser.add_argument('--suffix', type=str,            default='',
                    help="String appended to output file names.")
parser.add_argument('--seed_arr', nargs='+', type=int,
                    help="Explicit list of seeds to load; overrides --start_seed/--n_total_seeds.")
args = parser.parse_args()

# NOTE(review): `device` is not referenced anywhere below in this script;
# kept so nothing that might rely on the name breaks.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Per-seed hidden-state files live under <model_dir>/<task>/hxs.
hxs_dir = os.path.join(args.model_dir, args.task, 'hxs')
# Ensure the output directory exists before anything is written to it.
os.makedirs(args.save_dir, exist_ok=True)

# An explicit --seed_arr wins; otherwise use the consecutive range
# [start_seed, start_seed + n_total_seeds).
seed_arr = (
    args.seed_arr
    if args.seed_arr
    else np.arange(args.start_seed, args.start_seed + args.n_total_seeds)
)

# Load one hidden-state array per seed into `models`. Seeds whose file does
# not exist are skipped and reported; loading stops once --n_networks models
# have been collected.
models = []
trial_idx = np.arange(args.n_trials)
missing_seeds = []  # seeds with no file on disk, reported after the loop

for seed in seed_arr:
    # 'bff' tasks prefix their files with the task name.
    fn = f"{args.task}_seed_{seed}.npy" if 'bff' in args.task else f"seed_{seed}.npy"
    fp = os.path.join(hxs_dir, fn)
    if not os.path.isfile(fp):
        missing_seeds.append(int(seed))
        continue

    data = np.load(fp)
    # Keep the first n_trials entries along axis 0 (presumably trials — verify).
    data = data[:len(trial_idx)]
    # Flatten all leading axes and transpose, giving (last-axis size, samples);
    # assumes the last axis indexes hidden units — TODO confirm.
    data = data.reshape(-1, data.shape[-1]).T
    models.append(data)

    if len(models) >= args.n_networks:
        break

# Previously missing files were skipped silently; surface them so a wrong
# --model_dir/--task or an incomplete run is visible.
if missing_seeds:
    print(f"Skipped {len(missing_seeds)} seed(s) with no file in {hxs_dir}: {missing_seeds}")
print(f"Loaded {len(models)} models for task {args.task}")


# Pairwise similarity matrix: entry (i, j) is the mean of the "cca_coef1"
# coefficients returned by robust_cca_similarity for models i and j.
# The matrix is filled symmetrically from the strict upper triangle. The
# diagonal is set to 1 by definition rather than by running CCA of a model
# against itself, which costs n extra fits and may come out slightly off 1
# numerically, leaving a non-zero self-distance in the saved matrix.
n = len(models)
sim = np.eye(n)
for i in range(n):
    for j in range(i + 1, n):
        cca_res = robust_cca_similarity(models[i], models[j])
        mean_cca = np.mean(cca_res["cca_coef1"])
        sim[i, j] = sim[j, i] = mean_cca

# Convert similarity to distance: 1 - mean CCA.
dist_matrix = 1.0 - sim

# linkage() cannot cluster fewer than 2 observations (the condensed matrix
# would be empty); fail with a clear message instead of a scipy traceback.
if dist_matrix.shape[0] < 2:
    raise SystemExit(
        f"Need at least 2 models to cluster, but loaded {dist_matrix.shape[0]} "
        f"for task {args.task}"
    )

# Reorder via average-linkage hierarchical clustering. checks=False skips
# squareform's symmetry / zero-diagonal validation so tiny numerical noise
# does not abort the run; only the upper triangle is used.
condensed = squareform(dist_matrix, checks=False)
Z = linkage(condensed, method='average')
order = leaves_list(Z)
dist_reordered = dist_matrix[np.ix_(order, order)]

# Save the reordered distance matrix.
out_mat = os.path.join(args.save_dir, f"dist_matrix_{args.task}{args.suffix}.npy")
np.save(out_mat, dist_reordered)

# Also save the permutation (indices into the load order of `models`) so rows
# of the reordered matrix can be mapped back to models.
out_order = os.path.join(args.save_dir, f"dist_order_{args.task}{args.suffix}.npy")
np.save(out_order, order)


# Render the reordered distance matrix as a heatmap and write it as a PNG
# next to the .npy output.
out_png = os.path.join(args.save_dir, f"dist_matrix_{args.task}{args.suffix}.png")

fig, ax = plt.subplots(figsize=(8, 8))
im = ax.imshow(dist_reordered, interpolation='nearest')
ax.set_title('Hierarchical Clustering of Models\n(1 − mean CCA)')
ax.axis('off')

# Colorbar attached to this figure's axes explicitly rather than via pyplot state.
cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
cbar.set_label('Distance (1 − mean CCA)')

fig.tight_layout()
fig.savefig(out_png, dpi=300)
# Release the figure; the script creates no further plots.
plt.close(fig)
