import pytest
import numpy as np
from project_qsl.utils import partial_trace


class TestPartialTrace:
    """Tests for ``partial_trace`` on small multi-qubit density matrices.

    ``partial_trace(dm, qubits)`` is called with a ``2**n x 2**n`` density
    matrix and a list of qubit indices to trace out, and the result is
    compared against the expected reduced density matrix.

    Qubit ordering follows the QCompute convention pinned by
    ``test_general``: qubit ``k`` corresponds to bit ``k`` of the matrix
    index, so the ket ``|01>`` (qubit 0 in |0>, qubit 1 in |1>) sits at
    matrix index ``0b10``.
    """

    def test_partial_trace(self):
        """Bell state: tracing out either qubit leaves the maximally mixed state."""
        # Bell state (|00> + |11>) / sqrt(2)
        bell_dm = np.asarray(
            [
                [0.5, 0.0, 0.0, 0.5],
                [0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0],
                [0.5, 0.0, 0.0, 0.5],
            ]
        )
        mixed_dm = 0.5 * np.eye(2)

        # Tracing out no qubits returns the original density matrix.
        rbell_dm = partial_trace(bell_dm, [])
        np.testing.assert_array_almost_equal(rbell_dm, bell_dm)

        # Tracing out the last qubit gives the maximally mixed state.
        rbell_dm = partial_trace(bell_dm, [1])
        np.testing.assert_array_almost_equal(rbell_dm, mixed_dm)

        # The Bell state is symmetric, so tracing out the first qubit
        # must give the same maximally mixed state.
        rbell_dm = partial_trace(bell_dm, [0])
        np.testing.assert_array_almost_equal(rbell_dm, mixed_dm)

        # Tracing out every qubit leaves the (unit) trace as a 1x1 matrix.
        rrbell_dm = partial_trace(bell_dm, [0, 1])
        np.testing.assert_array_almost_equal(rrbell_dm, np.array([[1.0]]))

    def test_ghz_trace(self):
        """3-qubit GHZ state: reductions lose all off-diagonal coherence."""
        # GHZ state (|000> + |111>) / sqrt(2)
        ghz_dm = np.zeros([8, 8])
        ghz_dm[0, 0] = 0.5
        ghz_dm[0, -1] = 0.5
        ghz_dm[-1, 0] = 0.5
        ghz_dm[-1, -1] = 0.5

        mixed_dm = 0.5 * np.eye(2)
        # Classical mixture of |00> and |11> on the remaining two qubits.
        mixed_dm2 = np.zeros([4, 4])
        mixed_dm2[0, 0] = 0.5
        mixed_dm2[-1, -1] = 0.5

        # Tracing out the last qubit.
        rghz_dm = partial_trace(ghz_dm, [2])
        np.testing.assert_array_almost_equal(rghz_dm, mixed_dm2)

        # Tracing out the last 2 qubits.
        rghz_dm = partial_trace(ghz_dm, [1, 2])
        np.testing.assert_array_almost_equal(rghz_dm, mixed_dm)

        # Tracing out every qubit leaves the (unit) trace as a 1x1 matrix.
        rghz_dm = partial_trace(ghz_dm, [0, 1, 2])
        np.testing.assert_array_almost_equal(rghz_dm, np.array([[1.0]]))

    def test_w_trace(self):
        """3-qubit W state: reductions keep the expected coherences."""
        # W state (|001> + |010> + |100>) / sqrt(3); nonzero amplitudes at
        # matrix indices 1, 2 and 4.
        w_state = np.zeros([8])
        w_state[1] = 1.0 / np.sqrt(3)
        w_state[2] = 1.0 / np.sqrt(3)
        w_state[4] = 1.0 / np.sqrt(3)
        w_dm = np.outer(w_state, w_state)

        # Two-qubit reduction: weight 1/3 on index 0 plus a coherent
        # superposition of indices 1 and 2.
        w_rdm1 = (1.0 / 3) * np.diag([1.0, 1.0, 1.0, 0.0])
        w_rdm1[1, 2] = 1.0 / 3
        w_rdm1[2, 1] = 1.0 / 3
        # Single-qubit reduction: |0> with probability 2/3, |1> with 1/3.
        w_rdm2 = np.diag([2.0 / 3, 1.0 / 3])

        # Tracing out the last qubit.
        rw_dm = partial_trace(w_dm, [2])
        np.testing.assert_array_almost_equal(rw_dm, w_rdm1)

        # Tracing out the last 2 qubits.
        rw_dm = partial_trace(w_dm, [1, 2])
        np.testing.assert_array_almost_equal(rw_dm, w_rdm2)

        # Tracing out every qubit leaves the (unit) trace as a 1x1 matrix.
        rw_dm = partial_trace(w_dm, [0, 1, 2])
        np.testing.assert_array_almost_equal(rw_dm, np.array([[1.0]]))

    def test_general(self):
        """Product state |0> (x) |+>: pins the qubit-ordering convention.

        This is the only asymmetric state in the suite, so both subsystems
        are checked: tracing out qubit 1 must leave |0><0| on qubit 0, and
        tracing out qubit 0 must leave |+><+| on qubit 1.
        """
        # |psi> = (|00> + |01>) / sqrt(2), i.e. qubit 0 in |0> and qubit 1
        # in |+>.
        # NOTE: with QCompute's ordering the ket |01> sits at matrix index
        # 0b10 = 2, so the nonzero entries are at indices 0 and 2.
        dm = 0.5 * np.asarray(
            [
                [1.0, 0.0, 1.0, 0.0],
                [0.0, 0.0, 0.0, 0.0],
                [1.0, 0.0, 1.0, 0.0],
                [0.0, 0.0, 0.0, 0.0],
            ]
        )
        rdm2 = np.asarray([[1.0, 0.0], [0.0, 0.0]])
        # |+><+| on the surviving qubit 1.
        plus_dm = 0.5 * np.asarray([[1.0, 1.0], [1.0, 1.0]])

        # Tracing out qubit 1 leaves qubit 0 in |0><0|.
        rg_dm = partial_trace(dm, [1])
        np.testing.assert_array_almost_equal(rg_dm, rdm2)

        # Tracing out qubit 0 leaves qubit 1 in |+><+|; this would fail if
        # the implementation used the opposite bit ordering.
        rg_dm = partial_trace(dm, [0])
        np.testing.assert_array_almost_equal(rg_dm, plus_dm)
