import torch
import wandb

import os
import multiprocessing as mp
from typing import Any, Literal
from itertools import product, chain
from functools import partial

from qtorch.unitaries import rx,ry,rz,cnot
from qtorch.quantumstate.measurements import getZDistribution,getExpectation
from qtorch.noise import BitPhaseFlipChannel, DepolarizingChannel
from qtorch.config import updateConfig

from ..penalty_scheduler import oscillating_scheduler
from ..qas import QAS, RhoDARTS, QDARTS, InputGen

from .data import ErdosRenyiDataset
from .plots import make_max_cut_plot, quantum_circuit_plot
from .loss import loss, qdarts_angle_loss, qdarts_arch_loss
from .utils import getEntropySchedule

def run_final_circuit(model:QAS, psi0:torch.Tensor
                      )->tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Execute the discrete circuit currently preferred by ``model``.

    For every (layer, qubit) slot the gate with the largest softmax
    probability is selected and applied to ``psi0``. Everything runs under
    ``torch.no_grad()``.

    Gate indices are decoded as follows:

    * ``0``       -- no gate is applied to the slot.
    * ``1, 2, 3`` -- ``rx``, ``ry``, ``rz`` on the slot's qubit, using the
      angle ``thetas[layer, qubit]``.
    * ``>= 4``    -- ``cnot`` whose control is derived from ``G - 4``; the
      index is shifted by one once it reaches the target qubit, so the
      control can never coincide with the target.

    Args:
        model: Search model. ``model.input_gen()`` must return
            ``(thetas, logits)`` and ``model.search`` must expose
            ``num_qubits``, ``num_layers`` and ``noise_model``.
        psi0: Initial state vector.

    Returns:
        ``(qs, circuit, thetas)``: the final state, the selected gate index
        per (layer, qubit) on the CPU, and the angles returned by
        ``model.input_gen()``. When ``model.search.noise_model`` is set,
        ``qs`` starts from ``outer(psi0, psi0.conj())`` (a density matrix);
        otherwise it starts from a copy of ``psi0``.
    """
    rotations = (rx, ry, rz)
    with torch.no_grad():
        thetas,logits = model.input_gen()
        probs = torch.softmax(logits,dim=-1)
        n,m = model.search.num_qubits, model.search.num_layers
        # Most probable gate for every slot; indexed as [layer, qubit] below.
        circuit = probs.argmax(dim=-1).cpu()
        noise_model = model.search.noise_model
        noisy = noise_model is not None

        if noisy:
            qs = torch.outer(psi0, psi0.conj())
        else:
            qs = psi0.clone()

        for layer in range(m):
            for target_qubit in range(n):
                G = circuit[layer,target_qubit].item()
                if G == 0:
                    continue
                if G < 4:
                    qs = rotations[G-1](qs, thetas[layer,target_qubit],
                                        n, target_qubit)
                else:
                    # Control choices enumerate the other qubits, so skip
                    # over the target's own index.
                    control_qubit = G-4 if G-4<target_qubit else G-3
                    qs = cnot(qs, n, target_qubit, control_qubit)
            if noisy:
                # The noise model is called with a leading dimension of
                # size 1, which is stripped again afterwards.
                qs = noise_model(qs.unsqueeze(0))[0]
    return qs, circuit, thetas

def upload_circuits(edgeProb:float,
                    run_params:dict[str,Any],
                    final_architectures:torch.Tensor,
                    final_angles:torch.Tensor,
                    final_states:torch.Tensor,
                    PROJECT_NAME:str)->None:
    """Upload the discovered circuits to W&B as a ``circuit-data`` artifact.

    A dedicated run with job type ``upload-circuits`` is opened. The three
    tensors are serialized with ``torch.save`` into ``arch_p_<p>.pt``,
    ``angles_p_<p>.pt`` and ``states_p_<p>.pt`` inside the artifact, which is
    then logged to the run. The run is used as a context manager so it is
    finished even when serialization or the upload raises.

    Args:
        edgeProb: Edge probability of the graph dataset; used in the
            artifact name, the file names and the metadata.
        run_params: Experiment parameters. Reads ``search_type``,
            ``use_hidden_units``, ``noise_model``, ``num_layers`` and, when
            ``noise_model`` is not the string ``'None'``, ``noise_prob``.
        final_architectures: Selected gate indices; its last dimension is
            recorded as ``num_vertices`` in the artifact metadata.
        final_angles: Rotation angles to store.
        final_states: Final quantum states to store.
        PROJECT_NAME: W&B project to log to.
    """
    with wandb.init(project=PROJECT_NAME, job_type='upload-circuits') as run:
        search_type = run_params['search_type']
        hidden_tag = 'With' if run_params['use_hidden_units'] else 'Without'
        # Only noisy experiments get the noise suffix in the artifact name.
        if run_params['noise_model'] != 'None':
            noise_tag = (f'-{run_params["noise_model"]}'
                         f'-p-{run_params["noise_prob"]:0.2f}')
        else:
            noise_tag = ''

        circuit_artifact = wandb.Artifact(
            (f'{search_type}-MaxCut-Circuits-p-{edgeProb:0.2f}-'
             f'{hidden_tag}-Hidden-Units{noise_tag}'),
            type='circuit-data',
            description=f'The circuits found by {search_type} to'
            ' find the max cuts along with the corresponding states produced.',
            metadata={
                'num_vertices': final_architectures.shape[-1],
                'edge_probability':edgeProb,
                'num_layers': run_params['num_layers']
            }
        )

        payloads = (
            ('arch', final_architectures),
            ('angles', final_angles),
            ('states', final_states),
        )
        for prefix, tensor in payloads:
            with circuit_artifact.new_file(f'{prefix}_p_{edgeProb:0.2f}.pt',
                                           mode='wb') as file:
                torch.save(tensor, file)

        run.log_artifact(circuit_artifact)

def run_experiment(edgeProb:float,
                   run_params:dict[str,Any],
                   PROJECT_NAME:str,
                   local_artifact_dir:str,
                   cuda_device:int=0)->None:
    """Run the architecture search on every graph of one dataset file.

    The dataset ``p_<edgeProb>.pt`` is loaded from ``local_artifact_dir``.
    For each graph a fresh model is trained for
    ``run_params['num_iterations']`` steps inside its own W&B run
    (job type ``main-result``). The most probable circuit is then executed
    with ``run_final_circuit`` and its plots and energy are logged. Finally
    the selected architectures, angles and states of all graphs are uploaded
    through ``upload_circuits``.

    Args:
        edgeProb: Edge probability identifying the dataset file.
        run_params: Experiment parameters. Keys read here: ``search_type``
            (``'RhoDARTS'`` or ``'QDARTS'``), ``noise_model``,
            ``noise_prob``, ``num_layers``, ``num_hidden_units``,
            ``use_hidden_units``, ``num_iterations``, ``num_angle_iter``,
            ``learning_rate``, ``CAS_Tmax``, ``entropy_penalty_str``,
            ``angle_penalty_str``, ``schedule_start_point``,
            ``schedule_start_time``, ``schedule_duration``,
            ``num_oscillations`` and, for QDARTS, ``gumbel_temp``.
        PROJECT_NAME: W&B project to log to.
        local_artifact_dir: Directory containing the downloaded dataset.
        cuda_device: Index of the CUDA device to run on.

    Raises:
        ValueError: If ``run_params['search_type']`` is not supported.
        AssertionError: If depolarizing noise is requested for a search type
            other than ``'RhoDARTS'``.
    """
    search_type = run_params['search_type']
    if search_type not in ('RhoDARTS', 'QDARTS'):
        # Fail early instead of hitting an unbound `model` further down.
        raise ValueError(f'Unsupported search_type {search_type!r}; expected '
                         "'RhoDARTS' or 'QDARTS'.")

    dev = f'cuda:{cuda_device}'
    dset = ErdosRenyiDataset(os.path.join(local_artifact_dir,
                                            f'p_{edgeProb:0.2f}.pt'))
    numGraphs = len(dset)
    n = dset[0]['graph'].shape[0]
    # Equal-amplitude state vector over all 2**n basis states.
    psi0 = torch.ones(2**n,dtype=torch.complex64,device=dev)/(2**(n/2))

    if run_params['noise_model'] == 'BitPhaseFlip':
        noise_channel = partial(BitPhaseFlipChannel, num_qubits=n,
                                prob_x=run_params['noise_prob'],
                                prob_z=run_params['noise_prob'])
    elif run_params['noise_model'] == 'Depolarizing':
        assert run_params['search_type'] == 'RhoDARTS', ('Depolarizing '
        'noise requires density matrix simulation, i.e. RhoDARTS.')
        noise_channel = partial(DepolarizingChannel, num_qubits=n, 
                                prob=run_params['noise_prob'])
    else:
        noise_channel = None

    final_architectures = torch.empty([numGraphs, run_params['num_layers'], n],
                                dtype=torch.int32)
    final_angles = torch.empty([numGraphs, run_params['num_layers'], n])
    if noise_channel is None:
        noise_model = None
        final_states = torch.empty([numGraphs,2**n], dtype=psi0.dtype)
    else:
        # The channel is applied to the first entry of the leading
        # dimension, which is then restored; run_final_circuit calls the
        # noise model with the same convention.
        noise_model = lambda qs: noise_channel(qs[0]).unsqueeze(0)
        # Noisy runs store a (2**n, 2**n) matrix per graph.
        final_states = torch.empty([numGraphs,2**n,2**n],dtype=psi0.dtype)

    # Fixed by the configuration, so it is the same for every graph.
    if run_params['num_hidden_units'] > 0:
        softmax_temp = 10.0
    else:
        softmax_temp = 1.0

    for g_i in range(numGraphs):
        graphData = dset[g_i]
        G = graphData['graph']
        H = graphData['hamiltonian'].to(device=dev)
        num_edges = graphData['num_edges']
        num_max_cut = graphData['num_max_cut']
        bases = graphData['max_cut_bases']
        max_cut_value = graphData['max_cut_value']

        config = run_params | dict(
            edge_probability=edgeProb,
            num_max_cut = num_max_cut,
            min_energy = max_cut_value
        )

        N = run_params['num_iterations']
        entropySchedule = getEntropySchedule(
            N,
            run_params['schedule_start_point'],
            run_params['schedule_start_time'],
            run_params['schedule_duration'],
            run_params['num_oscillations']
        )
        entPenaltyStr = run_params['entropy_penalty_str']
        angPenaltyStr = run_params['angle_penalty_str']

        num_iter = run_params['num_angle_iter']
        if search_type == 'RhoDARTS':
            model = QAS(n, run_params['num_layers'],
                                run_params['num_hidden_units'], RhoDARTS,
                                psi0=psi0, noise_model = noise_model
                                ).to(device=dev)
            optimizer = torch.optim.Adam(model.parameters(),
                                         run_params['learning_rate'])
        else:
            model = QAS(n, run_params['num_layers'],
                                run_params['num_hidden_units'], QDARTS,
                                psi0=psi0,
                                gumbel_temp=run_params['gumbel_temp'],
                                noise_model = noise_model
                                ).to(device=dev)
            # Train the generator owned by `model`: the final circuit and
            # the stored angles are read from `model.input_gen`, so a
            # separate InputGen would leave them at their initial values.
            # NOTE(review): assumes QAS builds `input_gen` as an InputGen
            # exposing `logit_recipe` and `thetas` -- confirm against QAS.
            input_gen = model.input_gen
            qdarts = QDARTS(n, run_params['num_layers'], 
                            run_params['gumbel_temp'], psi0=psi0,
                            noise_model=noise_model).to(device=dev)
            optimizer = torch.optim.Adam(input_gen.logit_recipe, 
                                         run_params['learning_rate'])
            angles_optimizer = torch.optim.Adam([input_gen.thetas],
                                                run_params['learning_rate'])
            angle_loss = partial(qdarts_angle_loss, hamiltonian=H,
                                 energy_normalization_factor=num_edges,
                                 angle_penalty_str=angPenaltyStr)

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, run_params['CAS_Tmax'])

        with wandb.init(
            project=PROJECT_NAME,job_type='main-result',
            config=config,
            name=(f'main-{run_params["search_type"]}-p={edgeProb:0.2f}'
                  f'{"-hu" if run_params["use_hidden_units"] else ""}'
                  f'-{run_params["noise_model"] if noise_channel is not None else ""}'
                  f'-{g_i}')
        ) as run:
            for i in range(N):
                optimizer.zero_grad()

                if search_type == 'RhoDARTS':
                    qs, angles, probs = model(softmax_temperature=softmax_temp)
                    loss_val, metrics = loss(qs, angles, probs, H, num_edges, i, 
                                            entropySchedule, entPenaltyStr, 
                                            angPenaltyStr)
                else:
                    angles,logits = input_gen()
                    qs, sampled_gates = qdarts(logits, angles, angles_optimizer, num_iter,
                                angle_loss, softmax_temp)
                    probs = (logits/softmax_temp).softmax(dim=-1)
                    loss_val, metrics = qdarts_arch_loss(qs, probs, H, 
                                                         num_edges, i, 
                                                         entropySchedule, 
                                                         entPenaltyStr)
                
                loss_val.backward()
                optimizer.step()
                scheduler.step()

                # The last training step is left uncommitted so the final
                # images and energy logged below land on the same W&B step.
                wandb.log({
                    'loss':loss_val.item(),
                    'energy':metrics[0].item(),
                    'entropy':metrics[1].item(),
                    # 'angle_penalty':metrics[2].item()
                },commit=(i!=N-1))
            phi, circuit, thetas = run_final_circuit(model, psi0)
            dist = getZDistribution(phi)
            energy = getExpectation(phi, H)
            # Compare as many top outcomes as there are true max cuts.
            found_probs, found_bases = torch.topk(dist, num_max_cut)
            
            gt_image = make_max_cut_plot(G, bases, None, 
                                        'Ground Truth Max Cuts')
            found_img = make_max_cut_plot(G, found_bases, found_probs, 
                                            'Found Cuts')
            circuit_img = quantum_circuit_plot(circuit, thetas)


            wandb.log({
                'true_max_cuts_img': wandb.Image(gt_image,
                                                    caption='Ground Truth Max '
                                                    'Cuts'),
                'found_max_cuts_img': wandb.Image(found_img, 
                                                    caption='Found Max Cuts'),
                'circuit_img': wandb.Image(circuit_img,
                                           caption='Found Quantum Circuit'),
                'energy': energy
            })

        final_architectures[g_i] = circuit.cpu()
        # Detach so the result buffer does not get tied to the autograd
        # graph of the trained parameter.
        final_angles[g_i] = model.input_gen.thetas.detach().cpu()
        final_states[g_i] = phi.cpu()
    
    upload_circuits(edgeProb, run_params, final_architectures, final_angles,
                    final_states, PROJECT_NAME)

if __name__ == '__main__':
    # Script entry point: parse the CLI, fetch the graph datasets from W&B
    # and run every experiment configuration, serially or in a process pool.
    import matplotlib
    
    import argparse
    
    # Non-interactive backend; figures are only rendered to images.
    matplotlib.use('agg')
    updateConfig('skipValidation',True)

    wandb.login()

    numVertices = 10
    edgeProbs = [0.25, 0.5, 0.75]

    # Baseline hyper-parameters; the QDARTS set overrides/extends them.
    RhoDARTS_params = {
        'search_type': 'RhoDARTS',
        'num_layers': 15,
        'num_iterations': 250,
        'num_angle_iter': 10,
        'learning_rate': 0.1,
        'CAS_Tmax': 25,
        'angle_penalty_str': 0.01,
        'use_scheduler': True,
        'entropy_penalty_str': 0.1,
        'schedule_start_point': 0.0,
        'schedule_start_time': 0,
        'schedule_duration': 'half',
        'num_oscillations': 1,
    }
    QDARTS_params = {
        **RhoDARTS_params,
        'search_type': 'QDARTS',
        'gumbel_temp': 0.05,
    }

    cli = argparse.ArgumentParser(
        description='Script to run the macro-search experiment for the '
                    'max-cut problem.')

    cli.add_argument('--project-name', type=str, required=True,
                     help='WANDB project name')
    cli.add_argument('--dataset-artifact-name', type=str,
                     default='graph-datasets',
                     help='WANDB dataset artifact name')
    cli.add_argument('--local-artifact-path', type=str,
                     default='./artifacts',
                     help='Local directory to store WANDB artifacts')
    cli.add_argument('-c','--cuda-device', type=int, default=0,
                     help='Cuda device id you want to run the job on')
    cli.add_argument('-n', '--noisy-expt', action='store_true',
                     help='Whether to run the noise experiments')
    cli.add_argument('-m', '--multiprocessing', action='store_true',
                     help='Whether to use multiprocessing to parallelize '
                          'the runs')
    cli.add_argument('-p', '--num-processes', type=int, default=1,
                     help='Number of processes in the pool')

    args = cli.parse_args()

    # Download the dataset artifact inside a short-lived W&B run.
    load_dataset_run = wandb.init(project=args.project_name,
                                  job_type='load_data')
    artifact = load_dataset_run.use_artifact(
        f'{args.dataset_artifact_name}:latest')
    artifact_dir = artifact.download(root=args.local_artifact_path)
    load_dataset_run.finish()

    def worker(search_type:Literal['RhoDARTS','QDARTS'],edgeProb:float, 
               useHiddenUnits:bool, noise_model:str='None', 
               noise_prob:float=0.0):
        # Run one configuration: base parameters of the chosen search type
        # combined with the per-run hidden-unit and noise settings.
        if search_type == 'RhoDARTS':
            base_params = RhoDARTS_params
        else:
            base_params = QDARTS_params
        per_run = {
            'use_hidden_units': useHiddenUnits,
            'num_hidden_units': 26 if useHiddenUnits else 0,
            'noise_model': noise_model,
            'noise_prob': noise_prob,
        }
        run_experiment(edgeProb, base_params | per_run, args.project_name,
                       artifact_dir, args.cuda_device)

    # Build the grid of worker arguments for the requested experiment type.
    if args.noisy_expt:
        noise_probs = [0.01, 0.1, 0.25, 0.5]
        noise_models = ['BitPhaseFlip', 'Depolarizing']
        worker_args = chain(
            # RhoDARTS noise sweep, currently disabled:
            # product(['RhoDARTS'],edgeProbs, [False, True], noise_models,
            #         noise_probs),
            product(['QDARTS'], edgeProbs, [False,True], ['BitPhaseFlip'],
                    noise_probs)
        )
    else:
        worker_args = product(['RhoDARTS', 'QDARTS'],
                              edgeProbs,
                              [False,True])

    # Execute the grid, either in a process pool or one run after another.
    if args.multiprocessing:
        pool_size = min(args.num_processes, mp.cpu_count())
        with mp.Pool(pool_size) as pool:
            pool.starmap(worker, worker_args)
    else:
        for w_args in worker_args:
            worker(*w_args)
        
