import torch
import wandb

import os
from typing import Any, Literal
from itertools import product
from functools import partial
import multiprocessing as mp

from qtorch.unitaries import rx,ry,rz,cnot
from qtorch.quantumstate.measurements import getZDistribution,getExpectation
from qtorch.config import updateConfig

from ..penalty_scheduler import oscillating_scheduler
from ..qas_micro import QASMicro, QDARTSMicro, RhoDARTSMicro, InputGenMicro

from .data import ErdosRenyiDataset
from .plots import (make_max_cut_plot, quantum_super_circuit_plot, 
                    quantum_circuit_plot)
from .loss import loss, qdarts_angle_loss, qdarts_arch_loss
from .utils import getEntropySchedule, graphToSuperCircuitStructure

def run_final_circuit(model:QASMicro, psi0:torch.Tensor
                      )->tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Apply the most probable sub-circuit to ``psi0`` and return the result.

    The gate for every (layer, sub-circuit qubit) slot is the arg-max of the
    softmax over the logits produced by ``model.input_gen()``. That discrete
    sub-circuit is then applied once per row of the super-circuit structure,
    mapping local sub-circuit qubits onto the qubits listed in that row.

    Gate index encoding, as used by the branches below:
        0     -> no gate on this slot
        1..3  -> rx / ry / rz with the angle for (sub-circuit, layer, slot)
        >= 4  -> cnot whose control is one of the *other* sub-circuit qubits

    Args:
        model: Search model exposing ``input_gen`` and ``search``.
        psi0: Initial state; it is cloned, not modified.

    Returns:
        ``(final_state, chosen_gate_indices, angles)``.
    """
    with torch.no_grad():
        angles, logits = model.input_gen()
        gate_probs = torch.softmax(logits, dim=-1)

        search = model.search
        n_qubits = search.total_qubits
        sub_width = search.num_subcircuit_qubits
        depth = search.num_layers
        structure = search.super_circuit_structure

        subcircuit = gate_probs.argmax(dim=-1)
        rotations = (rx, ry, rz)

        state = psi0.clone()
        slots = product(range(structure.shape[0]), range(depth),
                        range(sub_width))
        for sub_idx, layer, slot in slots:
            target = structure[sub_idx, slot]
            gate_id = subcircuit[layer, slot].item()

            if gate_id == 0:
                continue

            if gate_id < 4:
                rotate = rotations[gate_id - 1]
                state = rotate(state, angles[sub_idx, layer, slot],
                               n_qubits, target)
                continue

            # Control indices skip the target's own slot, so offsets at or
            # beyond the target position are shifted up by one.
            offset = gate_id - 4
            control_slot = offset if offset < slot else offset + 1
            control = structure[sub_idx, control_slot].item()
            state = cnot(state, n_qubits, target, control)

        return state, subcircuit, angles

def upload_circuits(edgeProb:float,
                    run_params:dict[str,Any],
                    final_architectures:torch.Tensor,
                    final_angles:torch.Tensor,
                    final_states:torch.Tensor,
                    PROJECT_NAME:str)->None:
    """Upload the found circuits and their output states as a W&B artifact.

    Three files are written into one artifact of type ``circuit-data``:
    ``arch_p_<p>.pt``, ``angles_p_<p>.pt`` and ``states_p_<p>.pt``, each
    produced with ``torch.save``.

    Args:
        edgeProb: Edge probability of the graph dataset; used in the artifact
            name, the file names and the metadata.
        run_params: Run configuration. Reads ``search_type``,
            ``use_hidden_units`` and ``num_layers``.
        final_architectures: Selected gate indices for every graph.
        final_angles: Rotation angles for every graph.
        final_states: Final state vector for every graph.
        PROJECT_NAME: W&B project the artifact is logged to.
    """
    search_type = run_params["search_type"]
    hidden_tag = "With" if run_params["use_hidden_units"] else "Without"

    # The last axis of ``final_architectures`` is the sub-circuit width (the
    # caller in this file allocates it with size 2), not the vertex count.
    # The state vectors have 2**n amplitudes, so n is recovered from there.
    # NOTE(review): assumes the last axis of final_states is 2**num_vertices.
    num_vertices = final_states.shape[-1].bit_length() - 1

    # The context manager finishes the run even when saving/logging raises.
    with wandb.init(project=PROJECT_NAME, job_type='upload-circuits') as run:
        circuit_artifact = wandb.Artifact(
            f'{search_type}-Micro-MaxCut-Circuits-p-{edgeProb:0.2f}-'
            f'{hidden_tag}-Hidden-'
            'Units',
            type='circuit-data',
            description=f'The circuits found by {search_type} '
            'MicroSearch to find the max cuts along with the corresponding '
            'states produced.',
            metadata={
                'num_vertices': num_vertices,
                'edge_probability': edgeProb,
                'num_layers': run_params['num_layers']
            }
        )

        payloads = (
            ('arch', final_architectures),
            ('angles', final_angles),
            ('states', final_states),
        )
        for prefix, tensor in payloads:
            with circuit_artifact.new_file(f'{prefix}_p_{edgeProb:0.2f}.pt',
                                           mode='wb') as file:
                torch.save(tensor, file)

        run.log_artifact(circuit_artifact)

def run_experiment(edgeProb:float,
                   run_params:dict[str,Any],
                   PROJECT_NAME:str,
                   local_artifact_dir:str,
                   cuda_device:int=0)->None:
    """Run the micro-search on every graph of one dataset and upload results.

    For each graph in ``p_<edgeProb>.pt`` a fresh model is built and trained
    for ``run_params['num_iterations']`` steps inside its own W&B run. After
    training, the most probable circuit is executed, plots and the final
    energy are logged, and the architecture, angles and final state are
    collected. Once all graphs are done the collected tensors are uploaded
    with ``upload_circuits``.

    Args:
        edgeProb: Edge probability identifying the dataset file.
        run_params: Hyper-parameters. ``search_type`` must be ``'RhoDARTS'``
            or ``'QDARTS'``; ``'QDARTS'`` additionally needs ``gumbel_temp``.
        PROJECT_NAME: W&B project name.
        local_artifact_dir: Directory containing the downloaded dataset files.
        cuda_device: Index of the CUDA device to run on.

    Raises:
        ValueError: If ``run_params['search_type']`` is not supported.
    """
    search_type = run_params['search_type']
    if search_type not in ('RhoDARTS', 'QDARTS'):
        # Fail before any data is loaded or any W&B run is started; otherwise
        # the model/optimizer below would simply be undefined.
        raise ValueError(
            f"Unsupported search_type {search_type!r}; expected 'RhoDARTS' "
            "or 'QDARTS'.")

    dev = f'cuda:{cuda_device}'
    dset = ErdosRenyiDataset(os.path.join(local_artifact_dir,
                                          f'p_{edgeProb:0.2f}.pt'))
    numGraphs = len(dset)
    n = dset[0]['graph'].shape[0]
    max_edges = max(entry['num_edges'] for entry in dset)
    # Uniform superposition over all 2**n basis states.
    psi0 = torch.ones(2**n, dtype=torch.complex64, device=dev)/(2**(n/2))

    # Result buffers (kept on the CPU). Angles are NaN-padded because the
    # number of rows stored per graph (num_edges) varies between graphs.
    final_architectures = torch.empty([numGraphs,
                                       run_params['num_layers'],
                                       2],
                                       dtype=torch.int32)
    final_angles = torch.full([numGraphs, 
                               max_edges,
                               run_params['num_layers'],
                               2],
                               torch.nan)
    final_states = torch.empty([numGraphs, 2**n], dtype=torch.complex64)

    for g_i in range(numGraphs):
        graphData = dset[g_i]
        G = graphData['graph']
        H = graphData['hamiltonian'].to(device=dev)
        num_edges = graphData['num_edges']
        num_max_cut = graphData['num_max_cut']
        bases = graphData['max_cut_bases']
        max_cut_value = graphData['max_cut_value']

        config = run_params | dict(
            edge_probability=edgeProb,
            num_max_cut = num_max_cut,
            min_energy = max_cut_value
        )

        sc_struct = graphToSuperCircuitStructure(G).to(device=dev)

        N = run_params['num_iterations']
        num_iter = run_params['num_angle_iter']
        entropySchedule = getEntropySchedule(
            N,
            run_params['schedule_start_point'],
            run_params['schedule_start_time'],
            run_params['schedule_duration'],
            run_params['num_oscillations']
        )
        entPenaltyStr = run_params['entropy_penalty_str']
        angPenaltyStr = run_params['angle_penalty_str']

        if search_type == 'RhoDARTS':
            model = QASMicro(n, 2, sc_struct, run_params['num_layers'],
                             run_params['num_hidden_units'], RhoDARTSMicro,
                             psi0
                            ).to(device=dev)
            optimizer = torch.optim.Adam(model.parameters(),
                                         run_params['learning_rate'])
        else:  # 'QDARTS' (validated above)
            model = QASMicro(n, 2, sc_struct, run_params['num_layers'],
                             run_params['num_hidden_units'], QDARTSMicro,
                             psi0, gumbel_temp=run_params['gumbel_temp']
                            ).to(device=dev)
            input_gen = InputGenMicro(2, sc_struct.shape[0], 
                                      run_params['num_layers'], 
                                      run_params['num_hidden_units']
                                      ).to(device=dev)
            # Training below only ever updates this standalone generator,
            # while the final circuit and the stored angles are read from
            # ``model.input_gen``. Attach the trained generator to the model
            # so the reported results reflect the optimisation rather than
            # the model's untouched initial parameters.
            model.input_gen = input_gen
            qdarts = QDARTSMicro(n, 2, run_params['num_layers'], sc_struct,
                                 run_params['gumbel_temp'], psi0=psi0
                                 ).to(device=dev)
            # Architecture parameters and angles get separate optimizers; the
            # angle optimizer is handed to ``qdarts`` in the training loop.
            optimizer = torch.optim.Adam(input_gen.logit_recipe, 
                                         run_params['learning_rate'])
            angles_optimizer = torch.optim.Adam([input_gen.thetas],
                                                run_params['learning_rate'])
            angle_loss = partial(qdarts_angle_loss, hamiltonian=H,
                                 energy_normalization_factor=num_edges,
                                 angle_penalty_str=angPenaltyStr,
                                 batched=False)
        
        # A higher softmax temperature is used when hidden units are enabled.
        if run_params['num_hidden_units'] > 0:
            softmax_temp = 10.0
        else:
            softmax_temp = 1.0

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, run_params['CAS_Tmax'])

        with wandb.init(
            project=PROJECT_NAME,job_type='main-result-micro',
            config=config,
            name=f'main-{run_params["search_type"]}-micro-p={edgeProb:0.2f}{"-hu" if run_params["use_hidden_units"] else ""}-{g_i}'
        ) as run:
            for i in range(N):
                optimizer.zero_grad()

                if search_type == 'RhoDARTS':
                    qs, angles, probs = model(softmax_temperature=softmax_temp)
                    loss_val, metrics = loss(qs, angles, probs, H, num_edges, i, 
                                            entropySchedule, entPenaltyStr, 
                                            angPenaltyStr)
                else:
                    angles,logits = input_gen()
                    qs = qdarts(logits, angles, angles_optimizer, num_iter,
                                angle_loss, softmax_temp)
                    probs = (logits/softmax_temp).softmax(dim=-1)
                    loss_val, metrics = qdarts_arch_loss(qs, probs, H, 
                                                         num_edges, i, 
                                                         entropySchedule, 
                                                         entPenaltyStr,False)
                
                loss_val.backward()
                optimizer.step()
                scheduler.step()

                # The last step is left uncommitted so that the summary
                # images logged after the loop end up in the same W&B step.
                wandb.log({
                    'loss':loss_val.item(),
                    'energy':metrics[0].item(),
                    'entropy':metrics[1].item(),
                    # 'angle_penalty':metrics[2].item()
                },commit=(i!=N-1))
            
            phi, subcircuit, thetas = run_final_circuit(model, psi0)
            dist = getZDistribution(phi)
            energy = getExpectation(phi, H)
            # Compare the num_max_cut most probable basis states with the
            # ground-truth max-cut bases.
            found_probs, found_bases = torch.topk(dist, num_max_cut)
            
            gt_image = make_max_cut_plot(G, bases, None, 
                                        'Ground Truth Max Cuts')
            found_img = make_max_cut_plot(G, found_bases, found_probs, 
                                            'Found Cuts')
            circuit_img = quantum_super_circuit_plot(subcircuit, thetas, 
                                                     sc_struct, n)
            subcircuit_img = quantum_circuit_plot(subcircuit,thetas[0],plot_barriers=False)


            wandb.log({
                'true_max_cuts_img': wandb.Image(gt_image,
                                                    caption='Ground Truth Max '
                                                    'Cuts'),
                'found_max_cuts_img': wandb.Image(found_img, 
                                                    caption='Found Max Cuts'),
                'circuit_img': wandb.Image(circuit_img,
                                           caption='Found Quantum Super Circuit'),
                'subcircuit_img': wandb.Image(subcircuit_img,
                                              caption='Found Sub-circuit'),
                'energy': energy
            })

        final_architectures[g_i] = subcircuit.cpu()
        # detach() keeps the result buffer out of the autograd graph of the
        # trainable angle parameters.
        final_angles[g_i,:num_edges] = model.input_gen.thetas.detach().cpu()
        final_states[g_i] = phi.cpu()
    
    upload_circuits(edgeProb, run_params, final_architectures,
                    final_angles, final_states, PROJECT_NAME)

if __name__ == '__main__':
    import argparse
    import matplotlib

    # Select the 'agg' matplotlib backend before any figure is created.
    matplotlib.use('agg')
    updateConfig('skipValidation', True)

    wandb.login()

    # ---- Experiment grid and hyper-parameters -----------------------------
    numVertices = 10
    edgeProbs = [0.25, 0.5, 0.75]
    RhoDARTS_params = {
        'search_type': 'RhoDARTS',
        'num_layers': 3,
        'num_iterations': 1000,
        'num_angle_iter': 10,
        'learning_rate': 0.1,
        'CAS_Tmax': 100,
        'angle_penalty_str': 0.01,
        'use_scheduler': True,
        'entropy_penalty_str': 0.1,
        'schedule_start_point': 0.0,
        'schedule_start_time': 0,
        'schedule_duration': 'half',
        'num_oscillations': 1,
    }

    # QDARTS shares every RhoDARTS setting and adds a Gumbel temperature.
    QDARTS_params = {
        **RhoDARTS_params,
        'search_type': 'QDARTS',
        'gumbel_temp': 0.05,
    }

    # ---- Command line ------------------------------------------------------
    cli_description = ('Script to run the micro-'
                       'search experiment for the max-cut '
                       'problem.')
    parser = argparse.ArgumentParser(description=cli_description)

    parser.add_argument('--project-name', type=str, required=True,
                        help='WANDB project name')
    parser.add_argument('--dataset-artifact-name', type=str,
                        default='graph-datasets',
                        help='WANDB dataset artifact name')
    parser.add_argument('--local-artifact-path', type=str,
                        default='./artifacts',
                        help='Local directory to store WANDB artifacts')
    parser.add_argument('-c','--cuda-device', type=int, default=0,
                        help='Cuda device id you want to run the job on')
    parser.add_argument('-m', '--multiprocessing', action='store_true',
                        help='Whether to use multiprocessing to parallelize the'
                        ' runs')
    parser.add_argument('-p', '--num-processes', type=int, default=1,
                        help='Number of processes in the pool')

    args = parser.parse_args()

    # ---- Fetch the graph datasets once, before any experiment starts -------
    load_dataset_run = wandb.init(project=args.project_name,
                                  job_type='load_data')
    dataset_ref = f'{args.dataset_artifact_name}:' 'latest'
    artifact = load_dataset_run.use_artifact(dataset_ref)
    artifact_dir = artifact.download(root=args.local_artifact_path)
    load_dataset_run.finish()

    def worker(search_type:Literal['RhoDARTS','QDARTS'], p:float, 
               useHiddenUnits:bool):
        """Run one (search type, edge probability, hidden units) experiment."""
        if search_type == 'RhoDARTS':
            base_params = RhoDARTS_params
        else:
            base_params = QDARTS_params
        hidden_unit_params = {
            'use_hidden_units': useHiddenUnits,
            'num_hidden_units': 10 if useHiddenUnits else 0,
        }
        run_experiment(p,
                       base_params | hidden_unit_params,
                       args.project_name,
                       artifact_dir,
                       args.cuda_device)

    # Every combination of search type, edge probability and hidden-unit flag.
    worker_args = product(['RhoDARTS','QDARTS'],
                          edgeProbs,
                          [False,True])
    if args.multiprocessing:
        pool_size = min(args.num_processes, mp.cpu_count())
        with mp.Pool(pool_size) as pool:
            pool.starmap(worker, worker_args)
    else:
        for w_args in worker_args:
            worker(*w_args)
