import argparse
import copy
import itertools
import json
import os
import subprocess
import sys
from collections import defaultdict

import pandas as pd

# Command-line options, declared as (flag, keyword-arguments) pairs:
#   --config_file  base sweep config to expand into per-run configs
#   --cuda_device  GPU index forwarded to each training subprocess
#   --paired       use the paired training script instead of the default one
_CLI_OPTIONS = (
    (
        "--config_file",
        {
            "type": str,
            "default": "configs/upworthy/reward_paired.json",
            "help": "config file",
        },
    ),
    ("--cuda_device", {"type": int, "default": 0, "help": "cuda device to use"}),
    ("--paired", {"action": "store_true", "help": "paired training"}),
)

parser = argparse.ArgumentParser(description="Train Reward")
for _flag, _opts in _CLI_OPTIONS:
    parser.add_argument(_flag, **_opts)
args = parser.parse_args()

# Load the sweep definition. The "reward" section supplies the values to grid
# over: LoRA ranks, seeds, learning rates and epoch counts (each is expected to
# be an iterable of candidate values).
with open(args.config_file) as fh:
    base_config = json.load(fh)

reward_config = base_config["reward"]
rs = reward_config["lora_config"]["r"]
seeds = reward_config["seeds"]
lrs = reward_config["training_args"]["learning_rates"]
epochs = reward_config["training_args"]["num_train_epochs"]

# --paired switches the per-run training entry point.
python_script = "train_reward_pair.py" if args.paired else "train_reward.py"

# Expand the grid: write one derived config per (r, seed, lr, epoch)
# combination under "<base config name>-rewards/", with scalar values in place
# of the sweep lists.
config_paths = []
for r, seed, lr, epoch in itertools.product(rs, seeds, lrs, epochs):
    run_suffix = f"r_{r}_seed_{seed}_lr_{lr}_epoch_{epoch}"

    config = copy.deepcopy(base_config)
    config["exp_id"] = f"{config['exp_id']}-rewards/{run_suffix}"
    config["reward"]["lora_config"]["r"] = r
    config["reward"]["training_args"]["seed"] = seed
    config["reward"]["training_args"]["learning_rate"] = lr
    config["reward"]["training_args"]["num_train_epochs"] = epoch
    # The sweep-only list keys have no meaning in a single-run config.
    del config["reward"]["seeds"]
    del config["reward"]["training_args"]["learning_rates"]

    output_file = args.config_file.replace(".json", f"-rewards/{run_suffix}.json")
    config_paths.append(output_file)
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    # Close (and flush) the file before the training subprocess reads it.
    with open(output_file, "w") as fh:
        json.dump(config, fh, indent=4)

# Same paths, kept under a separate name for the training loop.
train_config_paths = list(config_paths)

# Train one reward model per derived config, each in its own subprocess.
# Passing an argv list (no shell) keeps paths with spaces or shell
# metacharacters intact, and sys.executable reuses the interpreter that is
# running this sweep rather than whatever "python" is first on PATH.
for config_path in train_config_paths:
    print(f"Training reward with config: {config_path}")
    completed = subprocess.run(
        [
            sys.executable,
            python_script,
            "--config_file",
            config_path,
            "--cuda_device",
            str(args.cuda_device),
        ],
        check=False,
    )
    if completed.returncode != 0:
        # Best-effort sweep: report the failure and continue with the rest.
        print(
            f"Training failed (exit code {completed.returncode}) "
            f"for config: {config_path}"
        )

outcomes_output_path = args.config_file.replace(".json", "-rewards/outcomes.json")


def _summarize_train_logs(log_history, n_back=20):
    """Average the numeric fields of the recent training log entries.

    Considers the ``n_back`` entries that precede the final entry of
    ``log_history``, skips any entry containing ``"eval_loss"``, and divides
    each numeric field's sum by the number of training entries found.

    Returns the averaged dict, or ``None`` if no training entry was found.
    """
    sums = defaultdict(float)
    count = 0
    # Slice equivalent of indices -1-n_back .. -2 (the last entry is excluded).
    for entry in log_history[-1 - n_back : -1]:
        if "eval_loss" in entry:
            continue
        for field, value in entry.items():
            if isinstance(value, (int, float)):
                sums[field] += value
        count += 1
    if count == 0:
        return None
    return {field: total / count for field, total in sums.items()}


# Merge into any outcomes recorded by earlier sweeps.
try:
    with open(outcomes_output_path) as fh:
        outcomes = json.load(fh)
except FileNotFoundError:
    outcomes = {}

for config_path in config_paths:
    with open(config_path) as fh:
        config = json.load(fh)
    reward_training_path = config["reward"]["training_output_path"].format(**config)
    if not os.path.isdir(reward_training_path):
        print(f"No training output at {reward_training_path}; skipping {config_path}")
        continue

    # Checkpoint directories are named "checkpoint-<step>"; ignore anything else.
    dirs_nums = []
    for d in os.listdir(reward_training_path):
        parts = d.split("-")
        if d.startswith("checkpoint") and len(parts) > 1 and parts[1].isdigit():
            dirs_nums.append(int(parts[1]))

    for dir_num in dirs_nums:
        # Computed before anything can fail, so a failure is never recorded
        # under a different checkpoint's key.
        reward_path_checkpoint_dir = os.path.join(
            reward_training_path, f"checkpoint-{dir_num}"
        )
        reward_path_checkpoint = os.path.join(
            reward_path_checkpoint_dir, "trainer_state.json"
        )
        try:
            with open(reward_path_checkpoint) as fh:
                log_history = json.load(fh)["log_history"]
        except (OSError, json.JSONDecodeError, KeyError, TypeError) as e:
            print(f"Skipping {reward_path_checkpoint_dir}: {e}")
            continue
        if not log_history:
            print(f"Skipping {reward_path_checkpoint_dir}: empty log_history")
            continue

        # NOTE(review): the final log entry is stored as "valid" on the
        # assumption that it is the evaluation log — confirm with the trainer.
        train_outcome = _summarize_train_logs(log_history)
        if train_outcome is None:
            outcomes[reward_path_checkpoint_dir] = {
                "valid": log_history[-1],
                "train": {"loss": None},
            }
        else:
            # At least one training entry precedes the last, so [-2] exists.
            outcomes[reward_path_checkpoint_dir] = {
                "valid": log_history[-1],
                "train_last": log_history[-2],
                "train": train_outcome,
            }

with open(outcomes_output_path, "w") as fh:
    json.dump(outcomes, fh, indent=4)

# Flatten the recorded outcomes into one row per checkpoint. Hyperparameters
# are recovered from the fourth-from-last path component of each checkpoint
# key, expected to end in "r_<r>_seed_<seed>_lr_<lr>_epoch_<epoch>".
# NOTE(review): this depends on the layout of "training_output_path" — confirm.
outcomes_list = []
for key, value in outcomes.items():
    print(f"Processing {key} with value {value}")
    try:
        # Parsing lives inside the try: keys from older sweeps with a
        # different layout must be skipped, not crash the whole report.
        key_first = key.split("/")[-4]
        print(key_first)
        parts = key_first.split("_")
        r = parts[-7]
        seed = parts[-5]
        lr = parts[-3]
        epoch = parts[-1].split(".")[0]
        checkpoint = key.split("/")[-1].split("-")[-1]
        outcomes_list.append(
            {
                "seed": int(seed),
                "config_path": key,
                "lr": float(lr),
                "r": int(r),
                "epoch": int(epoch),
                "checkpoint": int(checkpoint),
                "train_loss": value["train"].get("loss"),
                "eval_loss": value["valid"]["eval_loss"],
            }
        )
    except (AttributeError, IndexError, KeyError, TypeError, ValueError) as e:
        print(e)
        print(f"Skipping outcome {key}: could not parse it")

if not outcomes_list:
    # groupby on an empty frame would fail with an opaque KeyError.
    raise SystemExit(f"No usable reward outcomes found in {outcomes_output_path}")

outcomes_df = pd.DataFrame(outcomes_list)
outcomes_df_gb = outcomes_df.groupby(["r", "lr", "epoch", "checkpoint", "config_path"]).agg(
    {
        "train_loss": ["mean", "min", "max"],
        "eval_loss": ["mean", "min", "max"],
    }
)
print(outcomes_df_gb)
# Flatten ("eval_loss", "mean") -> "eval_loss_mean", etc.
outcomes_df_gb.columns = [
    "_".join(col).strip() for col in outcomes_df_gb.columns.values
]

# Select the checkpoint with the lowest mean eval loss (idxmin skips NaN and
# returns the first minimum) and export a config that points both the reward
# model and PPO's reward PEFT adapter at it.
best_info = outcomes_df_gb["eval_loss_mean"].idxmin()

r, lr, epoch, checkpoint, config_path = best_info

# config_path is a checkpoint from one specific training run. Take the seed
# from that run instead of assuming seeds[0]; otherwise the exported config
# and its filename describe a different run than the model they point to.
seed = int(outcomes_df.loc[outcomes_df["config_path"] == config_path, "seed"].iloc[0])
output_file = args.config_file.replace(
    ".json", f"-rewards/r_{r}_seed_{seed}_lr_{lr}_epoch_{epoch}.json"
)
best_path = os.path.join(
    os.path.dirname(args.config_file),
    "bests",
    f"{base_config['exp_id']}_r_{r}_seed_{seed}_lr_{lr}_epoch_{epoch}_checkpoint_{checkpoint}.json",
)
# The "bests" directory (and any sub-path implied by exp_id) may not exist yet.
os.makedirs(os.path.dirname(best_path) or ".", exist_ok=True)

with open(output_file) as fh:
    config = json.load(fh)
config["reward"]["model"] = config_path
config["ppo"]["reward_model_peft"] = config_path
with open(best_path, "w") as fh:
    json.dump(config, fh, indent=4)
print(f"Best model saved to {best_path}")
