import torch
import numpy as np
import os
import matplotlib.pyplot as plt
import argparse
import sys
from tasks.constructor import *
from utils import *

def main(run_args=None):
    """Train one network on one task, saving measures before and after training.

    Steps:
      1. Seed the run with ``set_seed`` and build the task with
         ``construct_task``.
      2. Before any training, record ``get_dynamics`` output via
         ``task.model.save_measures(..., initial_save=True)`` and call
         ``save_ntk(..., initial_save=True)``.
      3. Train with ``task.model.train`` using the epoch / early-stopping
         settings from the namespace.
      4. Save the post-training dynamics and NTK only when no early-stopping
         threshold is configured, or when the returned loss is below it.

    Args:
        run_args: Parsed command-line namespace, with the options defined in
            this file's ``__main__`` block. When omitted, the module-level
            ``args`` created by running this file as a script is used. This
            matches the original zero-argument calling convention.

    Raises:
        RuntimeError: if ``run_args`` is omitted and no module-level ``args``
            exists, e.g. because the module was imported rather than run.
    """
    if run_args is None:
        # Backward compatibility: the script entry point stores the parsed
        # namespace in the module global ``args`` and calls ``main()``.
        run_args = globals().get("args")
    if run_args is None:
        raise RuntimeError(
            "main() needs parsed arguments: pass a namespace explicitly "
            "or run this file as a script."
        )

    set_seed(run_args.seed)

    task = construct_task(run_args.task_name, run_args)

    # Snapshot taken before any training step.
    activations = get_dynamics(run_args, task)
    task.model.save_measures(activations, initial_save=True)
    save_ntk(run_args, task, initial_save=True)

    dataloader = task.get_train_loader()

    loss = task.model.train(
        dataloader,
        epochs=run_args.epoch,
        steps_per_epoch=run_args.steps_per_epoch,
        early_stopping_threshold=run_args.early_stopping_threshold,
        patience=run_args.patience,
    )

    threshold = run_args.early_stopping_threshold
    # With a threshold set, final measures are written only for runs whose
    # returned loss beat it; presumably so unconverged runs leave no final
    # snapshot -- confirm against the analysis code that reads these files.
    if threshold is None or loss < threshold:
        activations = get_dynamics(run_args, task, check_output=False)
        task.model.save_measures(activations)
        save_ntk(run_args, task)

        
def build_parser():
    """Return the command-line parser for task, network and training options.

    Options are grouped by what they configure (general network setup,
    training, regularization, and per-task parameters). ``--device`` defaults
    to ``None``, meaning "choose automatically"; call ``resolve_device`` on
    the parsed namespace to fill it in.
    """
    parser = argparse.ArgumentParser(description='Tasks')

    # general task and network specifications
    parser.add_argument('--task_name', type=str, default='3bff')
    parser.add_argument('--n_hidden', type=int, default=128)
    parser.add_argument('--rnn_type', type=str, default='vrnn')
    parser.add_argument('--muP_param', action='store_true', default=False)
    parser.add_argument('--tau', type=float, default=0.1)
    parser.add_argument('--gain', type=float, default=0.6)
    parser.add_argument('--gamma', type=float, default=1.0)

    # training parameters
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--lr_scheduler', default=None)
    parser.add_argument('--n_batch', type=int, default=256)
    parser.add_argument('--epoch', type=int, default=500)
    parser.add_argument('--steps_per_epoch', type=int, default=128)
    parser.add_argument('--early_stopping_threshold', type=float, default=None)
    parser.add_argument('--patience', type=int, default=2)

    # regularizations
    parser.add_argument('--W_rank_reg', type=float, default=0.0)
    parser.add_argument('--W_l1_reg', type=float, default=0.0)
    parser.add_argument('--trainable_ratio', type=float, default=1.0)
    parser.add_argument('--small_weight_init', action='store_true', default=False)
    parser.add_argument('--init_sigma', type=float, default=0.01)
    parser.add_argument('--weight_decay', type=float, default=0.0)
    parser.add_argument('--aux_loss', action='store_true', default=False)
    parser.add_argument('--time_shuffled', action='store_true', default=False)      # to test task complexity
    parser.add_argument('--input_reproducing', action='store_true', default=False)  # identity map

    # arguments for n-bits flip flop
    parser.add_argument('--n_bits', type=int, default=3)
    parser.add_argument('--p_flip', type=float, default=0.3)
    parser.add_argument('--n_time', type=int, default=100)
    parser.add_argument('--n_fps', type=int, default=2)  # n-bits integration
    parser.add_argument('--bound', type=int, default=20)  # n-bits integration

    # arguments for delayed discrimination task
    parser.add_argument('--dd_low_ts', type=int, default=2)
    parser.add_argument('--dd_high_ts', type=int, default=10)
    parser.add_argument('--dd_max_delay', type=int, default=20)
    parser.add_argument('--dd_target', type=str, default='sign', choices=['sign', 'abs_value', 'value'])
    parser.add_argument('--dd_increment', type=int, default=1)

    # arguments for sine wave generation task
    parser.add_argument('--num_channels', type=int, default=1)
    parser.add_argument('--freq_min', type=int, default=1)
    parser.add_argument('--freq_max', type=int, default=30)
    parser.add_argument('--freq_num', type=int, default=100)
    parser.add_argument('--dt', type=float, default=0.01)
    parser.add_argument('--sequence_length', type=int, default=100)
    parser.add_argument('--num_samples', type=int, default=2000)
    parser.add_argument('--dim_input', type=int, default=1)
    parser.add_argument('--dim_output', type=int, default=1)

    # arguments for path integration task
    parser.add_argument("--n_trials", type=int, default=1000, help="Number of trials to generate")
    parser.add_argument("--n_timesteps", type=int, default=60, help="Number of timesteps per trial")
    parser.add_argument("--v_max", type=float, default=0.4, help="Maximum speed of the agent")
    parser.add_argument("--theta_std", type=float, default=np.pi / 10, help="Standard deviation of direction increments (in radians)")
    parser.add_argument("--phi_std", type=float, default=np.pi / 10, help="Standard deviation of direction increments (in radians)")
    parser.add_argument("--speed_std", type=float, default=0.1, help="Standard deviation of speed increments")
    parser.add_argument("--x_noise_std", type=float, default=0.0001, help="Standard deviation of noise added to each position update")
    parser.add_argument("--xy_noise_std", type=float, default=0.0001, help="Standard deviation of noise added to each position update")
    parser.add_argument("--xyz_noise_std", type=float, default=0.0001, help="Standard deviation of noise added to each position update")
    parser.add_argument("--stop_mean_duration", type=float, default=30, help="Mean duration of pauses (exponential distribution)")
    parser.add_argument("--go_mean_duration", type=float, default=50, help="Mean duration of motion periods (exponential distribution)")
    parser.add_argument("--environment_size", type=float, default=10, help="Length of each side of the square environment")
    parser.add_argument('--path_dim', type=int, default=2)

    parser.add_argument('--save_path', type=str, default='RNN-degeneracy/degeneracy/data')
    # --cuda is kept for backward compatibility; resolve_device overwrites it
    # from the final --device so the two flags can never disagree.
    parser.add_argument('--cuda', action='store_true', default=False,
                        help="Deprecated; derived from --device after parsing")
    parser.add_argument('--device', type=str, default=None,
                        help="Torch device (e.g. 'cpu', 'cuda', 'cuda:1'); "
                             "defaults to 'cuda' when available, else 'cpu'")
    return parser


def resolve_device(args):
    """Fill in ``args.device`` and ``args.cuda`` consistently, in place.

    An explicitly requested ``--device`` is respected. Previously it was
    silently replaced with ``'cuda'`` whenever CUDA was available. Without
    one, ``'cuda'`` is chosen if ``torch.cuda.is_available()``, otherwise
    ``'cpu'``. ``args.cuda`` is then set to whether the chosen device is a
    CUDA device.

    Args:
        args: Namespace produced by ``build_parser().parse_args(...)``.

    Returns:
        The same namespace, for call chaining.
    """
    if args.device is None:
        args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    args.cuda = args.device.startswith('cuda')
    return args


if __name__ == "__main__":
    print("Starting run...")
    # ``args`` stays a module global: main() reads it when called without
    # arguments.
    args = resolve_device(build_parser().parse_args())
    main()


