import numpy as np
import gym
import random
from IPython import embed
import pickle
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# This module is meant to run on a GPU: fail fast at import time instead of
# silently falling back to the CPU.  An explicit raise, unlike ``assert``, is
# not stripped when Python runs with ``-O``.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available; this module requires a GPU.")
device = torch.device('cuda')

print("Using device: " + str(device))

class Net(torch.nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear (no output activation).

    Args:
        n_feature: size of each input vector.
        n_hidden: width of the single hidden layer.
        n_output: number of linear outputs.
    """

    def __init__(self, n_feature, n_hidden, n_output):
        super().__init__()
        # The attribute names below become the state_dict keys, so they are
        # what saved checkpoints are keyed on.
        self.hidden1 = torch.nn.Linear(n_feature, n_hidden)
        self.predict = torch.nn.Linear(n_hidden, n_output)

    def forward(self, x):
        """Return the raw (unactivated) output layer for input ``x``."""
        activated = F.relu(self.hidden1(x))
        return self.predict(activated)


class Alg:
    """Discrete-action value regressor: one network output per action.

    ``Net`` maps a ``dim``-dimensional input to ``K`` outputs; training fits
    only the output selected by each sample's action.  Checkpointing via
    ``save`` / ``load`` needs ``root`` (a filename prefix for numbered
    checkpoints) and ``path`` (the file holding the most recent checkpoint):
    pass them to the constructor or assign the attributes before use.
    """

    def __init__(self, dim, K, hidden, *, root=None, path=None):
        """Build the network, an Adam optimizer (lr=4e-3) and an MSE loss.

        Args:
            dim: number of input features.
            K: number of discrete actions (network outputs).
            hidden: width of the single hidden layer.
            root: optional filename prefix for numbered checkpoints.
            path: optional file path for the latest checkpoint.
        """
        self.net = Net(n_feature=dim, n_hidden=hidden, n_output=K).to(device)
        print(self.net)
        self.dim = dim
        self.K = K
        self.hidden = hidden
        # Only set when given, so values assigned elsewhere (e.g. after
        # construction or on the class) are not shadowed by None.
        if root is not None:
            self.root = root
        if path is not None:
            self.path = path

        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=4e-3)
        self.loss_func = torch.nn.MSELoss()

    def train(self, xs, actions, ys, epochs=10, batch_size=64):
        """Fit the network on (input, action, target) triples with MSE loss.

        Only the output unit selected by each sample's action contributes to
        the loss; the other outputs are masked out with a one-hot vector.

        Args:
            xs: array-like of inputs, one row per sample (cast to float32).
            actions: array-like of integer action indices in ``[0, K)``.
            ys: array-like of regression targets, one per sample.
            epochs: number of passes over the data.
            batch_size: mini-batch size for the shuffled DataLoader.
        """
        xs = torch.from_numpy(np.asarray(xs)).float().to(device)
        actions = torch.from_numpy(np.asarray(actions)).to(torch.int64).to(device)
        ys = np.asarray(ys)
        ys = torch.from_numpy(ys.reshape((len(ys), 1))).float().to(device)

        loader = DataLoader(TensorDataset(xs, actions, ys),
                            batch_size=batch_size, shuffle=True)

        print("\tTraining... epochs " + str(epochs))
        loss = None  # stays None if there are no batches (epochs=0 or empty data)
        for _epoch in range(epochs):
            for xs_batch, actions_batch, ys_batch in loader:
                one_hot = F.one_hot(actions_batch.reshape(-1), num_classes=self.K)
                # Keep only the prediction for the action actually taken.
                prediction = (self.net(xs_batch) * one_hot.float()).sum(
                    dim=1, keepdim=True)

                loss = self.loss_func(prediction, ys_batch)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        if loss is not None:
            print("\tLoss: " + str(loss.detach().cpu().numpy()))
        print("\tDone training.")

    def pred(self, xs, actions=None):
        """Predict network outputs for a batch of inputs.

        Args:
            xs: array-like of inputs, one row per sample.
            actions: optional integer action index per row.

        Returns:
            numpy array of the raw network output when ``actions`` is None,
            otherwise the output selected by each row's action.
        """
        xs = torch.from_numpy(np.asarray(xs)).float().to(device)
        with torch.no_grad():  # inference only: skip autograd bookkeeping
            res = self.net(xs).cpu().numpy()
        if actions is None:
            return res
        return res[np.arange(len(actions)), actions]

    def _require(self, name):
        """Return checkpoint attribute ``name`` or raise a descriptive AttributeError."""
        value = getattr(self, name, None)
        if value is None:
            raise AttributeError(
                f"{type(self).__name__}.{name} is not set; pass {name}=... to "
                f"the constructor or assign it before saving/loading checkpoints")
        return value

    def save(self, k):
        """Save the weights to ``root + str(k) + '.pt'`` and also to ``path``."""
        # Validate both locations up front so nothing is written half-way.
        root, latest = self._require('root'), self._require('path')
        state = self.net.state_dict()
        torch.save(state, root + str(k) + '.pt')
        torch.save(state, latest)

    def load(self, k=-1):
        """Load weights from ``path`` (k == -1) or from ``root + str(k) + '.pt'``."""
        if k == -1:
            ckpt = self._require('path')
        else:
            ckpt = self._require('root') + str(k) + '.pt'
        # map_location lets checkpoints written on another device load here.
        self.net.load_state_dict(torch.load(ckpt, map_location=device))

class AlgContinuous(Alg):
    """Value regressor for continuous actions with a single network output.

    The network input is the state features concatenated with the action
    features (``dim + K_dim`` columns); it predicts one scalar per
    (input, action) pair.  ``save`` / ``load`` are inherited from ``Alg``.
    """

    def __init__(self, dim, K_dim, hidden, *, root=None, path=None):
        """Build the network, an Adam optimizer (lr=4e-3) and an MSE loss.

        Args:
            dim: number of input (state) features.
            K_dim: number of action features.
            hidden: width of the single hidden layer.
            root: optional filename prefix for numbered checkpoints.
            path: optional file path for the latest checkpoint.
        """
        self.net = Net(n_feature=dim + K_dim, n_hidden=hidden, n_output=1).to(device)
        print(self.net)
        self.dim = dim
        self.K_dim = K_dim
        self.hidden = hidden
        # Only set when given, so values assigned elsewhere are not shadowed.
        if root is not None:
            self.root = root
        if path is not None:
            self.path = path

        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=4e-3)
        self.loss_func = torch.nn.MSELoss()

    @staticmethod
    def _concat_inputs(xs, actions):
        """Concatenate inputs and actions along the last (feature) axis.

        A 1-D ``actions`` paired with a 2-D ``xs`` is treated as a single
        action feature per row.
        """
        if xs.dim() == 2 and actions.dim() == 1:
            actions = actions.unsqueeze(1)
        return torch.cat((xs, actions), dim=-1)

    def pred(self, xs, actions):
        """Predict one value per (input, action) row.

        Args:
            xs: array-like of inputs, one row per sample.
            actions: array-like of action features, one row per sample.

        Returns:
            Flattened numpy array of predictions.
        """
        xs = torch.from_numpy(np.asarray(xs)).float().to(device)
        actions = torch.from_numpy(np.asarray(actions)).float().to(device)
        with torch.no_grad():  # inference only: skip autograd bookkeeping
            res = self.net(self._concat_inputs(xs, actions)).cpu().numpy()
        return res.flatten()

    def train(self, xs, actions, ys, epochs=5, batch_size=64):
        """Fit the network on (input, action, target) rows with MSE loss.

        Args:
            xs: array-like of inputs, one row per sample.
            actions: array-like of action features, one row per sample.
            ys: array-like of regression targets, one per sample.
            epochs: number of passes over the data.
            batch_size: mini-batch size for the shuffled DataLoader.
        """
        xs = torch.from_numpy(np.asarray(xs)).float().to(device)
        actions = torch.from_numpy(np.asarray(actions)).float().to(device)
        inputs = self._concat_inputs(xs, actions)
        ys = np.asarray(ys)
        ys = torch.from_numpy(ys.reshape((len(ys), 1))).float().to(device)

        loader = DataLoader(TensorDataset(inputs, ys),
                            batch_size=batch_size, shuffle=True)

        print("\tTraining... epochs " + str(epochs))
        loss = None  # stays None if there are no batches (epochs=0 or empty data)
        for _epoch in range(epochs):
            for inputs_batch, ys_batch in loader:
                loss = self.loss_func(self.net(inputs_batch), ys_batch)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        if loss is not None:
            print("\tLoss: " + str(loss.detach().cpu().numpy()))
        print("\tDone training.")


