import argparse
import pickle

from pathlib import Path
from functools import partial

import wandb
import torch

from configparser import ConfigParser

from behavioral_cloning.dataset import TrajectoryDataset
from behavioral_cloning.model import CnnBCetwork
from behavioral_cloning.trainer import Trainer
from behavioral_cloning.logging_bc import (
    EpochLogger,
    setup_logger_kwargs,
)
from behavioral_cloning.utils import set_seed
from behavioral_cloning.evaluate import evaluate
from envs import AtariEnvironment


def parse_args():
    """Build the run configuration from an INI file plus command-line overrides.

    Parsing happens in two stages:

    1. ``-c/--conf_file`` and ``-d/--dataset`` are parsed first; all other
       arguments are left for stage 2.
    2. Every key in the config file's ``[DEFAULT]`` section becomes an optional
       ``--<key>`` flag whose default is the file's value, so any setting can
       be overridden on the command line.

    ConfigParser lower-cases option names, so generated flags are lower-case
    (``Seed`` in the file becomes ``--seed``), and lookups on the returned
    section proxy are case-insensitive for the same reason.

    ``dataset`` is taken from ``-d/--dataset`` when given, otherwise from a
    ``dataset`` key in the config file.

    Returns:
        configparser.SectionProxy: the ``DEFAULT`` section of a fresh
        ConfigParser holding every option as a string (an unset value is
        stored as the string ``"None"``).

    Raises:
        SystemExit: via argparse when ``--conf_file`` is given but cannot be
            read, or when the command line is invalid.
    """
    conf_parser = argparse.ArgumentParser(add_help=False)
    conf_parser.add_argument(
        "-c", "--conf_file", help="Specify config file", metavar="FILE"
    )
    conf_parser.add_argument("-d", "--dataset", help="path to dataset files, pickles")
    known_args, remaining_argv = conf_parser.parse_known_args()

    defaults = {}
    if known_args.conf_file:
        config = ConfigParser()
        # read() silently skips files it cannot open; fail fast here instead of
        # surfacing a confusing KeyError for the first missing option later.
        if not config.read([known_args.conf_file]):
            conf_parser.error(f"cannot read config file: {known_args.conf_file}")
        defaults |= dict(config.items("DEFAULT"))

    # Options already owned by conf_parser cannot be re-added (argparse would
    # raise a conflicting-option error); they are resolved explicitly below.
    reserved = {"conf_file", "dataset"}

    # Dynamically add one --<key> flag per config-file option.
    parser = argparse.ArgumentParser(parents=[conf_parser])
    for key, value in defaults.items():
        if key in reserved:
            continue
        parser.add_argument(f"--{key}", default=value)

    args = parser.parse_args(remaining_argv)
    # remaining_argv no longer contains -c/-d, so take them from stage 1;
    # a dataset given on the CLI wins over one from the config file.
    args.conf_file = known_args.conf_file
    args.dataset = (
        known_args.dataset
        if known_args.dataset is not None
        else defaults.get("dataset")
    )

    # interpolation=None stores values verbatim: with the default
    # BasicInterpolation, any value containing a bare '%' (run names, paths)
    # makes set() raise ValueError.
    config_proxy = ConfigParser(interpolation=None)
    for key, value in vars(args).items():
        config_proxy.set("DEFAULT", key, str(value))

    return config_proxy["DEFAULT"]


def _load_trajectories(dataset_path):
    """Unpickle every regular file directly inside *dataset_path*.

    Files are visited in sorted name order, so the trajectory order does not
    depend on filesystem listing order (keeps seeded runs reproducible).
    Files that fail to unpickle are reported and skipped (best-effort).

    SECURITY: ``pickle.load`` can execute arbitrary code; only point
    ``--dataset`` at trusted files.

    Args:
        dataset_path: ``pathlib.Path`` of the directory holding the pickles.

    Returns:
        list: one unpickled object per successfully loaded file.

    Raises:
        FileNotFoundError: if *dataset_path* is not an existing directory.
        ValueError: if no file could be loaded, so training never starts on
            an empty dataset.
    """
    if not dataset_path.is_dir():
        raise FileNotFoundError(
            f"Dataset directory not found: {dataset_path} (pass -d/--dataset)"
        )

    trajectories = []
    for file_path in sorted(dataset_path.iterdir()):
        if not file_path.is_file():
            continue
        try:
            with open(file_path, "rb") as f:
                loaded_trajectories = pickle.load(f)
        except Exception as e:  # pickle.load can raise almost any exception type
            print(f"Error loading trajectories from file {file_path}: {e}")
            continue
        trajectories.append(loaded_trajectories)

    if not trajectories:
        raise ValueError(f"No trajectories could be loaded from {dataset_path}")
    return trajectories


def main():
    """Train a behavioral-cloning model on pickled trajectories, logging to W&B.

    Reads all settings from ``parse_args()`` (config file + CLI overrides),
    loads the trajectory pickles from the ``dataset`` directory, builds an
    ``AtariEnvironment`` used by the evaluation callback, and runs ``Trainer``.
    """
    default_config = parse_args()
    print(dict(default_config))

    # Observation size given to dataset/model, and the number of discrete
    # actions (presumably the full 18-action Atari set — confirm for EnvID).
    input_size = (84, 84)
    output_size = 18

    seed = int(default_config["Seed"])
    set_seed(seed)

    num_train_steps = int(default_config["NumTrainSteps"])
    batch_size = int(default_config["BatchSize"])
    log_every = int(default_config["LogEvery"])
    save_every = int(default_config["SaveEvery"])
    num_eval_episodes = int(default_config["NumEvaluateEpisodes"])

    use_cuda = default_config.getboolean("UseGPU")
    epoch = int(default_config["Epoch"])
    learning_rate = float(default_config["LearningRate"])
    clip_grad_norm = float(default_config["ClipGradNorm"])

    # create dataset
    trajectories = _load_trajectories(Path(default_config["dataset"]))
    dataset = TrajectoryDataset(
        trajectories=trajectories,
        state_dim=input_size,
        act_dim=output_size,
        batch_size=batch_size,
    )

    env_id = default_config["EnvID"]
    is_render = False
    life_done = default_config.getboolean("LifeDone")
    action_prob = float(default_config["ActionProb"])
    sticky_action = default_config.getboolean("StickyAction")
    device = torch.device("cuda" if use_cuda else "cpu")

    # NOTE(review): the third positional argument (0) is presumably an
    # env/worker index — confirm against AtariEnvironment's signature.
    env = AtariEnvironment(
        env_id,
        is_render,
        0,
        child_conn=None,
        life_done=life_done,
        sticky_action=sticky_action,
        p=action_prob,
        writer=None,
        use_state_loading=False,
        room_saving=False,
        should_calc_additional_metrics=True,
    )

    # Pre-bind env, episode count and device; any remaining arguments are
    # supplied by Trainer when it invokes eval_fns (not visible here).
    evaluate_episodes = partial(
        evaluate,
        env=env,
        num_eval_episodes=num_eval_episodes,
        device=device,
    )

    # create model, optimizer and trainer
    logger_kwargs = setup_logger_kwargs(default_config["RunGroup"], seed)
    logger = EpochLogger(log_to_wandb=True, **logger_kwargs)
    model = CnnBCetwork(input_size, output_size).to(device)
    optim = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    trainer = Trainer(
        model=model,
        optimizer=optim,
        dataset=dataset,
        logger=logger,
        device=device,
        eval_fns=[evaluate_episodes],
        clip_grad_norm=clip_grad_norm,
    )

    wandb.init(
        name=default_config["RunName"],
        group=default_config["RunGroup"],
        config=dict(default_config),
        reinit=True,
        resume=False,
        project="montezuma_finetuning",
        entity="<name>",
    )

    trainer.train(
        num_epochs=epoch,
        num_train_steps=num_train_steps,
        log_every=log_every,
        save_every=save_every,
    )


if __name__ == "__main__":
    # Run training only when this file is executed as a script, not when imported.
    main()
