from abc import ABC

import torch
import gym
from gym import spaces
import pygame
from cgqn import GenerativeQueryNetwork

import strawberryfields as sf
from strawberryfields.ops import *
import numpy as np


# Torch device for the GQN representation network and the tensors fed to it:
# use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class QCVENV(gym.Env):
    """Gym environment that steers a cat state towards a target cat state.

    The agent controls the real and imaginary parts of the cat amplitude
    ``alpha = agent_param[0] + 1j * agent_param[1]``. Each part is clipped to
    [-2, 2] after every step. Observations are representations produced by a
    ``GenerativeQueryNetwork`` from homodyne marginals taken at ``n_views``
    randomly chosen quadrature angles. The reward is the negated, scaled RMS
    distance between the agent's and the target's representations, computed
    on the same angles.
    """

    def __init__(self, target_param=None, initial_param=None,
                 action_step=0.01,
                 r_dim=24, h_dim=32, z_dim=32, L=2,
                 n_views=10):
        """
        :param target_param: (real, imag) parts of the target cat amplitude.
            When None, each part is drawn uniformly from [-2, 2).
        :param initial_param: (real, imag) parts of the agent's starting
            amplitude. A small random jitter is added on every reset. When
            None, it is drawn the same way as ``target_param``.
        :param action_step: parameter change applied per action component.
        :param r_dim: representation size, which is also the observation size.
        :param h_dim: GQN hidden size.
        :param z_dim: GQN latent size.
        :param L: GQN number of generation steps/layers.
            ``r_dim``, ``h_dim`` and ``z_dim`` also select the checkpoint file.
        :param n_views: number of quadrature angles used per observation.
        """
        # Candidate homodyne angles and the quadrature grid for the marginals.
        self.phis = np.linspace(0, np.pi, 300)
        scale = np.sqrt(sf.hbar)
        self.quad_axis = np.linspace(-6, 6, 100) * scale

        if target_param is None:
            # Randomly generate the target parameters in [-2, 2).
            target_param = np.random.random(2) * 4 - 2
        self.target_param = np.asarray(target_param, dtype=float)

        if initial_param is None:
            # reset() adds a jitter to this, so it must never stay None.
            initial_param = np.random.random(2) * 4 - 2
        self.init_param = np.asarray(initial_param, dtype=float)
        self.n_views = n_views

        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(r_dim,),
                                            dtype=np.float32)

        self.action_step = action_step
        # One ternary choice per parameter: increase, decrease or keep.
        self.action_space = spaces.MultiDiscrete(np.array([3, 3]))

        self._action_to_direction = {
            0: 1,
            1: -1,
            2: 0
        }

        # x_dim must equal the length of each marginal, i.e. len(self.quad_axis).
        self.model = GenerativeQueryNetwork(x_dim=len(self.quad_axis),
                                            v_dim=1,
                                            r_dim=r_dim,
                                            h_dim=h_dim,
                                            z_dim=z_dim, L=L)
        self._load_pretrained_model(r_dim, h_dim, z_dim)

        self.target_cat = self.generate_cat_state(self.target_param[0] + 1j * self.target_param[1])

    def _load_pretrained_model(self, r_dim, h_dim, z_dim):
        """Best-effort load of the pretrained GQN weights matching the given dims.

        The model is always moved to ``device`` and put in eval mode, even when
        no checkpoint could be loaded. Observation tensors are created on
        ``device``, so a model left on the CPU would fail on a CUDA machine.
        """
        import os
        model_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), 'models',
            str(r_dim) + "_" + str(h_dim) + "_" + str(z_dim) + '_cat_alpha_minus2_plus2_cpu')
        try:
            self.model.load_state_dict(torch.load(model_path, map_location='cpu'))
        except Exception as err:  # deliberately best-effort, but never swallow Ctrl-C
            print("NO Load", repr(err))
        else:
            print("Total number of param in Model is ", sum(x.numel() for x in self.model.parameters()))
        self.model.to(device)
        self.model.eval()

    def calculate_fidelity(self, state1, state2):
        """Return the overlap Tr(rho1 rho2) estimated from Wigner functions.

        Uses Tr(rho1 rho2) = 2*pi*hbar * integral(W1 * W2 dx dp). This equals
        the fidelity when at least one state is pure. The integral is a
        Riemann sum on a 401x401 grid covering [-15, 15]^2.
        """
        xvec = np.linspace(-15, 15, 401)
        dx = xvec[1] - xvec[0]
        W1 = state1.wigner(mode=0, xvec=xvec, pvec=xvec)
        W2 = state2.wigner(mode=0, xvec=xvec, pvec=xvec)
        return np.sum(W1 * W2) * dx * dx * 2 * np.pi * sf.hbar

    def generate_cat_state(self, a, p=0):
        """Prepare a single-mode cat state on the strawberryfields bosonic backend.

        :param a: complex cat amplitude, passed to ``Catstate`` as ``a``
            (presumably combined with the default phase 0 — verify for your SF version).
        :param p: parity parameter forwarded to ``Catstate``.
        :return: the resulting strawberryfields state.
        """
        prog_cat = sf.Program(1)
        with prog_cat.context as q:
            sf.ops.Catstate(a=a, p=p) | q
        # A fresh engine per call: an engine keeps state between runs.
        eng = sf.Engine("bosonic")
        cat = eng.run(prog_cat).state
        return cat

    def generate_cat_homodyne_prob(self, state, phi):
        """Return the homodyne marginal of mode 0 at angle ``phi`` on ``self.quad_axis``."""
        return state.marginal(0, self.quad_axis, phi=phi)

    def _encode_views(self, state, view_indices):
        """Encode ``state`` with the GQN from its marginals at the selected angles.

        :param state: strawberryfields state whose homodyne marginals are used.
        :param view_indices: indices into ``self.phis`` selecting the angles.
        :return: numpy array holding the view-averaged representation. The
            squeeze() removes the batch dimension and any other size-1 dims.
        """
        view_indices = list(view_indices)
        # Viewpoints: (1, len(phis), 1), then select the requested angles.
        v = torch.FloatTensor(self.phis).to(device).unsqueeze(0).unsqueeze(2)
        context_v = v[:, view_indices]
        marginals = [self.generate_cat_homodyne_prob(state, angle)
                     for angle in self.phis[view_indices]]
        # Observations: (1, n_selected, len(quad_axis)).
        context_x = torch.FloatTensor(np.expand_dims(np.array(marginals), axis=0)).to(device)

        b, m, *x_dims = context_x.shape
        _, _, *v_dims = context_v.shape
        x = context_x.view((-1, *x_dims))
        v = context_v.view((-1, *v_dims))

        with torch.no_grad():
            rep = self.model.representation(x, v)
            _, *rep_dims = rep.shape
            rep = rep.view((b, m, *rep_dims))
            # Average over the views to obtain a single representation.
            neural_rep = torch.mean(rep, dim=1)

        return neural_rep.squeeze().cpu().numpy()

    def _get_target_obs(self, phis_basis):
        """Encode the target cat state on the angles ``self.phis[phis_basis]``."""
        return self._encode_views(self.target_cat, phis_basis)

    def _get_agent_obs(self, next_state):
        """Encode ``next_state`` on ``n_views`` randomly chosen angles.

        :return: ``(representation, indices)``. The indices are returned so the
            target can be encoded on exactly the same angles.
        """
        # Draw from self.np_random so that reset(seed=...) makes this reproducible.
        representation_idx = self.np_random.permutation(len(self.phis))[:self.n_views].tolist()
        return self._encode_views(next_state, representation_idx), representation_idx

    def _get_param_info(self):
        """Return parameter-space diagnostics that need no quantum simulation."""
        return {"param_distance": np.linalg.norm(self.agent_param - self.target_param, ord=1),
                "agent_param_1": self.agent_param[0],
                "agent_param_2": self.agent_param[1]}

    def _get_info(self, agent_state):
        """
        :return: the fidelity with the target state, the L1 distance between the
            agent and target parameters, and the agent's current parameters.
        """
        quantum_fidelity = self.calculate_fidelity(agent_state, self.target_cat)
        return {"quantum_fidelity": np.real(quantum_fidelity), **self._get_param_info()}

    def reset(self, seed=None, options=None):
        """Start a new episode near ``self.init_param``.

        :param seed: forwarded to ``gym.Env.reset`` to seed ``self.np_random``.
        :param options: accepted for Gym API compatibility; unused.
        :return: ``(observation, info)``.
        """
        # Seeds self.np_random, which all per-episode randomness below uses.
        super().reset(seed=seed)
        # The initial parameters are fixed; only this small jitter varies per episode.
        agent_small_change = self.np_random.uniform(-0.05, 0.05, 2)
        self.agent_param = self.init_param + agent_small_change

        next_state = self.generate_cat_state(self.agent_param[0] + 1j * self.agent_param[1])
        initial_observation, _ = self._get_agent_obs(next_state)
        info = self._get_info(next_state)
        return initial_observation, info

    def step(self, action):
        """Apply one action.

        Each action component is 0, 1 or 2. It moves the matching parameter by
        +action_step, -action_step or 0 respectively.

        :return: ``(observation, reward, terminated, truncated, info)``.
        """
        action_change = np.zeros(2)
        for i, action_element in enumerate(action):
            action_change[i] = self._action_to_direction[int(action_element)] * self.action_step

        self.agent_param = np.clip(self.agent_param + action_change, -2, 2)
        next_state = self.generate_cat_state(self.agent_param[0] + 1j * self.agent_param[1])
        agent_observation, basis_index = self._get_agent_obs(next_state)
        # Encode the target on the same angles so the representations are comparable.
        target_observation = self._get_target_obs(basis_index)
        normalized_euclidean_distance = (np.linalg.norm(agent_observation - target_observation)
                                         / np.sqrt(agent_observation.shape[0]))
        reward = -1 * normalized_euclidean_distance * 10

        try:
            info = self._get_info(next_state)
        except Exception:
            # Fidelity evaluation is best-effort: report 0 but keep the same keys.
            info = {"quantum_fidelity": np.real(0), **self._get_param_info()}

        # Gym expects a plain bool for the termination flag.
        terminated = bool(reward == 0)
        truncated = False

        return agent_observation, reward, terminated, truncated, info


def test_env():
    """Smoke test: build an environment, reset it, take one step and print results."""
    env = QCVENV(target_param=np.array([1, 1]), initial_param=np.array([0.1, 0.1]))
    # reset() returns (observation, info); unpack it rather than keeping a 1-tuple.
    observation, info = env.reset()
    agent_observation, reward, done, _, info = env.step([1, 0])

    print(reward)
    print(info["quantum_fidelity"])

# test_env()