import os
import argparse
import torch
from config.config_loader import ConfigLoader
from environments.env_loader import parallel_env_maker
from losses import WeightedClipPPOLoss
from torchrl.objectives.value import GAE
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from utils import generate_exp_name, get_device, create_actor_critic
from trainers import train_ppo
import wandb


def _setup_experiment(args, config):
    """Name the run, create its checkpoint directory and start a wandb run.

    Mutates ``config`` in place by adding ``exp_dir`` and ``exp_name``, so the
    config logged to wandb (and later handed to the trainer) includes them.
    """
    # The suffix records whether config["archive_buffer"] is truthy
    # ("_ELT") or not ("_ELF").
    s_energy = "_ELT" if config["archive_buffer"] else "_ELF"
    run_id = generate_exp_name('iac', args.config)
    exp_name = f"main_{config['name']}{s_energy}/{run_id}"

    config.update({'exp_dir': "./ckpts"})
    os.makedirs(f"{config['exp_dir']}/{exp_name}", exist_ok=True)
    config.update({'exp_name': exp_name})

    wandb.login()
    wandb.init(project="aa", config=config, name=exp_name)


def main(args, config):
    """Build the PPO components described by ``config`` and run training.

    Args:
        args: Parsed command-line namespace. All of its fields are merged into
            ``config`` (CLI values override same-named config entries);
            ``total_frames``, ``debug``, ``id_file`` and ``ood_file`` are also
            read directly.
        config: Mutable configuration mapping (as returned by
            ``ConfigLoader.load_config``).

    Raises:
        ValueError: if ``sgd_minibatch_size`` is not in
            ``(0, train_batch_size]``. Otherwise the number of minibatches per
            collected batch would be zero (or the division would fail).
    """
    config.update(vars(args))

    # Validate the batch geometry before any side effects (checkpoint dir,
    # wandb run, env worker processes) happen.
    frames_per_batch = config['train_batch_size']
    mini_batch_size = config['sgd_minibatch_size']
    num_sgd_iter = config['num_sgd_iter']
    if not 0 < mini_batch_size <= frames_per_batch:
        raise ValueError(
            f"sgd_minibatch_size ({mini_batch_size}) must be in "
            f"(0, train_batch_size={frames_per_batch}]"
        )
    # Floor division: a remainder of frames_per_batch % mini_batch_size
    # does not form an extra minibatch.
    num_mini_batches = frames_per_batch // mini_batch_size

    if not args.debug:
        _setup_experiment(args, config)
    print(config)

    total_frames = args.total_frames

    # NOTE(review): the --test_interval CLI flag is not consulted here; the
    # evaluation cadence is fixed for "single_cook" and is otherwise
    # total_frames // 50.
    if config.get("env") == "single_cook":
        test_interval = 200000
    else:
        # Clamp so very short runs (total_frames < 50) never get a zero
        # interval.
        test_interval = max(1, total_frames // 50)

    num_env = config['num_parallel_envs']
    device = get_device(config)

    actor, critic = create_actor_critic(config, device)
    print(f"actor in main {actor}")
    collector = SyncDataCollector(
        create_env_fn=lambda: parallel_env_maker(config, num_env),
        policy=actor,
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
        device=device,
    )

    # On-policy buffer: holds exactly one collected batch and is sampled
    # without replacement in minibatches of size mini_batch_size.
    data_buffer = TensorDictReplayBuffer(
        storage=LazyMemmapStorage(frames_per_batch),
        sampler=SamplerWithoutReplacement(),
        batch_size=mini_batch_size,
    )

    if config["archive_buffer"] or config["save_images"]:
        print("creating archive buffer")
        archive_buffer = TensorDictReplayBuffer(
            storage=LazyMemmapStorage(config.get("archive_buffer_size", 3000)),
            sampler=SamplerWithoutReplacement(),
        )
    else:
        archive_buffer = None

    adv_module = GAE(
        gamma=config['gamma'],
        lmbda=config['lambda'],
        value_network=critic,
        average_gae=False,
    )

    # Only metaworld configs choose the critic loss and a KL target; other
    # environments use an L2 critic loss and no KL target.
    if config['env'] == "metaworld":
        loss_critic_type = config['vf_loss_fn']
        kl_target = config['kl_target']
    else:
        loss_critic_type = 'l2'
        kl_target = None

    loss_module = WeightedClipPPOLoss(
        actor_network=actor,
        critic_network=critic,
        clip_epsilon=config['clip_param'],
        loss_critic_type=loss_critic_type,
        entropy_coef=config['entropy_coeff'],
        critic_coef=config['vf_loss_coeff'],
        vf_clip_param=config['vf_clip_param'],
        linear_entropy=config.get("linear_entropy", None),
        entropy_bonus=True,
        kl_target=kl_target,
        kl_coeff=config['kl_coeff'],
        normalize_advantage=config['norm_adv'],
        id_file=args.id_file,
        ood_file=args.ood_file,
        margin_in=config.get("margin_in", 12),
        margin_out=config.get("margin_out", 14),
        lambda_energy=config.get("lambda_energy", 0.0001),
        device=device,
    )

    optim = torch.optim.Adam(loss_module.parameters(), lr=config['lr'])

    test_env = parallel_env_maker(config, 10, device=device)
    test_env.eval()

    try:
        train_ppo(
            config, actor, critic, collector, data_buffer, adv_module,
            loss_module, total_frames, num_sgd_iter, num_mini_batches, optim,
            test_env, test_interval, device, archive_buffer,
        )
    finally:
        # Release env resources (e.g. worker processes) even if training
        # raises. collector.shutdown() is expected to be idempotent in
        # torchrl, and the env is only closed if it is not already closed.
        collector.shutdown()
        if not getattr(test_env, "is_closed", False):
            test_env.close()

    # Expected to be a no-op when no run was started (debug mode).
    wandb.finish()


def build_parser():
    """Build the command-line parser for this training script.

    ``main`` merges every parsed option into the loaded config via
    ``config.update(vars(args))``, so CLI values override same-named config
    entries (including store_true flags left at their ``False`` default).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        default="metaworld_indep",
        help="Configuration file to use.",
    )
    parser.add_argument(
        "--total_frames",
        type=int,
        default=1000000,
        help="Total number of environment frames to collect during training.",
    )
    parser.add_argument(
        "--test_interval",
        type=int,
        default=100,
        help="Currently unused: main() derives the evaluation interval "
             "from the environment and --total_frames.",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Skip experiment-directory creation and wandb logging.",
    )
    parser.add_argument(
        "--hpo",
        action="store_true",
        help="Whether to perform HPO during training.",
    )
    parser.add_argument(
        "--save_images",
        action="store_true",
        help="save sample images during training",
    )
    parser.add_argument(
        "--archive_buffer",
        action="store_true",
        help="use replay buffer for IID",
    )
    parser.add_argument(
        "--archive_buffer_size",
        type=int,
        default=3000,
        help="Size of the archive buffer for storing past experiences.",
    )
    # Paths to saved `.pt` files forwarded to the loss module.
    parser.add_argument(
        "--id_file",
        type=str,
        default=None,
        help="Path to in-distribution `.pt` file.",
    )
    parser.add_argument(
        "--ood_file",
        type=str,
        default=None,
        help="Path to out-of-distribution `.pt` file.",
    )
    return parser


if __name__ == '__main__':
    args = build_parser().parse_args()
    config = ConfigLoader.load_config(args.config, args.hpo)
    main(args, config)

