import os
import copy
import argparse
import wandb
import torch
from config.config_loader import ConfigLoader
from environments.env_loader import parallel_env_maker
from torchrl.objectives.value import GAE
from collectors import GuidedSyncDataCollector, JumpStartSyncDataCollector, KickStartSyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from utils import generate_exp_name, get_device, create_actor_critic
from losses import WeightedClipPPOLoss
from trainers import train_ppo


def _compute_threshold(energies, q_th):
    """Return the energy found at quantile ``q_th`` of ``energies``.

    Args:
        energies: tensor of teacher energies. It is sorted with ``.sort()``
            and indexed along its first dimension (assumed 1-D -- TODO
            confirm against the files stored under ``iid_raw_energy``).
        q_th: requested quantile.

    Returns:
        ``float("-inf")`` when ``q_th < 0``; the element of the sorted
        energies at index ``int(len * min(q_th, 0.99))`` when
        ``0 <= q_th <= 1``; ``float("inf")`` otherwise.

    Raises:
        ValueError: if ``energies`` is empty and ``q_th`` lies in ``[0, 1]``.
    """
    if q_th < 0:
        return float("-inf")
    elif 0 <= q_th <= 1:
        sorted_vals = energies.sort()[0]
        if len(sorted_vals) == 0:
            raise ValueError("cannot compute an energy threshold from an empty tensor")
        # Capping the quantile at 0.99 keeps the index strictly below len(),
        # so q_th == 1 does not index past the end.
        idx = int(len(sorted_vals) * min(q_th, 0.99))
        return sorted_vals[idx]
    else:
        return float("inf")


def _build_exp_name(args, config):
    """Compose the experiment name used for the checkpoint dir and wandb run.

    The name always ends with ``<generated name>/from_<teacher dir name>``;
    a prefix describing the advice settings is prepended depending on
    ``args.advice_mode``.
    """
    base_exp_name = generate_exp_name('transfer_iac', args.config)
    # normpath strips a trailing slash; without it basename("runs/teacher/")
    # is "" and the name would end in a bare "from_".
    teacher_tag = os.path.basename(os.path.normpath(args.teacher_dir))
    base_exp_name = f"{base_exp_name}/from_{teacher_tag}"

    if args.advice_mode == "energy":
        # fix_IID is checked first so the suffix agrees with the precedence
        # used when the energy file is actually loaded.
        if args.fix_IID:
            suffix = "_fix_IID"
        elif args.random_IID:
            suffix = "_random_IID"
        else:
            suffix = ""
        num_transfer_s = f"_{config['num_transfer']}"
        start_frame_s = f"_s{config['start_frame']}"
        budget_s = "_budgetT" if config["limit_budget"] else "_budgetF"
        # NOTE(review): there is no separator between the budget tag and the
        # half_random tag; left as-is so names keep matching earlier runs.
        use_least_threshold_s = "half_randomT" if config["use_least_threshold"] else "half_randomF"
        return (
            f"{config['name']}_{args.check_num}_qth{args.q_th}_lambda_{args.ex_decay_advice}"
            f"{start_frame_s}{num_transfer_s}{budget_s}{use_least_threshold_s}{suffix}/{base_exp_name}"
        )
    if args.advice_mode == "js":
        # JumpStart runs are identified by init_guided_step, n_stages and tolerance.
        return (
            f"{config['name']}_{args.check_num}_start_{args.init_guided_step}"
            f"_stages_{args.n_stages}_t_{args.tolerance}/{base_exp_name}"
        )
    if args.advice_mode == "kick":
        return (
            f"kick_{config['name']}_{args.check_num}_lmb_{args.init_lambda}"
            f"_ends_{args.imitation_ends}/{base_exp_name}"
        )
    # Any other advice mode falls back to the base name.
    return base_exp_name


def _load_teacher_energies(args, config, device):
    """Load the teacher's IID energy tensor from ``<teacher_dir>/iid_raw_energy``.

    ``fix_IID`` takes precedence over ``random_IID``; with neither set the
    energies sampled from training are used.
    """
    if config.get("fix_IID", False):
        print("Fixed IID")
        label, file_name = "fixed", f"{args.check_num}.pt"
    elif config.get("random_IID", False):
        print("Random IID")
        label, file_name = "random", f"{args.check_num}_random.pt"
    else:
        print("IID sampled from training")
        label, file_name = "training", f"{args.check_num}_sample_images_{args.check_num}.pt"

    iid_energy_path = os.path.join(args.teacher_dir, "iid_raw_energy", file_name)
    print(f"Loading {label} IID energy from: {iid_energy_path}")
    return torch.load(iid_energy_path, map_location=device)


def _make_collector(args, config, actor, t_actor, frames_per_batch, total_frames, num_env, device):
    """Create the teacher-guided data collector selected by ``config['advice_mode']``."""
    advice_mode = config.get("advice_mode", "")

    if advice_mode == "js":
        print("using JumpStartSyncDataCollector")
        return JumpStartSyncDataCollector(
            create_env_fn=lambda: parallel_env_maker(config, num_env),
            policy=actor,
            teacher_policy=t_actor,
            teacher_kwargs={
                'init_guided_step': config.get('init_guided_step', 300),
                'n_stages': config.get('n_stages', 20),
                'tolerance': config.get('tolerance', 0),
            },
            frames_per_batch=frames_per_batch,
            total_frames=total_frames,
            device=device
        )

    if advice_mode == "kick":
        print("Using KickStartSyncDataCollector")
        imitation_ends = config.get("imitation_ends")
        return KickStartSyncDataCollector(
            create_env_fn=lambda: parallel_env_maker(config, num_env),
            policy=actor,
            teacher_policy=t_actor,
            teacher_kwargs={
                "init_lambda": config.get("init_lambda", 1.0),
                "imitation_ends": imitation_ends
            },
            frames_per_batch=frames_per_batch,
            total_frames=total_frames,
            device=device
        )

    # Every remaining mode uses the energy-guided advisor.
    print("Using Energy Guided Advisor")
    t_energies = _load_teacher_energies(args, config, device)

    # Threshold at the requested quantile, plus the smallest observed energy
    # (quantile 0 selects index 0 of the sorted values).
    t_eth = _compute_threshold(t_energies, args.q_th)
    t_least = _compute_threshold(t_energies, 0)
    print(f"t_eth is {t_eth} and t_least is {t_least}")

    # -1 means "use the whole training run" as the transfer horizon.
    if config['num_transfer'] == -1:
        num_transfer = config['total_frames']
    else:
        num_transfer = config['num_transfer']

    return GuidedSyncDataCollector(
        create_env_fn=lambda: parallel_env_maker(config, num_env),
        policy=actor,
        teacher_policy=t_actor,
        teacher_kwargs={
            'threshold': t_eth,
            'least_threshold': t_least,
            'use_least_threshold': config.get('use_least_threshold', False),
            'follow_prob': config['follow_prob'],
            'linear_decay_advice': config['linear_decay_advice'],
            'ex_decay_advice': config['ex_decay_advice'],
            'num_transfer': num_transfer,
            'start_frame': config['start_frame'],
            'fix': config["limit_budget"],
        },
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
        device=device,
        advice_reset_interval=config.get("advice_reset_interval", 5000),
        interval_advice_rate=config.get("interval_advice_rate", 0.125),
    )


def main(args, config):
    """Train a student with PPO while a pretrained teacher supplies advice.

    Args:
        args: parsed command-line namespace; all of its entries are merged
            into ``config`` (CLI values override the loaded configuration).
        config: mapping-like configuration object supporting ``update``,
            ``get`` and item access. It is mutated in place.

    Unless ``args.debug`` is set, an experiment directory is created under
    ``<teacher_dir>/ckpts`` and a wandb run is started.
    """
    config.update(vars(args))

    if not args.debug:
        exp_name = _build_exp_name(args, config)
        print(exp_name)
        config.update({'exp_dir': f"{config['teacher_dir']}/ckpts"})
        os.makedirs(f"{config['exp_dir']}/{exp_name}", exist_ok=True)
        config.update({'exp_name': exp_name})

        wandb.login()
        wandb.init(project="aa", config=config, name=exp_name)

    total_frames = args.total_frames
    if config.get("env") == "single_cook":
        test_interval = 200000
    else:
        test_interval = total_frames // 50

    frames_per_batch = config['train_batch_size']
    mini_batch_size = config['sgd_minibatch_size']
    num_sgd_iter = config['num_sgd_iter']
    num_mini_batches = frames_per_batch // mini_batch_size
    num_env = config['num_parallel_envs']
    device = get_device(config)

    actor, critic = create_actor_critic(config, device)
    # The teacher reuses the student's architecture; only the actor is needed
    # to give advice, so the critic is not duplicated.
    t_actor = copy.deepcopy(actor)

    # NOTE(review): torch.load unpickles the file -- only point --teacher_dir
    # at checkpoints you trust.
    t_checkpoint = torch.load(f'{args.teacher_dir}/model-{args.check_num}.pt', map_location=device)

    # Checkpoints are either a dict holding the actor weights under 'actor'
    # or the actor state dict itself.
    if isinstance(t_checkpoint, dict) and 'actor' in t_checkpoint:
        t_actor.load_state_dict(t_checkpoint['actor'])
    else:
        t_actor.load_state_dict(t_checkpoint)

    t_actor.eval()

    print(config)
    collector = _make_collector(
        args, config, actor, t_actor, frames_per_batch, total_frames, num_env, device
    )

    sampler = SamplerWithoutReplacement()
    data_buffer = TensorDictReplayBuffer(
        storage=LazyMemmapStorage(frames_per_batch),
        sampler=sampler,
        batch_size=mini_batch_size,
    )

    adv_module = GAE(
        gamma=config['gamma'],
        lmbda=config['lambda'],
        value_network=critic,
        average_gae=False,
    )

    loss_module = WeightedClipPPOLoss(
        actor_network=actor,
        critic_network=critic,
        clip_epsilon=config['clip_param'],
        loss_critic_type='l2',
        entropy_coef=config['entropy_coeff'],
        critic_coef=config['vf_loss_coeff'],
        vf_clip_param=config['vf_clip_param'],
        entropy_bonus=True,
        kl_coeff=config['kl_coeff'],
        normalize_advantage=config['norm_adv'],
    )

    if config["save_images"]:
        print("creating archive buffer")
        archive_buffer = TensorDictReplayBuffer(
            storage=LazyMemmapStorage(config.get("archive_buffer_size", 3000)),
            sampler=SamplerWithoutReplacement(),
        )
    else:
        archive_buffer = None

    optim = torch.optim.Adam(loss_module.parameters(), lr=config['lr'])

    test_env = parallel_env_maker(config, 10, device=device)
    test_env.eval()

    train_ppo(
        config, actor, critic, collector, data_buffer, adv_module, loss_module,
        total_frames, num_sgd_iter, num_mini_batches, optim, test_env, test_interval, device,
        archive_buffer
    )

    # Called unconditionally, as before. NOTE(review): in --debug mode no run
    # was started; presumably this is then a no-op -- confirm.
    wandb.finish()


def build_parser():
    """Build the command-line parser for the teacher-guided training script.

    Returns:
        argparse.ArgumentParser: parser whose resulting namespace is merged
        into the loaded configuration by ``main``.
    """
    parser = argparse.ArgumentParser()

    # --- general -----------------------------------------------------------
    parser.add_argument(
        "--config",
        type=str,
        default="pacman",
        help="Configuration file to use."
    )
    # Required: every code path reads the teacher checkpoint from this
    # directory, so a missing value could only fail later with an obscure error.
    parser.add_argument(
        "--teacher_dir",
        type=str,
        required=True,
        help="checkpoint for the pretrained model"
    )
    parser.add_argument(
        "--total_frames",
        type=int,
        default=1000000,
        help="Total number of frames to collect during training."
    )
    parser.add_argument(
        "--num_transfer",
        type=int,
        default=-1,
        help="number of transfer iterations."
    )
    parser.add_argument(
        "--q_th",
        type=float,
        default=0.1,
        help='energy quantile threshold'
    )
    parser.add_argument(
        "--follow_prob",
        type=float,
        default=1,
        help="prob to follow teacher's advice"
    )
    parser.add_argument(
        "--linear_decay_advice", action='store_true'
    )
    parser.add_argument(
        "--ex_decay_advice",
        type=int,
        default=-1,
        help="lambda to control exponential decay steepness"
    )
    parser.add_argument(
        "--debug", action='store_true'
    )
    parser.add_argument(
        "--hpo",
        action="store_true",
        help="Whether to perform HPO during training."
    )
    parser.add_argument(
        "--check_num",
        type=str,
        default="best",
        help="what checkpoint to use in the teacher dir"
    )

    # --- energy-guided advice ---------------------------------------------
    parser.add_argument(
        "--random_IID",
        action="store_true",
        help="use random setting to gather IID"
    )
    parser.add_argument(
        "--fix_IID",
        action="store_true",
        help="use fix setting to gather IID"
    )
    parser.add_argument(
        "--use_least_threshold",
        action="store_true",
        help="force student to take random action when OOD"
    )
    parser.add_argument(
        "--limit_budget",
        action="store_true",
        help="limit advice issue"
    )
    parser.add_argument(
        "--save_images",
        action="store_true",
        help="save sample images during training"
    )
    parser.add_argument(
        "--advice_mode",
        type=str,
        default="energy",
        choices=["js", "energy", "kick"],  # adjust choices as necessary
        help="Mode for teacher advice; use 'js' for JumpStart."
    )

    # --- JumpStart ----------------------------------------------------------
    parser.add_argument(
        "--init_guided_step",
        type=int,
        default=100,
        help="Initial guided step for JumpStartSyncDataCollector"
    )
    parser.add_argument(
        "--start_frame",
        type=int,
        default=0,
        help="when do we issue advice"
    )
    parser.add_argument(
        "--n_stages",
        type=int,
        default=10,
        help="Number of stages for JumpStartSyncDataCollector"
    )
    parser.add_argument(
        "--tolerance",
        type=float,
        default=0,
        help="Tolerance (as a fraction) for stage advancement in JumpStartSyncDataCollector"
    )

    # --- KickStart ----------------------------------------------------------
    parser.add_argument(
        "--init_lambda",
        type=float,
        default=1.0,
        help="Initial distillation weight (lambda_k) for KickStart"
    )
    parser.add_argument(
        "--imitation_ends",
        type=int,
        default=1000000,
        help="Total number of frames over which to linearly decay lambda_k to zero"
    )

    return parser


if __name__ == '__main__':
    args = build_parser().parse_args()

    # Load the named configuration (optionally in HPO mode), then train.
    config = ConfigLoader.load_config(args.config, args.hpo)

    main(args, config)

