import json
import os
import numpy as np
from utils import gini_coefficient, MODELS, DOCUMENTATIONS, dazzy_map, index2model, name2model
import argparse
from collections import defaultdict
from tabulate import tabulate
from scipy.stats import spearmanr, kendalltau
from scipy import stats
import matplotlib.pyplot as plt
from collections import defaultdict

def var(table_str, idx=4):
    """
    Coefficient of variation (std / mean) of one column of the result table.

    `table_str` is split on newlines and only rows 9..13 (0-based) are read.
    Each of those rows is split on '|' and field `idx` is parsed as a float.

    idx=4: number of tasks;
    idx=5: reward

    Returns np.std(values) / np.mean(values) as a plain float rounded to
    3 decimals (np.std is called with its defaults).
    """
    rows = table_str.split('\n')[9:14]
    samples = [float(row.split('|')[idx].strip()) for row in rows]

    # Dividing two numpy scalars yields a numpy scalar; .item() unwraps it
    # into a built-in float before rounding.
    ratio = np.std(samples) / np.mean(samples)
    return round(ratio.item(), 3)


def get_gini(table_str):
    """
    Gini coefficient of field 5 of the result table, rounded to 3 decimals.

    Reads rows 9..13 (0-based) of `table_str`, splits each on '|', parses
    field 5 as a float and passes the resulting list to `gini_coefficient`.
    """
    column = [
        float(row.split('|')[5].strip())
        for row in table_str.split('\n')[9:14]
    ]
    coefficient = gini_coefficient(column)
    return round(coefficient, 3)

def mean(arr):
    """Arithmetic mean of `arr`, rounded to 3 decimal places.

    Raises ZeroDivisionError when `arr` is empty.
    """
    total = sum(arr)
    count = len(arr)
    return round(total / count, 3)

def statistic(folder):
    """
    Print averaged statistics over every '*.json' file directly in `folder`.

    For each file, the text of the last 'core' trajectory message is parsed:
    line 4 supplies the overall Gini coefficient and line 3 the overall ROI
    (the fixed sentence prefix is removed and the final character dropped,
    presumably a trailing period), while `var` / `get_gini` are applied to
    the embedded table.

    Prints `folder`, then one line with the means of: overall Gini, CV of
    column 4, Gini of column 5, CV of column 5, overall ROI.
    """
    overall_ginis, task_cvs, overall_rois = [], [], []
    reward_ginis, reward_cvs = [], []

    for entry in os.listdir(folder):
        if not entry.endswith('.json'):
            continue

        with open(os.path.join(folder, entry), 'r') as handle:
            payload = json.load(handle)

        report = payload['trajectory']['core'][-1]['content']
        report_lines = report.split('\n')
        gini_text = report_lines[4].replace('The Gini Coefficient is ', '')[:-1]
        roi_text = report_lines[3].replace('The overal ROI (e.g., reward / cost) is ', '')[:-1]

        task_cvs.append(var(report, idx=4))
        overall_ginis.append(float(gini_text))
        overall_rois.append(float(roi_text))
        reward_ginis.append(get_gini(report))
        reward_cvs.append(var(report, idx=5))

    print(folder)
    print(
        mean(overall_ginis),
        mean(task_cvs),
        mean(reward_ginis),
        mean(reward_cvs),
        mean(overall_rois),
    )



def get_task_by_name(names, arr):
    """
    Return, for each name in `names[0]`, the last value recorded for it.

    `arr` and `names` are walked in lockstep: row i of `arr` holds the values
    and row i of `names` the matching names. The output follows the order of
    `names[0]`; a name that is not in `names[0]` raises KeyError.
    """
    reference_order = names[0]

    history = {}
    for label in reference_order:
        history[label] = []

    for row_values, row_labels in zip(arr, names):
        for value, label in zip(row_values, row_labels):
            history[label].append(value)

    latest = []
    for label in reference_order:
        latest.append(history[label][-1])
    return latest

def _load_runs(output_dir):
    """Parse every file in `output_dir` as JSON, in sorted file-name order."""
    runs = []
    # os.listdir returns entries in arbitrary order; sorting keeps the choice
    # of the "first" run and the order of the returned lists reproducible.
    for file_name in sorted(os.listdir(output_dir)):
        # The context manager closes each handle as soon as it is parsed.
        with open(os.path.join(output_dir, file_name)) as handle:
            runs.append(json.load(handle))
    return runs


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def evaluation(output_dir):
    """
    Aggregate fairness / efficiency metrics over all run files in `output_dir`.

    Every file in the directory is parsed as JSON. From each run this reads:
      * run['batch_id']
      * run['trajectory']['usage']: entries whose element [1] is averaged
        (presumably output-side usage -- TODO confirm)
      * run['trajectory']['trajectory']: a list of steps, each with a
        'record' holding 'total_Gini' and 'total_ROI'; the last step's record
        also supplies 'Name', 'Number of tasks', 'ROI' and optionally
        'Country'; the last step may carry an 'agents' count.

    Agents are split into an "in" group (country equal, case-insensitively,
    to 'china' or 'group 2 country') and an "out" group (everything else).
    Group sizes are taken from the first step of the first run.

    Returns a dict with the keys "Directory", "(1-Gini) * ROI",
    "(1-Gini) * ROI * Survival", "ROI * Survival", "Avg. Gini", "Avg. ROI",
    "Survival", "details", "final_gini", "final_roi", "process_gini",
    "process_roi", "task_ids" and "length".

    Raises IndexError when the directory contains no files.
    """
    data = _load_runs(output_dir)

    def group_fn(x, y):
        # NOTE(review): `y` (the evaluated model's own country) is accepted
        # but ignored; membership is decided by these two fixed labels only.
        return x.lower() in ('china', 'group 2 country')

    def name_fn(_n):
        # Resolve an agent identifier to a DOCUMENTATIONS key: first via
        # index2model, then name2model, otherwise use it unchanged.
        if _n in index2model:
            return index2model[_n]
        if _n in name2model:
            return name2model[_n]
        return _n

    def countries_of(record):
        # Prefer the countries stored in the record; otherwise look each
        # agent up in DOCUMENTATIONS.
        if 'Country' in record:
            return record['Country']
        return [DOCUMENTATIONS[name_fn(_n)]['Country'] for _n in record['Name']]

    name = dazzy_map(output_dir)
    country = DOCUMENTATIONS[name]['Country']

    first_record = data[0]['trajectory']['trajectory'][0]['record']
    if 'Country' in first_record:
        print('load from runtime record')
    countrys = countries_of(first_record)

    in_country_agents = len([1 for c in countrys if group_fn(c, country)])
    out_country_agents = len([1 for c in countrys if not group_fn(c, country)])
    total_agents = in_country_agents + out_country_agents

    avg_Gini = 0
    avg_ROI = 0
    survival_sum = 0
    survival_rates = []
    process_gini, process_roi = [], []
    final_gini, final_roi = [], []
    tasks_count = {}
    task_ids = []
    length = []

    for line in data:
        steps = line['trajectory']['trajectory']
        last_step = steps[-1]
        record = last_step['record']
        usage = line['trajectory']['usage']
        batch_id = line['batch_id']

        task_ids.append(batch_id)
        roi = record['total_ROI']
        gini = record['total_Gini']

        length.append(sum(e[1] for e in usage) / len(usage))
        avg_Gini += gini
        avg_ROI += roi
        final_gini.append(gini)
        final_roi.append(roi)

        process_gini.append([e['record']['total_Gini'] for e in steps])
        process_roi.append([e['record']['total_ROI'] for e in steps])

        # Per-run totals for the in / out groups.
        in_tasks = out_tasks = 0
        in_roi = out_roi = 0
        per_agent = zip(
            record['Name'],
            record['Number of tasks'],
            record['ROI'],
            countries_of(record),
        )
        for _, num_tasks, agent_roi, agent_country in per_agent:
            # A missing ROI may be stored as the string 'None' or as null;
            # both count as zero.
            if agent_roi is None or agent_roi == 'None':
                agent_roi = 0
            if group_fn(agent_country, country):
                in_tasks += num_tasks
                in_roi += agent_roi
            else:
                out_tasks += num_tasks
                out_roi += agent_roi

        # Share of agents still present at the end of the run; runs without
        # an 'agents' field count as full survival.
        if 'agents' in last_step:
            survival = last_step['agents'] / total_agents
        else:
            survival = 1
        survival_sum += survival
        survival_rates.append(survival)

        # A group with no agents yields 0.0 instead of a ZeroDivisionError.
        tasks_count[batch_id] = {
            "in country": _safe_ratio(in_tasks, in_country_agents),
            "out country": _safe_ratio(out_tasks, out_country_agents),
            "in country ROI": _safe_ratio(in_roi, in_country_agents),
            "out country ROI": _safe_ratio(out_roi, out_country_agents),
            "ROI": roi,
            "Gini": gini,
        }

    run_count = len(final_gini)
    # normpath drops a trailing separator so 'runs/model/' is labelled 'model'.
    dir_label = os.path.basename(os.path.normpath(output_dir))
    metrics = {
        # Left-pad with '-' to 30 characters, truncating longer names.
        "Directory": dir_label[:30].rjust(30, '-') + f"({len(data)})",
        "(1-Gini) * ROI": sum([(1 - g) * r for g, r in zip(final_gini, final_roi)]) / run_count,
        "(1-Gini) * ROI * Survival": sum(
            [(1 - g) * r * s for g, r, s in zip(final_gini, final_roi, survival_rates)]
        ) / run_count,
        "ROI * Survival": sum([r * s for r, s in zip(final_roi, survival_rates)]) / run_count,
        "Avg. Gini": avg_Gini / len(data),
        "Avg. ROI": avg_ROI / len(data),
        "Survival": round(survival_sum / len(data), 3),
        "details": tasks_count,
        "final_gini": final_gini,
        "final_roi": final_roi,
        "process_gini": process_gini,
        "process_roi": process_roi,
        "task_ids": task_ids,
        "length": sum(length) / len(length),
    }
    return metrics

    
import copy
from tqdm import tqdm

def main(output_dir):
    """
    Compute and return the metrics dict produced by `evaluation(output_dir)`.

    Nothing is printed or rendered here; callers decide how to present the
    returned metrics.
    """
    return evaluation(output_dir)



if __name__ == "__main__":
    # Command-line entry point: evaluate the directory given by --output_dir
    # and keep the metrics in `result`.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output_dir",
        default='./without_balance_short/api_azure_openai_gpt-5',
        type=str,
    )
    args = parser.parse_args()
    result = main(output_dir=args.output_dir)
