import os
from google.cloud import storage
import numpy as np
from matplotlib import pyplot as plt
from utils import list_json_files, download_and_save, upload_file_to_gcs
import json
import argparse

def parse_args():
    """Define and parse the command-line options of the report script.

    Every option is a plain string with a default, so the script runs with
    no arguments at all.

    Returns:
        argparse.Namespace with the attributes ``algo``, ``bucket``,
        ``prefix``, ``local_folder`` and ``latency_prefix``.
    """
    cli = argparse.ArgumentParser(description="Report")

    # (flag, default, help text) — registered in this order so --help output is stable.
    option_specs = (
        ("--algo", "rpo", "The algorithm used for training the model."),
        ("--bucket", "rpo-nips-bucket", "Google Cloud Bucket to store the data."),
        ("--prefix", "experiments/rpo/", "The prefix of the files to download."),
        ("--local_folder", "logs/results/rpo/", "The local folder to save the downloaded files."),
        ("--latency_prefix", "validation_rewards/rpo", "The prefix of the latency report"),
    )
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    return cli.parse_args()

def _fetch_json(bucket, blob_name, local_folder):
    """Return the parsed JSON content of ``blob_name``, downloading only when no local copy exists.

    The local copy lives at ``local_folder/<basename of blob_name>`` and is
    reused on later runs.
    """
    local_path = os.path.join(local_folder, os.path.basename(blob_name))
    if os.path.exists(local_path):
        print(f"File '{local_path}' already exists. Skipping download.")
    else:
        download_and_save(bucket, blob_name, local_path)
    with open(local_path, "r", encoding="utf-8") as f:
        return json.load(f)


def main():
    """Aggregate per-run evaluation metrics from a GCS bucket into a summary report.

    Steps:
      1. List the JSON files under ``args.prefix`` and ``args.latency_prefix``.
      2. Load each score file and collect its DINO, CLIP I, CLIP T and Reward
         values. Downloaded files are cached in ``args.local_folder``.
      3. Load each latency file and collect its ``latency`` value.
      4. Print the mean/std of each score and the mean latency.
      5. Write the summary to ``logs/results/<algo>/<algo>_results.json``.
      6. Upload the summary to ``experiments/<algo>/results.json`` in the bucket.

    Raises:
        SystemExit: if no score files are found. In that case every statistic
            would be NaN, so the report is not written or uploaded.
    """
    args = parse_args()

    # Cached downloads are written here; create the folder before any download/read.
    os.makedirs(args.local_folder, exist_ok=True)

    json_files = list_json_files(args.bucket, args.prefix)
    latency_reports = list_json_files(args.bucket, args.latency_prefix)

    # Per-run score files. Aggregate "results" files (including the one this
    # script uploads under experiments/<algo>/) are skipped.
    dino_scores, clip_i, clip_t, reward = [], [], [], []
    for blob in json_files:
        if "results" in blob.name:
            continue
        data = _fetch_json(args.bucket, blob.name, args.local_folder)
        dino_scores.append(data['DINO Score'])
        clip_i.append(data['CLIP I Score'])
        clip_t.append(data['CLIP T Score'])
        reward.append(data['Reward'])

    if not dino_scores:
        # np.mean([]) is NaN; refuse to publish a report made only of NaNs.
        raise SystemExit(f"No score files found under gs://{args.bucket}/{args.prefix}")

    mean_dino, std_dino = np.mean(dino_scores), np.std(dino_scores)
    mean_clip_i, std_clip_i = np.mean(clip_i), np.std(clip_i)
    mean_clip_t, std_clip_t = np.mean(clip_t), np.std(clip_t)
    mean_reward, std_reward = np.mean(reward), np.std(reward)

    latency_list = [
        _fetch_json(args.bucket, blob.name, args.local_folder)["latency"]
        for blob in latency_reports
        if "latency" in blob.name
    ]

    print(f"mean DINO Score: {mean_dino:.3f}")
    print(f"std DINO Score: {std_dino:.3f}")
    print(f"mean CLIP I Score: {mean_clip_i:.3f}")
    print(f"std CLIP I Score: {std_clip_i:.3f}")
    print(f"mean CLIP T Score: {mean_clip_t:.3f}")
    print(f"std CLIP T Score: {std_clip_t:.3f}")
    print(f"mean Reward: {mean_reward:.3f}")
    print(f"std Reward: {std_reward:.3f}")

    if latency_list:
        mean_latency = np.mean(latency_list)
        mins, secs = divmod(mean_latency, 60)
        print(f"Latency report time: {int(mins)} minutes and {secs:.2f} seconds")
    else:
        # With no latency files, np.mean gives NaN and int(NaN) raises ValueError.
        # Record "no data" (null in the JSON) instead of crashing after the scores were computed.
        mean_latency = None
        print(f"No latency reports found under '{args.latency_prefix}'.")

    results = {"Mean DINO Score": mean_dino,
               "Std DINO Score": std_dino,
               "Mean CLIP I Score": mean_clip_i,
               "Std CLIP I Score": std_clip_i,
               "Mean CLIP T Score": mean_clip_t,
               "Std CLIP T Score": std_clip_t,
               "Mean Reward": mean_reward,
               "Std Reward": std_reward,
               "Mean latency": mean_latency}

    file_dir = os.path.join("logs", "results", args.algo)
    os.makedirs(file_dir, exist_ok=True)
    filename = os.path.join(file_dir, f"{args.algo}_results.json")
    with open(filename, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4)

    blob_destination = f"experiments/{args.algo}/results.json"
    upload_file_to_gcs(filename, args.bucket, blob_destination)

# Run the report only when executed as a script, not when imported.
if __name__ == "__main__":
    main()