import torch
import torch.nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

from typing import Literal

from qtorch import CTYPE,RTYPE
from qtorch.config import QTORCH_CONFIG,updateConfig
from qtorch.unitaries import copyStandardGatesTo,STANDARD_GATES_DICT
from qtorch.quantumstate.measurements import getZDistribution

from examples.qas.metrics import *
from examples.qas.penalty_scheduler import *
from examples.qas.qdarts import QDARTS
from examples.qas.rhodarts import RhoDARTS
from examples.qas.qas import QAS


def image_to_amplitude_encoding(img:torch.Tensor, num_qubits_per_dim:int)->torch.Tensor:
    """Amplitude-encode a square image as a unit-norm vector of ``4**num_qubits_per_dim`` entries.

    The image is zero-padded on the right and bottom to ``2**num_qubits_per_dim`` pixels per
    side and then flattened row-major. Padding each row (rather than only the tail of the
    flattened vector) keeps the row stride equal to ``2**num_qubits_per_dim``, which is the
    layout ``amplitude_encoding_to_image`` assumes when it views the decoded amplitudes as a
    ``(2**n, 2**n)`` grid.

    Args:
        img: square image of shape ``(H, H)``; a trailing singleton channel ``(H, H, 1)``
            is also accepted.
        num_qubits_per_dim: qubits per image axis; ``H`` must be at most ``2**num_qubits_per_dim``.

    Returns:
        Unit-norm vector of length ``4**num_qubits_per_dim`` cast to ``CTYPE``.

    Raises:
        AssertionError: if the image is not square or is larger than ``2**num_qubits_per_dim``.
        ValueError: if the image is all zeros and therefore cannot be normalized.
    """
    side = 2**num_qubits_per_dim
    assert img.shape[0] == img.shape[1]
    assert img.shape[0] <= side

    # Collapse a trailing singleton channel so the padding below acts on the two spatial axes.
    img = img.reshape(img.shape[0], img.shape[1])
    pad = side - img.shape[0]
    # (left, right, top, bottom) padding: zeros go to the right of every row and below the last row.
    data = F.pad(img, (0, pad, 0, pad)).reshape(-1)

    norm = data.norm()
    if norm == 0:
        raise ValueError('cannot amplitude-encode an all-zero image')
    return (data/norm).to(CTYPE)

def amplitude_encoding_to_image(amp:torch.Tensor, num_qubits_per_dim:int, epsilon:float=1e-12)->torch.Tensor:
    """Decode an amplitude-encoded state into a ``(2**n, 2**n)`` image of amplitude magnitudes.

    Each pixel is the square root of the probability that ``getZDistribution`` assigns to the
    corresponding basis state, so any phase information is discarded. Probabilities are
    floored at ``epsilon`` first, keeping ``sqrt`` away from zero where its derivative diverges.
    """
    side = 2**num_qubits_per_dim
    floored = getZDistribution(amp).clamp(min=epsilon)
    return floored.sqrt().view(side, side)

def image_to_frqi(img:torch.Tensor, num_qubits_per_dim:int)->torch.Tensor:
    """Encode a square image as a normalized FRQI-style state of ``2**(2*num_qubits_per_dim + 1)`` entries.

    Each pixel value ``v`` becomes the angle ``theta = (pi/2) * v``; pixel ``p`` contributes
    ``cos(theta)`` at index ``2*p`` and ``sin(theta)`` at index ``2*p + 1``. Values in ``[0, 1]``
    therefore map to angles in ``[0, pi/2]``. The image is zero-padded on the right and bottom
    to ``2**num_qubits_per_dim`` pixels per side. This guarantees the state length
    ``frqi_to_image`` asserts, and the row layout it reshapes back into.

    Args:
        img: square image of shape ``(H, H)``; a trailing singleton channel ``(H, H, 1)``
            is also accepted.
        num_qubits_per_dim: qubits per image axis; ``H`` must be at most ``2**num_qubits_per_dim``.

    Returns:
        Unit-norm vector of length ``2**(2*num_qubits_per_dim + 1)`` cast to ``CTYPE``.

    Raises:
        AssertionError: if the image is not square or is larger than ``2**num_qubits_per_dim``.
    """
    side = 2**num_qubits_per_dim
    assert img.shape[0] == img.shape[1]
    assert img.shape[0] <= side

    # Collapse a trailing singleton channel so the padding below acts on the two spatial axes.
    img = img.reshape(img.shape[0], img.shape[1])
    pad = side - img.shape[0]
    # Padded pixels have value 0, i.e. cos=1 / sin=0, matching a black pixel.
    theta = (torch.pi/2) * F.pad(img, (0, pad, 0, pad)).reshape(-1)

    # Interleave per pixel: even index = cos amplitude, odd index = sin amplitude.
    data = torch.stack([torch.cos(theta), torch.sin(theta)], dim=1).reshape(-1)

    return (data/data.norm()).to(CTYPE)

def frqi_to_image(frqi:torch.Tensor, num_qubits_per_dim:int, epsilon:float=1e-12)->'tuple[torch.Tensor, torch.Tensor]':
    """Decode an FRQI-style state into two image estimates and the per-position marginals.

    Args:
        frqi: state vector of length ``2**(2*num_qubits_per_dim + 1)`` laid out as produced by
            ``image_to_frqi`` (even index = cos amplitude, odd index = sin amplitude per pixel).
        num_qubits_per_dim: qubits per image axis.
        epsilon: numerical floor/ceiling used to keep ``sqrt``, ``arccos`` and ``arcsin``
            away from points where their derivatives diverge.

    Returns:
        ``(images, marginal_probs)``:

        - ``images``: shape ``(2, 2**n, 2**n)``. Channel 0 is recovered from the even-index
          (cos) probabilities via ``arccos``, channel 1 from the odd-index (sin) probabilities
          via ``arcsin``. Both are rescaled by ``2/pi`` back to pixel units.
        - ``marginal_probs``: length ``4**n``, the probability of each pixel position summed over
          its two slots, floored at ``epsilon``.

    Raises:
        AssertionError: if ``frqi`` does not have ``2**(2*num_qubits_per_dim + 1)`` entries.
    """
    side = 2**num_qubits_per_dim
    assert frqi.shape[0] == 2**(2*num_qubits_per_dim + 1)
    probs = getZDistribution(frqi)
    marginal_probs = torch.clamp(torch.sum(probs.view(-1,2),dim=1),epsilon)

    def _conditional_amplitude(slot_probs:torch.Tensor)->torch.Tensor:
        # sqrt(P(slot | position)), floored at epsilon and capped just below 1.
        ratio = torch.clamp(slot_probs/marginal_probs, epsilon)
        return torch.clamp(torch.sqrt(ratio), max=1.0-epsilon)

    cos_image = torch.arccos(_conditional_amplitude(probs[0::2])).reshape(side, side)
    sin_image = torch.arcsin(_conditional_amplitude(probs[1::2])).reshape(side, side)
    images = torch.stack([cos_image, sin_image], dim=0) / (torch.pi/2)
    return images, marginal_probs

def frqi_init_state(num_qubits_per_dim:int,device='cpu')->torch.Tensor:
    """Build the normalized FRQI starting state on ``device``.

    The vector has ``2 * 4**n`` entries of dtype ``CTYPE``. Every even index (the cos slot of
    each pixel position) holds ``1/sqrt(4**n)``, and every odd index (the sin slot) is zero.
    """
    num_pixels = 2**(2*num_qubits_per_dim)
    state = torch.zeros(2*num_pixels, dtype=CTYPE, device=device)
    state[0::2] = 1
    return state/state.norm()

def get_metrics(qs:torch.Tensor,
                num_qubits_per_dim:int,
                angles:torch.Tensor, 
                probs:torch.Tensor,
                ref_state:torch.Tensor,
                ref_img:torch.Tensor,
                img_type:Literal['AMP','FRQI'],
                iteration:int,
                eval_type:Literal['Fidelity','MSE'],
                exploration_type:str,
                exploration_schedule:torch.Tensor, 
                exploration_penalty_str:float, 
                angle_penalty_str:float,
                marginal_prob_error_str:float)->'tuple[torch.Tensor, ...]':
    """Compute the training loss and diagnostic metrics for one optimization step.

    Args:
        qs: state produced by the circuit model, in the same encoding as ``ref_state``.
        num_qubits_per_dim: qubits per image axis.
        angles: circuit angles, passed to ``anglePenalty``.
        probs: passed to ``normalized_mean_entropy`` for the entropy exploration term.
        ref_state: target state, compared with ``qs`` via ``fidelity``.
        ref_img: target image used for the MSE term.
        img_type: ``'AMP'`` (amplitude encoding) or ``'FRQI'``.
        iteration: current step; indexes ``exploration_schedule``.
        eval_type: ``'Fidelity'`` adds ``1 - fidelity`` to the loss. ``'MSE'`` adds the image
            MSE plus ``marginal_prob_error_str`` times the FRQI marginal error.
        exploration_type: ``'None'`` disables the exploration term. ``'Purity'`` scores with
            ``purity(qs)``, and any other value scores with ``normalized_mean_entropy(probs)``.
            ``'Entropy'`` enters the loss with a negative sign; every other type enters with a
            positive sign.
        exploration_schedule: per-iteration weight of the exploration term.
        exploration_penalty_str: weight of the exploration term.
        angle_penalty_str: weight of the angle penalty.
        marginal_prob_error_str: weight of the FRQI marginal error (only used with ``'MSE'``).

    Returns:
        ``(loss, fidelity, MSE, exploration_score, angle_penalty, marginal_prob_error, images)``.
        ``exploration_score`` is a zero tensor when ``exploration_type == 'None'``.
        ``marginal_prob_error`` is a zero tensor for ``'AMP'``.

    Raises:
        NotImplementedError: for an unknown ``img_type`` or ``eval_type``.
    """
    fid = fidelity(qs, ref_state)
    angle_penalty = anglePenalty(angles)

    if img_type == 'AMP':
        images = amplitude_encoding_to_image(qs, num_qubits_per_dim)
        MSE = F.mse_loss(images, ref_img/ref_img.reshape(-1).norm())
        # AMP has no marginal constraint; the placeholder lives on MSE's device/dtype.
        marginal_prob_error = MSE.new_zeros(())
    elif img_type == 'FRQI':
        images, marginal_probs = frqi_to_image(qs, num_qubits_per_dim)
        MSE = F.mse_loss(images, ref_img.unsqueeze(0).expand(2,-1,-1))
        # Penalize deviation of the position marginals from the uniform 1/(number of pixels).
        marginal_prob_error = F.l1_loss(marginal_probs, torch.ones_like(marginal_probs)/(2**(2*num_qubits_per_dim)))
    else:
        raise NotImplementedError(f'unknown img_type: {img_type!r}')

    # Accumulate out-of-place so no tensor in the autograd graph is modified in place.
    loss = angle_penalty*angle_penalty_str

    if eval_type == 'Fidelity':
        loss = loss + (1.0-fid)
    elif eval_type == 'MSE':
        loss = loss + (MSE + marginal_prob_error_str*marginal_prob_error)
    else:
        raise NotImplementedError(f'unknown eval_type: {eval_type!r}')

    if exploration_type == 'None':
        # Keep the return value defined (callers call .item() on it) when exploration is off.
        exploration_score = loss.new_zeros(())
    else:
        exploration_score = purity(qs) if exploration_type == 'Purity' else normalized_mean_entropy(probs)
        loss = loss + (-1 if exploration_type=='Entropy' else 1)*exploration_schedule[iteration]*exploration_score*exploration_penalty_str

    return loss, fid, MSE, exploration_score, angle_penalty, marginal_prob_error, images

# MNIST training split populated by load_dataset(); None until it has been called.
mnist_train:'torchvision.datasets.MNIST | None' = None
def load_dataset(img_dim:int=16, root:str='./data', download:bool=False)->torchvision.datasets.MNIST:
    """Load the MNIST training split, resized to ``img_dim x img_dim``, into ``mnist_train``.

    Args:
        img_dim: side length every image is resized to before ``ToTensor``.
        root: directory passed to ``torchvision.datasets.MNIST``.
        download: forwarded to ``torchvision.datasets.MNIST``. When True, the dataset is fetched
            into ``root`` if it is not already there. The default keeps the original behavior
            of only reading local files.

    Returns:
        The loaded dataset (also stored in the module-level ``mnist_train``).
    """
    global mnist_train
    transform = transforms.Compose([
        transforms.Resize((img_dim,img_dim)),
        transforms.ToTensor(),
    ])
    mnist_train = torchvision.datasets.MNIST(root=root, train=True, transform=transform, download=download)
    return mnist_train

def get_ref_image(digit:int, index:int, dataset=None)->torch.Tensor:
    """Return the ``index``-th image labelled ``digit`` with its leading channel dim squeezed.

    Args:
        digit: class label to select.
        index: position among the images of that class (negative indices count from the end).
        dataset: dataset of ``(image, label)`` pairs to search. Defaults to the module-level
            ``mnist_train`` set by ``load_dataset``.

    Raises:
        RuntimeError: if no dataset was passed and ``load_dataset`` has not been called.
        IndexError: if fewer than ``index + 1`` images carry the label ``digit``.
    """
    ds = mnist_train if dataset is None else dataset
    if ds is None:
        raise RuntimeError('no dataset loaded; call load_dataset() first or pass `dataset`')

    targets = getattr(ds, 'targets', None)
    if targets is not None:
        # Read the stored labels directly. Iterating the dataset would decode and transform
        # every image just to look at its label.
        indices = torch.nonzero(torch.as_tensor(targets) == digit).flatten().tolist()
    else:
        indices = [i for i, (_, label) in enumerate(ds) if label == digit]

    return ds[indices[index]][0].squeeze(0)

if __name__ == '__main__':
    import math

    import wandb

    wandb.login()
    PROJECT_NAME = 'IMAGE INIT TEST'

    dev = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    PAULI_KEYS = ['I','X','Y','Z']
    cuda_gate_dict = copyStandardGatesTo(dev,PAULI_KEYS) if torch.cuda.is_available() else STANDARD_GATES_DICT
    # Stack the Pauli basis by explicit key, in the fixed order I, X, Y, Z. On CPU the full
    # STANDARD_GATES_DICT is used, which is presumably not limited to these four gates, so
    # stacking every entry of the dict could pull in unrelated gates.
    paulis = torch.stack([cuda_gate_dict[key] for key in PAULI_KEYS],dim=0)

    updateConfig('skipValidation',True)

    def run_experiment(config=None):
        """Run one sweep trial: build the model from the wandb config and optimize it toward ``ref_img``.

        Reads the module-level ``ref_img``, ``img_log_freq``, ``dev`` and ``paulis`` defined in
        this ``__main__`` block.
        """
        with wandb.init(config=config):
            config = wandb.config

            img_type = config.img_type
            model_type = config.model_type
            num_qubits_per_dim = config.num_qubits_per_dim

            if img_type == 'AMP':
                n = 2*num_qubits_per_dim
                psi0 = torch.zeros(2**n,dtype=CTYPE,device=dev)
                psi0[0] = 1.0
                ref_state = image_to_amplitude_encoding(ref_img,num_qubits_per_dim)
            elif img_type == 'FRQI':
                n = 2*num_qubits_per_dim + 1
                psi0 = frqi_init_state(num_qubits_per_dim,device=dev)
                ref_state = image_to_frqi(ref_img,num_qubits_per_dim)
            else:
                raise NotImplementedError()

            m, K = config.num_layers, config.num_hidden_units
            max_iter = config.max_iterations
            eval_type = config.eval_type
            expl_type = config.exploration_type
            expl_scheduler = config.exploration_scheduler
            num_cycles = config.num_exploration_cycles
            softmax_temperature = config.softmax_temperature

            expl_penalty_str = config.exploration_penalty_strength
            angle_penalty_str = config.angle_penalty_strength
            frqi_marginal_penalty_str = config.frqi_marginal_penalty_str
            lr = config.learning_rate
            T_max = math.floor(config.CAS_T_max * max_iter)
            tau = config.gumbel_temp

            # Normalized time in [0, 1]; the max() guard avoids dividing by zero when max_iter == 1.
            t = torch.arange(max_iter,device=dev)/max(max_iter-1, 1)

            if expl_scheduler == 'Linear':
                expl_schedule = linear_scheduler(t,1.0,-1.0)
            elif expl_scheduler == 'Oscillating':
                expl_schedule = oscillating_scheduler(t,1.0,-1.0,num_cycles)
            else:
                raise NotImplementedError()

            if model_type == 'QDARTS':
                model = QAS(n, m, K, QDARTS, tau=tau, gumbel_hard=True, pauli_matrices=paulis).to(dev)
            elif model_type == 'RhoDARTS':
                model = QAS(n, m, K, RhoDARTS, pauli_matrices=paulis).to(dev)
            else:
                raise NotImplementedError()

            optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max)

            wandb.log({'ref_image': wandb.Image(ref_img if img_type == 'FRQI' else ref_img/ref_img.view(-1).norm(),
                                                caption='Reference Image')},commit=False)

            for i in range(max_iter):
                optimizer.zero_grad()
                # A single forward pass per step (the original ran the model twice and discarded the first result).
                qs, angles, logits = model(softmax_temperature,psi0=psi0)
                probs = torch.softmax(logits/softmax_temperature,dim=-1)

                l, fid, mse, expl_score, angle_score, marginal_prob_error, images = get_metrics(
                    qs, num_qubits_per_dim, angles, probs, ref_state, ref_img,
                    img_type, i, eval_type, expl_type, expl_schedule,
                    expl_penalty_str, angle_penalty_str, frqi_marginal_penalty_str
                )

                log_dict = {
                    'fidelity':fid.item(),
                    'mse':mse.item(),
                    expl_type:expl_score.item(),
                    'angle_penalty':angle_score.item(),
                    'marginal_dist_penalty':marginal_prob_error.item(),
                    'loss':l.item(),
                    'logits_max': logits.max().item(),
                    'logits_min': logits.min().item(),
                }
                # Only build the (comparatively expensive) wandb.Image on steps where it is logged.
                if i % img_log_freq == 0 or i == max_iter - 1:
                    if img_type == 'AMP':
                        log_dict['images'] = wandb.Image(images,
                                                         caption='Amplitude Encoded Image')
                    elif img_type == 'FRQI':
                        log_dict['images'] = wandb.Image(torchvision.utils.make_grid(images.unsqueeze(-1).permute(0,3,1,2)),
                                                         caption='Left: Cos Image, Right: Sin Image')
                wandb.log(log_dict)

                l.backward()

                optimizer.step()
                scheduler.step()


    num_qubits_per_dim = 4
    num_layers = 15
    num_hidden_units = 0
    max_iterations = 2_500
    img_log_freq = 100
    exploration_type = 'Entropy'
    num_cycles = 1
    softmax_temp = 1.0
    exploration_penalty_str = [0.1, 0.5]
    angle_penalty_str = [0.01, 0.10]
    frqi_marginal_penalty_str = [0.1, 0.5]
    learning_rate = [1e-3, 1e-2]
    CAS_T_max = [0.25, 0.50, 1.0]
    gumbel_temp = 0.5

    load_dataset(2**num_qubits_per_dim)
    ref_img = get_ref_image(4,0).to(device=dev)

    NUM_RUNS = 10


    for img_type in ['AMP']:
        torch.cuda.empty_cache()
        parameters = dict(
            # Experiment Configuration
            img_type = {'value': img_type},
            model_type = {'value': 'RhoDARTS'},
            num_qubits_per_dim = {'value': num_qubits_per_dim},
            num_layers = {'value': num_layers},
            num_hidden_units = {'value': num_hidden_units},
            max_iterations = {'value': max_iterations},
            exploration_type = {'value': exploration_type},
            softmax_temperature = {'value': softmax_temp},
            num_exploration_cycles = {'value': num_cycles},
            # Hyperparameters
            eval_type = {'values': ['MSE', 'Fidelity']},
            exploration_scheduler = {'values': ['Linear', 'Oscillating']},
            exploration_penalty_strength = {
                'distribution': 'uniform',
                'min': exploration_penalty_str[0],
                'max': exploration_penalty_str[1]
            },
            angle_penalty_strength = {
                'distribution': 'uniform',
                'min': angle_penalty_str[0],
                'max': angle_penalty_str[1]
            },
            frqi_marginal_penalty_str = {
                'distribution': 'uniform',
                'min': frqi_marginal_penalty_str[0],
                'max': frqi_marginal_penalty_str[1]
            } if img_type == 'FRQI' else {'value': 0.0},
            learning_rate = {
                'distribution': 'uniform',
                'min': learning_rate[0],
                'max': learning_rate[1]
            },
            CAS_T_max = {
                'values': CAS_T_max
            },
            # Gumbel temperature is only passed to the QDARTS model; unused for RhoDARTS.
            gumbel_temp = {'value': gumbel_temp}
        )

        sweep_name = f'{img_type} HyperParameter Sweep'

        sweep_config = dict(
            name=sweep_name,
            method='random',
            metric={'name':'loss',
                    'goal':'minimize'},
            parameters=parameters
        )
        sweep_id = wandb.sweep(sweep_config,project=PROJECT_NAME)
        wandb.agent(sweep_id,run_experiment,count=NUM_RUNS)
        # NOTE(review): wandb.alert is normally tied to an active run; confirm it works after wandb.agent returns.
        wandb.alert('Sweep Done', f'{img_type} sweep finished')
        wandb.teardown()

