import wandb
import os
import pandas as pd

from fairret import *
from .constants import LOGS_DIR, ENTITY, PROJECT, COLUMN_RENAMES, RESULTS_DIR


def build_fairret(stat=None, loss_cfg=None):
    """Instantiate a fairret loss from a configuration mapping.

    Args:
        stat: statistic handed to the loss class as its first positional
            argument. ``None`` builds the loss without a statistic. A string
            is looked up in ``STATISTIC_BY_NAME`` and instantiated. A list or
            tuple of names is instantiated element-wise and stacked into a
            ``StackLinearFractionalStatistic``. Anything else is passed through
            unchanged.
        loss_cfg: mapping with a ``'name'`` key, which is looked up in
            ``LOSS_BY_NAME``. All remaining entries are forwarded as keyword
            arguments to the loss class. The mapping itself is NOT modified.

    Returns:
        The constructed loss object.

    Raises:
        ValueError: if ``loss_cfg`` is None.
        KeyError: if ``'name'`` is missing, or a loss/statistic name is unknown.
    """
    if loss_cfg is None:
        raise ValueError("loss_cfg must be a mapping containing a 'name' key")

    # Work on a shallow copy: popping 'name' from the caller's mapping would
    # make a second call with the same config fail with a KeyError.
    loss_kwargs = dict(loss_cfg)
    loss_cls = LOSS_BY_NAME[loss_kwargs.pop('name')]

    if stat is None:
        return loss_cls(**loss_kwargs)

    if isinstance(stat, str):
        stat = STATISTIC_BY_NAME[stat]()
    elif isinstance(stat, (list, tuple)):
        stat = StackLinearFractionalStatistic([STATISTIC_BY_NAME[s]() for s in stat])
    return loss_cls(stat, **loss_kwargs)


def unnest_dict(d, sep='.'):
    """Flatten arbitrarily nested dicts into a single-level dict.

    Keys of nested dicts are joined to their parent key with ``sep``
    (e.g. ``{'a': {'b': 1}}`` -> ``{'a.b': 1}``). Only ``dict`` values are
    descended into; every other value, including lists, is kept as-is.
    A non-dict argument is returned unchanged.
    """
    if not isinstance(d, dict):
        return d

    flat = {}
    for outer_key, val in d.items():
        if not isinstance(val, dict):
            flat[outer_key] = val
            continue
        # Recursively flatten the child, then prefix each of its keys.
        flat.update({
            f"{outer_key}{sep}{inner_key}": inner_val
            for inner_key, inner_val in unnest_dict(val, sep).items()
        })
    return flat


def get_sweeps_df(sweep_ids, api=None):
    """Load the results of several W&B sweeps into one DataFrame.

    Each sweep is read from its cached CSV ``RESULTS_DIR/<sweep_id>.csv`` when
    present. Otherwise it is fetched through ``download_sweep_results``, which
    also writes that cache file.

    Args:
        sweep_ids: iterable of sweep ids.
        api: optional ``wandb.Api`` instance. One is created only if at least
            one sweep has to be downloaded.

    Returns:
        The concatenation of all per-sweep frames, indexed by run name.

    Raises:
        ValueError: if ``sweep_ids`` is empty (raised by ``pd.concat``).
    """
    all_dfs = []
    for sweep_id in sweep_ids:
        sweep_path = os.path.join(RESULTS_DIR, f"{sweep_id}.csv")
        try:
            # The cache is written with the run-name index as its first column.
            # Restore it as the index so cached and freshly downloaded frames
            # have the same layout.
            df = pd.read_csv(sweep_path, index_col=0)
        except FileNotFoundError:
            if api is None:
                api = wandb.Api()
            df = download_sweep_results(sweep_id, api)
        all_dfs.append(df)
    return pd.concat(all_dfs)


def download_sweep_results(sweep_id, api=None):
    """Fetch every run of a W&B sweep and cache the results as CSV.

    Each row corresponds to one run and is indexed by ``run.name``. The columns
    are:

    * the run's config, flattened with ``unnest_dict``;
    * a ``'sweep_id'`` column;
    * the run's summary values.

    On a key collision, summary values win over config values (dict union).
    The frame is written to ``RESULTS_DIR/<sweep_id>.csv``; ``RESULTS_DIR`` is
    created if it does not exist yet.

    Args:
        sweep_id: id of the sweep within ``ENTITY/PROJECT``.
        api: optional ``wandb.Api`` instance; created if omitted.

    Returns:
        The DataFrame that was written to disk.
    """
    if api is None:
        api = wandb.Api()

    series = []
    for run in api.sweep(f"{ENTITY}/{PROJECT}/{sweep_id}").runs:
        # unnest_dict builds a new dict, so adding 'sweep_id' leaves run.config untouched.
        config = unnest_dict(run.config)
        config['sweep_id'] = run.sweep.id
        summary = dict(run.summary)
        series.append(pd.Series(config | summary, name=run.name))

    df = pd.DataFrame(series)
    # Without this, to_csv raises OSError on a fresh checkout with no results dir.
    os.makedirs(RESULTS_DIR, exist_ok=True)
    df.to_csv(os.path.join(RESULTS_DIR, f"{sweep_id}.csv"))
    return df


def get_runs_by_ids(run_ids, api=None):
    """Lazily yield the W&B run object for each id in ``run_ids``.

    This is a generator: nothing, including ``wandb.Api()`` creation when
    ``api`` is None, happens until the first item is requested. Runs are looked
    up under ``ENTITY/PROJECT``.
    """
    client = wandb.Api() if api is None else api
    yield from (client.run(f"{ENTITY}/{PROJECT}/{rid}") for rid in run_ids)


def preprocess_df(df, get_config_cols=False):
    """Normalize a sweep-results DataFrame for analysis.

    Steps, in order:

    * Rename columns with ``COLUMN_RENAMES``.
    * Cast every column whose name ends in ``'dim'`` to ``str``.
    * Runs with ``model.fairret_strength == 0`` get:
      ``method='unfair'``, ``fairret.stat='none'``, and
      ``train/fairret_loss = test/fairret_loss = 0``.
    * Runs with ``model.name == 'adv_debias'`` get ``method='adv_debias'``.
      Only those with ``model.label_given == False`` are kept, and their
      ``fairret.stat`` is set to ``'none'``.
    * If ``'fairret.loss_cfg.ffb_name'`` exists, runs with ``method == 'ffb'``
      take that column's value as their method and get ``fairret.stat='none'``.

    Args:
        df: raw results (e.g. from ``get_sweeps_df``). The input frame is
            not modified; ``rename`` returns a new frame.
        get_config_cols: when True, also return the configuration columns.

    Returns:
        The processed DataFrame, or ``(df, config_cols)`` if ``get_config_cols``
        is True. ``config_cols`` is a sorted list of the columns whose names
        start with ``'data'``, ``'model'``, ``'fairret'``, ``'method'`` or
        ``'seed'``, plus ``'method'`` and ``'fairret.stat'``.
    """
    df = df.rename(columns=COLUMN_RENAMES)

    # Cast to str via astype(dict) instead of `df.loc[:, cols] = ...`, which
    # relies on deprecated in-place setting of an incompatible dtype.
    dim_cols = df.columns[df.columns.str.endswith('dim')]
    df = df.astype(dict.fromkeys(dim_cols, str))

    config_cols = {col for col in df.columns
                   if col.startswith(('data', 'model', 'fairret', 'method', 'seed'))}
    # Both columns are (re)written below, so they count as config even if absent now.
    config_cols.update(('method', 'fairret.stat'))

    zero_strength_runs = df['model.fairret_strength'] == 0.
    df.loc[zero_strength_runs, 'method'] = 'unfair'
    df.loc[zero_strength_runs, 'fairret.stat'] = 'none'
    df.loc[zero_strength_runs, 'train/fairret_loss'] = 0.
    df.loc[zero_strength_runs, 'test/fairret_loss'] = 0.

    adv_debias_runs = df['model.name'] == 'adv_debias'
    if adv_debias_runs.any():
        # Elementwise comparison: missing values (NaN) compare as False.
        label_not_given = df['model.label_given'].eq(False)
        df.loc[adv_debias_runs, 'method'] = 'adv_debias'
        df.loc[adv_debias_runs & label_not_given, 'fairret.stat'] = 'none'
        # Drop adv_debias runs that were given the label. The explicit copy
        # avoids SettingWithCopyWarning on the .loc writes further down.
        df = df[~adv_debias_runs | label_not_given].copy()

    if 'fairret.loss_cfg.ffb_name' in df.columns:
        ffb_runs = df['method'] == 'ffb'
        df.loc[ffb_runs, 'method'] = df.loc[ffb_runs, 'fairret.loss_cfg.ffb_name']
        df.loc[ffb_runs, 'fairret.stat'] = 'none'

    if get_config_cols:
        # Sorted so the column order is deterministic across interpreter runs.
        return df, sorted(config_cols)
    return df


def get_checkpoint_path(run_name, api=None):
    """Download a run's latest model artifact and return its checkpoint path.

    The artifact ``ENTITY/PROJECT/model-<run_name>:latest`` is downloaded into
    ``LOGS_DIR/PROJECT/<run_name>/checkpoints``. The returned path points at
    ``model.ckpt`` inside that directory.
    NOTE(review): this assumes the artifact contains a file named exactly
    'model.ckpt' — confirm against how checkpoints are logged.
    """
    client = api if api is not None else wandb.Api()
    target_dir = os.path.join(LOGS_DIR, PROJECT, run_name, "checkpoints")
    model_artifact = client.artifact(f"{ENTITY}/{PROJECT}/model-{run_name}:latest")
    model_artifact.download(root=target_dir)
    return os.path.join(target_dir, "model.ckpt")

