import torch
import wandb

import os
import multiprocessing as mp
from typing import Any, Literal
from itertools import product,chain,repeat
from functools import partial
import math

from qtorch.unitaries import rx,ry,rz,cnot
from qtorch.quantumstate.measurements import getZDistribution,getExpectation
from qtorch.quantumstate import fidelity, Statevector
from qtorch.noise import BitPhaseFlipChannel, DepolarizingChannel
from qtorch.config import updateConfig

from ..penalty_scheduler import oscillating_scheduler
from ..qas import QAS, RhoDARTS, QDARTS, InputGen

from ..max_cut.plots import quantum_circuit_plot


from .loss import loss, qdarts_arch_loss, qdarts_angle_loss
from .utils import GHZ_State, W_State, getEntropySchedule

def _make_noise_channel(run_params: dict[str, Any], num_qubits: int) -> Any:
    """Return the configured noise channel for ``num_qubits`` qubits, or None.

    ``run_params['noise_model']`` selects ``BitPhaseFlipChannel`` (with
    ``prob_x == prob_z == noise_prob``) or ``DepolarizingChannel``; any other
    value (e.g. ``'None'``) means no noise. The channel class is wrapped in a
    ``functools.partial`` so it can later be called on a state.
    """
    noise_name = run_params['noise_model']
    if noise_name == 'BitPhaseFlip':
        prob = run_params['noise_prob']
        return partial(BitPhaseFlipChannel, num_qubits=num_qubits,
                       prob_x=prob, prob_z=prob)
    if noise_name == 'Depolarizing':
        return partial(DepolarizingChannel, num_qubits=num_qubits,
                       prob=run_params['noise_prob'])
    return None


def _make_batched_noise(channel: Any) -> Any:
    """Wrap ``channel`` so it is applied to ``qs[0]`` and re-batched.

    The returned callable computes ``channel(qs[0]).unsqueeze(0)``; QDARTS
    is handed this wrapper. Binding ``channel`` here (instead of closing over
    the per-qubit loop variable) avoids late-binding surprises.
    """
    def noise_model(qs):
        return channel(qs[0]).unsqueeze(0)
    return noise_model


def _identity_diagnostics(probs: torch.Tensor, num_qubits: int):
    """Compute the log-only first-layer "identity" diagnostics.

    Returns ``(q, uncond_id, cond_id)``:
      * ``q[i]`` sums, for qubit ``i`` of layer 0, the probabilities at gate
        indices 0, 3 and ``4 + i`` onward (presumably the choices that act
        as the identity on that qubit -- TODO confirm against the gate set).
      * ``uncond_id`` is ``sum(q - probs[0, :, 0])``.
      * ``cond_id`` is a ``(num_qubits, num_qubits - 1)`` table filled by the
        recurrence below.

    These values are only logged; they are not part of the optimized loss.
    """
    q = torch.empty(num_qubits, dtype=probs.dtype, device=probs.device)
    for qubit in range(num_qubits):
        q[qubit] = (probs[0, qubit, 0] + probs[0, qubit, 3]
                    + probs[0, qubit, 4 + qubit:].sum())
    uncond_id = (q - probs[0, :, 0]).sum()

    cond_id = torch.zeros([num_qubits, num_qubits - 1], dtype=probs.dtype,
                          device=probs.device)
    for row in range(1, num_qubits):
        for col in range(row):
            if col == 0:
                cond_id[row, col] = 1 - probs[0, col, 1:3].sum()
            else:  # 0 < col < row
                for k in range(col):
                    cond_id[row, col] += ((1 - cond_id[row - 1, k])
                                          * probs[0, row - 1, 4 + k])
                cond_id[row, col] += (probs[0, row - 1, 0]
                                      + probs[0, row - 1, 3]
                                      + probs[0, row - 1, 4 + col:].sum())
    return q, uncond_id, cond_id


def _upload_circuits(run_params: dict[str, Any], project_name: str,
                     final_gates: list, final_angles: list,
                     final_states: list) -> None:
    """Save every found circuit, its angles and its state as one artifact.

    Files are named ``circuit_<n>.pt``, ``angles_<n>.pt`` and
    ``state_<n>.pt`` for ``n`` in ``2..max_qubits``. The lists must be
    ordered by qubit count starting at 2.
    """
    with wandb.init(project=project_name, job_type='upload-circuits',
                    name='uploader'):
        noise_suffix = (
            f'-{run_params["noise_model"]}-p-{run_params["noise_prob"]:0.2f}'
            if run_params['noise_model'] != 'None' else ''
        )
        circuit_artifact = wandb.Artifact(
            (f'{run_params["search_type"]}-{run_params["entanglement_type"]}-'
             f'Circuits-'
             f'{"With" if run_params["use_hidden_units"] else "Without"}-Hidden-'
             'Units'
             + noise_suffix),
            type='circuit-data',
            description=f'The circuits found by {run_params["search_type"]} to '
            f'produce {run_params["entanglement_type"]} states, along with the '
            'corresponding states produced.',
            metadata={
                'max_qubits': run_params['max_qubits'],
                'entanglement_type': run_params['entanglement_type'],
            }
        )
        for num_qubits, gates, angles, state in zip(
                range(2, run_params['max_qubits'] + 1),
                final_gates, final_angles, final_states):
            with circuit_artifact.new_file(f'circuit_{num_qubits}.pt',
                                           mode='wb') as file:
                torch.save(gates, file)
            with circuit_artifact.new_file(f'angles_{num_qubits}.pt',
                                           mode='wb') as file:
                torch.save(angles, file)
            with circuit_artifact.new_file(f'state_{num_qubits}.pt',
                                           mode='wb') as file:
                torch.save(state, file)

        wandb.log_artifact(circuit_artifact)


def run_experiment(run_params: dict[str, Any],
                   PROJECT_NAME: str,
                   cuda_device: int = 0) -> None:
    """Run the architecture search once per qubit count and upload the results.

    For each ``num_qubits`` in ``2..run_params['max_qubits']`` this:
      * builds a reference GHZ or W state;
      * builds a RhoDARTS (``QAS``) or QDARTS model;
      * trains it for ``N`` outer iterations inside its own wandb run,
        stopping early once the logged fidelity reaches
        ``1 - stop_threshold``;
      * logs a plot of the argmax circuit;
      * keeps the gates, angles and state.

    A final ``upload-circuits`` wandb run stores everything as one artifact.

    Args:
        run_params: Experiment configuration. Keys read here:
            ``search_type`` ('RhoDARTS' | 'QDARTS'),
            ``entanglement_type`` ('GHZ' | 'W'), ``noise_model``,
            ``noise_prob``, ``num_iterations``, ``num_angle_iter``
            (QDARTS only; the outer iteration count is divided by it),
            ``CAS_Tmult``, ``use_scheduler`` (optional, default True),
            ``stop_threshold``, the ``schedule_*`` / ``num_oscillations``
            entropy-schedule keys, ``entropy_penalty_str``,
            ``angle_penalty_str``, ``use_hidden_units``, ``max_qubits``,
            ``w_layerFactor``, ``ghz_layers``, ``learning_rate`` and
            ``gumbel_temp``.
        PROJECT_NAME: wandb project that receives all runs and the artifact.
        cuda_device: Index of the CUDA device (``cuda:<index>``) to use.

    Raises:
        ValueError: If ``search_type`` or ``entanglement_type`` is unknown,
            if Depolarizing noise is requested without RhoDARTS, or if the
            configuration yields fewer than one outer iteration.
    """
    search_type = run_params['search_type']
    entanglement_type = run_params['entanglement_type']
    # Validate up front so a bad config fails before any GPU / wandb work.
    if search_type not in ('RhoDARTS', 'QDARTS'):
        raise ValueError(f'Unknown search_type {search_type!r}; expected '
                         "'RhoDARTS' or 'QDARTS'.")
    if entanglement_type not in ('GHZ', 'W'):
        raise ValueError(f'Unknown entanglement_type {entanglement_type!r}; '
                         "expected 'GHZ' or 'W'.")
    if (run_params['noise_model'] == 'Depolarizing'
            and search_type != 'RhoDARTS'):
        raise ValueError('Depolarizing '
                         'noise requires density matrix simulation, i.e. RhoDARTS.')

    dev = f'cuda:{cuda_device}'
    N = run_params['num_iterations']
    if search_type == 'QDARTS':
        # Each outer QDARTS step runs `num_iter` inner angle updates.
        num_iter = run_params['num_angle_iter']
        N //= num_iter
    if N < 1:
        raise ValueError(f'Configuration yields {N} outer iterations; need at '
                         'least 1 (check num_iterations / num_angle_iter).')
    # CosineAnnealingLR divides by T_max, so it must be at least 1.
    Tmax = max(1, int(N * run_params['CAS_Tmult']))
    use_scheduler = run_params.get('use_scheduler', True)
    stop_threshold = run_params['stop_threshold']
    entropySchedule = getEntropySchedule(
        N,
        run_params['schedule_start_point'],
        run_params['schedule_start_time'],
        run_params['schedule_duration'],
        run_params['num_oscillations']
    )
    entPenaltyStr = run_params['entropy_penalty_str']
    angPenaltyStr = run_params['angle_penalty_str']
    softmax_temp = 10.0 if run_params['use_hidden_units'] else 1.0

    final_gates = []
    final_angles = []
    final_states = []

    for num_qubits in range(2, run_params['max_qubits'] + 1):
        if entanglement_type == 'GHZ':
            ref_state = GHZ_State(num_qubits, dev)
        else:
            ref_state = W_State(num_qubits, dev)

        noise_channel = _make_noise_channel(run_params, num_qubits)
        if noise_channel is None:
            noise_model = None
        elif search_type == 'QDARTS':
            noise_model = _make_batched_noise(noise_channel)
        else:
            noise_model = noise_channel

        num_hidden_units = (2 * (num_qubits + 3)
                            if run_params['use_hidden_units'] else 0)
        if entanglement_type == 'W':
            num_layers = run_params['w_layerFactor'] * (num_qubits - 1)
        else:
            num_layers = run_params['ghz_layers']

        if search_type == 'RhoDARTS':
            model = QAS(num_qubits, num_layers, num_hidden_units, RhoDARTS,
                        noise_model=noise_model
                        ).to(device=dev)
            optimizer = torch.optim.Adam(model.parameters(),
                                         run_params['learning_rate'])
        else:  # QDARTS
            input_gen = InputGen(num_qubits, num_layers, num_hidden_units
                                 ).to(device=dev)
            qdarts = QDARTS(num_qubits, num_layers, run_params['gumbel_temp'],
                            noise_model=noise_model).to(device=dev)
            angle_loss = partial(qdarts_angle_loss, ref_state=ref_state,
                                 angle_penalty_str=angPenaltyStr)
            optimizer = torch.optim.Adam(input_gen.logit_recipe,
                                         run_params['learning_rate'])
            angles_optimizer = torch.optim.Adam([input_gen.thetas],
                                                run_params['learning_rate'])

        scheduler = (torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, Tmax)
                     if use_scheduler else None)

        config = run_params | dict(
            num_qubits=num_qubits
        )

        run_name = (f'{entanglement_type}-{num_qubits}-qubits-'
                    f'{num_layers}-layers')
        if num_hidden_units > 0:
            run_name += '-hu'
        if noise_channel is not None:
            run_name += (f'-{run_params["noise_model"]}-p-'
                         f'{run_params["noise_prob"]}')

        with wandb.init(project=PROJECT_NAME, config=config, name=run_name):
            early_stop = False
            for step in range(N):
                optimizer.zero_grad()
                if search_type == 'RhoDARTS':
                    qs, angles, probs = model(softmax_temperature=softmax_temp)
                    loss_val, metrics = loss(qs, angles, probs, ref_state, step,
                                             entropySchedule, entPenaltyStr,
                                             angPenaltyStr)
                else:
                    angles, logits = input_gen()
                    qs, sampled_gates = qdarts(logits, angles, angles_optimizer,
                                               num_iter, angle_loss,
                                               softmax_temp)
                    probs = (logits / softmax_temp).softmax(dim=-1)
                    loss_val, metrics = qdarts_arch_loss(qs, probs, ref_state,
                                                         step, entropySchedule,
                                                         entPenaltyStr)

                # Identity / small-angle penalties are logged only (not added
                # to the loss), so skip autograd bookkeeping for them.
                with torch.no_grad():
                    q, uncond_id, cond_id = _identity_diagnostics(probs,
                                                                  num_qubits)
                    small_angle_penalty = torch.exp(-(angles / 0.1) ** 2).sum()

                fid = metrics[0].item()
                wandb.log({
                    'loss': loss_val.item(),
                    'fidelity': fid,
                    'entropy': metrics[1].item(),
                    'id-prob': q.prod().item(),
                    'uncond-id': uncond_id.item(),
                    'cond-id': cond_id.sum().item(),
                    'small_angle_penalty': small_angle_penalty.item()
                })

                if fid >= 1.0 - stop_threshold:
                    early_stop = True
                    break

                loss_val.backward()
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()

            # Final read-out of the learned architecture; no gradients needed.
            # NOTE: for QDARTS, `qs` is the state from the last training step.
            with torch.no_grad():
                if search_type == 'RhoDARTS':
                    qs, angles, probs = model(softmax_temperature=softmax_temp)
                else:
                    angles, logits = input_gen()
                    if not early_stop:
                        probs = (logits / softmax_temp).softmax(dim=-1)
                    else:
                        probs = sampled_gates
            gates = probs.argmax(dim=-1)
            circuit_img = quantum_circuit_plot(gates, angles)
            wandb.log({
                'circuit_img': wandb.Image(circuit_img,
                                           caption='Found Quantum Circuit')
            })

        # Detach so the saved copies do not keep autograd graphs (and the GPU
        # tensors they reference) alive across qubit iterations.
        final_gates.append(gates.detach().cpu().clone())
        final_angles.append(angles.detach().cpu().clone())
        final_states.append(qs.detach().cpu().clone())

    _upload_circuits(run_params, PROJECT_NAME,
                     final_gates, final_angles, final_states)
    


if __name__ == '__main__':
    import matplotlib

    import argparse
    import yaml

    matplotlib.use('agg')
    updateConfig('skipValidation', True)

    parser = argparse.ArgumentParser(description='Script to run the '
                                     'architecture-search experiment for '
                                     'preparing GHZ / W entangled states.')

    parser.add_argument('--project-name', type=str, required=True,
                        help='WANDB project name')

    parser.add_argument('-c', '--cuda-device', type=int, default=0,
                        help='Cuda device id you want to run the job on')
    parser.add_argument('-n', '--noisy-expt', action='store_true',
                        help='Whether to run the noise experiments')
    parser.add_argument('-m', '--multiprocessing', action='store_true',
                        help='Whether to use multiprocessing to parallelize the'
                        ' runs')
    parser.add_argument('-p', '--num-processes', type=int, default=1,
                        help='Number of processes in the pool')
    parser.add_argument('-r', '--num-reps', type=int, default=1,
                        help='Number of times to repeat each experiment')
    parser.add_argument('-s', '--stop-threshold', type=float, default=1e-5,
                        help='Early stopping threshold for fidelity')
    parser.add_argument('--ghz-layers', type=int, default=1,
                        help='Number of layers for the GHZ circuits')
    parser.add_argument('--w-layers', type=int, default=3,
                        help='Number of layers for the W circuits will be '
                        '(number of qubits - 1) multiplied by this factor')
    parser.add_argument('--skip-w', action='store_true',
                        help='Set to skip the W state experiments')
    parser.add_argument('--cfg', type=str, default=None,
                        help='Path to yaml file containing the config dict')

    args = parser.parse_args()

    # Log in only after argument parsing so `--help` / bad flags do not
    # trigger a wandb login prompt.
    wandb.login()

    base_params = dict(
        max_qubits=6,
        num_iterations=10_000,
        num_angle_iter=10,
        learning_rate=0.01,
        CAS_Tmult=0.1,
        angle_penalty_str=0.01,
        use_scheduler=True,
        entropy_penalty_str=0.1,
        schedule_start_point=0.0,
        schedule_start_time=0,
        schedule_duration='half',
        num_oscillations=1,
        gumbel_temp=0.05,
        ghz_layers=args.ghz_layers,
        w_layerFactor=args.w_layers,
        stop_threshold=args.stop_threshold
    )
    if args.cfg is not None:
        with open(args.cfg, 'r') as f:
            # An empty YAML file loads as None; treat it as "no overrides".
            yaml_dict = yaml.safe_load(f) or {}
        if not isinstance(yaml_dict, dict):
            parser.error(f'`{args.cfg}` must contain a YAML mapping of '
                         'config keys')
        for key, value in yaml_dict.items():
            if key not in base_params:
                print(f'Unknown key `{key}` found in `{args.cfg}`, ignoring')
            else:
                base_params[key] = value

    # NOTE(review): `worker` is defined inside the `__main__` guard, so the
    # multiprocessing path only works with the 'fork' start method (under
    # 'spawn' the child would not define it).
    def worker(search_type: Literal['RhoDARTS', 'QDARTS'],
               entanglement_type: Literal['GHZ', 'W'],
               useHiddenUnits: bool,
               noise_model: str = 'None',
               noise_prob: float = 0.0):
        """Run one experiment configuration on top of `base_params`."""
        run_experiment(base_params
                       | {
                           'search_type': search_type,
                           'entanglement_type': entanglement_type,
                           'use_hidden_units': useHiddenUnits,
                           'noise_model': noise_model,
                           'noise_prob': noise_prob
                       },
                       args.project_name,
                       args.cuda_device)

    state_types = ['GHZ'] if args.skip_w else ['GHZ', 'W']
    if not args.noisy_expt:
        configs = product(['RhoDARTS', 'QDARTS'], state_types, [False, True])
    else:
        noise_probs = [0.01, 0.1, 0.25, 0.5]
        # Depolarizing noise needs density-matrix simulation (RhoDARTS only).
        configs = chain(
            product(['RhoDARTS'], state_types, [False, True],
                    ['BitPhaseFlip', 'Depolarizing'], noise_probs),
            product(['QDARTS'], state_types, [False, True],
                    ['BitPhaseFlip'], noise_probs),
        )
    # Each configuration is repeated `--num-reps` times, back to back.
    worker_args = list(chain.from_iterable(repeat(cfg, args.num_reps)
                                           for cfg in configs))

    if args.multiprocessing:
        with mp.Pool(min(args.num_processes, mp.cpu_count())) as pool:
            pool.starmap(worker, worker_args)
    else:
        for w_args in worker_args:
            worker(*w_args)
