from abc import ABC

import torch
import gym
from gym import spaces
import pygame
from cgqn import GenerativeQueryNetwork

import strawberryfields as sf
from strawberryfields.ops import *
import numpy as np


# Torch device shared by the model and its input tensors: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class QCVENV(gym.Env):
    """Gym environment that steers a single-mode cat state towards a target state.

    Each action applies a displacement followed by a rotation to the current
    state. Observations are vectors produced by the ``representation`` network
    of a ``GenerativeQueryNetwork`` from homodyne marginals of the state,
    evaluated at ``n_views`` randomly chosen quadrature angles.

    Args:
        target_param: Sequence ``(d, s, r1, r2)`` handed to
            ``generate_cat_state`` to build the target state. Drawn at random
            when ``None``.
        initial_param: Sequence ``(d, s, r1, r2)`` around which ``reset``
            perturbs the starting state. Must be set before ``reset`` is called.
        r_dim, h_dim, z_dim, L: Forwarded to ``GenerativeQueryNetwork``;
            ``r_dim``, ``h_dim`` and ``z_dim`` also select the checkpoint file.
        n_views: Number of quadrature angles used per observation.
    """

    def __init__(self, target_param=None, initial_param=None, r_dim=32, h_dim=48, z_dim=48, L=2,
                 n_views=10):
        import os

        super().__init__()

        # Parameters of the seed cat state (see ``generate_cat_state``).
        self.a = 1+1j
        self.p = 0
        # Candidate homodyne angles and the quadrature grid the marginals are sampled on.
        self.phis = np.linspace(0, np.pi, 300)
        scale = np.sqrt(sf.hbar)
        self.quad_axis = np.linspace(-6, 6, 100) * scale

        if target_param is None:
            # Randomly generate all four target parameters (d, s, r1, r2):
            # d in [-1, 1), s in [0, 1), r1 and r2 in [0, pi).
            random_target = np.random.random(4)
            random_target[0] = random_target[0] * 2 - 1
            random_target[2:] = random_target[2:] * np.pi
            self.target_param = random_target
        else:
            self.target_param = target_param

        self.init_param = initial_param

        self.n_views = n_views

        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(r_dim,), dtype=np.float32)
        # Continuous 2-D action: (displacement in [-1, 1], rotation angle in [0, pi]).
        self.action_space = spaces.Box(low=np.array([-1, 0]), high=np.array([1, np.pi]), shape=(2, ),
                                       dtype=np.float64)

        self.model = GenerativeQueryNetwork(x_dim=100,
                                            v_dim=1,
                                            r_dim=r_dim,
                                            h_dim=h_dim,
                                            z_dim=z_dim, L=L)
        try:
            current_directory = os.path.dirname(os.path.abspath(__file__))
            checkpoint = (f"{current_directory}/models/{r_dim}_{h_dim}_{z_dim}"
                          f"_cat_a{self.a}_p{self.p}_cpu")
            self.model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
            print("Total number of param in Model is ", sum(x.numel() for x in self.model.parameters()))
        except Exception as exc:
            # Best effort: keep the freshly initialised weights, but say why loading failed.
            print("NO Load", exc)

        # Done outside the try-block so the model always lives on the same
        # device as the tensors fed to it, even when no checkpoint was loaded.
        self.model.to(device)
        self.model.eval()

        d_target = self.target_param[0]
        s_target = self.target_param[1]
        r1_target = self.target_param[2]
        r2_target = self.target_param[3]
        self.target_cat = self.generate_cat_state(d_target, s_target, r1_target, r2_target)

    def calculate_fidelity(self, state1, state2):
        """Return the overlap of two states computed from their Wigner functions.

        Both Wigner functions are evaluated on the same 401 x 401 phase-space
        grid over [-15, 15] and the integral of their product is approximated
        by a Riemann sum, scaled by ``2 * pi * hbar`` (``4 * pi`` for the
        default ``sf.hbar == 2``).
        """
        xvec = np.linspace(-15, 15, 401)
        W1 = state1.wigner(mode=0, xvec=xvec, pvec=xvec)
        W2 = state2.wigner(mode=0, xvec=xvec, pvec=xvec)
        cell_area = (xvec[1] - xvec[0]) ** 2
        return np.sum(W1 * W2) * cell_area * 2 * np.pi * sf.hbar

    def generate_cat_state(self, d, s, r1, r2):
        """Build a cat state and apply ``Rgate(r1)``, ``Sgate(s)``, ``Rgate(r2)``, ``Dgate(d)``.

        The seed state is ``Catstate(a=self.a, p=self.p)``; the program is run
        on the ``"bosonic"`` engine and the resulting state object is returned.
        """
        prog_cat = sf.Program(1)
        with prog_cat.context as q:
            sf.ops.Catstate(a=self.a, p=self.p) | q
            Rgate(r1) | q
            Sgate(s) | q
            Rgate(r2) | q
            Dgate(d) | q
        eng = sf.Engine("bosonic")
        return eng.run(prog_cat).state

    def generate_cat_homodyne_prob(self, state, phi):
        """Return the marginal of ``state`` along angle ``phi`` sampled on ``self.quad_axis``."""
        return state.marginal(0, self.quad_axis, phi=phi)

    def _sample_view_indices(self):
        """Pick ``n_views`` distinct indices into ``self.phis`` with the global NumPy RNG."""
        indices = list(range(len(self.phis)))
        np.random.shuffle(indices)
        return indices[:self.n_views]

    def _encode_state(self, state, view_indices):
        """Return the mean network representation of ``state`` over the given views.

        Args:
            state: State object accepted by ``generate_cat_homodyne_prob``.
            view_indices: Indices into ``self.phis`` selecting the homodyne angles.

        Returns:
            NumPy array holding the representation averaged over the views,
            with singleton dimensions squeezed out.
        """
        angles = self.phis[view_indices]
        marginals = np.array([self.generate_cat_homodyne_prob(state, angle) for angle in angles])

        # Add the batch axis: context_x is (1, views, ...) and context_v is (1, views, 1).
        context_x = torch.FloatTensor(np.expand_dims(marginals, axis=0)).to(device)
        context_v = torch.FloatTensor(angles).to(device).unsqueeze(dim=0).unsqueeze(dim=2)

        b, m, *x_dims = context_x.shape
        _, _, *v_dims = context_v.shape
        # Flatten batch and view axes so every view is encoded independently.
        flat_x = context_x.view((-1, *x_dims))
        flat_v = context_v.view((-1, *v_dims))

        with torch.no_grad():
            phi = self.model.representation(flat_x, flat_v)
            _, *phi_dims = phi.shape
            phi = phi.view((b, m, *phi_dims))
            # Average over the views to obtain a single representation.
            representation = torch.mean(phi, dim=1)

        return representation.squeeze().cpu().numpy()

    def _get_initial_obs(self):
        """Encode ``self.initial_cat`` at freshly sampled views.

        Returns:
            Tuple ``(observation, view_indices)``.
        """
        representation_idx = self._sample_view_indices()
        observation = self._encode_state(self.initial_cat, representation_idx)
        return observation, representation_idx

    def _get_target_obs(self, phis_basis):
        """Encode the target state at the views given by ``phis_basis``.

        The marginals are recomputed on every call from the target state and
        the supplied view indices.
        """
        return self._encode_state(self.target_cat, phis_basis)

    def _get_agent_obs(self, controls):
        """Apply ``controls`` to the current state and encode the result.

        ``self.n_views`` homodyne angles are picked at random for the encoding.

        Args:
            controls: Indexable whose first entry is the displacement and
                whose second entry is the rotation angle.

        Returns:
            Tuple ``(observation, next_state, view_indices)``.
        """
        representation_idx = self._sample_view_indices()

        d = controls[0]  # displacement, clipped to [-1, 1] by ``step``
        r = controls[1]  # rotation angle, clipped to [0, pi] by ``step``

        next_state = generate_new_cat_state(self.current_state, d, r)
        observation = self._encode_state(next_state, representation_idx)

        return observation, next_state, representation_idx

    def _get_info(self, agent_state):
        """Return ``{"quantum_fidelity": ...}`` comparing ``agent_state`` with the target state."""
        quantum_fidelity = self.calculate_fidelity(agent_state, self.target_cat)
        return {"quantum_fidelity": np.real(quantum_fidelity)}

    def reset(self, seed=None):
        """Start a new episode from a randomly perturbed copy of ``initial_param``.

        Args:
            seed: Forwarded to ``gym.Env.reset``. The perturbation and the view
                sampling use the global ``np.random`` generator, so this seed
                alone does not make them reproducible.

        Returns:
            The observation of the perturbed initial state.

        Raises:
            TypeError: If the environment was built without ``initial_param``.
        """
        super().reset(seed=seed)

        if self.init_param is None:
            raise TypeError("QCVENV.reset() requires `initial_param` to be set, got None")

        # The initial parameters are fixed; the perturbation changes on every
        # reset. The squeezing entry (index 1) is left unperturbed.
        agent_small_change = np.random.uniform(-0.2, 0.2, 4)
        agent_small_change[1] = 0
        agent_param = self.init_param + agent_small_change

        self.initial_cat = self.generate_cat_state(agent_param[0], agent_param[1], agent_param[2], agent_param[3])
        initial_observation, _ = self._get_initial_obs()
        self.current_state = self.initial_cat
        return initial_observation

    def step(self, action):
        """Apply ``action`` to the current state and score the result against the target.

        Args:
            action: Indexable ``(displacement, rotation)``. The values are
                clipped to ``[-1, 1]`` and ``[0, pi]`` on a copy; the caller's
                object is left untouched.

        Returns:
            Tuple ``(observation, reward, done, info)``. ``reward`` is
            ``-10`` times the norm of the difference between the agent and
            target representations divided by the square root of their length;
            ``done`` is 1 only when the reward is exactly 0;
            ``info["quantum_fidelity"]`` falls back to ``0.0`` when the
            fidelity cannot be computed.
        """
        # Work on a copy so clipping never mutates (or fails on) the caller's action.
        controls = np.array(action, dtype=np.float64)
        controls[0] = np.clip(controls[0], -1, 1)
        controls[1] = np.clip(controls[1], 0, np.pi)
        agent_observation, next_agent_state, basis_index = self._get_agent_obs(controls)

        # Encode the target at the same views so both representations are comparable.
        target_observation = self._get_target_obs(basis_index)
        normalized_euclidean_distance = (np.linalg.norm(agent_observation - target_observation)
                                         / np.sqrt(agent_observation.shape[0]))
        reward = -1 * normalized_euclidean_distance * 10  # always <= 0

        # The fidelity computation is error-prone; fall back to 0 instead of aborting the step.
        try:
            info = self._get_info(next_agent_state)
        except Exception:
            info = {"quantum_fidelity": 0.0}

        done = 1 if reward == 0 else 0

        self.current_state = next_agent_state

        return agent_observation, reward, done, info


def generate_new_cat_state(org_state, d, r):
    """Return the state obtained by applying ``Dgate(d)`` then ``Rgate(r)`` to ``org_state``.

    ``org_state.data`` is unpacked as ``(means, covs, weights)`` and used to
    re-prepare the state in a fresh single-mode program, which is then run on
    the ``"bosonic"`` engine.
    """
    means, covs, weights = org_state.data

    program = sf.Program(1)
    with program.context as mode:
        # Re-create the original state before applying the new gates.
        sf.ops.Bosonic(weights=weights, covs=covs, means=means) | mode
        Dgate(d) | mode
        Rgate(r) | mode

    return sf.Engine("bosonic").run(program).state


def test_env():
    """Smoke test: build an environment, reset it, take one step and print the results."""
    env = QCVENV(target_param=np.array([0.1, 0, 0.1, 0.1]),
                 initial_param=np.array([0.6, 0, 0.6, 0.6]))
    env.reset()
    # The action space is 2-D: (displacement, rotation angle).
    agent_observation, reward, done, info = env.step([0.5, 0.5])

    print(agent_observation.shape)
    print(reward)
    print(info["quantum_fidelity"])
