import torch
import torchvision.transforms as transforms
from PIL import Image
import wandb

from typing import Literal,Any
from itertools import product,repeat,chain
from functools import partial
from copy import deepcopy
import multiprocessing as mp
import os

from qtorch.config import updateConfig

from ..qas import QAS, RhoDARTS, QDARTS, InputGen
from ..max_cut.plots import quantum_circuit_plot

from ..entangled_states.loss import loss, qdarts_angle_loss, qdarts_arch_loss
from ..entangled_states.utils import getEntropySchedule
from .utils import (get_first_mnist_digits, get_amplitude_encoding, 
                    get_image_from_amplitude_encoding, make_ref_images)

def run_experiment(run_params:dict[str,Any],
                   ref_images:torch.Tensor,
                   img_id:int,
                   PROJECT_NAME:str,
                   cuda_device:int=0,
                   )->None:
    """Search for a circuit preparing one amplitude-encoded image and log it to wandb.

    Builds either a RhoDARTS or a QDARTS search, optimises it for at most
    ``run_params['num_iterations']`` steps against the amplitude encoding of
    ``ref_images[img_id]``, logs per-step metrics and images to a wandb run,
    and finally uploads the selected gates, angles, state and recovered image
    as a wandb artifact.

    Args:
        run_params: Hyperparameters of the run. Keys read here:
            ``search_type``, ``num_iterations``, ``image_size``,
            ``layer_factor``, ``hadamard_init``, ``stop_threshold``,
            ``CAS_Tmult``, ``schedule_start_point``, ``schedule_start_time``,
            ``schedule_duration``, ``num_oscillations``,
            ``entropy_penalty_str``, ``angle_penalty_str``,
            ``use_hidden_units``, ``learning_rate`` and, for QDARTS only,
            ``num_angle_iter`` and ``gumbel_temp``.
        ref_images: Reference images, indexed by ``img_id`` along the first
            dimension; the selected image is viewed as
            ``(image_size, image_size)``.
        img_id: Index of the reference image to reproduce.
        PROJECT_NAME: wandb project the run is logged to.
        cuda_device: Index of the CUDA device to run on.

    Raises:
        ValueError: If ``search_type`` is neither ``'RhoDARTS'`` nor
            ``'QDARTS'``, or if ``num_iterations`` is smaller than 1.
    """
    # Validate up front so a bad configuration fails before any device work or
    # wandb run is started (previously these surfaced later as NameErrors).
    search_type = run_params['search_type']
    if search_type not in ('RhoDARTS', 'QDARTS'):
        raise ValueError(f'unknown search_type {search_type!r}; '
                         "expected 'RhoDARTS' or 'QDARTS'")
    N = run_params['num_iterations']
    if N < 1:
        raise ValueError(f'num_iterations must be at least 1, got {N}')

    dev = f'cuda:{cuda_device}'
    ref_images = ref_images.to(dev)
    img_size = run_params['image_size']
    # Twice the number of bits needed to index one image axis.
    num_qubits = 2*(img_size-1).bit_length()
    layer_factor = run_params['layer_factor']
    if search_type == 'QDARTS':
        num_iter = run_params['num_angle_iter']
    had_init = run_params['hadamard_init']
    stop_threshold = run_params['stop_threshold']
    # Period parameter of the cosine-annealing scheduler, as a fraction of N.
    Tmax = int(run_params['CAS_Tmult']*N)

    entropySchedule = getEntropySchedule(
        N,
        run_params['schedule_start_point'],
        run_params['schedule_start_time'],
        run_params['schedule_duration'],
        run_params['num_oscillations']
    )
    entPenaltyStr = run_params['entropy_penalty_str']
    angPenaltyStr = run_params['angle_penalty_str']

    # Hidden units change both the softmax temperature and the generator size.
    if run_params['use_hidden_units']:
        softmax_temp = 10.0
        num_hidden_units = 2*(num_qubits+3)
    else:
        softmax_temp = 1.0
        num_hidden_units = 0

    noise_model = None

    ref_image = ref_images[img_id]
    ref_state = get_amplitude_encoding(ref_image.view(1,img_size,img_size))[0]
    # Kept to rescale recovered images back to the reference's magnitude.
    ref_norm = ref_image.view(-1).norm(p=2)

    if had_init:
        # Uniform superposition over all 2**num_qubits basis states.
        psi0 = torch.ones(2**num_qubits,dtype=torch.complex64)/(2**(num_qubits/2))
    else:
        psi0 = None

    num_layers = layer_factor*num_qubits
    if search_type == 'RhoDARTS':
        model = QAS(num_qubits, num_layers,
                    num_hidden_units, RhoDARTS,
                    noise_model = noise_model,
                    psi0=psi0
                    ).to(device=dev)
        optimizer = torch.optim.Adam(model.parameters(),
                                     run_params['learning_rate'])
    else:
        input_gen = InputGen(num_qubits, num_layers, num_hidden_units).to(device=dev)
        qdarts = QDARTS(num_qubits, num_layers, run_params['gumbel_temp'],
                        psi0=psi0,
                        noise_model=noise_model).to(device=dev)
        angle_loss = partial(qdarts_angle_loss,ref_state=ref_state,
                             angle_penalty_str=angPenaltyStr)
        # Separate optimizers for the architecture logits and the angles.
        optimizer = torch.optim.Adam(input_gen.logit_recipe,
                                     run_params['learning_rate'])
        angles_optimizer = torch.optim.Adam([input_gen.thetas],
                                            run_params['learning_rate'])

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, Tmax,)

    config = run_params | dict(
        num_qubits = num_qubits
    )

    run_name = f'img-{img_id}({num_qubits} qubit)'
    if had_init:
        run_name += '-had'
    if search_type == 'QDARTS':
        run_name += '-q'
    if num_hidden_units > 0:
        run_name += '-hu'

    with wandb.init(project=PROJECT_NAME,config=config,name=run_name) as run:
        # Plot every scalar against the fraction of the run completed.
        for metric_name in ('loss', 'fidelity', 'entropy',
                            'angle_penalty', 'img-mse'):
            run.define_metric(metric_name,step_metric='run_frac')
        wandb.log(
            {
                'reference_img': wandb.Image(
                    transforms.ToPILImage()(ref_image.unsqueeze(0)),
                    caption='Reference Image'),
                'img_id': img_id
            }
        )
        for i in range(N):
            optimizer.zero_grad()
            if search_type == 'RhoDARTS':
                qs, angles, probs = model(softmax_temperature=softmax_temp)
                loss_val, metrics = loss(qs, angles, probs, ref_state, i,
                                         entropySchedule, entPenaltyStr,
                                         angPenaltyStr)
            else:
                angles,logits = input_gen()
                qs, sampled_gates = qdarts(logits, angles, angles_optimizer,
                                           num_iter, angle_loss,
                                           softmax_temp)
                probs = (logits/softmax_temp).softmax(dim=-1)
                loss_val, metrics = qdarts_arch_loss(qs, probs, ref_state,
                                                     i, entropySchedule,
                                                     entPenaltyStr)

            rec_image = get_image_from_amplitude_encoding(qs[0])*ref_norm
            img_mse = torch.nn.functional.mse_loss(rec_image, ref_image)

            loss_val.backward()
            optimizer.step()
            scheduler.step()

            wandb.log({
                'run_frac': (i+1)/N,
                'loss': loss_val.item(),
                'fidelity': metrics[0].item(),
                'entropy': metrics[1].item(),
                # Key must match the metric defined above; it used to be
                # logged as 'angle penalty' and so never picked up run_frac.
                'angle_penalty': metrics[2].item() if search_type == 'RhoDARTS' else float('nan'),
                'img-mse': img_mse.item(),
                'recovered_img': wandb.Image(
                    transforms.ToPILImage()(rec_image.unsqueeze(0)),
                    caption='Recovered Image')
            })

            # Early stop once the infidelity drops to the threshold.
            if stop_threshold is not None:
                if 1.0 - metrics[0] <= stop_threshold:
                    break

        if search_type == 'RhoDARTS':
            qs, angles, probs = model(softmax_temperature=softmax_temp)
            # Recompute the image from the post-update state so the saved
            # image, state, gates and angles all describe the same circuit.
            rec_image = get_image_from_amplitude_encoding(qs[0])*ref_norm
        else:
            # QDARTS: `qs` and `rec_image` stay as produced by the last
            # training iteration; only angles and logits are re-read.
            angles, logits = input_gen()
            probs = (logits/softmax_temp).softmax(dim=-1)
        # Pick the most probable option along the last dimension.
        gates = probs.argmax(dim=-1)
        circuit_img = quantum_circuit_plot(gates, angles)

        wandb.log({
            'circuit_img': wandb.Image(circuit_img,
                                       caption='Found Quantum Circuit'),
        })

        final_gates  = gates.cpu().clone()
        final_angles = angles.cpu().clone()
        final_state = qs.cpu().clone()
        final_image = rec_image.cpu().clone()

        # Strip the parentheses and space that run_name contains.
        artifact_name = run_name.replace('(','-').replace(')','').replace(' ','-')

        circuit_artifact = wandb.Artifact(
            artifact_name,
            type='circuit-data',
            description=f'The cricuits found by {run_params["search_type"]} to '
            f'produce amplitude encoded image states, along with the '
            'corresponding states produced.',
        )
        # Reference data first, then the circuit, angles and final state.
        artifact_files = {
            'ref_images.pt': ref_image,
            'ref_states.pt': ref_state,
            'recovered_images.pt': final_image,
            f'img_{img_id:01d}_circuit.pt': final_gates,
            f'img_{img_id:01d}_angles.pt': final_angles,
            f'img_{img_id:01d}_state.pt': final_state,
        }
        for file_name, tensor in artifact_files.items():
            with circuit_artifact.new_file(file_name, mode='wb') as file:
                torch.save(tensor, file)

        wandb.log_artifact(circuit_artifact)

if __name__ == '__main__':
    import argparse

    import matplotlib
    import yaml

    # Select the non-interactive Agg backend before anything is plotted.
    matplotlib.use('agg')
    updateConfig('skipValidation',True)

    wandb.login()

    # Command-line options as (flags, add_argument keyword arguments), in the
    # order they are registered with the parser.
    cli_spec = [
        (('--project-name',),
         dict(type=str, required=True, help='WANDB project name')),
        (('-c', '--cuda-device'),
         dict(type=int, default=0,
              help='Cuda device id you want to run the job on')),
        (('-m', '--multiprocessing'),
         dict(action='store_true',
              help='Whether to use multiprocessing to parallelize the runs')),
        (('-p', '--num-processes'),
         dict(type=int, default=1,
              help='Number of processes in the pool')),
        (('--data-path',),
         dict(type=str, default='./data/',
              help='The file containing the reference images')),
        (('--mnist-path',),
         dict(type=str, default='./data/',
              help='Where to load the MNIST dataset from')),
        (('--image-size', '-i'),
         dict(type=int, default=16,
              help='Height of the resized images in pixels')),
        (('--num-repeats',),
         dict(type=int, default=3,
              help='Number of repeated experiments')),
        (('-l', '--layer-factor'),
         dict(type=int, default=3,
              help='Factor to multiply the number of qubits with to get the layer count')),
        (('-H', '--hadamard'),
         dict(action='store_true',
              help='Whether to use Hadamard initial state')),
        (('-s', '--stop-threshold'),
         dict(type=float, default=None,
              help='Early stopping threshold')),
        (('--cfg',),
         dict(type=str, default=None,
              help='Path to yaml file for config dictionary')),
    ]

    parser = argparse.ArgumentParser(
        description='Script to run the macro-search experiment for the image '
                    'initialization problem.')
    for flags, options in cli_spec:
        parser.add_argument(*flags, **options)

    args = parser.parse_args()

    # Default hyperparameters shared by every run; entries may later be
    # overridden per experiment or per search profile.
    base_params = {
        'search_type': 'RhoDARTS',

        # Swept hyperparameters
        'num_iterations': 2000,
        'learning_rate': 0.05,
        'CAS_Tmult': 0.1,
        'use_hidden_units': True,

        # Fixed hyperparameters
        'angle_penalty_str': 0.01,
        'use_scheduler': True,
        'entropy_penalty_str': 0.1,
        'schedule_start_point': 0.0,
        'schedule_start_time': 0,
        'schedule_duration': 'half',
        'num_oscillations': 1,

        # QDARTS Specific
        'num_angle_iter': 10,
        'gumbel_temp': 0.05,

        # Experiment Specific
        'layer_factor': None,
        'image_size': None,
        'hadamard_init': None,
        # The early-stopping threshold comes straight from the command line.
        'stop_threshold': args.stop_threshold,
        'search_profile': None,
    }

    # Single default experiment assembled from the command-line options.
    expt_params = [
        dict(
            layer_factor=args.layer_factor,
            image_size=args.image_size,
            hadamard_init=args.hadamard,
        ),
    ]

    def updateDict(base:dict[str,Any], new:dict[str,Any], name:str, 
                   force:bool=False)->None:
        """Copy entries of ``new`` into ``base`` in place.

        Keys that ``base`` does not already contain are skipped with a printed
        notice mentioning ``name``, unless ``force`` is true, in which case
        they are added as well.
        """
        for key, value in new.items():
            if not force and key not in base:
                print(f'key `{key}` not recognized in `{name}`, skipping')
                continue
            base[key] = value

    # Search-profile lists for each search type. They start empty so that a
    # config without `rhodarts_params` / `qdarts_params` (or no config at all)
    # leaves both names defined instead of unbound.
    rhoDARTS_params = []
    qDARTS_params = []

    if args.cfg is not None:
        with open(args.cfg, 'r') as f:
            # safe_load returns None for an empty document; treat that as a
            # config with no overrides.
            yml_dict = yaml.safe_load(f) or {}
        if not isinstance(yml_dict, dict):
            raise TypeError(f'config file `{args.cfg}` must contain a mapping '
                            'at the top level')

        if (k:='hyperparams') in yml_dict:
            updateDict(base_params, yml_dict[k], k)

        # Validate with explicit raises: asserts are stripped under `python -O`.
        for k in ('rhodarts_params', 'qdarts_params', 'expt_params'):
            if k in yml_dict and not isinstance(yml_dict[k], list):
                raise TypeError(f'`{k}` in `{args.cfg}` must be a list')
        rhoDARTS_params = yml_dict.get('rhodarts_params', rhoDARTS_params)
        qDARTS_params = yml_dict.get('qdarts_params', qDARTS_params)
        expt_params = yml_dict.get('expt_params', expt_params)

    # Build the reference images once per distinct image size; the second
    # value returned by make_ref_images is kept in img_paths.
    ref_images = {}
    img_paths = {}
    for expt in expt_params:
        imgsz = expt['image_size']
        if imgsz not in ref_images:
            ref_images[imgsz], img_paths[imgsz] = make_ref_images(args.data_path, args.mnist_path, imgsz)

    def worker(search_type:Literal['RhoDARTS','QDARTS'],
               expt_params:dict[str,Any],
               img_id):
        """Assemble the parameters for one run and execute it.

        The run parameters are a copy of ``base_params`` with ``search_type``
        set from the argument, then overridden by ``expt_params`` and by the
        search profile that ``expt_params['search_profile']`` selects from the
        list belonging to ``search_type``.

        Args:
            search_type: Which search to run; also selects the profile list.
            expt_params: Experiment settings; must contain ``search_profile``
                and ``image_size``.
            img_id: Index of the reference image to reproduce.

        Raises:
            KeyError: If ``expt_params`` has no ``search_profile`` entry.
            ValueError: If ``search_profile`` is not a valid index into the
                profile list of ``search_type``.
        """
        if 'search_profile' not in expt_params:
            raise KeyError('key `search_profile` not found in `expt_params`')
        profiles = rhoDARTS_params if search_type == 'RhoDARTS' else qDARTS_params
        sp = expt_params['search_profile']
        # Valid indices are 0 .. len(profiles)-1; the previous `sp > len(...)`
        # test let `sp == len(...)` through to an IndexError.
        if not 0 <= sp < len(profiles):
            raise ValueError(f'invalid search_profile ({sp}) for search type {search_type}')

        run_params = deepcopy(base_params)
        # Make the run use the requested search type instead of whatever
        # base_params defaults to; the overrides below can still replace it.
        run_params['search_type'] = search_type
        updateDict(run_params, expt_params, 'expt_params')
        updateDict(run_params, profiles[sp], 'search_params')

        run_experiment(run_params,
                       ref_images[expt_params['image_size']],
                       img_id,
                       args.project_name,
                       args.cuda_device)
    
    # One job per (search type, experiment, image id) combination.
    worker_args = product(['RhoDARTS','QDARTS'],
                          expt_params,
                          range(10))

    # Each combination is emitted num_repeats times in a row.
    args_generator = chain.from_iterable(
        repeat(combo, args.num_repeats) for combo in worker_args)

    if args.multiprocessing:
        # Never start more processes than there are CPUs.
        pool_size = min(args.num_processes,mp.cpu_count())
        with mp.Pool(pool_size) as pool:
            pool.starmap(worker, args_generator)
    else:
        for w_args in args_generator:
            worker(*w_args)