import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from torch.optim import Adam

from utils_psd import get_minibatch, onehot2radius
import os
import numpy as np
import time

# Weight initializer applied to every nn.Linear layer of a network
def weights_init_(m):
    """Initialise an ``nn.Linear`` module in place; ignore any other module.

    Weights receive Xavier-uniform initialisation (gain 1) and the bias, when
    the layer has one, is set to zero. The single-module signature makes the
    function usable as the callback of ``nn.Module.apply``.

    Args:
        m: Any ``nn.Module``; only ``nn.Linear`` instances are modified.
    """
    if not isinstance(m, nn.Linear):
        return
    torch.nn.init.xavier_uniform_(m.weight, gain=1)
    # Layers created with ``bias=False`` have ``bias is None``; skip them
    # instead of crashing inside ``constant_``.
    if m.bias is not None:
        torch.nn.init.constant_(m.bias, 0)

class Psi(nn.Module):
    """Two-branch MLP that maps a state to a ``psd_skill_dim``-sized vector.

    The state is expected to be laid out along its last axis as
    ``[obs, psd_skill, metra_skill]``. One branch consumes
    ``(obs, metra_skill)``, the other ``(obs, psd_skill)``; their hidden
    features are concatenated and projected by a final linear layer.
    The module owns its Adam optimizer (``self.optimizer``).
    """

    def __init__(self, num_inputs, args):
        """Build the network and its optimizer.

        Args:
            num_inputs: Total width of the state vector, i.e.
                ``obs_dim + psd_skill_dim + metra_skill_dim``.
            args: Namespace providing ``lr``, ``radius_latent_dim``,
                ``metra_skill_dim`` and ``hidden_size``.
        """
        super().__init__()

        self.lr = args.lr
        self.psd_skill_dim = args.radius_latent_dim
        self.metra_skill_dim = args.metra_skill_dim
        self.hidden_dim = args.hidden_size
        # Device on which inputs are placed by forward_np/update_parameters.
        # Falls back to CPU so the class stays usable without CUDA; the
        # module itself is not moved here (the caller does that if needed).
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        input_dims = [
            num_inputs - self.psd_skill_dim,     # (obs, metra_skill)
            num_inputs - self.metra_skill_dim,   # (obs, psd_skill)
        ]

        # Each branch contributes half of the hidden width to the final layer.
        split_dims = [self.hidden_dim // 2, self.hidden_dim // 2]

        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, self.hidden_dim),
                nn.ReLU(),
                nn.Linear(self.hidden_dim, split_dim),
                nn.ReLU(),
            ) for input_dim, split_dim in zip(input_dims, split_dims)
        ])

        # Kept as a Sequential so parameter names (``final.0.*``) stay stable.
        self.final = nn.Sequential(
            nn.Linear(sum(split_dims), self.psd_skill_dim)
        )

        self.apply(weights_init_)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)

    def _as_float_tensor(self, array):
        """Convert array-like data to a float32 tensor on ``self.device``."""
        return torch.as_tensor(array, dtype=torch.float32, device=self.device)

    def _split_state(self, state):
        """Split ``state`` along its last axis into (obs, psd_skill, metra_skill).

        Positive offsets are used on purpose: negative slices such as
        ``state[..., -0:]`` silently select the whole axis when a skill
        dimension is zero.
        """
        psd_start = state.shape[-1] - self.psd_skill_dim - self.metra_skill_dim
        metra_start = psd_start + self.psd_skill_dim
        obs = state[..., :psd_start]
        psd_skill = state[..., psd_start:metra_start]
        metra_skill = state[..., metra_start:]
        return obs, psd_skill, metra_skill

    def forward(self, state):
        """Compute psi for a tensor whose last axis is the state vector.

        Args:
            state: Tensor of shape ``(..., num_inputs)``; any number of
                leading dimensions (including none) is accepted.

        Returns:
            Tensor of shape ``(..., psd_skill_dim)``.
        """
        obs, psd_skill, metra_skill = self._split_state(state)

        branch_inputs = (
            torch.cat([obs, metra_skill], dim=-1),   # (obs, metra_skill)
            torch.cat([obs, psd_skill], dim=-1),     # (obs, psd_skill)
        )

        features = torch.cat(
            [mlp(branch) for mlp, branch in zip(self.layers, branch_inputs)],
            dim=-1,
        )
        return self.final(features)

    def forward_np(self, state):
        """Evaluate psi for a numpy state and return a numpy array.

        The input is cast to float32 and moved to ``self.device``; no
        autograd graph is recorded.

        Args:
            state: Array whose last axis is the state vector (typically 1-D).

        Returns:
            ``numpy.ndarray`` with the last axis of size ``psd_skill_dim``.
        """
        state = self._as_float_tensor(state)
        with torch.no_grad():
            return self.forward(state).cpu().numpy()

    def update_parameters(self, memory_traj, args):
        """Run one optimisation step on a minibatch drawn from ``memory_traj``.

        Args:
            memory_traj: Buffer exposing ``sample(batch_size)`` that returns
                ``(states_batch, radius_batch)``.
            args: Namespace providing ``lambda_value``, ``epsilon``,
                ``traj_batch_size`` and ``pos_dim`` (it is also forwarded to
                ``get_minibatch``).

        Returns:
            Tuple ``(loss, loss_max, loss_min, loss_const_1, loss_const_2,
            updated)`` holding the batch means of each term as Python floats
            and a flag telling whether a step was taken. When no usable
            minibatch is produced the result is ``(0, 0, 0, 0, 0, False)``.
        """
        lambda_value = torch.tensor(args.lambda_value, device=self.device, dtype=torch.float32)
        epsilon = torch.tensor(args.epsilon, device=self.device, dtype=torch.float32)

        # dim : [minibatch, len_traj, feature], [minibatch, len_traj]
        states_batch, radius_batch = memory_traj.sample(args.traj_batch_size)

        # dim : [minibatch, feature], [minibatch, feature], [minibatch, feature], [minibatch]
        minibatch_before, minibatch_before_prime, minibatch_after, minibatch_radius = get_minibatch(states_batch, radius_batch, args)

        # If a trajectory that meets the conditions is not sampled, do not update
        if minibatch_before.size == 0:
            return 0, 0, 0, 0, 0, False

        # Numpy to float32 tensors (the linear layers hold float32 weights).
        minibatch_before = self._as_float_tensor(minibatch_before)
        minibatch_before_prime = self._as_float_tensor(minibatch_before_prime)
        minibatch_after = self._as_float_tensor(minibatch_after)
        L = self._as_float_tensor(minibatch_radius)

        # The first ``pos_dim`` features are dropped before the network sees them.
        psi_before = self.forward(minibatch_before[:, args.pos_dim:])
        psi_before_prime = self.forward(minibatch_before_prime[:, args.pos_dim:])
        psi_after = self.forward(minibatch_after[:, args.pos_dim:])

        # Distance between the two embeddings, reused by two loss terms.
        dist_after = torch.norm(psi_after - psi_before, p=2, dim=-1)
        dist_prime = torch.norm(psi_before_prime - psi_before, p=2, dim=-1)

        # Push psi_before and psi_after apart ...
        loss_max = -dist_after
        # ... while keeping their midpoint close to the origin.
        loss_min = torch.norm((psi_after + psi_before) / 2, p=2, dim=-1)

        # Penalty terms: each slack is clipped from above at epsilon and
        # weighted by lambda_value, so only violated/near-violated
        # constraints contribute a varying penalty.
        # NOTE(review): L == 0 would divide by zero below - confirm radii are > 0.
        loss_const_1 = -lambda_value * torch.min(epsilon, L - dist_after)
        loss_const_2 = -lambda_value * torch.min(epsilon, L * torch.sin(np.pi / (2 * L)) - dist_prime)
        loss = loss_max + loss_min + loss_const_1 + loss_const_2

        mean_loss = loss.mean()
        self.optimizer.zero_grad()
        mean_loss.backward()
        self.optimizer.step()

        return (
            mean_loss.item(),
            loss_max.mean().item(),
            loss_min.mean().item(),
            loss_const_1.mean().item(),
            loss_const_2.mean().item(),
            True,
        )
