import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from torch.optim import Adam

from utils_psd import get_minibatch, onehot2radius
import os
import numpy as np
import time

# Weight initialiser for nn.Linear layers (meant to be passed to nn.Module.apply)
def weights_init_(m):
    """Initialise an ``nn.Linear`` layer in place; leave every other module alone.

    Designed to be handed to ``nn.Module.apply``, which invokes it once per
    sub-module. Linear weights receive Xavier-uniform values (gain 1) and the
    bias, if the layer has one, is set to zero.

    Args:
        m: Any module. Only instances of ``nn.Linear`` are modified.
    """
    if not isinstance(m, nn.Linear):
        return

    torch.nn.init.xavier_uniform_(m.weight, gain=1)
    # Layers created with ``bias=False`` expose ``bias is None``; initialising
    # that would raise, so skip it.
    if m.bias is not None:
        torch.nn.init.constant_(m.bias, 0)

# PSI encoder
class CNNPSIEncoder(nn.Module):
    """Four-layer strided CNN that flattens an image stack into a feature vector.

    Every convolution uses a 4x4 kernel, stride 2 and padding 1. The channel
    width grows as ``cnn_depth -> 2*cnn_depth -> 4*cnn_depth -> 8*cnn_depth``.

    Args:
        args: Configuration object; only ``args.cnn_depth`` is read here.
        input_channels: Number of channels of the input tensor (default 9,
            written as 3*3 -- presumably 3 colour channels x 3 stacked frames).
    """

    def __init__(self, args, input_channels=3*3):
        super().__init__()
        depth = args.cnn_depth
        conv_kwargs = dict(kernel_size=4, stride=2, padding=1)

        # Layers are created in order conv1..conv4 and keep those attribute
        # names so state-dict keys stay the same.
        self.conv1 = nn.Conv2d(input_channels, 1 * depth, **conv_kwargs)
        self.conv2 = nn.Conv2d(1 * depth, 2 * depth, **conv_kwargs)
        self.conv3 = nn.Conv2d(2 * depth, 4 * depth, **conv_kwargs)
        self.conv4 = nn.Conv2d(4 * depth, 8 * depth, **conv_kwargs)

        # Flattened feature length, assuming the last feature map is 5x5
        # (with these layer settings that corresponds to an 80x80 input).
        self.conv_output_size = 8 * depth * 5 * 5

        self.apply(weights_init_)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Batch-first image tensor; it is divided by 255 before the
                convolutions (presumably raw 0-255 pixel values).

        Returns:
            Tensor of shape ``(x.size(0), -1)``: the ELU-activated output of
            the last convolution, flattened per sample.
        """
        feat = x / 255.

        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            feat = F.elu(conv(feat))

        # Collapse everything except the batch dimension.
        return feat.reshape(feat.size(0), -1)



class Psi(nn.Module):
    """Maps (image state, radius input) to a ``skill_dim``-dimensional latent.

    The image is encoded by ``CNNPSIEncoder``; the flattened features are
    concatenated with the radius input and passed through a three-layer MLP.
    The class also owns the Adam optimizer used by ``update_parameters``.

    Args:
        args: Configuration object. Read here: ``lr``, ``radius_latent_dim``,
            ``hidden_size``, ``cnn_depth``, ``radius_input_dim``.
    """

    def __init__(self, args):
        super(Psi, self).__init__()

        self.lr = args.lr
        self.skill_dim = args.radius_latent_dim
        self.hidden_dim = args.hidden_size
        # Prefer CUDA as before, but fall back to CPU so the class still works
        # on hosts without a GPU. This constructor does not move the
        # parameters; inputs are sent to ``self.device``, so the caller is
        # expected to place the module on the same device.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Must match the flattened output length of CNNPSIEncoder.
        self.conv_output_size = 8 * args.cnn_depth * 5 * 5

        # Psi architecture
        self.linear1 = nn.Linear(self.conv_output_size + args.radius_input_dim, self.hidden_dim)
        self.linear2 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.linear3 = nn.Linear(self.hidden_dim, self.skill_dim)

        # CNN encoder
        self.CNNencoder = CNNPSIEncoder(args)

        # Weight init
        self.apply(weights_init_)

        # Build the optimizer only after every sub-module is registered.
        # Creating it before ``CNNencoder`` existed meant ``self.parameters()``
        # held only the linear layers, so the encoder was never trained and
        # its gradients were never zeroed.
        # NOTE: optimizer state dicts saved with the old layout (linear layers
        # only) will not load into this optimizer.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)

    def forward(self, state, radius_input):
        """Compute the latent for a batch.

        Args:
            state: Batch-first image tensor accepted by ``CNNPSIEncoder``.
            radius_input: Tensor concatenated to the CNN features along the
                last dimension (so it needs a matching batch dimension).

        Returns:
            Tensor whose last dimension is ``skill_dim``.
        """
        cnn_feature = self.CNNencoder(state)
        joint = torch.cat((cnn_feature, radius_input), dim=-1)

        hidden = F.relu(self.linear1(joint))
        hidden = F.relu(self.linear2(hidden))
        return self.linear3(hidden)

    def forward_np(self, state, radius_input):
        """Run ``forward`` on a single un-batched NumPy sample.

        Args:
            state: NumPy array holding one image stack (no batch dimension).
            radius_input: NumPy array holding one radius input (no batch
                dimension).

        Returns:
            NumPy array with the latent for that sample (batch dim removed).
        """
        # ``copy()`` is kept from the original implementation; it hands
        # ``from_numpy`` a fresh array instead of the caller's buffer.
        state_t = torch.from_numpy(state.copy()).float().to(self.device)
        radius_t = torch.from_numpy(radius_input).float().to(self.device)

        # add batch dim to single dim data
        state_t = state_t.unsqueeze(0)
        radius_t = radius_t.unsqueeze(0)

        # Inference only: no autograd graph is needed for a NumPy result.
        with torch.no_grad():
            latent = self.forward(state_t, radius_t)

        return latent.squeeze(0).cpu().numpy()  # remove batch dim

    def _as_float_tensor(self, array):
        """Convert a NumPy array to a float32 tensor on ``self.device``."""
        return torch.from_numpy(array).to(self.device, dtype=torch.float32)

    def update_parameters(self, memory_traj, args):
        """Sample trajectories, build the Psi loss and take one optimizer step.

        Loss terms (per sample, before averaging), with ``d`` the L2 norm of
        ``psi_after - psi_before`` and ``L`` the sampled radius:

        * ``loss_max     = -d``
        * ``loss_min     = ||(psi_after + psi_before) / 2||``
        * ``loss_const_1 = -lambda * min(epsilon, L - d)``
        * ``loss_const_2 = -lambda * min(epsilon,
          L * sin(pi / (2 L)) - ||psi_before_prime - psi_before||)``

        Args:
            memory_traj: Object exposing ``sample(batch_size)`` that returns
                ``(states, radius_input, radius)`` batches.
            args: Configuration object. Read here: ``lambda_value``,
                ``epsilon``, ``traj_batch_size`` (and whatever
                ``get_minibatch`` reads).

        Returns:
            Tuple ``(loss, loss_max, loss_min, loss_const_1, loss_const_2,
            updated)``. The first five are Python floats holding batch means;
            ``updated`` is ``True``. If ``get_minibatch`` yields an empty
            batch, nothing is updated and ``(0, 0, 0, 0, 0, False)`` is
            returned.
        """
        lambda_value = torch.tensor(args.lambda_value, device=self.device, dtype=torch.float32)
        epsilon = torch.tensor(args.epsilon, device=self.device, dtype=torch.float32)

        # dim : [minibatch, len_traj, feature], [minibatch, len_traj]
        states_batch, radius_input_batch, radius_batch = memory_traj.sample(args.traj_batch_size)

        # dim : [minibatch, feature], [minibatch, feature], [minibatch, feature], [minibatch, feature], [minibatch]
        (minibatch_before, minibatch_before_prime, minibatch_after,
         minibatch_radius_input, minibatch_radius) = get_minibatch(
            states_batch, radius_input_batch, radius_batch, args)

        # If a trajectory that meets the conditions is not sampled, do not update
        if minibatch_before.size == 0:
            return 0, 0, 0, 0, 0, False

        # Numpy to tensor. Cast to float32 like ``forward_np`` does, so that
        # float64 arrays do not clash with the float32 network weights.
        before = self._as_float_tensor(minibatch_before)
        before_prime = self._as_float_tensor(minibatch_before_prime)
        after = self._as_float_tensor(minibatch_after)
        radius_input = self._as_float_tensor(minibatch_radius_input)
        L = self._as_float_tensor(minibatch_radius)

        # Calculate network
        psi_before = self.forward(before, radius_input)
        psi_before_prime = self.forward(before_prime, radius_input)
        psi_after = self.forward(after, radius_input)

        # Shared by loss_max and loss_const_1, so compute it once.
        displacement = torch.norm(psi_after - psi_before, p=2, dim=-1)
        step_length = torch.norm(psi_before_prime - psi_before, p=2, dim=-1)

        loss_max = -displacement
        loss_min = torch.norm((psi_after + psi_before) / 2, p=2, dim=-1)
        loss_const_1 = -lambda_value * torch.min(epsilon, L - displacement)
        # NOTE(review): L == 0 would make pi / (2 L) infinite and the loss NaN;
        # confirm that get_minibatch never yields a zero radius.
        loss_const_2 = -lambda_value * torch.min(epsilon, L * torch.sin(np.pi / (2 * L)) - step_length)

        loss = loss_max + loss_min + loss_const_1 + loss_const_2
        mean_loss = loss.mean()

        self.optimizer.zero_grad()
        mean_loss.backward()
        self.optimizer.step()

        return (mean_loss.item(), loss_max.mean().item(), loss_min.mean().item(),
                loss_const_1.mean().item(), loss_const_2.mean().item(), True)
    

