import torch
import wandb

import os
from typing import Literal,Optional
from functools import partial

from qtorch.quantumstate.measurements import getZDistribution

from .data import ErdosRenyiDataset
from .plots import make_max_cut_plot
from .loss import loss
from .utils import getEntropySchedule


from ..qas import QAS, RhoDARTS, QDARTS

def configure_sweep(N:int, numVertices:int, edgeProbability:float, 
                    numLayers:int,
                    use_scheduler:bool,
                    use_hidden_units:bool,
                    search_type:Literal['RhoDARTS','QDARTS'],
                    PROJECT_NAME:str,
                    LOCAL_ARTIFACT_PATH:str,
                    NUM_RUNS:int,
                    DSET_ARTIFACT_NAME:str='graph-datasets',
                    cuda_device:int=0
                    )->None:
    """Register a random-search W&B sweep for one experiment variant and run it.

    The sweep configuration combines the fixed experiment settings with the
    hyperparameter candidate lists defined below. It is registered with
    ``wandb.sweep`` under ``PROJECT_NAME`` and then a ``wandb.agent`` is started
    in the current process, calling ``run_experiment`` up to ``NUM_RUNS`` times.

    Args:
        N: Number of training iterations per run. Several candidate lists
            (``CAS_Tmax``, ``schedule_start_time``) are derived from it.
        numVertices: Number of graph vertices; also used to derive the
            candidate hidden-unit counts.
        edgeProbability: Edge probability identifying the dataset file to use.
        numLayers: Number of layers of the searched model.
        use_scheduler: If False, every entropy-scheduler hyperparameter is
            pinned to ``None`` instead of being swept.
        use_hidden_units: If False, ``num_hidden_units`` is pinned to ``0``.
        search_type: ``'RhoDARTS'`` or ``'QDARTS'``. ``gumbel_temp`` is only
            swept for ``'QDARTS'``.
        PROJECT_NAME: W&B project in which the sweep is created.
        LOCAL_ARTIFACT_PATH: Forwarded to ``run_experiment``.
        NUM_RUNS: Number of runs the agent executes (``count`` of
            ``wandb.agent``).
        DSET_ARTIFACT_NAME: Forwarded to ``run_experiment``.
        cuda_device: Forwarded to ``run_experiment``.

    Raises:
        ValueError: If ``search_type`` is not one of the supported values.
    """
    if search_type not in ('RhoDARTS', 'QDARTS'):
        raise ValueError(
            f"search_type must be 'RhoDARTS' or 'QDARTS', got {search_type!r}"
        )

    def swept_if(enabled:bool, candidates:list)->dict:
        # Disabled hyperparameters stay in the config (as None) because
        # run_experiment reads every key unconditionally.
        return {'values': candidates} if enabled else {'value': None}

    # Hyperparameter Values
    learning_rate = [1e-3, 5e-3, 1e-2, 5e-2, 1e-1, 5e-1]
    # CAS_Tmax becomes T_max of CosineAnnealingLR in run_experiment, which
    # divides by it, so it must be >= 1. Clamping (and de-duplicating) only
    # changes the list when N < 10; e.g. N=2000 still gives
    # [200, 500, 1000, 2000].
    CAS_Tmax = sorted({max(1, N // divisor) for divisor in (10, 4, 2, 1)})
    entropy_penalty_str = [1e-3,1e-2,1e-1]
    angle_penalty_str = [1e-3,1e-2,1e-1]
    schedule_start_point = [-1.0, 0.0]
    schedule_start_time = [0, N//4, N//2]
    schedule_duration = ['full', 'half']
    num_oscillations = [1, 2, 3]
    gumbel_temp = [1e-3, 1e-2, 1e-1]
    num_hidden_units = [(numVertices+3)//2, numVertices+3, 2*(numVertices+3)]

    parameters = dict(
        # Experiment settings (constant across the sweep)
        num_iterations       = {'value': N},
        num_vertices         = {'value': numVertices},
        num_layers           = {'value': numLayers},
        edge_probability     = {'value': edgeProbability},
        use_scheduler        = {'value': use_scheduler},
        use_hidden_units     = {'value': use_hidden_units},
        search_type          = {'value': search_type},

        # Hyperparameters
        learning_rate        = {'values': learning_rate},
        CAS_Tmax             = {'values': CAS_Tmax},
        angle_penalty_str    = {'values': angle_penalty_str},
        # Pinned to 0 (not None) when hidden units are disabled.
        num_hidden_units     = ({'values': num_hidden_units} if 
                                use_hidden_units else {'value': 0}),
        # Conditioned on using the entropy scheduler
        entropy_penalty_str  = swept_if(use_scheduler, entropy_penalty_str),
        schedule_start_point = swept_if(use_scheduler, schedule_start_point),
        schedule_start_time  = swept_if(use_scheduler, schedule_start_time),
        schedule_duration    = swept_if(use_scheduler, schedule_duration),
        num_oscillations     = swept_if(use_scheduler, num_oscillations),
        # Conditioned on `search_type`
        gumbel_temp          = swept_if(search_type == 'QDARTS', gumbel_temp)
    )

    scheduler_tag = 'With' if use_scheduler else 'Without'
    hidden_units_tag = 'With' if use_hidden_units else 'Without'
    sweep_name = (f'{search_type} | {scheduler_tag} Scheduler | '
                  f'p={edgeProbability:0.2f} | {hidden_units_tag} Hidden Units'
                  ' | HyperParameter Sweep')

    sweep_config = dict(
        name=sweep_name,
        method='random',
        # 'energy' is one of the keys logged every step by run_experiment.
        metric={'name': 'energy',
                'goal': 'minimize'},
        parameters=parameters
    )

    # The agent invokes the function without arguments, so everything that is
    # not a sweep hyperparameter is bound here.
    expt_fn = partial(run_experiment,
                      LOCAL_ARTIFACT_PATH=LOCAL_ARTIFACT_PATH,
                      DSET_ARTIFACT_NAME=DSET_ARTIFACT_NAME,
                      cuda_device=cuda_device)

    sweep_id = wandb.sweep(sweep_config, project=PROJECT_NAME)
    wandb.agent(sweep_id, expt_fn, count=NUM_RUNS)
    wandb.finish()

def make_dataset(numVertices:int, edgeProbs:float|list[float], num_graphs:int,
                 PROJECT_NAME:str, DSET_ARTIFACT_NAME:str='graph-datasets')->None:
    """Generate Erdos-Renyi graph datasets and upload them as one W&B artifact.

    One ``ErdosRenyiDataset`` is created per edge probability and saved into
    the artifact as ``p_<probability with two decimals>.pt``. The artifact is
    logged from a W&B run with job type ``'make-dataset'``.

    Args:
        numVertices: Number of vertices of every generated graph.
        edgeProbs: A single edge probability or a sequence of them. A plain
            number (``int`` or ``float``) is treated as a one-element list.
        num_graphs: Number of graphs generated per edge probability.
        PROJECT_NAME: W&B project in which the run is created.
        DSET_ARTIFACT_NAME: Name of the dataset artifact.

    Raises:
        ValueError: If ``edgeProbs`` is an empty sequence.
    """
    # Accept ints too (e.g. p=1); the original float-only check let them fall
    # through to the loop, where iterating a bare number fails.
    if isinstance(edgeProbs, (int, float)):
        edgeProbs = [float(edgeProbs)]
    else:
        # Copy into a list so tuples/generators work and the metadata below
        # holds a plain list.
        edgeProbs = list(edgeProbs)
    if not edgeProbs:
        # An empty artifact would become the ':latest' version and leave
        # later experiment runs without a dataset file to load.
        raise ValueError('edgeProbs must contain at least one probability')

    description = (
        f'A collection of Erdos-Renyi Graphs with {numVertices} vertices along '
        'with their '
        'corresponding max-cut Hamiltonians in diagonal form, the edge count, '
        'number of max cut solutions, the max cut partitions and the eigenvalue'
        ' of the max cut states.'
    )
    metadata = {
        'keys': [
            'graph',
            'hamiltonian',
            'num_edges',
            'num_max_cut',
            'max_cut_bases',
            'max_cut_value'
        ],
        'num_vertices': numVertices,
        'edge_probability': edgeProbs,
        'num_graphs': num_graphs
    }

    # The context manager finishes the run even if dataset generation raises,
    # matching how run_experiment handles its run.
    with wandb.init(project=PROJECT_NAME, job_type='make-dataset') as run:
        artifact = wandb.Artifact(DSET_ARTIFACT_NAME, type='dataset',
                                  description=description,
                                  metadata=metadata)
        for p in edgeProbs:
            # NOTE(review): the first two positional arguments (None, True)
            # presumably mean "no file to load" and "generate" -- confirm
            # against ErdosRenyiDataset.
            dset = ErdosRenyiDataset(None, True, numVertices, p, num_graphs)
            # Only two decimals of p survive in the file name, which is the
            # same naming run_experiment uses to locate the file.
            filename = f'p_{p:0.2f}.pt'
            with artifact.new_file(filename, mode='wb') as file:
                dset.save(file)
        run.log_artifact(artifact)

def run_experiment(config=None, 
                   LOCAL_ARTIFACT_PATH:Optional[str]=None,
                   DSET_ARTIFACT_NAME:str='graph-datasets',
                   cuda_device:int=0)->None:
    """Execute one training run and log its metrics and plots to W&B.

    The run reads all settings from ``wandb.config``, builds a ``QAS`` model
    with the requested search strategy, downloads the dataset file matching
    the configured edge probability, and trains on the first graph of that
    dataset for ``num_iterations`` steps with Adam and cosine-annealed
    learning rate. ``loss``, ``energy``, ``entropy`` and ``angle_penalty`` are
    logged every step; an image of the ground-truth cuts and one of the found
    cuts are logged as well.

    Args:
        config: Forwarded to ``wandb.init``; the values actually used are read
            back from ``wandb.config``.
        LOCAL_ARTIFACT_PATH: Directory the dataset artifact is downloaded to
            (``root`` of ``Artifact.download``).
        DSET_ARTIFACT_NAME: Name of the dataset artifact; its ``latest``
            version is used.
        cuda_device: Index of the CUDA device all tensors are placed on.

    Raises:
        ValueError: If ``search_type`` is not ``'QDARTS'`` or ``'RhoDARTS'``,
            or if ``num_iterations`` is smaller than 1.
    """
    dev = f'cuda:{cuda_device}'
    with wandb.init(config=config) as run:
        config = wandb.config

        N:int = config.num_iterations
        n:int = config.num_vertices
        m:int = config.num_layers
        p:float = config.edge_probability
        use_scheduler:bool = config.use_scheduler
        search_type:Literal['QDARTS','RhoDARTS'] = config.search_type
        lr:float = config.learning_rate
        Tmax:int = config.CAS_Tmax
        angPenaltyStr:float|None = config.angle_penalty_str
        entPenaltyStr:float|None = config.entropy_penalty_str
        schedStartVal:float|None = config.schedule_start_point
        schedStartTime:int|None = config.schedule_start_time
        schedDuration:Literal['full','half']|None = config.schedule_duration
        numOscillations:int|None = config.num_oscillations
        k:int = config.num_hidden_units
        tau:float|None = config.gumbel_temp

        # Fail early with a clear message. Previously an unknown search type
        # left `model` unbound and N < 1 left `qs` unbound, both surfacing
        # later as a NameError.
        if search_type not in ('QDARTS', 'RhoDARTS'):
            raise ValueError(
                "search_type must be 'QDARTS' or 'RhoDARTS', "
                f'got {search_type!r}'
            )
        if N < 1:
            raise ValueError(f'num_iterations must be at least 1, got {N}')

        # Sweeps pin disabled penalties to None; treat that as "no penalty".
        if entPenaltyStr is None:
            entPenaltyStr = 0.0
        if angPenaltyStr is None:
            angPenaltyStr = 0.0

        # Initial state: 2**n equal amplitudes of 2**(-n/2), i.e. unit norm.
        psi0 = torch.ones(2**n,dtype=torch.complex64,device=dev)/(2**(n/2))

        if use_scheduler:
            entropySchedule = getEntropySchedule(N, schedStartVal, 
                                                 schedStartTime,
                                                 schedDuration, 
                                                 numOscillations).to(device=dev)
        else:
            # Without a scheduler the schedule is all zeros for every step.
            entropySchedule = torch.zeros(N,device=dev)
        
        if search_type == 'QDARTS':
            # Only QDARTS receives the Gumbel temperature.
            model = QAS(n, m, k, QDARTS,psi0=psi0,
                             gumbel_temp=tau).to(device=dev)
        else:
            model = QAS(n, m, k, RhoDARTS, psi0=psi0).to(device=dev)

        # Temperature handed to the model's forward pass: 1.0 when no hidden
        # units are used, 10.0 otherwise.
        softmax_temp = 1.0 if k == 0 else 10.0
        
        optimizer = torch.optim.Adam(model.parameters(), lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, Tmax)

        # File naming must match make_dataset: probability with two decimals.
        dset_name = f'p_{p:0.2f}.pt'
        dset_artifact:wandb.Artifact = run.use_artifact(f'{DSET_ARTIFACT_NAME}:latest',type='dataset')
        # path_prefix limits the download to the file for this probability.
        dset_dir = dset_artifact.download(root=LOCAL_ARTIFACT_PATH,path_prefix=dset_name)

        # Only the first graph of the dataset is used.
        graphData = ErdosRenyiDataset(os.path.join(dset_dir,dset_name))[0]
        
        G = graphData['graph']
        H = graphData['hamiltonian'].to(device=dev)
        num_edges = graphData['num_edges']
        num_max_cut = graphData['num_max_cut']
        bases = graphData['max_cut_bases']

        gt_image = make_max_cut_plot(G, bases, None, 'Ground Truth Max Cuts')
        # commit=False: the image is held back and sent together with the
        # first metrics logged in the training loop.
        wandb.log({
            'true_max_cuts_img': wandb.Image(gt_image,
                                             caption='Ground Truth Max Cuts')
            },commit=False)

        for i in range(N):
            optimizer.zero_grad()

            qs, angles, probs = model(softmax_temperature=softmax_temp)
            loss_val, metrics = loss(qs, angles, probs, H, num_edges, i, 
                                     entropySchedule, entPenaltyStr, 
                                     angPenaltyStr)
            
            loss_val.backward()
            optimizer.step()
            scheduler.step()

            # The last iteration is left uncommitted so the "found cuts"
            # image below is sent together with the final metrics.
            wandb.log({
                'loss':loss_val.item(),
                'energy':metrics[0].item(),
                'entropy':metrics[1].item(),
                'angle_penalty':metrics[2].item()
            },commit=(i!=N-1))
        
        # `qs` is the output of the last forward pass, i.e. it was computed
        # before the final optimizer step.
        dist = getZDistribution(qs)
        # Keep as many top entries of the distribution as there are true
        # max-cut solutions.
        found_probs, found_bases = torch.topk(dist, num_max_cut)
        found_img = make_max_cut_plot(G, found_bases, found_probs, 'Found Cuts')
        wandb.log({
            'found_max_cuts_img': wandb.Image(found_img,
                                              caption='Found Max Cuts'),
        })


def _sweep_worker(args, search_type, use_scheduler, use_hidden_units):
    """Run one sweep variant; used as the target of a child process.

    Defined at module level (not inside the ``__main__`` guard) and given the
    parsed arguments explicitly so the child process can resolve it under the
    ``spawn`` and ``forkserver`` start methods as well as ``fork``.

    Args:
        args: Parsed command-line arguments (``argparse.Namespace``).
        search_type: ``'RhoDARTS'`` or ``'QDARTS'``.
        use_scheduler: Whether the entropy scheduler is used in this sweep.
        use_hidden_units: Whether hidden units are used in this sweep.
    """
    # A child that re-imports this module does not execute the __main__ guard,
    # so the non-interactive backend is selected here too.
    import matplotlib
    matplotlib.use('agg')

    wandb.login()
    # The sweep uses a single edge probability: the first one given on the
    # command line.
    configure_sweep(N=args.num_iterations, 
                    numVertices=args.num_vertices, 
                    edgeProbability=args.edge_probs[0],
                    numLayers=args.num_layers, 
                    PROJECT_NAME=args.project_name, 
                    LOCAL_ARTIFACT_PATH=args.local_artifact_path, 
                    NUM_RUNS=args.num_runs, 
                    DSET_ARTIFACT_NAME=args.dataset_artifact_name,
                    cuda_device=args.cuda_device,
                    search_type=search_type,
                    use_scheduler=use_scheduler,
                    use_hidden_units=use_hidden_units)


if __name__ == '__main__':
    # Two modes:
    #   -g / --generate-dataset : build the graph datasets and upload them as
    #                             a W&B artifact.
    #   (default)               : launch one hyperparameter sweep per
    #                             combination of search type, scheduler
    #                             on/off and hidden units on/off, each in its
    #                             own process.
    import argparse
    import matplotlib
    # Non-interactive backend: figures are only rendered for logging.
    matplotlib.use('agg')

    parser = argparse.ArgumentParser(description='Hyperparameter Sweep & Dataset Generation for the MaxCut Problem')
    
    parser.add_argument('-g', '--generate-dataset', action='store_true',
                        help='Flag to generate dataset instead of running hyperparameter sweep')
    
    # Arguments for dataset generation
    parser.add_argument('-n', '--num-vertices', type=int, default=10, help='Number of vertices in the graphs')
    # NOTE: in sweep mode only the first probability of this list is used.
    parser.add_argument('-p', '--edge-probs', type=float, nargs='+', default=[0.5],
                        help='List of edge probabilities for the graphs')
    parser.add_argument('-G', '--num-graphs', type=int, default=100,
                        help='Number of graphs to generate')
    parser.add_argument('--project-name', type=str, required=True,
                        help='WANDB project name')
    parser.add_argument('--dataset-artifact-name', type=str, default='graph-datasets',
                        help='WANDB dataset artifact name')
    
    # Arguments for hyperparameter sweep.
    # The search type, scheduler usage and hidden-unit usage are deliberately
    # not command-line options: every combination of them is swept below.
    parser.add_argument('-N', '--num-iterations', type=int, default=2000, help='Number of training iterations')
    parser.add_argument('-m', '--num-layers', type=int, default=5, help='Number of layers in the model')

    parser.add_argument('--local-artifact-path', type=str, default='./artifacts',
                        help='Local directory to store WANDB artifacts')
    parser.add_argument('-R', '--num-runs', type=int, default=10, help='Number of runs for the sweep')

    parser.add_argument('-c','--cuda-device', type=int, default=0, help='Cuda device id you want to run the job on')
    
    args = parser.parse_args()

    if args.generate_dataset:
        wandb.login()
        make_dataset(numVertices=args.num_vertices, 
                     edgeProbs=args.edge_probs, 
                     num_graphs=args.num_graphs, 
                     PROJECT_NAME=args.project_name, 
                     DSET_ARTIFACT_NAME=args.dataset_artifact_name)
    else:
        import multiprocessing as mp
        from itertools import product

        # One process per (search type, scheduler, hidden units) combination;
        # all of them target the same CUDA device.
        processes = [
            mp.Process(target=_sweep_worker,
                       name=f'{st}-scheduler={sched}-hidden_units={hu}',
                       kwargs={
                'args':args,
                'search_type':st, 
                'use_scheduler':sched, 
                'use_hidden_units':hu
                }) for st,sched,hu in product(['RhoDARTS','QDARTS'],[True,False],[True,False])
        ]

        for proc in processes:
            proc.start()
        for proc in processes:
            proc.join()

        # Surface crashed sweeps instead of exiting with status 0.
        failed = [proc.name for proc in processes if proc.exitcode != 0]
        if failed:
            raise SystemExit(
                f'{len(failed)} of {len(processes)} sweep processes failed: '
                + ', '.join(failed)
            )

