#!/usr/bin/env python
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import math

# Device the discriminator is moved to in DensityRatio.__init__; fixed to CPU.
device = torch.device("cpu")

class Discriminator(nn.Module):
    """Three-layer MLP with a sigmoid output that carries its own optimizer.

    Layout: Linear(inp, 64) -> LeakyReLU -> Linear(64, 64) -> LeakyReLU
    -> Linear(64, out) -> Sigmoid. An Adam optimizer (lr=0.01) over the
    network parameters is exposed as ``self.optimizer``.
    """

    def __init__(self, inp, out):
        """Build the network and its optimizer.

        Args:
            inp: Number of input features.
            out: Number of output units.
        """
        super(Discriminator, self).__init__()
        width = 64
        # Layers are created in forward order so that parameter
        # initialisation consumes the RNG in the same sequence.
        stack = [
            nn.Linear(inp, width),
            nn.LeakyReLU(),
            nn.Linear(width, width),
            nn.LeakyReLU(),
            nn.Linear(width, out),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*stack)
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.01)

    def forward(self, x):
        """Return the sigmoid outputs of the network for input ``x``."""
        return self.net(x)



class DensityRatio():
    """Estimate a density ratio between two sample sets with a discriminator.

    The discriminator is trained with binary cross-entropy to output 1 on
    agent samples and 0 on expert samples. ``get_ratio`` then maps its
    output ``D(x)`` to ``1 / D(x) - 1``.
    """

    def __init__(self, state_dim, out=1):
        """Create the discriminator.

        Args:
            state_dim: Number of input features per sample.
            out: Number of discriminator outputs per sample.
        """
        self.disc = Discriminator(state_dim,out).to(device)

    def _as_input(self, data):
        """Convert ``data`` to a float32 tensor on the discriminator's device."""
        target_device = next(self.disc.parameters()).device
        return torch.as_tensor(data, dtype=torch.float32, device=target_device)

    def train(self,expert,agent,epochs=2,batch_size=1000):
        """Train the discriminator to separate agent samples from expert samples.

        Each optimisation step uses one mini-batch of ``expert`` (label 0)
        together with the complete ``agent`` set (label 1). Only full expert
        batches are used; if ``expert`` has fewer than ``batch_size`` rows
        they are used as one smaller batch. The mean loss per step is
        printed after every epoch.

        Args:
            expert: 2-D array-like of expert samples, one row per sample.
            agent: 2-D array-like of agent samples, one row per sample.
            epochs: Number of passes over the expert data.
            batch_size: Number of expert rows per optimisation step.

        Raises:
            ValueError: If ``batch_size`` is not positive, or if ``expert``
                or ``agent`` contains no samples.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        # Both conversions are loop-invariant, so do them once up front.
        expert_data = self._as_input(expert)
        agent_batch = self._as_input(agent)
        if expert_data.shape[0] == 0 or agent_batch.shape[0] == 0:
            raise ValueError("expert and agent must each contain at least one sample")

        # At least one step per epoch, even with fewer rows than batch_size.
        n_batches = max(1, expert_data.shape[0] // batch_size)
        criterion = nn.BCELoss()

        for epoch in range(epochs):
            epoch_loss = 0.0
            for batch_idx in range(n_batches):
                start = batch_idx * batch_size
                expert_batch = expert_data[start:start + batch_size]
                self.disc.optimizer.zero_grad()

                # Gradients of both terms accumulate before the single step.
                agent_out = self.disc(agent_batch)
                agent_loss = criterion(agent_out, torch.ones_like(agent_out))
                agent_loss.backward()

                expert_out = self.disc(expert_batch)
                expert_loss = criterion(expert_out, torch.zeros_like(expert_out))
                expert_loss.backward()

                self.disc.optimizer.step()
                # .item() keeps the running total free of autograd history.
                epoch_loss += agent_loss.item() + expert_loss.item()
            print("Epoch: {} Discriminator Loss: {}".format(epoch, epoch_loss / n_batches))

    def get_ratio(self, x):
        """Return ``1 / D(x) - 1`` for every row of ``x`` as a NumPy array.

        With the labels used in ``train`` (agent = 1, expert = 0) this equals
        ``(1 - D) / D``; for a well-trained discriminator that presumably
        approximates the expert-to-agent density ratio.

        Args:
            x: 2-D array-like of samples, one row per sample.

        Returns:
            NumPy array with the same shape as the discriminator output.
        """
        with torch.no_grad():
            disc_out = self.disc(self._as_input(x))
        disc_out = disc_out.cpu().numpy()

        # The small constant avoids division by zero when D(x) is 0.
        return (1/(disc_out+1e-8))-1

class GridDensity():
    """Histogram estimate of sample mass over a fixed ``n`` x ``n`` 2-D grid.

    ``fit`` stores, for every cell, the fraction of fitted samples that fall
    into it; ``score_samples`` looks up that fraction for new points and
    returns its logarithm.
    """

    def __init__(self, range_lim, n=64):
        """Set up the grid geometry.

        Args:
            range_lim: Pair ``((x_min, x_max), (y_min, y_max))`` giving the
                data range; the grid extends slightly beyond it.
            n: Number of cells along each axis.
        """
        range_x, range_y = range_lim
        # NOTE(review): the padding is asymmetric (0.05 / 0.005 for x versus
        # 0.01 / 0.01 for y) -- kept as-is; confirm this is intended.
        self.x_low = range_x[0] - 0.05
        self.x_high = range_x[1] + 0.005
        self.y_low = range_y[0] - 0.01
        self.y_high = range_y[1] + 0.01
        self.n = n
        self.dx = (self.x_high - self.x_low) / n
        self.dy = (self.y_high - self.y_low) / n
        print(self.x_low)
        print(self.x_high)
        print(self.y_low)
        print(self.y_high)

    def _cell_indices(self, states):
        """Map the first two columns of ``states`` to grid cell indices.

        Returns:
            Tuple ``(ix, iy, inside)`` where ``inside`` is a boolean mask over
            the rows of ``states`` marking points that lie on the grid, and
            ``ix`` / ``iy`` are the integer cell indices of those points only.
        """
        states = np.asarray(states, dtype=float)
        bin_x = np.floor((states[:, 0] - self.x_low) / self.dx)
        bin_y = np.floor((states[:, 1] - self.y_low) / self.dy)
        # Mask before casting to int: negative indices would otherwise wrap
        # around silently, and NaN/inf have no valid integer representation.
        inside = (bin_x >= 0) & (bin_x < self.n) & (bin_y >= 0) & (bin_y < self.n)
        return bin_x[inside].astype(np.intp), bin_y[inside].astype(np.intp), inside

    def fit(self, states):
        """Count samples per cell and store the normalised table in ``self.count``.

        Points outside the grid are not assigned to any cell but still count
        towards the normalisation total, so the stored fractions sum to the
        share of samples that landed on the grid. With no samples the table
        is all zeros.

        Args:
            states: 2-D array whose first two columns are x and y.
        """
        count = np.zeros((self.n, self.n))
        ix, iy, _ = self._cell_indices(states)
        # np.add.at accumulates correctly when the same cell occurs repeatedly.
        np.add.at(count, (ix, iy), 1)

        total = len(states)
        if total:
            count /= total
        self.count = count

    def score_samples(self, states):
        """Return ``log(mass + 1e-8)`` of the cell containing each sample.

        Points outside the grid are treated as having zero mass, the same as
        points in an empty cell. ``fit`` must have been called first.

        Args:
            states: 2-D array whose first two columns are x and y.

        Returns:
            1-D float array with one log-score per row of ``states``.
        """
        ix, iy, inside = self._cell_indices(states)
        probs = np.zeros(inside.shape[0])
        probs[inside] = self.count[ix, iy]

        return np.log(probs + 1e-8)
        
