import os
import time
from typing import List, Optional

import pandas as pd
import rich
import wandb
from rich import print
from rich.console import Console
from rich.table import Table
from tqdm.auto import tqdm


def login(project: Optional[str] = None, token: Optional[str] = None):
    """Authenticate this process with Weights & Biases.

    Args:
        project: Project name. NOTE(review): it is accepted (and conceptually
            defaults to ``$WANDB_PROJECT``) but is not used by this function;
            only the API key is forwarded to ``wandb.login``.
        token: API key. When ``None``, ``$WANDB_API_KEY`` is used instead; if
            that is also unset, ``wandb.login`` receives ``key=None``.
    """
    del project  # not needed: wandb.login is called with the key only
    api_key = token if token is not None else os.getenv("WANDB_API_KEY")
    wandb.login(key=api_key)


def _iter_project_runs(api, projects):
    """Yield every run of each project in ``projects``, with one progress bar per project."""
    for project in projects:
        yield from tqdm(api.runs(path=project), desc=str(project))


def fetch_runs(
    project_list,
    exp_name_list: Optional[List[str]] = None,
    exp_term_list: Optional[List[str]] = None,
    max_runs: Optional[int] = 26,
):
    """Fetch W&B runs and group their summaries and states by experiment name.

    Runs are read from every project in ``project_list``; a single project path
    string is also accepted. Only runs whose config contains ``"exp_name"`` are
    considered. The name is lower-cased before filtering and grouping.

    Args:
        project_list: A project path, or an iterable of project paths, each
            passed to ``wandb.Api().runs(path=...)``.
        exp_name_list: If given, keep only runs whose experiment name equals one
            of these names (compared case-insensitively).
        exp_term_list: If given, keep only runs whose experiment name contains at
            least one of these terms (compared case-insensitively).
        max_runs: Stop once this many matching runs have been collected
            (expected to be a positive int). ``None`` means no limit. Defaults
            to 26, the cap that was previously hard-coded here.

    Returns:
        ``(exp_name_to_summary_dict, exp_name_to_state_dict)``: two dicts keyed
        by lower-cased experiment name. They map to the list of
        ``run.summaryMetrics`` and the list of ``run.state`` respectively, in the
        order the runs were visited.
    """
    api = wandb.Api()
    # Previously only runs[0] was iterated: that dropped every project after the
    # first, and broke outright for a plain string path.
    projects = [project_list] if isinstance(project_list, str) else list(project_list)

    # Run names are lower-cased below, so normalise the filters the same way;
    # otherwise a filter with any upper-case character could never match.
    wanted_names = (
        None if exp_name_list is None else {name.lower() for name in exp_name_list}
    )
    wanted_terms = (
        None if exp_term_list is None else [term.lower() for term in exp_term_list]
    )

    exp_name_to_summary_dict = {}
    exp_name_to_state_dict = {}
    collected = 0
    for run in _iter_project_runs(api, projects):
        if "exp_name" not in run.config:
            continue
        exp_name = run.config["exp_name"].lower()

        if wanted_names is not None and exp_name not in wanted_names:
            continue
        if wanted_terms is not None and not any(
            term in exp_name for term in wanted_terms
        ):
            continue

        exp_name_to_summary_dict.setdefault(exp_name, []).append(run.summaryMetrics)
        exp_name_to_state_dict.setdefault(exp_name, []).append(run.state)
        collected += 1
        if max_runs is not None and collected >= max_runs:
            break
    return exp_name_to_summary_dict, exp_name_to_state_dict


def _merge_lowercase_keys(mapping):
    """Return a new dict keyed by the lower-cased keys of ``mapping``, concatenating colliding lists."""
    merged = {}
    for key, values in mapping.items():
        merged.setdefault(key.lower(), []).extend(values)
    return merged


def fetch_run_status(exp_name_to_summary_dict, exp_name_to_state_dict):
    """Summarise each experiment's status from its runs' summaries and states.

    Experiment names are matched case-insensitively. The keys of both input
    dicts are lower-cased, and lists whose keys collide after lower-casing are
    concatenated, so no run is dropped.

    Args:
        exp_name_to_summary_dict: Mapping of experiment name to a list of
            summary-metric dicts, one per run (as returned first by
            ``fetch_runs``).
        exp_name_to_state_dict: Mapping of experiment name to a list of run
            state strings, one per run (as returned second by ``fetch_runs``).
            Experiments missing here are treated as not running.

    Returns:
        Dict keyed by lower-cased experiment name. Each value contains:

        - ``testing_completed``: some summary key contains ``"testing/ensemble"``.
        - ``currently_running``: some state equals ``"running"``
          (case-insensitive).
        - ``current_iter``: the largest ``global_step`` across runs, or 0 if
          none is reported.
        - ``model_compiled``: some summary key contains
          ``"model/num_parameters"``.
        - ``command``: the experiment's list of summary dicts.
    """
    # Normalising both dicts up front fixes two problems in the old version:
    # lookups with lower-cased names missed mixed-case keys, and
    # ``model_compiled`` was left unbound in the fallback branch.
    summaries_by_exp = _merge_lowercase_keys(exp_name_to_summary_dict)
    states_by_exp = _merge_lowercase_keys(exp_name_to_state_dict)

    exp_data = {}
    with tqdm(total=len(summaries_by_exp), desc="Checking experiments") as pbar:
        for exp_name, summaries in summaries_by_exp.items():
            metric_keys = {key for summary in summaries for key in summary.keys()}
            global_steps = [
                summary["global_step"]
                for summary in summaries
                if "global_step" in summary.keys()
            ]
            exp_data[exp_name] = {
                "testing_completed": any(
                    "testing/ensemble" in key for key in metric_keys
                ),
                "currently_running": any(
                    state.lower() == "running"
                    for state in states_by_exp.get(exp_name, [])
                ),
                "current_iter": max(global_steps) if global_steps else 0,
                "model_compiled": any(
                    "model/num_parameters" in key for key in metric_keys
                ),
                # NOTE(review): despite the key name, this holds the run
                # summaries, not a launch command; kept for existing callers.
                "command": summaries,
            }
            pbar.update(1)
    return exp_data


def pretty_print_runs(exp_data):
    """Print a rich table with one row per experiment.

    Args:
        exp_data: Mapping of experiment name to a status dict containing the
            keys ``"currently_running"``, ``"testing_completed"``,
            ``"current_iter"`` and ``"model_compiled"`` (as produced by
            ``fetch_run_status``). Rows follow the mapping's iteration order.
    """
    # Local import so the escape helper does not rely on rich.console having
    # loaded rich.markup as a side effect.
    from rich.markup import escape

    table = Table(show_header=True, header_style="bold magenta", style="dim")
    table.add_column("idx", justify="right")
    table.add_column("Experiment Name", width=50)
    table.add_column("Currently Running", justify="right")
    table.add_column("Testing Completed", justify="right")
    table.add_column("Current Iteration", justify="right")
    table.add_column("Model Compiled", justify="right")

    # Iterate the mapping directly; the former DataFrame(...).T round-trip
    # produced the same rows in the same order.
    for idx, (exp_name, status) in enumerate(exp_data.items()):
        table.add_row(
            str(idx),
            # rich parses plain strings as markup, so a name such as
            # "exp[bs=64]" could be read as a style tag and mangled.
            escape(str(exp_name)),
            str(status["currently_running"]),
            str(status["testing_completed"]),
            str(status["current_iter"]),
            str(status["model_compiled"]),
        )

    Console().print(table)


class WandbEinServer:
    """Periodically poll W&B and yield the status of the current experiments.

    Intended overall workflow (design notes):

    1. Launch experiments locally, either on a GPU machine or on a cluster:
       a. sample/mutate architectures and compile them to check viability,
          yielding a population;
       b. pick an individual and generate its training CLI command;
       c. send the command to a GPU machine, which starts a W&B run.
    2. A W&B server monitors experiment progress. When some fraction of the
       population has finished, it generates a new population. It should
       account for both local wait times (sending commands) and remote wait
       times (queueing, then running on the machine).
    3. Repeat until convergence.

    This class currently implements only the monitoring loop of step 2. The
    current generation is selected via ``exp_name_list`` / ``exp_term_list``.
    """

    def __init__(
        self,
        project_list: list,
        exp_name_list: Optional[list],
        exp_term_list: Optional[list],
        pretty_print: bool = True,
        poll_interval: float = 60,
    ):
        """Store the polling configuration.

        Args:
            project_list: Project path(s) forwarded to ``fetch_runs``.
            exp_name_list: Exact experiment names to keep, or ``None`` for no
                name filter.
            exp_term_list: Substrings an experiment name must contain (any of
                them), or ``None`` for no term filter.
            pretty_print: Whether to print a status table after each poll.
            poll_interval: Seconds to sleep between polls (default 60, the
                previously hard-coded value).
        """
        super().__init__()
        self.project_list = project_list
        self.exp_name_list = exp_name_list
        self.exp_term_list = exp_term_list
        self.pretty_print = pretty_print
        self.poll_interval = poll_interval

    def run(self):
        """Poll W&B indefinitely, yielding one status snapshot per poll.

        The loop has no termination condition; stop it by no longer consuming
        the generator.

        Yields:
            ``(exp_name_to_summary_dict, exp_name_to_state_dict, exp_data)``:
            the two dicts from ``fetch_runs`` and the status dict from
            ``fetch_run_status``.
        """
        while True:
            exp_name_to_summary_dict, exp_name_to_state_dict = fetch_runs(
                project_list=self.project_list,
                exp_name_list=self.exp_name_list,
                exp_term_list=self.exp_term_list,
            )
            exp_data = fetch_run_status(
                exp_name_to_summary_dict, exp_name_to_state_dict
            )
            if self.pretty_print:
                pretty_print_runs(exp_data)
            yield exp_name_to_summary_dict, exp_name_to_state_dict, exp_data
            time.sleep(self.poll_interval)


if __name__ == "__main__":
    # Placeholder project name; override it by setting $WANDB_PROJECT.
    project = os.getenv("WANDB_PROJECT", "your_project")
    token = None  # None -> login() falls back to $WANDB_API_KEY
    # Authenticate explicitly; previously `token` was defined but login()
    # was never called.
    login(project=project, token=token)
    server = WandbEinServer(
        project_list=[project],
        exp_name_list=None,
        exp_term_list=["ein", "cifar100", "clevr"],
    )
    # run() loops forever; each iteration yields one status snapshot
    # (and prints the table, since pretty_print defaults to True).
    for _snapshot in server.run():
        pass
