#!/usr/bin/env python
# coding: utf-8


import math
import re
import os

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from qiskit import QuantumCircuit
from qiskit.quantum_info import Operator

from data_utils.aae_dataset import MNIST_AAE_Dataset

def seed_everything(seed):
    """
    Seed Python's ``random`` module, NumPy's global RNG and torch
    (``torch.manual_seed``) with the same ``seed``.
    """
    import random

    import numpy as np
    import torch

    # Order matches the original: stdlib, then NumPy, then torch.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)


def visual_compare(result_state, target_state):
    """
    Visually compare two batches of states side by side (e.g. in a notebook).

    Each row of the figure shows one sample: the result state on the left and
    the target state on the right, both reshaped to a square image.

    result_state: (B, D) tensor, where D is the number of amplitudes
        (2**n_qubits); D is expected to be a perfect square.
    target_state: (B, D) tensor; B is taken from this argument.
    """
    result_state = result_state.detach().cpu().to(torch.float64)
    target_state = target_state.detach().cpu().to(torch.float64)

    n_samples, n_amplitudes = target_state.shape
    image_size = math.floor(n_amplitudes**0.5)

    # squeeze=False keeps ``axes`` 2-D even when n_samples == 1; otherwise
    # matplotlib returns a 1-D array and axes[i][0] would fail.
    _, axes = plt.subplots(n_samples, 2, figsize=(8, 16), squeeze=False)
    for i in range(n_samples):
        # A plain 2-D array is the canonical scalar-image input for imshow.
        axes[i][0].imshow(result_state[i].reshape(image_size, image_size).numpy())
        axes[i][1].imshow(target_state[i].reshape(image_size, image_size).numpy())
    axes[0][0].set_title("result state")
    axes[0][1].set_title("target state")
    plt.show()

def get_test_loaders(folder_path, n_qubits, n_samples_per_ds=1):
    """
    Build one shuffled ``DataLoader`` per dataset file for a given qubit count.

    ``folder_path`` is searched recursively (``os.walk``) for files whose name
    contains "<n_qubits>-qubits" and ".pt". Each match is loaded as a
    ``MNIST_AAE_Dataset`` and keyed by ``file_name.split(".pt")[-2]``
    (e.g. "foo-4-qubits.pt" -> "foo-4-qubits").

    Args:
        folder_path: root directory to search; a trailing separator is optional.
        n_qubits: qubit count used to select files.
        n_samples_per_ds: batch size of every returned loader.

    Returns:
        dict mapping distribution name -> DataLoader.
    """
    # The look-behind keeps n_qubits=2 from also matching "12-qubits".
    qubit_tag = re.compile(rf"(?<!\d){re.escape(str(n_qubits))}-qubits")

    distribution_files = {}
    for root, _dirs, files in os.walk(folder_path):
        for file_name in files:
            # Names without ".pt" cannot yield a distribution name (the
            # split below would raise IndexError), so skip them.
            if ".pt" not in file_name or not qubit_tag.search(file_name):
                continue
            distribution_name = file_name.split(".pt")[-2]
            # Join with the directory os.walk found the file in, so files in
            # sub-directories and folder paths without a trailing "/" work.
            distribution_files[distribution_name] = os.path.join(root, file_name)

    test_loaders = {}
    for name, file_path in distribution_files.items():
        test_dataset = MNIST_AAE_Dataset(file_path)
        test_loaders[name] = DataLoader(
            test_dataset,
            shuffle=True,
            batch_size=n_samples_per_ds,
            num_workers=0,
            pin_memory=True,
        )

    return test_loaders


def to_openqasm(qml):
    """
    Export ``qml.qtape`` as an OpenQASM string that Qiskit can parse.

    Gate parameters that were stringified as tensors, e.g.
    ``tensor(0.5000)``, ``tensor(0.5000, grad_fn=<UnbindBackward0>)`` or
    ``tensor(0.5, dtype=torch.float64, requires_grad=True)``, are reduced to
    the bare number. All lines containing "measure" are dropped.

    qml: object exposing ``qtape.to_openqasm()`` (presumably a PennyLane
        QNode that has already been executed -- confirm with callers).
    """
    # Capture the leading number inside "tensor(...)" and discard any
    # trailing ", key=value" attributes (none of which contain parentheses).
    tensor_repr = re.compile(r"tensor\(\s*([^,()\s]+)[^()]*\)")

    kept_lines = []
    for line in qml.qtape.to_openqasm().split("\n"):
        if "tensor" in line:
            line = tensor_repr.sub(r"\1", line)
        if "measure" not in line:
            kept_lines.append(line)
    return "\n".join(kept_lines)


def reverse_qasm(qasm, n_qubits):
    """
    Mirror the qubit indices of ``ry`` and ``cx`` instructions in an OpenQASM
    string: every ``q[i]`` on those lines becomes ``q[n_qubits - 1 - i]``.

    Only lines whose instruction is exactly ``ry`` or ``cx`` are rewritten;
    register declarations and all other lines are returned unchanged.
    NOTE(review): other gates are not mirrored -- confirm circuits passed here
    contain only ry/cx.

    Args:
        qasm: OpenQASM source; the quantum register is assumed to be named ``q``.
        n_qubits: number of qubits in that register.

    Returns:
        The rewritten OpenQASM string.
    """
    gate_line = re.compile(r"\s*(?:ry|cx)\b")
    qubit_ref = re.compile(r"q\[(\d+)\]")

    def _mirror(match):
        # Regex-based so multi-digit indices (q[10], q[11], ...) are handled.
        return f"q[{n_qubits - 1 - int(match.group(1))}]"

    lines = qasm.split("\n")
    for i, line in enumerate(lines):
        if gate_line.match(line):
            lines[i] = qubit_ref.sub(_mirror, line)
    return "\n".join(lines)


def to_qiskit(qml, n_qubits):
    """
    Convert ``qml`` into a Qiskit ``QuantumCircuit``.

    The circuit is exported with ``to_openqasm`` (measurement lines removed),
    the qubit indices of its ry/cx instructions are mirrored by
    ``reverse_qasm``, and the text is parsed with
    ``QuantumCircuit.from_qasm_str``. ``qml`` must expose
    ``qtape.to_openqasm()``.
    """
    return QuantumCircuit.from_qasm_str(
        reverse_qasm(to_openqasm(qml), n_qubits)
    )


def get_matrix(qc):
    """Return ``Operator(qc).data``, the matrix of circuit ``qc`` as computed by Qiskit."""
    return Operator(qc).data


def replace(qml, new_angles, n_qubits):
    """
    Build a Qiskit circuit from ``qml`` with its ry angles replaced.

    The circuit is exported with ``to_openqasm``. The argument of the k-th
    ``ry(...)`` instruction (in text order) is overwritten with
    ``new_angles[k]``. The ry/cx qubit indices are then mirrored with
    ``reverse_qasm`` and the result is parsed by
    ``QuantumCircuit.from_qasm_str``.

    Args:
        qml: object exposing ``qtape.to_openqasm()``.
        new_angles: indexable sequence with at least one entry per ry gate;
            every entry must be accepted by ``float()`` (Python/NumPy floats,
            0-d or single-element torch tensors).
        n_qubits: number of qubits, forwarded to ``reverse_qasm``.

    Raises:
        IndexError: if ``new_angles`` has fewer entries than there are ry gates.
    """
    # Replace the whole ry(...) argument instead of searching for
    # str(float(old)): the printed text (e.g. "0.5000", "1.0000e-05") rarely
    # equals str(float(old)), which either corrupted the line or replaced nothing.
    ry_gate = re.compile(r"^(\s*)ry\([^()]*\)")

    lines = to_openqasm(qml).split("\n")
    ry_count = 0
    for i, line in enumerate(lines):
        match = ry_gate.match(line)
        if match is None:
            continue
        # float() avoids emitting "tensor(...)" when angles are torch values.
        new_value = float(new_angles[ry_count])
        lines[i] = f"{match.group(1)}ry({new_value}){line[match.end():]}"
        ry_count += 1

    reversed_qasm = reverse_qasm("\n".join(lines), n_qubits)
    return QuantumCircuit.from_qasm_str(reversed_qasm)


def append_log(log_file_path, item):
    """Append the text form of ``item`` followed by a newline to ``log_file_path``.

    The file is opened in append mode, so it is created if it does not exist.
    """
    with open(log_file_path, mode="a") as log_file:
        log_file.write(format(item) + "\n")


def resize_and_norm(images, n_qubits=None):
    """
    Resize a batch of images with ``resize`` and then flatten and L2-normalise
    each sample with ``norm_image``.

    This is exactly the composition of those two helpers, so the resizing rules
    are defined in one place:
    - ``n_qubits`` is None or <= 4: images become (B, C, 4, 4).
    - otherwise: images become (B, C, s, s) with s = floor(sqrt(2**n_qubits)).

    Returns:
        A (B, C*s*s) tensor with one row per sample. The output is flattened,
        not (B, C, s, s).
    """
    return norm_image(resize(images, n_qubits))


def resize(images, n_qubits=None):
    """
    Resize a batch of images to (B, C, new_size, new_size). No normalisation
    is applied (see ``norm_image``).

    - ``n_qubits`` is None or <= 4: unless the input is already 28x28, it is
      first resized to 28x28. It is then 7x7 average-pooled down to 4x4.
    - ``n_qubits`` > 4: the input is resized with torchvision to
      new_size = floor(sqrt(2**n_qubits)), the largest square whose pixel count
      fits into 2**n_qubits amplitudes.

    images: assumed (B, C, H, W) -- TODO confirm with callers.
    """
    if n_qubits is None or n_qubits <= 4:  # default to (B,C,4,4)
        # 7x7 pooling yields 4x4 only from a 28x28 input, so check both spatial
        # dims (previously only the width was checked).
        if tuple(images.shape[-2:]) != (28, 28):
            images = torchvision.transforms.functional.resize(images, [28, 28])
        images = F.avg_pool2d(images, kernel_size=(7, 7))
        # images shape: (batchsize, C, 4, 4)
    else:
        new_size = math.floor((2**n_qubits) ** 0.5)  # max size n_qubits can support
        images = torchvision.transforms.functional.resize(images, [new_size, new_size])

    return images


def norm_image(images):
    """
    Flatten each sample and scale it to unit L2 norm.

    images: tensor whose first dimension is the batch size B.

    Returns:
        A (B, D) tensor, where D is the product of the remaining dimensions.
        Rows whose norm is zero (e.g. an all-black image) are returned as
        zeros instead of NaN.
    """
    flat = images.reshape(images.size(0), -1)
    norms = torch.norm(flat, p=2, dim=1, keepdim=True)
    # Guard 0/0 -> NaN for all-zero rows; every non-zero row is still divided
    # by its exact norm.
    safe_norms = torch.where(norms > 0, norms, torch.ones_like(norms))
    return flat / safe_norms


def add_noise(images, noise_factor, device):
    """
    Blend ``images`` with uniform noise.

    Returns ``images * (1 - noise_factor) + noise * noise_factor``. The noise
    has the shape of ``images`` and is drawn from U(-1, 1) as float32 on the
    CPU using torch's global RNG, then moved to ``device``.
    """
    attenuated = images * (1 - noise_factor)
    noise = torch.empty(images.shape, dtype=torch.float32, device="cpu")
    noise = noise.uniform_(-1, 1).to(device)
    return attenuated + noise * noise_factor
