import torch
import torch.optim as optim
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader
import itertools
from tqdm import tqdm

class MLP(nn.Module):
    """Fully connected network with ReLU between layers and an optional output activation.

    Args:
        input_dim: Size of the input feature vector.
        output_dim: Size of the output vector.
        hidden_dims: Iterable with the width of each hidden layer. Any
            iterable of ints is accepted (list, tuple, ...); an empty one
            yields a single linear layer.
        final_activation: ``'softmax'`` applies ``nn.Softmax(dim=1)`` and
            ``'tanh'`` applies ``nn.Tanh()`` to the output; any other value
            leaves the output of the last linear layer untouched.
    """

    def __init__(self, input_dim, output_dim, hidden_dims=(256, 256), final_activation=None):
        super().__init__()

        if final_activation == 'softmax':
            self.final_activation = nn.Softmax(dim=1)
        elif final_activation == 'tanh':
            self.final_activation = nn.Tanh()
        else:
            self.final_activation = None

        # Unpacking (rather than list concatenation) lets callers pass a
        # tuple or any other iterable, and never touches the default object.
        dimensions = [input_dim, *hidden_dims, output_dim]
        layers = []
        for in_features, out_features in zip(dimensions[:-1], dimensions[1:]):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.ReLU())
        # Drop the trailing ReLU so the last linear layer is the raw output.
        # Building then trimming keeps the Sequential indices (0, 2, 4, ...)
        # of the linear layers stable for saved state dicts.
        self.layers = nn.Sequential(*layers[:-1])

    def forward(self, x):
        """Run ``x`` through the layer stack and the optional final activation."""
        x = self.layers(x)
        if self.final_activation is not None:
            return self.final_activation(x)
        return x


class Q_network(MLP):
    """MLP that scores (state, action) pairs.

    ``forward`` concatenates states and actions along dim 1, so
    ``input_dim`` must equal the state width plus the action width.

    Args:
        input_dim: Width of the concatenated state-action vector.
        output_dim: Number of values produced per pair.
        hidden_dims: Iterable of hidden layer widths; ``None`` selects
            two hidden layers of 256 units.
        final_activation: Forwarded unchanged to ``MLP``.
    """

    def __init__(self, input_dim, output_dim, hidden_dims=None, final_activation=None):
        # Resolve the default here instead of sharing one list object across
        # every instance; always hand the parent a fresh list.
        if hidden_dims is None:
            hidden_dims = [256, 256]
        super().__init__(input_dim, output_dim, list(hidden_dims), final_activation)

    def forward(self, states, actions):
        """Return the network output for the row-wise concatenation of ``states`` and ``actions``."""
        x = torch.cat((states, actions), dim=1)
        return super().forward(x)


class PDCA:
    """Offline actor-critic learner with multipliers on cost constraints.

    The object owns a reward critic, a cost critic with one output per
    constraint, a tanh-bounded actor, and a multiplier vector ``lambdas``
    that is rescaled after every update so that the multipliers (plus
    ``dummy_lambda`` in the single-constraint case) sum to ``B``.

    Args:
        env: Object exposing ``state_size``, ``action_size`` and ``gamma``.
        data_collector: Object whose ``get_onehot_encode_dataset()`` returns
            a mapping with the keys ``'observation'``, ``'action'``,
            ``'reward'``, ``'cost'``, ``'new_observation'``, ``'goal'``,
            ``'hole'`` and ``'is_init'``.
        constraints: A single threshold or a sequence of thresholds, one per
            cost signal.
        batch_size: Mini-batch size used by :meth:`train`.
        fast_lr: Learning rate of the two critics.
        slow_lr: Learning rate of the actor.
        lambdas_lr: Step size of the multiplicative multiplier update.
        B: Total mass the multipliers are rescaled to.
        epochs: Number of gradient steps performed by :meth:`train`.

    Raises:
        ValueError: If ``constraints`` is empty.
    """

    def __init__(self,
                 env,
                 data_collector,
                 constraints=0,
                 batch_size=32,
                 fast_lr=0.0003,
                 slow_lr=0.0001,
                 lambdas_lr=0.0003,
                 B=50,
                 epochs=1000,
                 ):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        offline_dataset = data_collector.get_onehot_encode_dataset()
        self.env = env
        self.dim_states = env.state_size
        self.dim_actions = env.action_size

        # Normalise to a 1-D float vector so that a bare scalar (including
        # the default 0) and a sequence of thresholds are handled alike.
        self.constraints = torch.as_tensor(constraints, dtype=torch.float32).reshape(-1).to(self.device)
        self.num_constraints = self.constraints.numel()
        if self.num_constraints == 0:
            raise ValueError("constraints must contain at least one threshold")

        self.batch_size = batch_size
        self.epochs = epochs
        critic_input_dim = self.dim_states + self.dim_actions
        self.reward_network = Q_network(input_dim=critic_input_dim, output_dim=1).to(self.device)
        self.cost_network = Q_network(input_dim=critic_input_dim, output_dim=self.num_constraints).to(
            self.device)
        self.actor = MLP(input_dim=self.dim_states, output_dim=self.dim_actions, final_activation="tanh").to(self.device)
        self.fast_lr = fast_lr
        self.slow_lr = slow_lr
        self.lambdas_lr = lambdas_lr
        self.critic_optimizer = optim.Adam(
            [
                {'params': self.reward_network.parameters()},
                {'params': self.cost_network.parameters()},
            ],
            lr=self.fast_lr,
            weight_decay=0.
        )
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=self.slow_lr, weight_decay=0.)
        self.gamma = env.gamma

        self.B = B
        if self.num_constraints > 1:
            # Several constraints: the multipliers alone share the mass B.
            self.dummy_lambda = None
            self.lambdas = torch.tensor([self.B / self.num_constraints] * self.num_constraints).to(self.device)
        else:
            # One constraint: an extra slack multiplier shares the mass B
            # with the single real multiplier.
            self.dummy_lambda = torch.tensor([self.B / 2.]).to(self.device)
            self.lambdas = torch.tensor([self.B / 2.]).to(self.device)

        action = torch.tensor(offline_dataset['action'], dtype=torch.float32)
        observation = torch.tensor(offline_dataset['observation'], dtype=torch.float32)
        new_observation = torch.tensor(offline_dataset['new_observation'], dtype=torch.float32)
        # One column per constraint so the targets line up with the cost critic.
        cost = torch.tensor(offline_dataset['cost'], dtype=torch.float32).view(-1, self.num_constraints)
        reward = torch.tensor(offline_dataset['reward'], dtype=torch.float32).view(-1, 1)
        goal = torch.tensor(offline_dataset['goal'], dtype=torch.float32).view(-1, 1)
        hole = torch.tensor(offline_dataset['hole'], dtype=torch.float32).view(-1, 1)
        is_init = torch.tensor(offline_dataset['is_init'], dtype=torch.float32).view(-1, 1)
        self.dataset = torch.utils.data.TensorDataset(
            observation, action, reward, cost, new_observation, goal, hole, is_init)

    def _as_float_tensor(self, value):
        """Return ``value`` as a float32 tensor on ``self.device``."""
        if not torch.is_tensor(value):
            value = torch.as_tensor(np.asarray(value))
        return value.to(device=self.device, dtype=torch.float32)

    def func_E(self, f, R, actor, states, actions, next_states, dones):
        """Return the larger of the mean positive and mean negative part of the residual.

        The residual is
        ``f(states, actions) - R - gamma * f(next_states, actor(next_states)) * dones``.
        ``dones`` multiplies the bootstrap term directly, so it must be 1
        where the next-state value should be kept and 0 where it should be
        dropped.
        """
        X = f(states, actions) - R - self.gamma * f(next_states, actor(next_states)) * dones
        sum_positive = torch.mean(torch.clamp(X, min=0))
        sum_negative = torch.mean(torch.clamp(X, max=0))
        return torch.max(sum_positive, -sum_negative)

    def func_A(self, f, actor, states, actions):
        """Return the mean of ``f(states, actor(states)) - f(states, actions)``."""
        return torch.mean(f(states, actor(states)) - f(states, actions))

    def combined_reward(self, states, actions):
        """Return the reward critic plus the multiplier-weighted constraint slack.

        The slack ``constraints - cost_network(states, actions)`` is weighted
        by ``lambdas`` and summed over the constraint axis, so the result has
        one column regardless of the number of constraints.
        """
        slack = self.constraints - self.cost_network(states, actions)
        penalty = (slack * self.lambdas).sum(dim=1, keepdim=True)
        return self.reward_network(states, actions) + penalty

    def online_update(self, initial_states):
        """Multiplicatively update the multipliers and rescale them to mass ``B``.

        ``initial_states`` must contain at least one row; the batch mean of
        an empty tensor would turn the multipliers into NaN.
        """
        with torch.no_grad():
            w = torch.mean(self.constraints - self.cost_network(initial_states, self.actor(initial_states)), dim=0)
            lambdas = self.lambdas * torch.exp(-self.lambdas_lr * w)

            # Compare against None explicitly: truth-testing the tensor would
            # switch branches if the slack multiplier ever reached zero.
            if self.dummy_lambda is not None:
                total = lambdas + self.dummy_lambda
                self.lambdas = self.B * lambdas / total
                self.dummy_lambda = self.B * self.dummy_lambda / total
            else:
                self.lambdas = self.B * lambdas / torch.sum(lambdas)

    def update(self, batch):
        """Run one critic step, one actor step and one multiplier update.

        Args:
            batch: Eight array-likes or tensors in the order of
                ``self.dataset``: observation, action, reward, cost,
                new_observation, goal, hole, is_init.

        Returns:
            dict with the scalar ``'loss_critic'`` and ``'loss_actor'``.
        """
        states, actions, rewards, costs, next_states, done, _hole, is_init = batch

        states = self._as_float_tensor(states)
        actions = self._as_float_tensor(actions)
        next_states = self._as_float_tensor(next_states)
        n = states.shape[0]
        # reshape accepts both flat (n,) inputs and the (n, 1) columns
        # produced by the TensorDataset.
        rewards = self._as_float_tensor(rewards).reshape(n, -1)
        costs = self._as_float_tensor(costs).reshape(n, -1)
        # func_E multiplies the bootstrap term by its mask, so pass the
        # continuation mask. The sixth field is the dataset's 'goal' flag and
        # is treated as the termination signal here; 'hole' is unused.
        # NOTE(review): confirm whether holes also end an episode.
        not_done = 1.0 - self._as_float_tensor(done).reshape(n, 1)
        # assumes is_init flags transitions whose observation starts an
        # episode -- TODO confirm against the data collector
        init_mask = self._as_float_tensor(is_init).reshape(n) > 0.5
        initial_states = states[init_mask]

        self.critic_optimizer.zero_grad()
        loss_reward = 2 * self.func_E(self.reward_network, rewards, self.actor, states, actions, next_states, not_done) + \
                      self.func_A(self.reward_network, self.actor, states, actions)
        loss_costs = 2 * self.func_E(self.cost_network, costs, self.actor, states, actions, next_states, not_done) - \
                     self.func_A(self.cost_network, self.actor, states, actions)
        loss_critic = loss_reward + loss_costs
        loss_critic.backward()
        nn.utils.clip_grad_norm_(self.reward_network.parameters(), max_norm=1.0)
        nn.utils.clip_grad_norm_(self.cost_network.parameters(), max_norm=1.0)
        self.critic_optimizer.step()

        # zero_grad also discards the actor gradients accumulated by the
        # critic loss above.
        self.actor_optimizer.zero_grad()
        loss_actor = -self.func_A(self.combined_reward, self.actor, states, actions)
        loss_actor.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=1.0)
        self.actor_optimizer.step()

        # A batch without any initial state carries no signal for the
        # multipliers; skip the update instead of averaging an empty tensor.
        if initial_states.shape[0] > 0:
            self.online_update(initial_states)

        return {'loss_critic': loss_critic.item(), 'loss_actor': loss_actor.item()}

    def train(self):
        """Run ``self.epochs`` calls to :meth:`update` on shuffled mini-batches.

        Returns:
            The trained actor network.

        Raises:
            ValueError: If the offline dataset is empty.
        """
        if len(self.dataset) == 0:
            raise ValueError("cannot train on an empty dataset")

        loader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True)
        batches = iter(loader)
        with tqdm(total=self.epochs) as pbar:
            for _ in range(self.epochs):
                try:
                    batch = next(batches)
                except StopIteration:
                    # Restart the loader so every pass over the data is
                    # reshuffled (itertools.cycle would replay the first pass).
                    batches = iter(loader)
                    batch = next(batches)
                self.update(batch)
                pbar.update(1)
        return self.actor


# def eval(env, alg, test_epoch):
#     with torch.no_grad():
#         epoch_reward = 0
#         epoch_cost = 0
#         for e in range(test_epoch):
#             state = env.reset()
#             done = False
#             t = 0

#             while not done:
#                 action = alg.policy(
#                     torch.tensor(state[:-env.num_constraints], device=alg.device, dtype=torch.float32).unsqueeze(0))
#                 next_state, reward, done, info, cost = env.log_step(action.detach().cpu().squeeze().numpy())
#                 epoch_reward += reward
#                 epoch_cost += cost
#                 t += 1
#                 state = next_state

#     return epoch_reward / test_epoch, epoch_cost / test_epoch


# def run(conf):
#     env_name = conf['env']['env_name']
#     task_name = conf['env']['task_name']
#     safety_coeff = conf['env']['safety_coeff']
#     env = make_rwrl_env(env_name, task_name, safety_coeff)
#     env = CustomEnv(env, env_name, lamb=0)
#     dim_states = env.observation_space.shape[0] - env.num_constraints
#     dim_actions = env.action_space.shape[0]

#     constraints = [conf['env']['constraints']]
#     max_epoch = conf['alg']['max_epoch']
#     dataset = Dataset(capacity=3000000)

#     dataset.load_from_files(conf['alg']['dataset'])
#     alg = PDCA(dim_states, dim_actions, constraints, dataset=dataset, conf=conf['alg'])
#     for n in range(max_epoch):
#         alg.update()

#         if (n + 1) % 1000 == 0:
#             res = eval(env, alg, 3)
#             print("Reward: ", res[0], "Cost: ", res[1])
