import torch
import wandb
import pandas as pd

import os
from itertools import product

from qtorch.quantumstate import fidelity
from .utils import GHZ_State, W_State

def expt_loop(noise:bool=False):
    """Yield every experiment configuration as a 5-tuple.

    Each tuple is ``(search_type, entanglement_type, hidden_units,
    noise_model, noise_prob)``. The base grid is the product of search type
    (RhoDARTS, QDARTS), entanglement type (GHZ, W) and the hidden-unit flag
    (False, True).

    Args:
        noise: When False, each base configuration is yielded once with the
            noise placeholder ``('None', 0.0)``. When True, each base
            configuration is instead crossed with the noise grid (BitPhaseFlip
            and Depolarizing at p in {0.01, 0.10, 0.25, 0.50}). Depolarizing
            runs are omitted for QDARTS.
    """
    noise_grid = list(product(['BitPhaseFlip', 'Depolarizing'],
                              [0.01, 0.10, 0.25, 0.50]))
    base_grid = product(['RhoDARTS', 'QDARTS'], ['GHZ', 'W'], [False, True])
    for config in base_grid:
        if not noise:
            yield (*config, 'None', 0.0)
            continue
        searcher = config[0]
        for model, prob in noise_grid:
            # QDARTS was not run under the Depolarizing channel.
            if searcher == 'QDARTS' and model == 'Depolarizing':
                continue
            yield (*config, model, prob)
        

def _artifact_name(search_type: str, entanglement_type: str, hu: bool,
                   noise_model: str, noise_prob: float) -> str:
    """Build the W&B artifact name (without version suffix) for one configuration."""
    name = f'{search_type}-{entanglement_type}-Circuits-{"With" if hu else "Without"}-Hidden-Units'
    # Noiseless runs are named without any noise suffix.
    if noise_model != 'None':
        name += f'-{noise_model}-p-{noise_prob:0.2f}'
    return name


def download_data(root:str, project_name:str,
                  download:bool=True,versions:int=3,
                  noise:bool=False)->dict[tuple, str]:
    """Map every experiment configuration/version to a local artifact directory.

    Args:
        root: Local directory under which artifact folders live.
        project_name: W&B project the artifacts are fetched from (only used
            when ``download`` is True).
        download: When True, fetch each artifact version from W&B into its
            local directory (creating the directory if needed). When False,
            only compute the paths.
        versions: Number of versions ``v0 .. v{versions-1}`` per experiment.
        noise: Forwarded to ``expt_loop`` to include the noisy experiments.

    Returns:
        Dict keyed by ``(search_type, entanglement_type, hidden_units,
        noise_model, noise_prob, version)`` whose values are the directories
        ``<root>/<artifact_name>-v<version>``.
    """
    artifact_paths = {}
    versioned_names = {}
    for config in expt_loop(noise):
        artifact_name = _artifact_name(*config)
        artifact_path = os.path.join(root, artifact_name)
        for v in range(versions):
            key = (*config, v)
            artifact_paths[key] = f"{artifact_path}-v{v}"
            versioned_names[key] = f"{artifact_name}:v{v}"

    if download:
        with wandb.init(project=project_name, name='downloader',
                        job_type='downloader') as run:
            for key, v_path in artifact_paths.items():
                os.makedirs(v_path, exist_ok=True)
                v_name = versioned_names[key]
                artifact: wandb.Artifact = run.use_artifact(v_name)
                print(v_name)
                artifact.download(v_path)
    return artifact_paths

def make_dataframe(artifact_paths:dict, qubit_range=range(2, 7))->pd.DataFrame:
    """Load saved states and compute their fidelity to the target entangled state.

    Args:
        artifact_paths: Mapping as returned by ``download_data``. Keys are
            ``(search_type, entanglement_type, hidden_units, noise_model,
            noise_prob, version)`` and values are directories containing
            ``state_<n>.pt`` files.
        qubit_range: Qubit counts whose ``state_<n>.pt`` files are loaded
            from every directory.

    Returns:
        One row per (artifact, qubit count) with columns ``search_type``,
        ``entanglement_type``, ``hidden_units``, ``noise_model``,
        ``noise_probability``, ``n_qubits`` and ``fidelity``. The columns
        are present even when ``artifact_paths`` is empty.

    Raises:
        ValueError: If an entanglement type other than ``'GHZ'`` or ``'W'``
            is encountered.
    """
    columns = ['search_type', 'entanglement_type', 'hidden_units',
               'noise_model', 'noise_probability', 'n_qubits', 'fidelity']
    data = []
    for (search_type, entanglement_type, hu, noise_model, noise_prob, _version), path in artifact_paths.items():
        # Resolve the reference-state constructor up front; unknown types used to
        # silently fall through to the W state.
        if entanglement_type == 'GHZ':
            make_reference = GHZ_State
        elif entanglement_type == 'W':
            make_reference = W_State
        else:
            raise ValueError(f'Unknown entanglement type: {entanglement_type!r}')
        for n in qubit_range:
            state = torch.load(os.path.join(path, f'state_{n}.pt'), weights_only=True)
            ref_state = make_reference(n, state.device)
            # NOTE(review): state[0] assumes the saved tensor has a leading
            # batch-like dimension — confirm against the saving code.
            fid = fidelity(state[0], ref_state).item()
            data.append({
                'search_type': search_type,
                'entanglement_type': entanglement_type,
                'hidden_units': hu,
                'noise_model': noise_model,
                'noise_probability': noise_prob,
                'n_qubits': n,
                'fidelity': fid
            })
    return pd.DataFrame(data, columns=columns)

def get_noiseless_runs(df:pd.DataFrame):
    """Summarise fidelity over the noise-free runs only.

    Rows whose ``noise_model`` equals the string ``'None'`` are grouped by
    search type, entanglement type, qubit count and hidden-unit flag. For each
    group the result reports ``fidelity_mean`` and ``fidelity_std`` (pandas'
    default sample standard deviation). The grouping keys are returned as
    ordinary columns.
    """
    group_keys = ['search_type', 'entanglement_type', 'n_qubits', 'hidden_units']
    noiseless = df.loc[df['noise_model'].eq('None')]
    grouped = noiseless.groupby(group_keys)
    stats = grouped.agg(
        fidelity_mean=('fidelity', 'mean'),
        fidelity_std=('fidelity', 'std'),
    )
    return stats.reset_index()

def generate_latex_table_with_bold(summary_df, qubit_range=range(2, 7)):
    """Render noiseless fidelity summaries as LaTeX table body rows.

    For each state (GHZ, then W) one row is emitted per qubit count in
    ``qubit_range``. Each row is laid out as follows:

    * a label cell: the ``\\multirow`` state header on the first row of each
      state block, blank otherwise;
    * the qubit count;
    * four ``mean \\pm std`` cells, in the order RhoDARTS without hidden
      units, RhoDARTS with, QDARTS without, QDARTS with.

    Within a row, the cell(s) with the largest mean are set in ``\\mathbf``.
    Combinations absent from ``summary_df`` give empty cells. A
    ``\\midrule\\midrule`` separates the two states and a ``\\bottomrule``
    closes the table.

    Args:
        summary_df: DataFrame with columns ``search_type``,
            ``entanglement_type``, ``n_qubits``, ``hidden_units``,
            ``fidelity_mean`` and ``fidelity_std`` (as produced by
            ``get_noiseless_runs``).
        qubit_range: Qubit counts to tabulate, in row order.

    Returns:
        The rows joined by newlines.
    """
    qubits = list(qubit_range)
    table_rows = []

    for state in ['GHZ', 'W']:
        # The header must span exactly as many rows as are emitted for this state.
        state_header = r'\multirow{' + str(len(qubits)) + '}{*}{' + state + '}'
        for row_idx, n in enumerate(qubits):
            row_label = [state_header if row_idx == 0 else '', f"{n}"]

            # Collect (mean, std) per column; (None, None) marks a missing combo.
            cell_data = []
            for search in ['RhoDARTS', 'QDARTS']:
                for hu in [False, True]:
                    hit = summary_df[
                        (summary_df['search_type'] == search) &
                        (summary_df['entanglement_type'] == state) &
                        (summary_df['n_qubits'] == n) &
                        (summary_df['hidden_units'] == hu)
                    ]
                    if hit.empty:
                        cell_data.append((None, None))
                    else:
                        cell_data.append((hit['fidelity_mean'].values[0],
                                          hit['fidelity_std'].values[0]))

            # Best mean in this row (ties are all bolded).
            valid_means = [m for m, _ in cell_data if m is not None]
            max_mean = max(valid_means) if valid_means else None

            entries = []
            for mean, std in cell_data:
                if mean is None:
                    entries.append("")
                    continue
                mean_fmt = f'{mean:0.4f}'
                std_fmt = f'{std:0.4f}'
                if mean == max_mean:
                    entries.append(f"$\\mathbf{{{mean_fmt} \\pm {std_fmt}}} $")
                else:
                    entries.append(f"${mean_fmt} \\pm {std_fmt}$")

            table_rows.append(" & ".join(row_label + entries) + " \\\\")
        table_rows.append(r'\midrule\midrule' if state == 'GHZ' else r'\bottomrule')
    return "\n".join(table_rows)

if __name__ == '__main__':
    import argparse

    # The first positional argument of ArgumentParser is `prog`; pass the text as
    # the description so usage lines still show the real script name.
    parser = argparse.ArgumentParser(description='Entangled State Data Analysis')
    parser.add_argument('--project-name',type=str,required=True,
                        help='WANDB project name')
    parser.add_argument('--local-artifact-path',type=str,default='./artifacts/entangled-states-data',
                        help='Where to store the downloaded artifacts')
    parser.add_argument('-n','--noise',action='store_true',help='Whether to include the noisy experiments')
    parser.add_argument('-v','--num-versions',type=int,
                        help='Number of artifacts in each experiment',
                        default=3)
    parser.add_argument('-d', '--download', action='store_true',
                        help='Whether to download the data')
    args = parser.parse_args()

    # Only the download path talks to W&B; analysing artifacts that are already
    # on disk should not require logging in.
    if args.download:
        wandb.login()
    artifact_paths = download_data(args.local_artifact_path, args.project_name,
                                   args.download, args.num_versions, args.noise)
    df = make_dataframe(artifact_paths)
    print(df)