import cvxpy as cp
import numpy as np
import torch
import torch.nn as nn
from lpcmdp.algorithm.utils import *
# from tqdm.notebook import tqdm
from tqdm import tqdm
from lpcmdp.algorithm.model import *
import warnings
# from environment import FrozenLakeEnv, FrozenLakeEnv_nocost
# from datacollector import datacollect, get_behave_policy_prob

# Install an "ignore" filter for every Warning category at import time.
# NOTE(review): this is process-wide — it also hides warnings raised by other
# modules (cvxpy, torch, numpy) for everything that imports this file.
warnings.simplefilter("ignore")


class QCritic(nn.Module):
    """MLP mapping a (concatenated) observation/action vector to one scalar.

    Hidden layers use ``nn.ReLU``; ``activation`` is handed to ``mlp`` as the
    output activation. After construction, the weight of every ``nn.Linear``
    directly inside ``q_net`` is re-drawn from a normal distribution; biases
    keep whatever initialisation ``mlp`` gave them.

    Args:
        obs_dim: width of the observation input.
        act_dim: width of the action input (concatenated after ``obs``).
        hidden_sizes: iterable of hidden-layer widths.
        activation: activation class used on the output layer.
        num_q: accepted for interface compatibility; not used here.
        mean: mean of the normal weight initialisation.
        var: passed as the ``std`` argument of ``nn.init.normal_`` (so it is a
            standard deviation, despite the name).
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, num_q=1, mean=0.01, var=0.01):
        super().__init__()
        widths = [obs_dim + act_dim, *hidden_sizes, 1]
        self.q_net = mlp(widths, nn.ReLU, output_activation=activation)
        self.net_mean, self.net_var = mean, var

        # Iterates the direct children of q_net (presumably an nn.Sequential
        # built by mlp — TODO confirm) in order, so the RNG draws match layer order.
        linear_layers = (module for module in self.q_net if isinstance(module, torch.nn.Linear))
        for linear in linear_layers:
            nn.init.normal_(linear.weight, self.net_mean, self.net_var)

    def forward(self, obs, act=None):
        """Return ``q_net(obs)``, or ``q_net([obs, act])`` concatenated on the last dim."""
        if act is not None:
            obs = torch.cat([obs, act], dim=-1)
        return self.q_net(obs)

def ImportanceSamplingDiscreteSolver(
        env,
        datacollector,
        cost_threshold,
        E_n_delta,
        E_n_delta_k,
        ):
    """Solve the empirical constrained LP over importance weights and extract a policy.

    Empirical quantities are built from ``datacollector.get_original_dataset()``
    and the problem

        max_w  u_D^T w
        s.t.   ||K_D w - (1 - gamma) mu0||_2 <= E_n_delta
               h_D^T w - cost_threshold     <= E_n_delta_k
               w >= 0

    is solved with cvxpy, where K_D = M @ diag(mu_D), u_D = r * mu_D and
    h_D = c * mu_D. The weights are then mapped to a policy with
    pi(a|s) proportional to w(s, a) * pi_b(a|s), using
    ``datacollector.get_behave_policy_prob()`` as pi_b.

    Args:
        env: object exposing ``state_size``, ``action_size`` and ``gamma``.
        datacollector: provides ``get_original_dataset()`` (dict with keys
            'observation', 'action', 'new_observation', 'reward', 'cost')
            and ``get_behave_policy_prob()``.
        cost_threshold: cost budget in the cost constraint.
        E_n_delta: slack on the (L2) flow constraint.
        E_n_delta_k: slack on the cost constraint.

    Returns:
        np.ndarray of shape (state_size, action_size). States whose weighted
        behaviour mass is zero get an all-zero row.

    Raises:
        ValueError: if the offline dataset contains no transitions.
        RuntimeError: if cvxpy produces no solution (e.g. infeasible problem).
    """
    offline_dataset = datacollector.get_original_dataset()
    n_states, n_actions = env.state_size, env.action_size
    n_pairs = n_states * n_actions
    gamma = env.gamma

    # Initial-state distribution: all mass on state 0.
    mu0 = np.zeros(n_states)
    mu0[0] = 1

    mu_D_count = np.zeros(n_pairs)
    r_s_a = np.zeros(n_pairs)
    c_s_a = np.zeros(n_pairs)
    P = np.zeros((n_states, n_pairs))
    behav_policy_prob = datacollector.get_behave_policy_prob()

    samples = zip(
        offline_dataset['observation'],
        offline_dataset['action'],
        offline_dataset['new_observation'],
        offline_dataset['reward'],
        offline_dataset['cost'],
    )
    for s, a, ns, r, c in samples:
        s_a = s * n_actions + a  # flat (state, action) index
        mu_D_count[s_a] += 1
        # The last observed value wins; this assumes r(s, a) and c(s, a) are deterministic.
        r_s_a[s_a] = r
        c_s_a[s_a] = c
        P[ns, s_a] += 1

    n_samples = mu_D_count.sum()
    if n_samples == 0:
        raise ValueError("offline dataset is empty; cannot estimate the data distribution")
    mu_D_s_a = mu_D_count / n_samples

    # Column-normalise transition counts into P(s' | s, a); unvisited pairs keep a zero column.
    col_sums = P.sum(axis=0)
    visited = col_sums != 0
    P[:, visited] /= col_sums[visited]

    # M[i, (s, a)] = 1{s == i} - gamma * P(i | s, a)
    M = np.repeat(np.eye(n_states), n_actions, axis=1) - gamma * P
    # Column scaling by mu_D: K_D @ w == M @ (mu_D * w).
    K_D = M * mu_D_s_a
    u_D = r_s_a * mu_D_s_a
    h_D = c_s_a * mu_D_s_a

    w_var = cp.Variable(n_pairs)
    problem = cp.Problem(
        cp.Maximize(u_D @ w_var),
        [
            cp.norm(K_D @ w_var - (1 - gamma) * mu0) <= E_n_delta,
            h_D @ w_var - cost_threshold <= E_n_delta_k,
            w_var >= 0,
        ],
    )
    problem.solve()
    if w_var.value is None:
        raise RuntimeError(f"cvxpy returned no solution (status: {problem.status})")

    # Solver tolerance can leave tiny negative entries despite w >= 0; clip them.
    w = np.clip(w_var.value, 0.0, None).reshape(n_states, n_actions)
    weighted = w * np.asarray(behav_policy_prob, dtype=np.float64)
    row_mass = weighted.sum(axis=1, keepdims=True)
    # pi(a|s) = w(s,a) pi_b(a|s) / sum_a' w(s,a') pi_b(a'|s); zero-mass rows stay zero.
    return np.divide(weighted, row_mass, out=np.zeros_like(weighted), where=row_mass != 0)


class ImportanceSamplingApproximateSolver():
    """Primal-dual neural approximation of the offline constrained LP.

    Importance weights w(s, a) are produced by a ``QCritic`` network on the
    encodings returned by ``env.state_action_onehot_encode()``. The network is
    trained with Adam on the Lagrangian (up to constants)

        -u_D^T w + lmbda * ||K_D w - (1 - gamma) mu0||_1 + tau * h_D^T w

    The multipliers ``lmbda`` and ``tau`` take Adam steps on their constraint
    slacks and are then clamped to be non-negative. Every ``test_ever`` epochs
    the current weights are turned into a policy and evaluated with ``test``.
    The weights with the highest evaluation reward are kept.
    """

    def __init__(self,
                 env,
                 datacollector,
                 cost_threshold,
                 E_n_delta,
                 E_n_delta_k,
                 epoches=80000,
                 critic_lr=0.00001,
                 lambda_lr=0.0001,
                 tau_lr=0.0001,
                 device='cpu',
                 test_ever=1000,
                 test_episode=100,
                 behavior_policy_style='real'
                 ) -> None:
        """Build the empirical LP quantities from the offline dataset.

        Args:
            env: exposes ``state_size``, ``action_size``, ``gamma`` and
                ``state_action_onehot_encode()``.
            datacollector: provides ``get_original_dataset()`` (dict with keys
                'observation', 'action', 'new_observation', 'reward', 'cost'),
                ``get_real_behave_policy()`` and ``get_behave_policy_prob()``.
            cost_threshold: cost budget.
            E_n_delta: slack on the flow constraint.
            E_n_delta_k: slack on the cost constraint.
            epoches: number of training iterations.
            critic_lr, lambda_lr, tau_lr: Adam learning rates.
            device: torch device for training.
            test_ever: evaluate every this many epochs.
            test_episode: episodes per evaluation (forwarded to ``test``).
            behavior_policy_style: 'real' selects ``real_behav_policy`` in
                ``policy_extraction``; anything else selects ``behav_policy_prob``.

        Raises:
            ValueError: if the offline dataset contains no transitions.
        """
        offline_dataset = datacollector.get_original_dataset()
        self.env = env
        n_states, n_actions = env.state_size, env.action_size
        n_pairs = n_states * n_actions

        # Initial-state distribution: all mass on state 0.
        self.mu0 = np.zeros(n_states)
        self.mu0[0] = 1
        self.mu_D_count = np.zeros(n_pairs)
        self.r_s_a = np.zeros(n_pairs)
        self.c_s_a = np.zeros(n_pairs)
        self.P = np.zeros((n_states, n_pairs))
        self.real_behav_policy = datacollector.get_real_behave_policy()
        self.behav_policy_prob = datacollector.get_behave_policy_prob()
        self.behavior_policy_style = behavior_policy_style

        samples = zip(
            offline_dataset['observation'],
            offline_dataset['action'],
            offline_dataset['new_observation'],
            offline_dataset['reward'],
            offline_dataset['cost'],
        )
        for s, a, ns, r, c in samples:
            s_a = s * n_actions + a  # flat (state, action) index
            self.mu_D_count[s_a] += 1
            # The last observed value wins; this assumes r(s, a) and c(s, a) are deterministic.
            self.r_s_a[s_a] = r
            self.c_s_a[s_a] = c
            self.P[ns, s_a] += 1

        n_samples = self.mu_D_count.sum()
        if n_samples == 0:
            raise ValueError("offline dataset is empty; cannot estimate the data distribution")
        self.mu_D_s_a = self.mu_D_count / n_samples

        # Column-normalise transition counts into P(s' | s, a); unvisited pairs keep a zero column.
        col_sums = self.P.sum(axis=0)
        visited = col_sums != 0
        self.P[:, visited] /= col_sums[visited]

        # M[i, (s, a)] = 1{s == i} - gamma * P(i | s, a)
        self.M = np.repeat(np.eye(n_states), n_actions, axis=1) - env.gamma * self.P
        # Column scaling by mu_D: K_D @ w == M @ (mu_D * w).
        self.K_D = self.M * self.mu_D_s_a
        self.u_D = self.r_s_a * self.mu_D_s_a
        self.h_D = self.c_s_a * self.mu_D_s_a

        self.epoches = epoches
        self.critic_lr = critic_lr
        self.lambda_lr = lambda_lr
        self.tau_lr = tau_lr
        self.cost_threshold = cost_threshold
        self.E_n_delta = E_n_delta
        self.E_n_delta_k = E_n_delta_k
        self.device = device
        self.test_ever = test_ever
        self.test_episode = test_episode
        self.logger = {}

    def train(self):
        """Run primal-dual training and return the policy for the best evaluated weights.

        Returns:
            np.ndarray of shape (state_size, action_size) from ``policy_extraction``
            applied to the weights with the highest evaluation reward. If no
            evaluation ran (``epoches < test_ever``), the final network weights
            are used instead.

        Side effects:
            Sets ``self.logger`` keys 'test_reward', 'test_cost', 'best_epoch'
            and 'best_reward'. The last two are None when no evaluation ran.
        """
        device = self.device
        gamma = self.env.gamma

        obs_encode, acts_encode = self.env.state_action_onehot_encode()
        obs_encode, acts_encode = obs_encode.to(device), acts_encode.to(device)

        def to_tensor(array):
            return torch.from_numpy(array).to(torch.float32)

        u_D_row = to_tensor(self.u_D).reshape(1, -1).to(device)
        h_D_row = to_tensor(self.h_D).reshape(1, -1).to(device)
        K_D_tensor = to_tensor(self.K_D).to(device)
        # Loop-invariant right-hand side of the flow constraint: (1 - gamma) * mu0.
        flow_target = (1 - gamma) * to_tensor(self.mu0).reshape(-1, 1).to(device)

        w_network = QCritic(obs_encode.shape[1], acts_encode.shape[1], [64, 64], nn.ReLU,
                            num_q=1, mean=0.05, var=0.01).to(device)
        lmbda = torch.tensor([0.7], requires_grad=True, device=device)
        tau = torch.tensor([1.0], requires_grad=True, device=device)
        w_optim = torch.optim.Adam(w_network.parameters(), lr=self.critic_lr)
        lmbda_optim = torch.optim.Adam([lmbda], lr=self.lambda_lr)
        tau_optim = torch.optim.Adam([tau], lr=self.tau_lr)

        test_reward, test_cost = [], []
        best_epoch, best_reward, best_w = None, None, None

        with tqdm(total=self.epoches, desc="Importance Sampling Approximate Training", leave=False) as pbar:
            for epoch in range(self.epoches):
                w_s_a = w_network(obs_encode, acts_encode).reshape(-1, 1)
                w_s_a_nograd = w_s_a.detach()

                # Primal step on w with the multipliers held fixed.
                # NOTE(review): this uses an L1 flow norm, whereas
                # ImportanceSamplingDiscreteSolver uses cvxpy's default L2 norm — confirm intent.
                w_loss = (-(u_D_row @ w_s_a)
                          + lmbda.detach() * torch.norm(K_D_tensor @ w_s_a - flow_target, 1)
                          + tau.detach() * (h_D_row @ w_s_a))
                w_optim.zero_grad()
                w_loss.backward()
                w_optim.step()

                # Dual step on lmbda: grows while the flow residual exceeds E_n_delta.
                flow_residual = torch.norm(K_D_tensor @ w_s_a_nograd - flow_target, 1)
                lmbda_loss = lmbda * (self.E_n_delta - flow_residual)
                lmbda_optim.zero_grad()
                lmbda_loss.backward()
                lmbda_optim.step()
                with torch.no_grad():
                    lmbda.clamp_(min=0)

                # Dual step on tau: grows while the cost exceeds cost_threshold + E_n_delta_k.
                tau_loss = tau * (self.cost_threshold + self.E_n_delta_k - h_D_row @ w_s_a_nograd)
                tau_optim.zero_grad()
                tau_loss.backward()
                tau_optim.step()
                with torch.no_grad():
                    tau.clamp_(min=0)

                if (epoch + 1) % self.test_ever == 0:
                    w_np = w_s_a_nograd.cpu().numpy()
                    policy = self.policy_extraction(w_np, behavior_policy_style=self.behavior_policy_style)
                    reward, cost = test(env=self.env, policy=policy, test_episode=self.test_episode)
                    test_reward.append(reward)
                    test_cost.append(cost)
                    # Compare against the running best (previously reset to 0 every epoch).
                    if best_w is None or reward > best_reward:
                        best_epoch, best_reward, best_w = epoch, reward, w_np
                pbar.update(1)

        if best_w is None:
            # No evaluation happened; fall back to the final network output.
            with torch.no_grad():
                best_w = w_network(obs_encode, acts_encode).reshape(-1, 1).cpu().numpy()

        self.logger['test_reward'] = test_reward
        self.logger['test_cost'] = test_cost
        self.logger['best_epoch'] = best_epoch
        self.logger['best_reward'] = best_reward
        return self.policy_extraction(best_w, behavior_policy_style=self.behavior_policy_style)

    def policy_extraction(self, w_s_a, behavior_policy_style='real'):
        """Turn importance weights into a policy pi(a|s) proportional to w(s, a) * pi_b(a|s).

        Args:
            w_s_a: array with state_size * action_size entries, reshaped to
                (state_size, action_size).
            behavior_policy_style: 'real' uses ``real_behav_policy``; any other
                value uses ``behav_policy_prob``.

        Returns:
            float64 np.ndarray of shape (state_size, action_size). Negative
            weights are clipped to 0. States whose weighted behaviour mass is 0
            get an all-zero row.
        """
        w = np.asarray(w_s_a, dtype=np.float64).reshape(self.env.state_size, self.env.action_size)
        w = np.clip(w, 0.0, None)
        if behavior_policy_style == 'real':
            behav_policy = self.real_behav_policy
        else:
            behav_policy = self.behav_policy_prob
        weighted = w * np.asarray(behav_policy, dtype=np.float64)
        row_mass = weighted.sum(axis=1, keepdims=True)
        return np.divide(weighted, row_mass, out=np.zeros_like(weighted), where=row_mass != 0)

    def get_logger(self):
        """Return the dict of metrics recorded by ``train``."""
        return self.logger
    

