from typing import List
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, optim, Tensor
import copy

from config import config


class Embedding(nn.Module):
    """Predict the action linking two observations and score transitions by that error.

    Both observations go through a shared two-layer ReLU encoder
    (``embedding``); the two encodings are concatenated and mapped by ``last``
    to a softmax over ``num_outputs``. ``train_model`` regresses that softmax
    onto ``actions`` with MSE, and ``get_prob`` converts the per-sample MSE into
    a score using the ``[min_in_dis, max_in_dis]`` band.

    NOTE(review): ``actions`` is presumably a one-hot (or probability) vector of
    width ``num_outputs`` — confirm against the caller.

    Args:
        obs_size: Width of each observation vector fed to ``fc1``.
        num_outputs: Width of the softmax output compared against ``actions``.
        Mean_in: Stored as ``mean_in_dis``; not read by any method of this class.
        max_in_dis: Upper edge of the per-sample MSE band scored as 0.
        min_in_dis: Lower edge of that band.
        lr: Adam learning rate (defaults to the previously hard-coded 1e-5).
    """

    def __init__(self, obs_size, num_outputs, Mean_in, max_in_dis, min_in_dis, lr=1e-5):
        super().__init__()
        self.obs_size = obs_size
        self.num_outputs = num_outputs
        self.device = config.device
        self.latent_dim = 32
        self.fc1 = nn.Linear(obs_size, self.latent_dim).to(self.device)
        self.fc2 = nn.Linear(self.latent_dim, self.latent_dim).to(self.device)
        # Consumes the concatenated encodings of both observations.
        self.last = nn.Linear(self.latent_dim * 2, num_outputs).to(self.device)

        self.optimizer = optim.Adam(self.parameters(), lr=lr)
        self.vae_list = []  # losses recorded by test_one_step
        self.max_loss = 10000  # capacity of the circular training-loss buffer
        self.losses = np.zeros(self.max_loss)
        self.mean_in_dis = Mean_in
        self.max_in_dis = max_in_dis
        self.min_in_dis = min_in_dis
        self.point = 0  # next write index into self.losses

    def embedding(self, x):
        """Encode one observation batch with the shared fc1 -> ReLU -> fc2 -> ReLU stack."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x

    def forward(self, x1, x2):
        """Return a softmax over ``num_outputs`` computed from the pair ``(x1, x2)``.

        The encodings are concatenated along dim 1, so inputs are expected to
        be batched 2-D tensors.
        """
        x1 = self.embedding(x1)
        x2 = self.embedding(x2)
        x = torch.cat([x1, x2], dim=1)
        x = self.last(x)
        return F.softmax(x, dim=1)

    def train_model(self, batch):
        """Run one Adam step on MSE(forward(states, next_states), actions).

        Args:
            batch: ``(states, actions, rewards, next_states, dones)``; only
                ``states``, ``actions`` and ``next_states`` are used.

        Returns:
            ``{"loss/loss_embedding": float loss, "output": detached predictions}``.
        """
        states, actions, _rewards, next_states, _dones = batch

        # TODO (original note): find out why only use last 5 in sequence —
        # nothing in this method slices the batch.
        self.optimizer.zero_grad()
        out = self.forward(states, next_states)
        loss = nn.MSELoss()(out, actions)
        self.add_to_losses(loss)
        loss.backward()
        self.optimizer.step()
        # Detach so the logged output does not keep the autograd graph alive.
        log_dict = {"loss/loss_embedding": loss.item(),
                    "output": out.detach()}
        return log_dict

    @torch.no_grad()
    def add_to_losses(self, loss):
        """Store the scalar ``loss`` in the circular buffer ``self.losses``."""
        # .item() works for grad-tracking and CUDA tensors alike; the previous
        # .cpu().numpy() raised on CPU tensors that still required grad.
        self.losses[self.point] = loss.item()
        self.point += 1
        if self.point == self.max_loss:
            self.point = 0

    @torch.no_grad()
    def get_prob(self,
                 states,
                 next_states,
                 actions):
        """Score each sample by how far its prediction MSE lies outside the band.

        For the per-sample MSE ``m`` between ``forward(states, next_states)``
        and ``actions`` (mean over dim 1), with ``w = max_in_dis - min_in_dis``:

        * ``m <= min_in_dis``: ``(min_in_dis - m) / w * 2``
        * ``m >= max_in_dis``: ``(m - max_in_dis) / w * 2``
        * otherwise: ``0``

        The result is capped at 1. Assumes ``max_in_dis > min_in_dis``.

        Returns:
            A 1-D tensor with one score per sample.
        """
        out = self.forward(states, next_states)
        mse = torch.mean(torch.square(out - actions), dim=1)
        width = self.max_in_dis - self.min_in_dis
        below = (self.min_in_dis - mse) / width * 2
        above = (mse - self.max_in_dis) / width * 2
        # Same branch order as the element-wise rule: the lower bound is tested first.
        scores = torch.where(
            mse <= self.min_in_dis,
            below,
            torch.where(mse >= self.max_in_dis, above, torch.zeros_like(mse)),
        )
        return torch.clamp(scores, max=1)

    def test_one_step(self, states, next_states, actions, rewards, dones):
        """Evaluate MSE(forward(states, next_states), actions) without training.

        The loss tensor is appended to ``self.vae_list``; no parameters change.
        ``rewards`` and ``dones`` are accepted but unused.
        """
        with torch.no_grad():
            out = self.forward(states, next_states)
            loss = nn.MSELoss()(out, actions)
            self.vae_list.append(loss)

def compute_intrinsic_reward(
    episodic_memory: List,
    current_c_state: Tensor,
    k=10,
    kernel_cluster_distance=0.008,
    kernel_epsilon=0.0001,
    c=0.001,
    sm=8,
) -> float:
    """Episodic novelty bonus from the k nearest neighbours in memory.

    1. Compute the distance from ``current_c_state`` to every entry of
       ``episodic_memory`` with ``torch.dist`` (Euclidean, p=2).
    2. Keep the ``k`` smallest distances and normalise them by their mean.
    3. Subtract ``kernel_cluster_distance`` from each and clip it at zero.
    4. Apply the inverse kernel ``eps / (d + eps)``.
    5. Take ``s = sqrt(sum(kernel)) + c`` and return ``1 / s``.

    The shape follows the episodic-reward kernel of Never Give Up (Badia et
    al., 2020). NOTE(review): the paper uses squared distances, while this
    uses plain Euclidean distance — confirm that is intended.

    Args:
        episodic_memory: Stored embeddings comparable with ``current_c_state``.
        current_c_state: Embedding of the current state.
        k: Number of nearest neighbours to use.
        kernel_cluster_distance: Normalised distance below which a neighbour
            counts as maximally similar.
        kernel_epsilon: Kernel smoothing constant.
        c: Pseudo-count added to the similarity.
        sm: Similarity above which the bonus is zeroed.

    Returns:
        The bonus as a float. It is 0.0 when there are no neighbours, when
        every neighbour coincides with the current state, when the similarity
        is NaN, or when the similarity exceeds ``sm``.
    """
    dists = np.array(
        [torch.dist(c_state, current_c_state).item() for c_state in episodic_memory],
        dtype=float,
    )
    nearest = np.sort(dists)[:k]
    if nearest.size == 0:
        # Empty memory (or k <= 0): there is nothing to compare against.
        return 0.0

    mean_dist = nearest.mean()
    if mean_dist == 0:
        # All neighbours coincide with the current state. Normalising would
        # be 0/0 -> NaN, which is zeroed below anyway.
        return 0.0
    nearest = nearest / mean_dist

    # Element-wise clip at zero. np.max(x, 0) would instead reduce the array
    # to its single largest value.
    nearest = np.maximum(nearest - kernel_cluster_distance, 0)
    kernel = kernel_epsilon / (nearest + kernel_epsilon)
    s = np.sqrt(np.sum(kernel)) + c

    if np.isnan(s) or s > sm:
        return 0.0
    return float(1 / s)


class RND_net(nn.Module):
    """Four-layer ReLU MLP over the concatenation of ``states`` and ``actions``.

    Widths: (action_size + obs_size) -> latent_dim -> 2 * latent_dim
    -> latent_dim -> num_outputs. A ReLU follows every layer, including the
    last, so every output is non-negative.

    NOTE(review): a ReLU on the output layer can leave output units stuck at
    zero, which weakens an RND prediction-error signal. Confirm it is intended.
    """

    def __init__(self, obs_size, action_size, num_outputs):
        super().__init__()
        self.obs_size = obs_size
        self.action_size = action_size
        self.device = config.device
        self.latent_dim = 64
        self.num_outputs = num_outputs

        in_width = self.action_size + self.obs_size
        width = self.latent_dim
        # Layers are created in this order so parameter initialisation
        # consumes the RNG identically.
        self.net1 = nn.Linear(in_width, width).to(self.device)
        self.net2 = nn.Linear(width, width * 2).to(self.device)
        self.net3 = nn.Linear(width * 2, width).to(self.device)
        self.net4 = nn.Linear(width, self.num_outputs).to(self.device)

    def forward(self, states, actions):
        """Concatenate inputs along dim 1 and apply Linear+ReLU four times."""
        hidden = torch.cat([states, actions], 1)
        for layer in (self.net1, self.net2, self.net3, self.net4):
            hidden = F.relu(layer(hidden))
        return hidden


class RND_trainer:
    """Random Network Distillation trainer.

    ``rnd_net`` (the predictor) is trained to regress the outputs of
    ``fix_rnd_net``, a frozen network of the same architecture with its own
    independent random initialisation.

    Args:
        obs_size: Observation width passed to ``RND_net``.
        action_size: Action width passed to ``RND_net``.
        num_outputs: Output width of both networks.
        lr: Adam learning rate for the predictor (defaults to the previously
            hard-coded 1e-5).
    """

    def __init__(self, obs_size, action_size, num_outputs, lr=1e-5):
        self.obs_size = obs_size
        self.action_size = action_size
        self.device = config.device
        self.latent_dim = 64
        self.num_outputs = num_outputs

        self.rnd_net = RND_net(self.obs_size, self.action_size, self.num_outputs)
        # The target must be initialised independently. A deepcopy of the
        # predictor starts with identical weights, so the distillation error
        # and its gradient are zero everywhere and the predictor never learns.
        self.fix_rnd_net = RND_net(
            self.obs_size, self.action_size, self.num_outputs
        ).requires_grad_(False).to(self.device)
        self.optimizer = optim.Adam(self.rnd_net.parameters(), lr=lr)
        self.vae_list = []

    def train_model(self, batch):
        """Run one optimizer step on MSE(rnd_net(s, a), fix_rnd_net(s, a)).

        Args:
            batch: ``(states, actions, rewards, next_states, dones)``; only
                ``states`` and ``actions`` are used.

        Returns:
            ``{"loss/loss_rnd": float distillation loss}``.
        """
        states, actions, _rewards, _next_states, _dones = batch
        self.optimizer.zero_grad()
        change_out = self.rnd_net(states, actions)
        with torch.no_grad():
            fix_out = self.fix_rnd_net(states, actions)

        loss = F.mse_loss(change_out, fix_out)
        loss.backward()
        self.optimizer.step()
        return {"loss/loss_rnd": loss.item()}




