#!/usr/bin/env python
import os
import argparse
import numpy as np
import torch
import traceback

from tasks.constructor import construct_task, set_seed


def main(args):
    """Evaluate saved per-seed model weights on an out-of-distribution test set.

    The function first overrides task-specific length/delay attributes on
    ``args`` (mutating it in place), builds a single task object from those
    settings, and generates one test set from it. For every seed in
    ``range(args.seed_min, args.seed_max)`` it then loads
    ``<load_path>/<task_name>/weights/seed_<seed>.pt``, copies the stored
    tensors into the task's model by parameter name, and records the scalar
    returned by ``model.get_test_loss``.

    Seeds whose checkpoint cannot be loaded or evaluated are skipped after
    printing the traceback, so the saved list may be shorter than the seed
    range and is not indexed by seed.

    The collected losses are written to
    ``<save_path>/OOD_eval/<task_name><prefix>_OOD_losses.npy`` and printed.

    Args:
        args: Namespace-like object carrying at least ``task_name``,
            ``seed_min``, ``seed_max``, ``load_path``, ``save_path``,
            ``prefix`` and ``device``, plus whatever ``construct_task``
            reads from it.
    """
    task_name = args.task_name

    # Adjust args for OOD evaluation. The first matching task family wins;
    # the lengths are set to 200 (this script's CLI defaults are 100).
    if "bff" in task_name:
        args.n_time = 200
    elif "sine" in task_name:
        args.sequence_length = 200
    elif "delayed_discrimination" in task_name:
        # The max delay is double the number in the `delay_N` tag of the task
        # name; tags are checked in this order and 40 is the fallback.
        delay_overrides = (
            ("delay_10", 20),
            ("delay_20", 40),
            ("delay_30", 60),
            ("delay_40", 80),
        )
        for tag, max_delay in delay_overrides:
            if tag in task_name:
                args.dd_max_delay = max_delay
                break
        else:
            args.dd_max_delay = 40
    elif "path_integration" in task_name:
        args.n_timesteps = 200

    # Generate the OOD set once; it is shared by every seed below.
    task_ood = construct_task(task_name, args)
    X_ood, Y_ood = task_ood.generate_test_data()

    ood_losses = []

    # Loop over seeds, reusing the same model object and overwriting its weights.
    for seed in range(args.seed_min, args.seed_max):
        set_seed(seed)
        model_file = os.path.join(
            args.load_path, task_name, "weights", f"seed_{seed}.pt"
        )
        try:
            # map_location lets checkpoints saved on another device (e.g. a
            # GPU) be loaded on the evaluation device instead of failing.
            state_dict = torch.load(model_file, map_location=args.device)
            for name, param in task_ood.model.named_parameters():
                param.data = state_dict[name].clone()

            task_ood.model.to(args.device)

            # evaluate losses
            loss = task_ood.model.get_test_loss(X_ood, Y_ood)
            ood_losses.append(loss.item())

        except Exception:
            # Best effort: report the failure and move on to the next seed.
            traceback.print_exc()
            continue

    # Save results
    out_dir = os.path.join(args.save_path, "OOD_eval")
    os.makedirs(out_dir, exist_ok=True)

    out_file = os.path.join(
        out_dir, f"{task_name}{args.prefix}_OOD_losses.npy"
    )
    np.save(out_file, ood_losses)

    print("OOD losses:", ood_losses)

        
if __name__ == "__main__":
    # Command-line entry point: parse the task/network/training options,
    # pick the compute device, and run the OOD evaluation.

    parser = argparse.ArgumentParser(description='Tasks')

    # general task and network specifications
    parser.add_argument('--task_name', type=str, default='3bff')
    parser.add_argument('--n_hidden', type=int, default=128)
    parser.add_argument('--rnn_type', type=str, default='vrnn')
    parser.add_argument('--muP_param', action='store_true', default=False)
    parser.add_argument('--tau', type=float, default=0.1)
    parser.add_argument('--gain', type=float, default=0.6)
    parser.add_argument('--gamma', type=float, default=1.0)

    # training parameters
    parser.add_argument('--seed', type=int, default=1)
    # seeds evaluated are range(seed_min, seed_max), i.e. seed_max is exclusive
    parser.add_argument('--seed_min', type=int, default=0)
    parser.add_argument('--seed_max', type=int, default=50)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--lr_scheduler', default=None)
    parser.add_argument('--n_batch', type=int, default=256)
    parser.add_argument('--epoch', type=int, default=20)
    parser.add_argument('--steps_per_epoch', type=int, default=128)
    parser.add_argument('--early_stopping_threshold', type=float, default=None)
    parser.add_argument('--patience', type=int, default=10)

    # regularizations
    parser.add_argument('--W_rank_reg', type=float, default=0.0)
    parser.add_argument('--W_l1_reg', type=float, default=0.0)
    parser.add_argument('--trainable_ratio', type=float, default=1.0)
    parser.add_argument('--small_weight_init', action='store_true', default=False)
    parser.add_argument('--init_sigma', type=float, default=0.01)
    parser.add_argument('--weight_decay', type=float, default=0.0)
    parser.add_argument('--aux_loss', action='store_true', default=False)
    parser.add_argument('--time_shuffled', action='store_true', default=False)      # to test task complexity
    parser.add_argument('--input_reproducing', action='store_true', default=False)  # identity map

    # arguments for n-bits flip flop
    parser.add_argument('--n_bits', type=int, default=3)
    parser.add_argument('--p_flip', type=float, default=0.3)
    parser.add_argument('--n_time', type=int, default=100)
    parser.add_argument('--n_fps', type=int, default=2) # n-bits integration
    parser.add_argument('--bound', type=int, default=20) # n-bits integration

    # arguments for delayed discrimination task
    parser.add_argument('--dd_low_ts', type=int, default=2)
    parser.add_argument('--dd_high_ts', type=int, default=10)
    parser.add_argument('--dd_max_delay', type=int, default=20)
    parser.add_argument('--dd_target', type=str, default='sign', choices=['sign', 'abs_value', 'value'])
    parser.add_argument('--dd_increment', type=int, default=1)

    # arguments for sine wave generation task
    parser.add_argument('--num_channels', type=int, default=1)
    parser.add_argument('--freq_min', type=int, default=5)
    parser.add_argument('--freq_max', type=int, default=30)
    parser.add_argument('--freq_num', type=int, default=100)
    parser.add_argument('--dt', type=float, default=0.01)
    parser.add_argument('--sequence_length', type=int, default=100)
    parser.add_argument('--num_samples', type=int, default=2000)
    parser.add_argument('--dim_input', type=int, default=1)
    parser.add_argument('--dim_output', type=int, default=1)

    # arguments for path integration task
    parser.add_argument("--n_trials", type=int, default=1000, help="Number of trials to generate")
    parser.add_argument("--n_timesteps", type=int, default=100, help="Number of timesteps per trial")
    parser.add_argument("--v_max", type=float, default=0.4, help="Maximum speed of the agent")
    parser.add_argument("--theta_std", type=float, default=np.pi / 10, help="Standard deviation of direction increments (in radians)")
    parser.add_argument("--phi_std", type=float, default=np.pi / 10, help="Standard deviation of direction increments (in radians)")
    parser.add_argument("--speed_std", type=float, default=0.1, help="Standard deviation of speed increments")
    parser.add_argument("--x_noise_std", type=float, default=0.0001, help="Standard deviation of noise added to each position update")
    parser.add_argument("--xy_noise_std", type=float, default=0.0001, help="Standard deviation of noise added to each position update")
    parser.add_argument("--xyz_noise_std", type=float, default=0.0001, help="Standard deviation of noise added to each position update")
    parser.add_argument("--stop_mean_duration", type=float, default=30, help="Mean duration of pauses (exponential distribution)")
    parser.add_argument("--go_mean_duration", type=float, default=50, help="Mean duration of motion periods (exponential distribution)")
    parser.add_argument("--environment_size", type=float, default=10, help="Length of each side of the square environment")
    parser.add_argument('--path_dim', type=int, default=2)

    # output naming, I/O locations and device selection
    parser.add_argument('--prefix', type=str, default='_len')
    parser.add_argument('--load_path', type=str, default='RNN-degeneracy/degeneracy/data')
    parser.add_argument('--save_path', type=str, default='RNN-degeneracy/degeneracy/data')
    parser.add_argument('--cuda', action='store_true', default=False)
    parser.add_argument('--device', type=str, default='cpu')
    args = parser.parse_args()

    # NOTE: whenever CUDA is available this overrides any --device / --cuda
    # value given on the command line.
    if torch.cuda.is_available():
        args.device = 'cuda'
        args.cuda = True

    # main() requires the parsed namespace; calling it with no arguments
    # raises TypeError before any evaluation happens.
    main(args)

