import os
import torch
import pickle
import numpy as np
import pandas as pd
from pathlib import Path
from collections import defaultdict

def load_client_label_map(datapath, data,part, num_clients,beta=1.0):
    """Load the pickled client -> label map of a partitioned dataset.

    The file is looked up under ``datapath`` at
    ``{data}_disjoint_{part}_{beta}/{num_clients}/client_label_map.pkl`` when
    ``part`` is ``"dirichlet"`` and at
    ``{data}_disjoint_{part}/{num_clients}/client_label_map.pkl`` otherwise.

    Args:
        datapath: Root directory that contains the partition folders.
        data: Dataset name used as the folder prefix.
        part: Partitioning scheme name; only ``"dirichlet"`` adds ``beta``
            to the folder name.
        num_clients: Number of clients, used as the sub-folder name.
        beta: Value embedded in the folder name for ``"dirichlet"``
            partitions; ignored otherwise.

    Returns:
        Whatever object was pickled into ``client_label_map.pkl``.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).

    Note:
        Unpickling can execute arbitrary code; only point this at files
        produced by a trusted source.
    """
    if part == "dirichlet":
        pkl_str = f'{data}_disjoint_{part}_{beta}/{num_clients}/client_label_map.pkl'
    else:
        pkl_str = f'{data}_disjoint_{part}/{num_clients}/client_label_map.pkl'

    # Context manager guarantees the handle is closed even if unpickling fails.
    with open(os.path.join(datapath, pkl_str), 'rb') as handle:
        return pickle.load(handle)

# Complete label sets of the hard-coded datasets.
_CORA_ALL = (0, 1, 2, 3, 4, 5, 6)
_CITESEER_ALL = (0, 1, 2, 3, 4, 5)
_PUBMED_ALL = (0, 1, 2)

# Hard-coded label sets, keyed by dataset name and then by client count.
# Entry ``i`` of each tuple is the label list of ``"client_{i}"``.  Tuples are
# used so the shared table cannot be modified through a returned map.
_CLIENT_LABEL_TABLES = {
    "Cora": {
        3: (_CORA_ALL,) * 3,
        5: (
            (0, 1, 2, 3, 4),
            (0, 1, 2, 3, 4, 5),
            (1, 2, 3, 5),
            (0, 1, 2, 3, 4, 5),
            _CORA_ALL,
        ),
        10: (
            (3, 4, 5),
            (0, 1, 3, 4),
            _CORA_ALL,
            _CORA_ALL,
            _CORA_ALL,
            (0, 1, 2, 3, 4),
            (0, 3, 4),
            (0, 1, 2, 3, 4, 5),
            (0, 1, 2, 3, 5),
            (1, 2, 3),
        ),
        20: (
            (0, 1, 2, 3, 4, 5),
            (0, 1, 2, 3, 4),
            (0, 5, 6),
            (3, 4, 5),
            (1, 3, 4, 5),
            (0, 3, 4),
            (0, 1, 3, 4, 5, 6),
            _CORA_ALL,
            (0, 1, 3, 4, 5, 6),
            (0, 2, 3, 4, 5, 6),
            (0, 1, 2, 3, 4),
            (0, 1, 3, 4),
            (0, 3, 4),
            (0, 1, 3, 5),
            (0, 1, 2, 3, 4),
            (0, 2, 3),
            (0, 2, 3, 4, 5),
            (2, 3),
            (0, 1, 2, 5),
            (1, 2, 3, 5),
        ),
    },
    "CiteSeer": {
        3: (_CITESEER_ALL,) * 3,
        5: (_CITESEER_ALL,) * 5,
        10: (
            (0, 2, 3, 4),
            _CITESEER_ALL,
            _CITESEER_ALL,
            _CITESEER_ALL,
            _CITESEER_ALL,
            (0, 1, 2, 4, 5),
            (0, 2, 3, 4, 5),
            _CITESEER_ALL,
            _CITESEER_ALL,
            _CITESEER_ALL,
        ),
        20: (
            (1, 2, 3, 4, 5),
            (0, 3, 4, 5),
            (0, 2, 3, 4, 5),
            _CITESEER_ALL,
            (0, 1, 2, 4, 5),
            (0, 1, 3, 4, 5),
            _CITESEER_ALL,
            (0, 1, 2, 4, 5),
            (1, 2, 3, 4, 5),
            (0, 3, 5),
            (0, 1, 2, 3, 5),
            _CITESEER_ALL,
            (1, 2, 3, 4, 5),
            (0, 1, 2, 3, 5),
            (1, 2, 3, 4, 5),
            _CITESEER_ALL,
            (1, 2, 3),
            (0, 2, 3),
            (0, 2, 3, 4),
            # NOTE(review): label 3 is listed twice here; kept exactly as it
            # was hard-coded -- confirm whether another label was intended.
            (3, 1, 2, 3),
        ),
    },
    "PubMed": {count: (_PUBMED_ALL,) * count for count in (3, 5, 10, 20)},
}


def get_client_map(dataset, num_clients):
    """Return the hard-coded ``"client_{i}" -> label list`` map.

    Supported combinations are ``"Cora"``, ``"CiteSeer"`` and ``"PubMed"``
    with 3, 5, 10 or 20 clients, plus ``"arxiv"`` with any client count (every
    client then gets the labels ``0..39``).

    Args:
        dataset: Dataset name (case-sensitive).
        num_clients: Number of clients.

    Returns:
        A new dict mapping ``"client_0"`` .. ``"client_{num_clients - 1}"`` to
        a new list of integer labels; callers may mutate it freely.

    Raises:
        ValueError: If the dataset / client-count combination has no
            hard-coded map.
    """
    if dataset == "arxiv":
        return {f"client_{i}": list(range(40)) for i in range(num_clients)}

    try:
        per_client = _CLIENT_LABEL_TABLES[dataset][num_clients]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments; both mean "no such entry".
        raise ValueError(
            f"No client label map for dataset={dataset!r} "
            f"with num_clients={num_clients!r}"
        ) from None

    # Copy into fresh lists so callers never share state with the table.
    return {f"client_{i}": list(labels) for i, labels in enumerate(per_client)}


def load_scores(experiment: Path = None, dataset=None) -> dict:
    """Load the saved validation/test scores and targets of one experiment.

    Files are searched in ``<experiment>/scores`` using the patterns
    ``*_{stage}_{val|test}_{scores|targets}.pth``, where ``stage`` is
    ``"stage2"`` when the experiment directory name contains ``"tct"`` and
    ``"stage1"`` otherwise.  Everything is loaded onto the CPU.

    Args:
        experiment: Experiment directory (``Path`` or path string).
        dataset: Unused; kept for backward compatibility.

    Returns:
        A dict with keys ``val_scores``, ``val_targets``, ``test_scores`` and
        ``test_targets``, or ``None`` when loading fails for any reason (the
        error is printed instead of raised).
    """

    def _load_single(scores_dir, pattern):
        # Exactly one file must match; anything else is reported explicitly
        # rather than surfacing as an argument-count error.
        matches = sorted(scores_dir.glob(pattern))
        if len(matches) != 1:
            raise FileNotFoundError(
                f"expected exactly one file matching {pattern!r} in "
                f"{scores_dir}, found {len(matches)}"
            )
        return torch.load(matches[0], map_location=torch.device("cpu"))

    # Deliberately best-effort: any failure is printed and reported as None.
    try:
        if experiment is None:
            raise ValueError("load_scores needs an experiment directory, got None")
        experiment = Path(experiment)
        stage = "stage2" if "tct" in experiment.name else "stage1"
        scores_dir = experiment / "scores"

        return dict(
            val_scores=_load_single(scores_dir, f"*_{stage}_val_scores.pth"),
            val_targets=_load_single(scores_dir, f"*_{stage}_val_targets.pth"),
            test_scores=_load_single(scores_dir, f"*_{stage}_test_scores.pth"),
            test_targets=_load_single(scores_dir, f"*_{stage}_test_targets.pth"),
        )
    except Exception as e:
        print(e)
        return None


def get_new_trial(experiments, key = "tct", frac=0.5, fitzpatrick_df=None):
    """Draw a fresh random validation/test split shared by all experiments.

    The validation and test tensors of each experiment are concatenated
    (validation first), one random permutation of the pooled rows is drawn,
    and the first ``int(frac * n)`` permuted rows become the new validation
    set while the remainder become the new test set.  The same permutation is
    applied to every experiment.

    Args:
        experiments: Mapping from experiment name to a dict holding
            ``val_scores``, ``val_targets``, ``test_scores`` and
            ``test_targets`` tensors.
        key: Name of the reference experiment (the algorithm in use, e.g.
            tct, fedavg, fedprox); its pooled targets are what every other
            experiment is checked against.
        frac: Fraction of pooled rows assigned to the new validation set.
        fitzpatrick_df: Optional dataframe split with ``.loc`` using the same
            indices -- presumably indexed 0..n-1 in pooled (val, then test)
            order; verify against the caller.

    Returns:
        ``dict(experiments=..., val_df=..., test_df=...)``; both dataframe
        entries are ``None`` when ``fitzpatrick_df`` is not given.

    Raises:
        AssertionError: If an experiment's pooled targets differ from the
            reference experiment's (the message is the experiment name).
    """
    fields = ("val_scores", "val_targets", "test_scores", "test_targets")

    def pooled(entry):
        # Validation rows first, then test rows, for scores and targets alike.
        v_scores, v_targets, t_scores, t_targets = (entry[f] for f in fields)
        return torch.cat([v_scores, t_scores]), torch.cat([v_targets, t_targets])

    ref_scores, ref_targets = pooled(experiments[key])
    assert ref_scores.size(0) == ref_targets.size(0)
    total = ref_scores.size(0)

    # A single permutation is drawn and reused for every experiment.
    shuffled = torch.randperm(total)
    cut = int(frac * total)
    val_index, test_index = shuffled[:cut], shuffled[cut:]
    assert val_index.shape[0] + test_index.shape[0] == total

    new_experiments = {}
    for name, entry in experiments.items():
        scores, targets = pooled(entry)
        assert (targets == ref_targets).all(), name
        assert targets.sum() == ref_targets.sum(), name
        new_experiments[name] = dict(
            val_scores=scores[val_index],
            val_targets=targets[val_index],
            test_scores=scores[test_index],
            test_targets=targets[test_index],
        )

    val_df = test_df = None
    if fitzpatrick_df is not None:
        val_df = fitzpatrick_df.copy().loc[val_index]
        test_df = fitzpatrick_df.copy().loc[test_index]
    return dict(experiments=new_experiments, val_df=val_df, test_df=test_df)


def combine_trials(trials):
    """Aggregate per-trial metric values into their mean and standard deviation.

    Args:
        trials: Mapping from trial id to ``{metric: {alpha: value}}``.

    Returns:
        ``dict(mean=..., std=...)`` where each entry maps
        ``metric -> {alpha: statistic}``.  The statistics are ``np.mean`` and
        ``np.std`` over the values collected across all trials for that
        ``(metric, alpha)`` pair.  An empty ``trials`` mapping yields two
        empty dicts, and metrics that only appear in some trials are
        aggregated over the trials that contain them.
    """
    # Gather values once; mean and std are both derived from the same lists.
    collected = defaultdict(lambda: defaultdict(list))
    for trial in trials.values():
        for met, res in trial.items():
            per_alpha = collected[met]  # registers the metric even if res is empty
            for alpha, val in res.items():
                per_alpha[alpha].append(val)

    mean_metrics = {
        met: {alpha: np.mean(values) for alpha, values in per_alpha.items()}
        for met, per_alpha in collected.items()
    }
    std_metrics = {
        met: {alpha: np.std(values) for alpha, values in per_alpha.items()}
        for met, per_alpha in collected.items()
    }
    return dict(mean=mean_metrics, std=std_metrics)
