import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from torch.optim import Adam

from utils_psd import get_minibatch, onehot2radius
import os
import numpy as np
import time

# Initialize nn.Linear layers: Xavier-uniform weights, zero biases
def weights_init_(m):
    """Initialize an ``nn.Linear`` module in place; intended for ``Module.apply``.

    Weights get Xavier-uniform init (gain 1) and biases are set to zero.
    Modules that are not ``nn.Linear`` are left untouched.

    Args:
        m: any ``nn.Module`` (as passed by ``Module.apply``).
    """
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight, gain=1)
        # Linear(..., bias=False) stores bias=None; constant_ cannot fill None.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)

class Psi(nn.Module):
    """Three-layer MLP mapping states to a ``radius_latent_dim``-sized embedding.

    The module owns its own Adam optimizer, which ``update_parameters`` steps.

    Args:
        num_inputs: size of the last dimension of the state input.
        args: config object; reads ``lr``, ``radius_latent_dim`` and
            ``hidden_size`` here (``update_parameters`` reads more fields).
    """

    def __init__(self, num_inputs, args):
        super(Psi, self).__init__()

        self.lr = args.lr
        self.skill_dim = args.radius_latent_dim
        self.hidden_dim = args.hidden_size
        # Fall back to CPU instead of hard-coding CUDA: otherwise every
        # .to(self.device) below fails on hosts without a GPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Psi architecture: two ReLU hidden layers followed by a linear output.
        self.linear1 = nn.Linear(num_inputs, self.hidden_dim)
        self.linear2 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.linear3 = nn.Linear(self.hidden_dim, self.skill_dim)

        self.apply(weights_init_)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)

    def forward(self, state):
        """Return the embedding of ``state`` (float tensor, last dim ``num_inputs``)."""
        x1 = F.relu(self.linear1(state))
        x1 = F.relu(self.linear2(x1))
        return self.linear3(x1)

    def forward_np(self, state):
        """Embed a NumPy ``state`` array and return the result as a NumPy array.

        The input is cast to float32 and moved to ``self.device``. No autograd
        graph is recorded, since the result leaves torch anyway.
        """
        state = torch.from_numpy(state).float().to(self.device)
        with torch.no_grad():
            x1 = self.forward(state)
        return x1.detach().cpu().numpy()

    def update_parameters(self, memory_traj, args):
        """Sample a minibatch from ``memory_traj`` and take one optimizer step.

        Reads ``args.lambda_value_1``, ``args.lambda_value_L``, ``args.epsilon``
        and ``args.traj_batch_size``, and forwards ``args`` to ``get_minibatch``.

        Returns:
            ``(loss, loss_max, loss_min, loss_const_1, loss_const_2, updated)``.
            The first five are batch means as Python floats; ``updated`` is True.
            Returns ``(0, 0, 0, 0, 0, False)`` without stepping when
            ``get_minibatch`` yields an empty batch.
        """
        lambda_value_1 = torch.tensor(args.lambda_value_1, device=self.device, dtype=torch.float32)
        lambda_value_L = torch.tensor(args.lambda_value_L, device=self.device, dtype=torch.float32)
        epsilon = torch.tensor(args.epsilon, device=self.device, dtype=torch.float32)

        # dim : [minibatch, len_traj, feature], [minibatch, len_traj]
        states_batch, radius_batch = memory_traj.sample(args.traj_batch_size)

        # dim : [minibatch, feature], [minibatch, feature], [minibatch, feature], [minibatch]
        minibatch_before, minibatch_before_prime, minibatch_after, minibatch_radius = get_minibatch(states_batch, radius_batch, args)

        # If a trajectory that meets the conditions is not sampled, do not update
        if minibatch_before.size == 0:
            return 0, 0, 0, 0, 0, False

        # torch.from_numpy keeps NumPy's dtype. Cast to float32 (as forward_np
        # does) so a float64 batch matches the float32 Linear weights.
        minibatch_before = torch.from_numpy(minibatch_before).float().to(self.device)
        minibatch_before_prime = torch.from_numpy(minibatch_before_prime).float().to(self.device)
        minibatch_after = torch.from_numpy(minibatch_after).float().to(self.device)
        L = torch.from_numpy(minibatch_radius).float().to(self.device)

        # Calculate loss
        psi_before = self.forward(minibatch_before)
        psi_before_prime = self.forward(minibatch_before_prime)
        psi_after = self.forward(minibatch_after)

        # Per-sample L2 distances. dist_after_before is computed once and used
        # by both loss_max and loss_const_1.
        dist_after_before = torch.norm(psi_after - psi_before, p=2, dim=-1)
        dist_prime_before = torch.norm(psi_before_prime - psi_before, p=2, dim=-1)

        loss_max = -dist_after_before
        loss_min = torch.norm((psi_after + psi_before) / 2, p=2, dim=-1)
        # Hinge-style terms: each is capped at epsilon via torch.min.
        # NOTE(review): L == 0 would make pi/(2*L) infinite (NaN loss);
        # presumably get_minibatch never yields zero radii, so confirm.
        loss_const_1 = -lambda_value_L * torch.min(epsilon, L - dist_after_before)
        loss_const_2 = -lambda_value_1 * torch.min(epsilon, L * torch.sin(np.pi / (2 * L)) - dist_prime_before)

        loss = loss_max + loss_min + loss_const_1 + loss_const_2
        loss_mean = loss.mean()

        self.optimizer.zero_grad()
        loss_mean.backward()
        self.optimizer.step()

        return loss_mean.item(), loss_max.mean().item(), loss_min.mean().item(), loss_const_1.mean().item(), loss_const_2.mean().item(), True


