import numpy as np
import matplotlib.pyplot as plt

def sac_arguments(parser):
    """Register the SAC training hyper-parameters on ``parser``.

    Each entry below is ``(flag, converter, default, help text)``; the options
    are added in table order.
    """
    option_table = (
        ('--task-id', int, 1, "Task ID to train on"),
        ('--horizon', int, 100, "Max episode steps"),
        ('--lr', float, 3e-4, "Learning rate for SAC"),
        ('--batch-size', int, 256, "Batch size"),
        ('--timesteps', int, int(1e6), "Number of training steps"),
        ('--eval-interval', int, 10000, "Evaluation interval"),
    )
    for flag, converter, default_value, help_text in option_table:
        parser.add_argument(flag, type=converter, default=default_value, help=help_text)

def add_dataset_args(parser):
    """Register the environment/dataset selection options on ``parser``.

    Adds ``--env`` (restricted to ``ml1_pick_place``, ``cheetah_vel`` or
    ``darkroom``), ``--policy_quality`` (kept as a string, default ``'80'``),
    ``--n_samples`` and ``--n_trails``.
    """
    parser.add_argument('--env', type=str, default='darkroom',
                        choices=['ml1_pick_place', 'cheetah_vel', 'darkroom'],
                        help="Environment to use")
    parser.add_argument('--policy_quality', type=str, default='80',
                        help="Policy quality: percentage of optimal policy")
    parser.add_argument('--n_samples', type=int, default=1,
                        help="Number of samples to calculate exp reward")
    # The flag keeps its historical spelling because code reads
    # ``args.n_trails`` (e.g. build_data_filename); only the help typo is fixed.
    parser.add_argument('--n_trails', type=int, default=1,
                        help="Number of trials for each env")

def add_ml1_pick_place_dataset_args(parser):
    """Register ML1 pick-place dataset options on ``parser``.

    Task lists are given as space-separated integers, e.g.
    ``--train_tasks 1 2 3``. ``type=list`` must not be used here: argparse would
    call ``list()`` on a single token and split it into characters
    (``--train_tasks 12`` -> ``['1', '2']``, strings rather than ints).
    ``--test_tasks`` accepts zero values, so it can be passed explicitly empty.
    """
    # NOTE(review): machine-specific absolute default path; override on other hosts.
    parser.add_argument('--model_ckpt_path', type=str, default='/hpc/home/mg585/baselines_ICLR/ml1-pick-place-ckpts', help="Path to model checkpoints")
    parser.add_argument('--horizon', type=int, default=100, help="Max episode steps")
    parser.add_argument('--gamma', type=float, default=0.8, help="Discount factor")
    parser.add_argument('--train_tasks', type=int, nargs='+',
                        default=[1, 2, 3, 4, 5, 7, 8, 10, 11, 12, 13, 14, 17, 18, 19],
                        help="Tasks to train on")
    # Alternative held-out split that was disabled in the original code: [6, 9, 15, 16, 20].
    parser.add_argument('--test_tasks', type=int, nargs='*', default=[], help="Tasks to test on")
    parser.add_argument('--state_dim', type=int, required=False, default=39, help="State dimension")
    parser.add_argument('--action_dim', type=int, required=False, default=4, help="Action dimension")
#  "train_tasks": [
    #  0,1,2,3,4,5,6,7,8,9,
    #  10,12,13,14,15,16,18,19,
    #  20,21,22,23,24,26,27,28,29,
    #  30,31,32,33,34,35,37,38,39,
    #  40,42,43,44,45,46,47,48,49],
#  "test_tasks": [11,17,25,36,41]
def add_cheetah_vel_dataset_args(parser):
    """Register cheetah-vel dataset options on ``parser``.

    The default split uses tasks 0-39: five held out for testing
    (2, 7, 15, 23, 26) and the remaining 35 for training. Task lists are given
    as space-separated integers (``--train_tasks 0 1 3``); ``type=list`` would
    instead split a single token into characters.
    """
    # NOTE(review): machine-specific absolute default path; override on other hosts.
    parser.add_argument('--model_ckpt_path', type=str, default='/hpc/home/mg585/baselines_ICLR/cheetah-vel-ckpts', help="Path to model checkpoints")
    parser.add_argument('--horizon', type=int, default=100, help="Max episode steps")
    parser.add_argument('--gamma', type=float, default=0.8, help="Discount factor")
    parser.add_argument('--train_tasks', type=int, nargs='+',
                        default=[0, 1, 3, 4, 5, 6, 8, 9, 10,
                                 11, 12, 13, 14, 16, 17, 18,
                                 19, 20, 21, 22, 24, 25,
                                 27, 28, 29, 30, 31, 32, 33, 34,
                                 35, 36, 37, 38, 39],
                        help="Tasks to train on")
    parser.add_argument('--test_tasks', type=int, nargs='*', default=[2, 7, 15, 23, 26], help="Tasks to test on")
    parser.add_argument('--state_dim', type=int, required=False, default=20, help="State dimension")
    parser.add_argument('--action_dim', type=int, required=False, default=6, help="Action dimension")

def add_darkroom_dataset_args(parser):
    """Register darkroom dataset/environment options on ``parser``.

    By default tasks 0-79 are used for training and 80-99 for testing. Task
    lists are given as space-separated integers (``--test_tasks 80 81``);
    ``type=list`` would instead split a single token into characters.
    """
    parser.add_argument('--p', type=float, default=0.3, help="Probability of taking random action, policy quality")
    parser.add_argument('--darkroom_goal_path', type=str, default='envs/darkroom/darkroom_goals.npy', help="Path to darkroom goals")
    parser.add_argument('--horizon', type=int, default=100, help="Max episode steps")
    parser.add_argument('--dim', type=int, default=10, help="Dimension of the darkroom")
    parser.add_argument('--gamma', type=float, default=0.8, help="Discount factor")
    parser.add_argument('--state_dim', type=int, required=False, default=2, help="State dimension")
    parser.add_argument('--action_dim', type=int, required=False, default=5, help="Action dimension")
    parser.add_argument('--train_tasks', type=int, nargs='+', default=list(range(80)), help="Tasks to train on")
    parser.add_argument('--test_tasks', type=int, nargs='*', default=list(range(80, 100)), help="Tasks to test on")

def add_model_args(parser):
    """Register model-architecture and optimiser options on ``parser``.

    Every option is optional; rows are ``(flag, converter, default, help)`` and
    are registered in the order listed.
    """
    model_options = [
        ("--embd", int, 256, "Embedding size"),
        ("--head", int, 1, "Number of heads"),
        ("--layer", int, 6, "Number of layers"),
        ("--lr", float, 1e-3, "Learning Rate"),
        ("--dropout", float, 0, "Dropout"),
        ('--type', str, 'Q', "Type of model (Q/V)"),
    ]
    for name, cast, fallback, description in model_options:
        parser.add_argument(name, type=cast, required=False,
                            default=fallback, help=description)


def _str2bool(value):
    """Convert a command-line token such as ``true``/``False``/``1``/``no`` to a bool.

    Used instead of ``type=bool``. ``bool('False')`` is ``True`` because any
    non-empty string is truthy, so ``--use_rewarder False`` used to enable it.

    Raises:
        ValueError: if ``value`` is not a recognised boolean spelling; argparse
            reports this as a usage error.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if token in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


def add_train_args(parser):
    """Register training-loop options on ``parser``.

    ``--use_rewarder`` accepts an explicit boolean (``--use_rewarder false``)
    or can be given bare (``--use_rewarder``) to mean ``True``.
    ``--reweight`` is a plain on/off switch. ``--Q_path``/``--V_path`` default
    to the string ``'none'``, not ``None``.
    """
    parser.add_argument('--device', type=str, default='cuda', help="Device to run on (cuda/cpu)")
    parser.add_argument("--num_epochs", type=int, required=False,
                        default=40, help="Number of epochs")
    parser.add_argument("--Heps", type=int, required=False, default=10, help="Number of episodes")
    parser.add_argument("--eta", type=float, required=False, default=0.5, help="Eta")
    parser.add_argument("--use_rewarder", type=_str2bool, nargs='?', const=True,
                        required=False, default=False, help="Use rewarder")
    parser.add_argument("--reweight", action='store_true', default=False, help="Reweight")
    parser.add_argument("--Q_path", type=str, required=False, default='none', help="Path to Q function")
    parser.add_argument("--V_path", type=str, required=False, default='none', help="Path to V function")



def build_data_filename(mode, args):
    """Return the relative pickle path of the dataset split ``mode`` for ``args.env``.

    The attributes read from ``args`` depend on the environment:

    * ``darkroom``: ``horizon`` and ``p``.
    * ``ml1_pick_place``: ``horizon``, ``policy_quality`` and ``n_trails``.
    * ``cheetah_vel``: the same as ``ml1_pick_place``, with a ``_normalized``
      suffix before the extension.

    Raises:
        ValueError: if ``args.env`` is not one of the environments above.
            Previously this surfaced as an ``UnboundLocalError``.
    """
    env = args.env
    if env == 'darkroom':
        return f'datasets/{env}/{env}_H{args.horizon}_p{args.p}_{mode}.pkl'
    if env == 'ml1_pick_place':
        return f'datasets/{env}/{env}_H{args.horizon}_q{args.policy_quality}_n{args.n_trails}_{mode}.pkl'
    if env == 'cheetah_vel':
        return f'datasets/{env}/{env}_H{args.horizon}_q{args.policy_quality}_n{args.n_trails}_{mode}_normalized.pkl'
    raise ValueError(
        f"Unsupported env {env!r}; expected 'darkroom', 'ml1_pick_place' or 'cheetah_vel'"
    )


def draw_pred(pred, gt):
    """Plot how predictions compare with ground truth over time.

    Builds a 1x2 figure. The left panel shows the mean absolute error per time
    step with a +/- one standard deviation band, with statistics taken over
    axis 0. The right panel overlays the axis-0 mean of ``pred`` and of ``gt``.

    Args:
        pred: Array-like whose axis 0 indexes samples and axis 1 indexes time
            steps; any further axes must have size 1 (presumably ``(n, T)`` or
            ``(n, T, 1)`` -- TODO confirm against callers).
        gt: Ground truth with the same shape as ``pred``.

    Returns:
        The matplotlib figure containing both panels. It is not closed here;
        the caller is responsible for saving/closing it.
    """
    pred = np.asarray(pred)
    gt = np.asarray(gt)
    n_samples, n_steps = pred.shape[0], pred.shape[1]

    # One x position per time step: 0, 1, ..., n_steps - 1.
    # np.linspace(0, n_steps, n_steps) spaced points n_steps / (n_steps - 1) apart.
    x = np.arange(n_steps)

    fig, axs = plt.subplots(1, 2, figsize=(12, 6))

    # Left panel: absolute error. Reshape rather than squeeze so the sample
    # axis survives when n_samples == 1; only trailing size-1 axes are dropped.
    diff = np.abs(pred - gt).reshape(n_samples, n_steps)
    mean_diff = diff.mean(axis=0)
    std_diff = diff.std(axis=0)
    axs[0].plot(x, mean_diff, label='Mean Absolute Difference')
    axs[0].fill_between(x, mean_diff - std_diff, mean_diff + std_diff,
                        alpha=0.3, label='Std Dev')
    axs[0].set_xlabel('Time step')
    axs[0].set_ylabel('Difference')
    axs[0].set_title('Prediction - Ground Truth Difference')
    axs[0].legend()

    # Right panel: mean trajectories of prediction and ground truth.
    axs[1].plot(x, np.mean(pred, axis=0), label='Prediction')
    axs[1].plot(x, np.mean(gt, axis=0), label='Ground Truth')
    axs[1].set_xlabel('Time step')
    axs[1].set_ylabel('Value')
    axs[1].set_title('Prediction vs Ground Truth')
    axs[1].legend()

    # Lay out this figure explicitly instead of whatever pyplot considers "current".
    fig.tight_layout()

    return fig