import sys
from DSA import DSA
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import MDS
import seaborn as sns
import pandas as pd
import argparse
import torch
import os
from utils_dsa import *

# Command-line options as (flag, add_argument keyword arguments) pairs, listed
# in the order they are registered (and therefore shown by --help).
_CLI_OPTIONS = (
    ('--task', {'type': str, 'default': '3bff'}),
    ('--n_delays', {'type': int, 'default': 10}),
    ('--delay_interval', {'type': int, 'default': 1}),
    # NOTE(review): --rank and --rank_explained_variance are not read anywhere
    # in this script as visible; confirm whether they are still needed.
    ('--rank', {'type': int, 'default': 200}),
    ('--rank_explained_variance', {'type': float, 'default': 0.95}),
    ('--model_path', {'type': str, 'default': 'RNN-degeneracy/degeneracy/data'}),
    ('--save_path', {'type': str, 'default': 'RNN-degeneracy/degeneracy/data/DSA'}),
    ('--n_trials', {'type': int, 'default': 256}),
    ('--n_total_seeds', {'type': int, 'default': 100}),
    ('--start_seed', {'type': int, 'default': 0}),
    ('--n_networks', {'type': int, 'default': 100}),
    ('--idx_network', {'type': int, 'default': 0}),
    # When given, this explicit seed list replaces the start_seed/n_total_seeds range.
    ('--seed_arr', {'nargs': '+', 'type': int, 'default': None}),
    ('--suffix', {'type': str, 'default': ''}),
)

parser = argparse.ArgumentParser(description='DSA')
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)

args = parser.parse_args()

# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load network dynamics: one array per seed, read from
# <model_path>/<task>/hxs/seed_<seed>.npy, stopping once n_networks arrays
# have been collected.
models = []
directory = os.path.join(args.model_path, args.task, 'hxs')

if args.seed_arr is not None:
    seed_arr = args.seed_arr
else:
    seed_arr = np.arange(args.start_seed, args.start_seed+args.n_total_seeds)

trial_idx = np.arange(args.n_trials)
for seed in seed_arr:
    file_path = os.path.join(directory, f"seed_{seed}.npy")
    try:
        hxs = np.load(file_path)
    except FileNotFoundError:
        # Seed sweeps may have gaps; a missing seed is skipped quietly.
        continue
    except (OSError, ValueError) as err:
        # Unreadable, corrupt or pickled files are still skipped (best-effort),
        # but reported so they do not silently drop out of the comparison.
        print(f"Skipping {file_path}: {err}", file=sys.stderr)
        continue

    # Keep only the first n_trials entries along axis 0 (presumably trials).
    if hxs.shape[0] > len(trial_idx):
        hxs = hxs[trial_idx]
    models.append(hxs)

    if len(models) == args.n_networks:
        break
          
# Compute DSA
print(f"Number of models loaded {len(models)}")

# Fail with a clear message instead of an IndexError deep inside the helpers.
if not models:
    raise SystemExit(f"No hidden-state files could be loaded from {directory}")
if args.idx_network >= len(models):
    raise SystemExit(
        f"--idx_network {args.idx_network} is out of range: "
        f"only {len(models)} models were loaded"
    )

# Report the stacked shape without actually stacking. np.array() on a ragged
# list raises ValueError on NumPy >= 1.24, and it copies every array just to
# read .shape.
model_shapes = {m.shape for m in models}
if len(model_shapes) == 1:
    print((len(models),) + next(iter(model_shapes)))
else:
    print(f"Models have differing shapes: {sorted(model_shapes)}")

# PCA_down_data comes from utils_dsa. It presumably reduces each model to the
# PCs explaining 99% of variance, with at least 10 dims; confirm in utils_dsa.
models = PCA_down_data(models, var_explained=0.99, min_dim=10)

# Choose the delay embedding on the target network only. np.arange excludes
# its stop value, so the candidates are 1 .. n_delays-1.
best_delay, test_mase = optimize_n_delay(models[args.idx_network],
                                         delay_range=np.arange(1, args.n_delays, 1),
                                         project_state=False,
                                         method='DMD')
print(f"Best delay: {best_delay}, MASE: {np.min(test_mase)}")

# Compare the target network against every lower-indexed network.
# NOTE(review): for idx_network == 0, models[:0] is empty; confirm that DSA
# accepts an empty reference list.
dsa = DSA(models[:args.idx_network],
          models[args.idx_network],
          n_delays=best_delay,
          delay_interval=args.delay_interval,
          verbose=True,
          device=device)

similarities = dsa.fit_score()

# Save DSA: write the scores to <save_path>/<task><suffix>/network_<idx>.npy,
# creating the output directory first if needed.
out_dir = os.path.join(args.save_path, f'{args.task}{args.suffix}')
os.makedirs(out_dir, exist_ok=True)

save_path = os.path.join(out_dir, f'network_{args.idx_network}.npy')
np.save(save_path, similarities)

print(similarities)
