import wandb
import torch
from qiskit import QuantumCircuit

import argparse
import math
import multiprocessing as mp
from functools import partial
from itertools import product
from typing import Any
from copy import deepcopy

from qtorch import updateConfig, RTYPE
from qtorch.quantumstate.measurements import getExpectation

from examples.qas.max_cut.utils import getEntropySchedule
from examples.qas.qas import QAS, QDARTS, RhoDARTS, InputGen

from .utils import * #data_dict,hartree_fock_state,load_hamiltonian_data
from .loss import loss, qdarts_angle_loss, qdarts_arch_loss

def _gate_indices_to_qiskit(circuit: torch.Tensor,
                            num_qubits: int,
                            num_layers: int) -> QuantumCircuit:
    """Build a qiskit circuit from a grid of integer gate indices.

    ``circuit[i, j]`` is read as the gate on qubit ``j`` in layer ``i``:
    0 adds nothing, 1/2/3 add X/Y/Z on qubit ``j``, and any value
    ``G > 3`` adds a CNOT targeting qubit ``j``. The control index is
    ``G - 4`` counted over the *other* qubits, i.e. qubit ``j`` itself is
    skipped when enumerating possible controls.
    """
    qc = QuantumCircuit(num_qubits)
    for layer, qubit in product(range(num_layers), range(num_qubits)):
        gate = circuit[layer, qubit].item()
        if gate == 0:
            continue
        elif gate == 1:
            qc.x(qubit)
        elif gate == 2:
            qc.y(qubit)
        elif gate == 3:
            qc.z(qubit)
        elif gate > 3:
            # Shift controls at or above the target by one so the target
            # can never be its own control.
            control = gate - 4 if gate - 4 < qubit else gate - 3
            qc.cx(control, qubit)
    return qc


def run_experiment(params: dict[str, Any],
                   run_name: str,
                   cuda_device: int = 0) -> None:
    """Run one VQE architecture search and log it to Weights & Biases.

    Loads the molecular Hamiltonian and starts from the Hartree-Fock state.
    It then trains either a ``RhoDARTS`` or a ``QDARTS`` search for up to
    ``params['num_iterations']`` steps and logs per-step metrics against the
    ``run_frac`` step metric. At the end it uploads the lowest-energy angles
    and the final argmax architecture as a ``circuit-data`` artifact. It also
    logs the final circuit's trainable-parameter count and qiskit depth.

    Args:
        params: Run configuration (see ``launch_runs`` for the keys used).
            Optional keys:
            ``use_scheduler`` (default ``True``) toggles the cosine-annealing
            LR schedule.
            ``early_stop_tol`` (default ``1e-5``) is the absolute
            energy-error threshold below which training stops early.
        run_name: W&B run name; also used as the artifact name.
        cuda_device: Index of the CUDA device to run on.

    Raises:
        ValueError: If ``params['search_type']`` is not ``'RhoDARTS'`` or
            ``'QDARTS'``, or if ``params['num_iterations'] < 1``.
    """
    search_type = params['search_type']
    if search_type not in ('RhoDARTS', 'QDARTS'):
        # Fail before opening a W&B run rather than crashing mid-setup.
        raise ValueError(f"Unknown search_type {search_type!r}; "
                         "expected 'RhoDARTS' or 'QDARTS'")
    N = params['num_iterations']
    if N < 1:
        raise ValueError(f"num_iterations must be >= 1, got {N}")
    dev = f'cuda:{cuda_device}'

    # The context manager finishes the run even when training raises
    # (marking it failed) instead of leaving it dangling.
    with wandb.init(project=params["project_name"],
                    name=run_name,
                    config=params) as run:
        for metric in ('loss', 'entropy', 'energy_error', 'energy',
                       'angle_penalty'):
            run.define_metric(metric, step_metric='run_frac')

        num_qubits = params['num_qubits']
        num_layers = params['num_layers']
        ham, min_eig = load_hamiltonian_data(params['molecule'],
                                             num_qubits,
                                             params['transform'])
        ham = ham.to(dev)
        num_electrons = electron_count(params['molecule'], num_qubits)

        entropySchedule = getEntropySchedule(
            N,
            params['schedule_start_point'],
            params['schedule_start_time'],
            params['schedule_duration'],
            params['num_oscillations'],
        )
        entPenaltyStr = params['entropy_penalty_str']
        angPenaltyStr = params['angle_penalty_str']
        softmax_temp = 10.0 if params['num_hidden_units'] > 0 else 1.0

        psi0 = hartree_fock_state(num_qubits, num_electrons,
                                  params['transform']).to(dev)
        # Magnitude of the Hartree-Fock energy; the losses receive its
        # negation (-hf_energy) as their energy reference.
        hf_energy = getExpectation(psi0, ham).abs()

        if search_type == 'RhoDARTS':
            model = QAS(num_qubits, num_layers,
                        params['num_hidden_units'], RhoDARTS,
                        psi0=psi0).to(device=dev)
            optimizer = torch.optim.Adam(model.parameters(),
                                         params['learning_rate'])
        else:  # 'QDARTS'
            num_iter = params['num_angle_iter']
            input_gen = InputGen(num_qubits, num_layers,
                                 params['num_hidden_units']).to(device=dev)
            qdarts = QDARTS(num_qubits, num_layers,
                            params['gumbel_temp'], psi0=psi0).to(device=dev)
            # The architecture logits and the angles get separate optimizers:
            # the angles are optimized inside the qdarts forward call.
            optimizer = torch.optim.Adam(input_gen.logit_recipe,
                                         params['learning_rate'])
            angles_optimizer = torch.optim.Adam([input_gen.thetas],
                                                params['learning_rate'])
            angle_loss = partial(qdarts_angle_loss, hamiltonian=ham,
                                 energy_normalization_factor=hf_energy,
                                 angle_penalty_str=angPenaltyStr)

        scheduler = None
        if params.get('use_scheduler', True):
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, params['CAS_Tmax'])
        early_stop_tol = params.get('early_stop_tol', 1e-5)

        best_angles: torch.Tensor | None = None
        least_energy = float('inf')
        for epoch in range(N):
            optimizer.zero_grad()

            if search_type == 'RhoDARTS':
                qs, angles, probs = model(softmax_temperature=softmax_temp)
                loss_val, metrics = loss(qs, angles, probs, ham, -hf_energy,
                                         epoch, entropySchedule,
                                         entPenaltyStr, angPenaltyStr)
            else:
                angles, logits = input_gen()
                qs, _ = qdarts(logits, angles, angles_optimizer, num_iter,
                               angle_loss, softmax_temp)
                probs = (logits / softmax_temp).softmax(dim=-1)
                loss_val, metrics = qdarts_arch_loss(qs, probs, ham,
                                                     -hf_energy, epoch,
                                                     entropySchedule,
                                                     entPenaltyStr)

            energy = metrics[0]
            # Track the running minimum; detach so the stored angles don't
            # keep this step's autograd graph alive.
            if energy < least_energy:
                least_energy = energy.item()
                best_angles = angles.detach().clone()

            loss_val.backward()
            optimizer.step()
            if scheduler is not None:
                scheduler.step()

            energy_error = (energy - min_eig).abs().item()
            run.log({
                'loss': loss_val.item(),
                'energy': energy.item(),
                'entropy': metrics[1].item(),
                'energy_error': energy_error,
                'angle_penalty': (metrics[2] if search_type == "RhoDARTS"
                                  else float('nan')),
                'run_frac': epoch / N,
            })
            if energy_error <= early_stop_tol:
                break

        # No energy compared below +inf (e.g. all NaN): fall back to the
        # final angles so the artifact can still be written.
        if best_angles is None:
            best_angles = angles.detach().clone()

        artifact = wandb.Artifact(f'{run_name}', type='circuit-data')
        with artifact.new_file('best_angles.pt', mode='wb') as file:
            torch.save(best_angles.cpu(), file)
        circuit = probs.argmax(dim=-1).detach().cpu()
        with artifact.new_file('architecture.pt', mode='wb') as file:
            torch.save(circuit, file)
        run.log_artifact(artifact)

        # Trainable parameters: entries with gate index 1..3 each count once;
        # index 0 clamps to 0 and CNOT indices (>= 4) are filtered out.
        num_angles = circuit[circuit < 4].clamp(0, 1).sum().item()
        circuit_depth = _gate_indices_to_qiskit(circuit, num_qubits,
                                                num_layers).depth()

        run.log({
            'num_params': num_angles,
            'depth': circuit_depth
        })

def worker(params, run_name, cuda):
    """Process-pool entry point that runs one sweep configuration.

    Kept at module level so ``multiprocessing`` can pickle it by reference.
    """
    run_experiment(params=params, run_name=run_name, cuda_device=cuda)

def _build_run_configs(project_name: str) -> list[tuple[dict[str, Any], str]]:
    """Return one ``(params, run_name)`` pair per point of the sweep grid.

    The grid is the Cartesian product of the molecule configurations,
    learning rates, iteration counts, layer factors and cosine-annealing
    ``T_max`` multipliers below. Every configuration uses RhoDARTS.
    """
    # Search space
    learning_rates = [0.1, 0.01, 0.001]
    iterations = [1000, 5000, 10000]
    layer_factors = [2, 4, 8, 16]
    cas_tmultipliers = [0.1, 0.5, 0.75, 1.0]

    # (molecule, num_qubits, transform)
    mol_cfgs = [
        # ('H2', 4, 'jw'),
        # ('LiH', 4, 'parity'),
        # ('LiH', 6, 'jw'),
        ('H2O', 8, 'jw'),
    ]

    run_cfgs = []
    for (mol, nq, tr), lr, iters, lf, cas_tmult in product(
        mol_cfgs, learning_rates, iterations, layer_factors, cas_tmultipliers
    ):
        params = {
            "project_name": project_name,
            "learning_rate": lr,
            "num_iterations": iters,
            "layer_factor": lf,
            "CAS_Tmult": cas_tmult,
            "search_type": "RhoDARTS",
            "num_angle_iter": 10,
            "angle_penalty_str": 0.01,
            "use_scheduler": True,
            "entropy_penalty_str": 0.1,
            "schedule_start_point": 0.0,
            "schedule_start_time": 0,
            "schedule_duration": "half",
            "num_oscillations": 1,
            "use_hidden_units": True,
            "molecule": mol,
            "num_qubits": nq,
            "transform": tr,
            "num_hidden_units": 2 * (nq + 3),
            "num_layers": nq * lf,
            # T_max for the cosine-annealing LR schedule, as a fraction
            # of the iteration budget.
            "CAS_Tmax": int(iters * cas_tmult),
        }
        run_name = f"{mol}-{nq}-{tr}-lr{lr}-iter{iters}-lf{lf}-cas{cas_tmult}"
        run_cfgs.append((params, run_name))
    return run_cfgs


def launch_runs(args: argparse.Namespace) -> None:
    """Run every configuration of the sweep grid, serially or in parallel.

    Args:
        args: Parsed CLI arguments. Reads ``project_name``, ``cuda_device``,
            ``multiprocessing`` and ``num_processes``.
            If ``multiprocessing`` is false, the runs execute one after
            another in this process.
            Otherwise they are spread over a pool of
            ``max(1, min(num_processes, cpu_count()))`` worker processes,
            all using the same CUDA device.
    """
    jobs = [(params, run_name, args.cuda_device)
            for params, run_name in _build_run_configs(args.project_name)]

    # Without -m/--multiprocessing, run serially. A Namespace lacking the
    # attribute (programmatic callers) keeps using the process pool.
    if not getattr(args, 'multiprocessing', True):
        for job in jobs:
            worker(*job)
        return

    # Pool(processes=0) raises ValueError, so keep at least one worker.
    num_proc = max(1, min(args.num_processes, mp.cpu_count()))
    # 'spawn' starts each worker in a fresh interpreter, which PyTorch
    # recommends for CUDA work in subprocesses; it also makes the start
    # method independent of the platform default.
    with mp.get_context('spawn').Pool(processes=num_proc) as pool:
        pool.starmap(worker, jobs)

            
 
if __name__ == '__main__':
    # Command-line entry point: parse the sweep options and launch every run.
    cli = argparse.ArgumentParser(description='Sweep runner for VQE')
    cli.add_argument('--project-name', required=True, type=str,
                     help='WANDB project name')
    cli.add_argument('-c', '--cuda-device', default=0, type=int,
                     help='CUDA device id')
    cli.add_argument('-m', '--multiprocessing', action='store_true',
                     help='Whether to use multiprocessing')
    cli.add_argument('-p', '--num-processes', default=8, type=int,
                     help='Number of processes to spawn for the sweep')

    launch_runs(cli.parse_args())
