import numpy as np
import copy
from qiskit.quantum_info import random_statevector
from functools import reduce
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

class QuantumDataGenarator:
    """Static helpers that build labelled quantum-state datasets.

    Every method is a ``@staticmethod``. State vectors are 1-D NumPy arrays
    and labels are the integers ``-1`` / ``1``.

    NOTE: the (misspelled) class name is kept as-is so existing callers and
    imports keep working.
    """

    @staticmethod
    def data_generation(n=2, _dtype=np.complex128):
        """Draw one labelled sample from the VSQL synthetic distribution.

        Synthetic data from Guangxi Li, Zhixin Song, and Xin Wang. VSQL:
        Variational shadow quantum learning for classification. Proceedings
        of the AAAI Conference on Artificial Intelligence, 35(9):8357–8365,
        May 2021, adjusted for higher qubits.

        With ``u, v`` drawn uniformly from ``[0, 1)`` and ``|a>`` denoting the
        computational-basis vector with index ``a``, one of three states is
        returned, each with probability 1/3:

        * ``sqrt(1 - u**2)|0> + u|2**(n-1)>`` with label ``-1``;
        * ``sqrt(1 - v**2)|1> + v|2**(n-1)>`` with label ``1``;
        * ``-sqrt(1 - v**2)|1> + v|2**(n-1)>`` with label ``1``.

        Three values are drawn from the global ``np.random`` stream (``u``,
        ``v``, then the class selector), so results are reproducible through
        ``np.random.seed``.

        Args:
            n: Number of qubits; must be at least 2.
            _dtype: dtype of the returned state vector.

        Returns:
            ``(state, label)`` where ``state`` has shape ``(2**n,)``.

        Raises:
            ValueError: If ``n < 2``. For ``n == 1`` the indices ``1`` and
                ``2**(n-1)`` coincide and the states would not be normalized.
        """
        if n < 2:
            raise ValueError(f"data_generation needs at least 2 qubits, got n={n}")

        # Draw order matters for reproducibility: u, v, then the selector.
        u = np.random.rand()
        v = np.random.rand()

        dim = 2 ** n

        def basis(index):
            # Computational-basis vector with a single 1 at ``index``.
            vec = np.zeros(dim, dtype=_dtype)
            vec[index] = 1
            return vec

        zero_zero = basis(0)
        one_zero = basis(2 ** (n - 1))  # only the most significant bit set
        zero_one = basis(1)  # only the least significant bit set

        rand_par = np.random.rand()
        amp_u = np.sqrt(1 - u ** 2)
        amp_v = np.sqrt(1 - v ** 2)
        if rand_par < 1 / 3:
            return (amp_u * zero_zero + u * one_zero, -1)
        if rand_par > 2 / 3:
            return (-1 * amp_v * zero_one + v * one_zero, 1)
        return (amp_v * zero_one + v * one_zero, 1)

    @staticmethod
    def GHZ(n, _dtype=np.complex128):
        """Return the n-qubit GHZ state ``(|0...0> + |1...1>) / sqrt(2)``.

        Args:
            n: Number of qubits; must be at least 1.
            _dtype: dtype of the returned vector.

        Returns:
            Array of shape ``(2**n,)``. It holds ``1/sqrt(2)`` at the first and
            last index and zeros elsewhere.

        Raises:
            ValueError: If ``n < 1``.
        """
        if n < 1:
            raise ValueError(f"GHZ needs at least 1 qubit, got n={n}")
        # |0...0> is index 0 and |1...1> is the last index, so the amplitudes
        # can be set directly instead of via repeated Kronecker products.
        state = np.zeros(2 ** n, dtype=_dtype)
        state[0] = state[-1] = 1 / np.sqrt(2)
        return state

    @staticmethod
    def randomstate(n, seed=None, _dtype=np.complex128):
        """Return a random n-qubit product state ``|A> (x) |B>``.

        ``|A>`` spans ``n // 2`` qubits and ``|B>`` spans the remaining
        ``n - n // 2`` qubits. Each factor is drawn with qiskit's
        ``random_statevector``.

        Args:
            n: Number of qubits; must be non-negative.
            seed: If ``None`` (the default), the fixed seeds 54 and 32 are
                used, so every call returns the same state. This preserves the
                historical behaviour. Any other value is passed to
                ``np.random.default_rng``, and that generator draws both
                factors.
            _dtype: dtype of the returned vector.

        Returns:
            Array of shape ``(2**n,)``.

        Raises:
            ValueError: If ``n < 0``.
        """
        if n < 0:
            raise ValueError(f"randomstate needs a non-negative qubit count, got n={n}")
        # Split by qubit count rather than int(sqrt(2**n)). The latter
        # truncates for odd n and produced a vector of the wrong length. For
        # even n the factor sizes are unchanged.
        n_a = n // 2
        n_b = n - n_a
        if seed is None:
            seed_a, seed_b = 54, 32
        else:
            seed_a = seed_b = np.random.default_rng(seed)
        A = random_statevector(2 ** n_a, seed=seed_a)
        B = random_statevector(2 ** n_b, seed=seed_b)
        return np.asarray(np.kron(A.data, B.data), dtype=_dtype)

    @staticmethod
    def data_generationGHZ(n):
        """Generate one GHZ-vs-product sample.

        With probability 1/2 (drawn from the global ``np.random`` stream) this
        returns ``(GHZ(n), -1)``; otherwise it returns ``(randomstate(n), 1)``.
        """
        rand_par = np.random.rand()
        if rand_par < 1 / 2:
            return (QuantumDataGenarator.GHZ(n), -1)
        return (QuantumDataGenarator.randomstate(n), 1)

    @staticmethod
    def genDataset(qubits, samples=1000, seed=54):
        """Build a reproducible dataset of ``data_generation`` samples.

        The global ``np.random`` stream is reseeded with ``seed`` first, so
        identical arguments always give identical data.

        Args:
            qubits: Number of qubits per state (at least 2).
            samples: Number of samples to draw.
            seed: Seed passed to ``np.random.seed``. The default of 54 matches
                the previously hard-coded value.

        Returns:
            ``(data, labels)``: two lists of length ``samples``.
        """
        np.random.seed(seed)
        pairs = [QuantumDataGenarator.data_generation(n=qubits) for _ in range(samples)]
        data = [state for state, _ in pairs]
        labels = [label for _, label in pairs]
        return data, labels

    @staticmethod
    def gen(qubits, samples=1000, threshold=0.5, seed=54):
        """Draw one label per sample index and return the two class prototypes.

        The global ``np.random`` stream is reseeded with ``seed``. Then, for
        each index, a uniform draw in ``[0, 1)`` greater than ``threshold``
        gives label ``1``; otherwise the label is ``-1``.

        Args:
            qubits: Number of qubits of the prototype states.
            samples: Number of labels to draw.
            threshold: Cut-off for label ``1`` (probability ``1 - threshold``).
            seed: Seed passed to ``np.random.seed``.

        Returns:
            ``(data_set, Samples)``, where ``data_set`` maps each index in
            ``range(samples)`` to its label. ``Samples`` maps ``-1`` to
            ``GHZ(qubits)`` and ``1`` to ``randomstate(qubits)``.
        """
        np.random.seed(seed)
        prototypes = {
            -1: QuantumDataGenarator.GHZ(qubits),
            1: QuantumDataGenarator.randomstate(qubits),
        }
        data_set = {
            i: 1 if np.random.uniform(0, 1) > threshold else -1
            for i in range(samples)
        }
        return data_set, prototypes

    @staticmethod
    def mnist_binary_classification(nqubits, samples=1000, mode='real'):
        """Amplitude-encode MNIST digits 3 and 6 as normalized state vectors.

        Steps:

        1. Load ``mnist_784`` via ``fetch_openml``, cached under
           ``./mnist_cache``.
        2. Keep digits 3 and 6 and standardize the pixels.
        3. Project onto ``2**nqubits`` principal components. If that is more
           than PCA can produce, zero-pad the remaining coordinates.
        4. Normalize each of the first ``samples`` rows:

           * ``mode='real'``: the PCA vector divided by its norm;
           * ``mode='complex'``: ``x + 1j*sin(pi*x)`` divided by its norm.

        Rows with zero norm are skipped, so fewer than ``samples`` vectors may
        be returned. If fewer than ``samples`` rows exist, all of them are
        used.

        Args:
            nqubits: Number of qubits; vectors have length ``2**nqubits``.
            samples: Maximum number of rows to encode.
            mode: ``'real'`` or ``'complex'``.

        Returns:
            ``(qdata, labels)``: a list of 1-D arrays and a list of labels,
            with ``-1`` for digit 3 and ``1`` for digit 6.

        Raises:
            ValueError: If ``mode`` is neither ``'real'`` nor ``'complex'``.
        """
        # Validate before the (possibly networked) dataset download.
        if mode not in ('real', 'complex'):
            raise ValueError(f"mode must be 'real' or 'complex', got {mode!r}")

        dim = 2 ** nqubits
        mnist = fetch_openml('mnist_784', version=1, data_home='./mnist_cache', as_frame=False)
        data = mnist['data']
        digits = mnist['target'].astype(int)

        keep = (digits == 3) | (digits == 6)
        data_filtered = data[keep]
        digits_filtered = digits[keep]

        data_normalized = StandardScaler().fit_transform(data_filtered)

        # PCA cannot return more than min(n_samples, n_features) components,
        # so pad with zeros when 2**nqubits exceeds that.
        n_components = min(dim, *data_normalized.shape)
        data_reduced = PCA(n_components=n_components).fit_transform(data_normalized)
        if n_components < dim:
            padding = np.zeros((data_reduced.shape[0], dim - n_components))
            data_reduced = np.hstack((data_reduced, padding))

        # Clamp so a negative ``samples`` yields nothing rather than slicing
        # from the end.
        limit = max(samples, 0)
        qdata = []
        filtered_labels = []
        for sample, digit in zip(data_reduced[:limit], digits_filtered[:limit]):
            if mode == 'complex':
                sample = sample + 1j * np.sin(sample * np.pi)
            norm = np.linalg.norm(sample)
            if norm == 0:
                continue
            qdata.append(sample / norm)
            filtered_labels.append(-1 if digit == 3 else 1)

        return qdata, filtered_labels
        