import matplotlib.axes
import torch
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import wandb

import os
from itertools import product, chain
from typing import Literal, Generator, Tuple
import argparse

from qtorch.quantumstate.measurements import getZDistribution, getExpectation
from .data import ErdosRenyiDataset

# One experiment configuration, in positional order:
#   (search_size, search_type, edgeProb, hu, noise_model, noiseProb)
ExptSetting = Tuple[
    Literal['Macro', 'Micro'],
    Literal['RhoDARTS', 'QDARTS'],
    float,
    bool,
    Literal['None', 'BitPhaseFlip', 'Depolarizing'],
    float,
]


def experiment_loop() -> Generator[ExptSetting, None, None]:
    """Yield every experiment configuration that survives the filters below.

    The full grid is the Cartesian product of search size, search type,
    edge probability, hidden-unit flag, noise model and noise probability.
    A combination is dropped when:

    * the search size is 'Micro' and a noise model is set,
    * the search type is 'QDARTS' and the noise model is 'Depolarizing',
    * the noise model is 'None' and the noise probability exceeds 0.01
      (so each noiseless configuration is emitted exactly once).

    Yields:
        Tuples ordered as ``(search_size, search_type, edgeProb, hu,
        noise_model, noiseProb)``, in ``itertools.product`` order.
    """
    grid = product(
        ('Macro', 'Micro'),
        ('RhoDARTS', 'QDARTS'),
        (0.25, 0.5, 0.75),
        (False, True),
        ('None', 'BitPhaseFlip', 'Depolarizing'),
        (0.01, 0.1, 0.25, 0.5),
    )
    for setting in grid:
        size, kind, _edge, _hidden, noise, p_noise = setting
        noiseless = noise == 'None'
        excluded = (
            (size == 'Micro' and not noiseless)
            or (kind == 'QDARTS' and noise == 'Depolarizing')
            or (noiseless and p_noise > 0.01)
        )
        if not excluded:
            yield setting

def getArtifactName(search_size:Literal['Macro','Micro'],
                    search_type:Literal['RhoDARTS','QDARTS'],
                    edgeProb:float,
                    hu:bool,
                    noise_model:Literal['None','BitPhaseFlip','Depolarizing'],
                    noiseProb:float):
    """Build the artifact name for one experiment setting.

    The name is a dash-separated string: the search type, an optional
    ``Micro`` marker, the edge probability (two decimals), whether hidden
    units were used, and - only when a noise model is set - the noise model
    with its probability (two decimals).

    Returns:
        The artifact name as a string (no version suffix).
    """
    tokens = [search_type]
    if search_size == "Micro":
        tokens.append('Micro')
    tokens.append('MaxCut-Circuits-p')
    tokens.append(f'{edgeProb:0.2f}')
    tokens.append('With' if hu else 'Without')
    tokens.append('Hidden-Units')
    if noise_model != 'None':
        # Noisy runs carry the noise model and its strength as a suffix.
        tokens.append(noise_model)
        tokens.append('p')
        tokens.append(f'{noiseProb:0.2f}')
    return '-'.join(tokens)

def getFolder(search_size:Literal['Macro','Micro'],
              search_type:Literal['RhoDARTS','QDARTS'],
              edgeProb:float,
              hu:bool,
              noise_model:Literal['None','BitPhaseFlip','Depolarizing'],
              noiseProb:float)->str:
    """Return the relative folder (with trailing slash) for one setting.

    Layout: ``<search_size>/<search_type>/<hu|direct>/`` followed by
    ``Noiseless/`` or ``<noise_model>/noiseProb_<p>/``. ``edgeProb`` is
    accepted so the signature matches ``getArtifactName`` but does not
    influence the path.
    """
    hidden_dir = 'hu' if hu else 'direct'
    if noise_model == 'None':
        noise_dir = 'Noiseless'
    else:
        noise_dir = f'{noise_model}/noiseProb_{noiseProb:0.2f}'
    return f'{search_size}/{search_type}/{hidden_dir}/{noise_dir}/'

def download_run_data(PROJECT_NAME:str, LOCAL_ARTIFACT_ROOT_DIR:str)->list[str]:
    """Download the circuit-data artifact of every experiment setting.

    For each setting produced by ``experiment_loop`` the ``:latest`` version
    of the artifact named by ``getArtifactName`` is downloaded into
    ``LOCAL_ARTIFACT_ROOT_DIR/<getFolder(...)>``. A setting is skipped when
    its ``angles_p_<edgeProb>.pt`` file is already present in that folder.

    Args:
        PROJECT_NAME: W&B project passed to ``wandb.init``.
        LOCAL_ARTIFACT_ROOT_DIR: Root directory for the downloads; created
            if it does not exist.

    Returns:
        Names of the artifacts whose retrieval raised ``wandb.CommError``.
    """
    wandb.login()
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(LOCAL_ARTIFACT_ROOT_DIR, exist_ok=True)

    na_artifact_list = []

    with wandb.init(project=PROJECT_NAME,job_type='downloader') as run:
        for expt_settings in experiment_loop():
            artifact_name = getArtifactName(*expt_settings)

            folder = os.path.join(LOCAL_ARTIFACT_ROOT_DIR,getFolder(*expt_settings))
            edgeProb = expt_settings[2]

            # The folder does not depend on edgeProb, so the per-edgeProb
            # file name is what marks this particular setting as present.
            if os.path.exists(os.path.join(folder,f'angles_p_{edgeProb:0.2f}.pt')):
                print(f'Artifact `{artifact_name}` already downloaded')
                continue

            try:
                artifact:wandb.Artifact = run.use_artifact(
                    artifact_name+":latest",
                    type='circuit-data'
                )
                # Only create the folder once the artifact lookup succeeded.
                os.makedirs(folder, exist_ok=True)
                artifact.download(folder)
            except wandb.CommError:
                print(f'Artifact `{artifact_name}` not found, skipping')
                na_artifact_list.append(artifact_name)
    return na_artifact_list

def download_dataset(PROJECT_NAME:str, LOCAL_DATASET_DIR:str)->str:
    """Download the ``graph-datasets:latest`` dataset artifact.

    Args:
        PROJECT_NAME: W&B project passed to ``wandb.init``.
        LOCAL_DATASET_DIR: Directory to download into; created if missing.

    Returns:
        The directory path reported by ``artifact.download``.
    """
    wandb.login()
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(LOCAL_DATASET_DIR, exist_ok=True)
    with wandb.init(project=PROJECT_NAME, job_type='download_dataset') as run:
        artifact:wandb.Artifact = run.use_artifact('graph-datasets:latest',type='dataset')
        artifact_dir = artifact.download(LOCAL_DATASET_DIR)
    print(f'Graph Datasets downloaded to `{artifact_dir}`')
    return artifact_dir

def build_dataframe(RUN_DATA_DIR:str,
                    GRAPH_DATASET_DIR:str,
                    num_graphs:int=10)->pd.DataFrame:
    """Collect per-graph metrics for every experiment setting found on disk.

    For each setting from ``experiment_loop`` the file
    ``states_p_<edgeProb>.pt`` is loaded from the setting's folder under
    ``RUN_DATA_DIR``. Settings without that file are skipped silently.

    Args:
        RUN_DATA_DIR: Root directory holding the downloaded run data.
        GRAPH_DATASET_DIR: Directory holding ``p_<edgeProb>.pt`` graph
            datasets for edge probabilities 0.25, 0.5 and 0.75.
        num_graphs: Number of graphs (indices ``0..num_graphs-1``) evaluated
            per setting. Defaults to 10, the value previously hard-coded.

    Returns:
        A DataFrame with one row per (setting, graph index) carrying the
        setting fields plus ``correct_prob`` and ``normalized_energy``.
    """
    graphs = {
        edgeProb: ErdosRenyiDataset(os.path.join(GRAPH_DATASET_DIR,
                                             f'p_{edgeProb:0.2f}.pt'))
        for edgeProb in [0.25, 0.5, 0.75]
    }

    # Storage for results
    results = []

    # Iterate through all experiment settings
    for expt_setting in experiment_loop():
        search_size, search_type, edgeProb, hu, noise_model, noiseProb = expt_setting
        folder = os.path.join(RUN_DATA_DIR, getFolder(*expt_setting))
        states_path = os.path.join(folder, f'states_p_{edgeProb:0.2f}.pt')
        if not os.path.exists(states_path):
            continue  # Skip missing results

        # Best-effort: a failure in one setting is reported and the loop
        # moves on. Rows appended before the failure are kept.
        try:
            states = torch.load(states_path, weights_only=True)
            dataset = graphs[edgeProb]

            for g_i in range(num_graphs):
                G = dataset[g_i]
                probs = getZDistribution(states[g_i])
                # Total probability on the entries indexed by
                # G['max_cut_bases'] (presumably the optimal-cut basis
                # states - confirm against the dataset definition).
                correct_prob = probs[G['max_cut_bases']].sum().item()
                # Expectation of the graph Hamiltonian, scaled by the
                # graph's max-cut value.
                energy = getExpectation(states[g_i], G['hamiltonian']) / G['max_cut_value']
                energy = energy.item()

                results.append({
                    'search_size': search_size,
                    'search_type': search_type,
                    'edge_prob': edgeProb,
                    'use_hidden_units': hu,
                    'noise_model': noise_model,
                    'noise_prob': noiseProb,
                    'graph_index': g_i,
                    'correct_prob': correct_prob,
                    'normalized_energy': energy
                })
        except Exception as e:
            print(f"Error processing {expt_setting}: {e}")

    # Convert to DataFrame
    return pd.DataFrame(results)

def noiseless_stat_plot(df:pd.DataFrame, ax:matplotlib.axes.Axes,
                        metric:Literal['correct_prob', 'normalized_energy'],
                        search_size:Literal['Macro','Micro'],
                        )->None:
    """Draw a box plot of ``metric`` for the noiseless runs of one search size.

    Rows are restricted to ``noise_model == 'None'`` and the given
    ``search_size``. Each box is one "setting": the search type with 'Rho'
    shown as 'ρ' and 'Q' as 'q', plus a '-HU' suffix when hidden units
    were used.

    Args:
        df: Results table with at least the columns ``noise_model``,
            ``search_size``, ``search_type``, ``use_hidden_units`` and
            ``metric``.
        ax: Axes to draw on; its y-limits and title are set here.
        metric: Column plotted on the y-axis.
        search_size: Search size to select.
    """
    n_df = df[(df['noise_model'] == 'None')
                      &(df['search_size'] == search_size)].copy()
    # Series.replace only matches whole cell values, so it never touched
    # 'RhoDARTS'/'QDARTS'; the substring rewrite needs the .str accessor.
    pretty_type = (
        n_df['search_type']
        .str.replace('Rho', 'ρ', regex=False)
        .str.replace('Q', 'q', regex=False)
    )
    n_df['setting'] = (
        pretty_type +
        n_df['use_hidden_units'].map({True: '-HU', False: ''})
    )

    sns.boxplot(data=n_df,x='setting',y=metric,hue='setting',ax=ax)
    ax.set_ylim(0,1.05)
    ax.set_title(f'{("Probability of Measuring GT" if metric == "correct_prob" else "Normalized Energy")} ({search_size} Search)')

def get_noiseless_df(df)->pd.DataFrame:
    """Summarise the noiseless runs per (search size, search type, HU flag).

    Keeps only rows whose ``noise_model`` is ``'None'`` and returns, for each
    group, the mean and standard deviation of ``normalized_energy`` and
    ``correct_prob``. The group keys come back as ordinary columns.
    """
    group_keys = ['search_size', 'search_type', 'use_hidden_units']
    # Output column name -> (source column, aggregation function).
    aggregations = {
        'norm_energy_mean': ('normalized_energy', 'mean'),
        'norm_energy_std': ('normalized_energy', 'std'),
        'correct_prob_mean': ('correct_prob', 'mean'),
        'correct_prob_std': ('correct_prob', 'std'),
    }
    noiseless_rows = df[df['noise_model'] == 'None']
    grouped = noiseless_rows.groupby(group_keys)
    return grouped.agg(**aggregations).reset_index()

def generate_latex_metric_table(summary_df):
    """Render the body rows of the noiseless-metrics LaTeX table.

    For each search size ('Macro', then 'Micro') two rows are produced, one
    for the normalized energy and one for the correct-measurement
    probability. Each row has four cells, ordered RhoDARTS / QDARTS and,
    within each, without / with hidden units. A cell shows "mean +- std"
    with four decimals; cells whose mean equals the row maximum are bold,
    and combinations missing from ``summary_df`` are left empty.

    Args:
        summary_df: Table with the columns ``search_size``, ``search_type``,
            ``use_hidden_units`` and ``<metric>_mean`` / ``<metric>_std``
            for ``norm_energy`` and ``correct_prob`` (the layout returned
            by ``get_noiseless_df``).

    Returns:
        The rows joined by newlines, with a double midrule after the Macro
        rows and a bottomrule after the Micro rows.
    """
    table_rows = []

    for search_size in ['Macro', 'Micro']:
        # Restrict to this search size; without this filter the Micro rows
        # would just repeat the first matching (Macro) entry.
        size_df = summary_df[summary_df['search_size'] == search_size]
        size_header = r'\multirow{2}{*}{'+search_size+' Search}'
        for i,(metric,label) in enumerate([('norm_energy',r'$E_m$'), ('correct_prob',r'$P_m$')]):
            # Only the first of the two rows carries the multirow header.
            row_label = [size_header if i == 0 else '', label]
            entries = []

            # Collect entries and their means for comparison
            cell_data = []
            for search in ['RhoDARTS', 'QDARTS']:
                for hu in [False, True]:
                    match = size_df[
                        (size_df['search_type'] == search) &
                        (size_df['use_hidden_units'] == hu)
                    ]
                    if not match.empty:
                        mean = match[f'{metric}_mean'].values[0]
                        std = match[f'{metric}_std'].values[0]
                        cell_data.append((mean, std))
                    else:
                        cell_data.append((None, None))

            # Find the max mean (ignoring Nones)
            valid_means = [m for m, _ in cell_data if m is not None]
            max_mean = max(valid_means) if valid_means else None

            for mean, std in cell_data:
                if mean is not None:
                    mean_fmt = f'{mean:0.4f}'
                    std_fmt =  f'{std:0.4f}'
                    cell = f"${mean_fmt} \\pm {std_fmt}$"
                    if mean == max_mean:
                        cell = f"$\\mathbf{{{mean_fmt} \\pm {std_fmt}}} $"
                else:
                    cell = ""
                entries.append(cell)

            table_rows.append(" & ".join(row_label + entries) + " \\\\")
        if search_size == 'Macro':
            table_rows.append(r'\midrule\midrule')
        else:
            table_rows.append(r'\bottomrule')
    return "\n".join(table_rows)

if __name__ == '__main__':
    # Command-line entry point: download either the run artifacts (-d) or
    # the graph dataset (-g) into --local-artifact-path.
    parser = argparse.ArgumentParser(description='Data Analysis for the MaxCut ' 
                                     'Runs')
    parser.add_argument('--project-name', type=str,
                        help='Name of the W&B project where the artifacts are '
                        'stored.')
    parser.add_argument('--local-artifact-path', type=str,
                        help='Local directory to store downloaded artifacts.')
    parser.add_argument('-d', '--download-runs', action='store_true',
                        help='If set, download the run data.')
    parser.add_argument('-g', '--download-graphs', action='store_true',
                        help='If set, download the graph dataset.')
    args = parser.parse_args()
    print(args.project_name)

    # Both modes use --local-artifact-path as their target directory, so the
    # combination is rejected. parser.error exits; previously the message
    # was printed and both downloads ran anyway.
    if args.download_runs and args.download_graphs:
        parser.error('Cannot set -d and -g at the same time')

    # The downloads need both values; fail with a usage message instead of
    # passing None into os.path / wandb.
    if ((args.download_runs or args.download_graphs)
            and (args.project_name is None or args.local_artifact_path is None)):
        parser.error('--project-name and --local-artifact-path are required '
                     'with -d or -g')

    if args.download_runs:
        na_artifact_list = download_run_data(args.project_name, args.local_artifact_path)
        print('List of artifacts that are not ready:')
        for a_name in na_artifact_list:
            print(a_name)
    if args.download_graphs:
        download_dataset(args.project_name, args.local_artifact_path)

