import os
import json
import logging
import random
import numpy as np
import torch
from typing import Union
from jsonargparse import CLI
from omegaconf import OmegaConf

from llm_router.data.registry import get_data
from llm_router.router.registry import get_router
from llm_router.router.base_router import RouterInput, RouterOutput
from llm_router.router.utils import plot_cost_score_tradeoff

# Module-level logger; its records propagate to the root logger that
# setup_logging() configures via logging.basicConfig.
logger = logging.getLogger(__name__)

def setup_logging(output_dir):
    """Create ``output_dir`` and route INFO-level logging to ``<output_dir>/eval.log``.

    The directory (and any missing parents) is created if needed. The log file
    is opened in "w" mode, so each run truncates the previous log. Note that
    ``logging.basicConfig`` is a no-op when the root logger already has
    handlers attached.
    """
    os.makedirs(output_dir, exist_ok=True)

    log_path = os.path.join(output_dir, "eval.log")
    line_format = "%(asctime)s - %(levelname)s - %(name)s -   %(message)s"
    timestamp_format = "%m/%d/%Y %H:%M:%S"

    logging.basicConfig(
        level=logging.INFO,
        filename=log_path,
        filemode="w",
        encoding="utf-8",
        format=line_format,
        datefmt=timestamp_format,
    )
    
def seed_everything(seed):
    """Seed every RNG this script relies on with the same ``seed``.

    Seeds, in order: Python's ``random``, NumPy's global RNG, PyTorch's CPU
    generator and all CUDA device generators. The seed is also exported as
    the string environment variable ``GLOBAL_RANDOM_SEED``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    os.environ["GLOBAL_RANDOM_SEED"] = str(seed)
    
def aggregate_results(results: "list[dict[float, RouterOutput]]"):
    """Average score, cost and utility per preference over all test examples.

    Each element of ``results`` is one example's routing outcome: a dict
    mapping a preference value to the RouterOutput produced under it. For
    every (example, preference) pair the selected model is
    ``router_output.routing_id``. Its score and cost are read from
    ``router_output.scores`` and ``router_output.costs``, and the utility is
    ``score - preference * cost``.

    Args:
        results: One dict per example. All dicts must contain the same
            preference keys. Columns are aligned on the key order of
            ``results[0]``, not on each dict's own insertion order.

    Returns:
        A dict with keys ``"scores"``, ``"costs"`` and ``"utilities"``. Each
        value is a list with one float32-precision mean per preference, in
        the key order of ``results[0]``. All three lists are empty when
        ``results`` is empty.

    Raises:
        ValueError: if an example's preference keys differ from those of
            the first example.
    """
    if not results:
        # A mean over an empty tensor would produce a bare NaN, not a list.
        return {"scores": [], "costs": [], "utilities": []}

    reference_keys = results[0].keys()
    preferences = list(reference_keys)

    scores, costs, utilities = [], [], []
    for idx, res in enumerate(results):
        # dict_keys compare as sets: same preferences, any insertion order.
        if res.keys() != reference_keys:
            raise ValueError(
                f"results[{idx}] has preference keys {list(res.keys())}, "
                f"expected {preferences}"
            )
        row_scores, row_costs, row_utilities = [], [], []
        # Walk the reference order so column j always means preferences[j].
        for preference in preferences:
            router_output = res[preference]
            score = router_output.scores[router_output.routing_id]
            cost = router_output.costs[router_output.routing_id]
            row_scores.append(score)
            row_costs.append(cost)
            row_utilities.append(score - preference * cost)
        scores.append(row_scores)
        costs.append(row_costs)
        utilities.append(row_utilities)

    # Rows are examples and columns are preferences: average over examples.
    return {
        "scores": torch.tensor(scores).float().mean(dim=0).tolist(),
        "costs": torch.tensor(costs).float().mean(dim=0).tolist(),
        "utilities": torch.tensor(utilities).float().mean(dim=0).tolist(),
    }
    
def evaluate(configs, testset, router):
    """Route every test example, persist the raw outputs, and report averages.

    Each ``testset[i]`` is unpacked as keyword arguments into ``RouterInput``
    and passed to ``router.route``, in dataset order. The following files are
    written into ``configs.training.output_dir``:

    * ``results.pth``: the list of per-example ``router.route`` outputs
      (``torch.save``).
    * ``results.json``: router/dataset names plus the aggregated metrics.
    * ``results.png``: the cost/score trade-off plot, with the cost axis bounded
      by the max of ``testset.get_benchmark()["costs"]``.

    The metrics JSON is also logged at INFO level.
    """
    out_dir = configs.training.output_dir

    outputs = [
        router.route(RouterInput(**testset[i])) for i in range(len(testset))
    ]
    torch.save(outputs, os.path.join(out_dir, "results.pth"))

    summary = aggregate_results(outputs)
    report = {
        "router": router.name,
        "dataset": testset.name,
        "score": summary["scores"],
        "costs": summary["costs"],
        "utilities": summary["utilities"],
    }
    # Serialize once; the same text goes to the log and to results.json.
    report_text = json.dumps(report, indent=2)
    logger.info(f"\n\n{report_text}\n\n")
    with open(os.path.join(out_dir, "results.json"), "w") as fh:
        fh.write(report_text)

    benchmark_max_cost = testset.get_benchmark()["costs"].max().item()
    plot_cost_score_tradeoff(
        costs=summary["costs"],
        scores=summary["scores"],
        max_cost=benchmark_max_cost,
        save_path=os.path.join(out_dir, "results.png"),
    )

def main(config_file: str):
    """Fit a router on the training split and evaluate it on the test split.

    Args:
        config_file: Path to a config loaded with ``OmegaConf.load``. It is
            read for ``training.output_dir``, ``training.seed``,
            ``data.train``, ``data.test`` and ``router``.
    """
    configs = OmegaConf.load(config_file)
    run_cfg = configs.training
    setup_logging(run_cfg.output_dir)
    seed_everything(run_cfg.seed)

    data_cfg = configs.data
    trainset = get_data(data_cfg.train)
    testset = get_data(data_cfg.test)
    logger.info(f"trainset: {len(trainset)} testset: {len(testset)}")

    router = get_router(configs.router)
    # NOTE(review): the test split is handed to fit(); confirm the router uses
    # it only for monitoring and not for training, or the evaluation leaks.
    router.fit(trainset, testset, configs)
    evaluate(configs, testset, router)
    
if __name__ == "__main__":
    # jsonargparse.CLI builds a command-line parser from main's signature
    # (its single `config_file` argument) and calls main with the parsed value.
    CLI(main)