import argparse
import os
import pickle as pkl
import random

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F

from network import BC

def parser_args() -> argparse.Namespace:
    """Parse the command-line options of the behavioral-cloning script.

    Options:
        -n / --name         environment name (required).
        -c / --cuda_device  CUDA device index, default 0.
        -e / --epoch        number of training iterations, default 200.
        -l / --lr           learning rate, default 1e-3.

    Returns:
        The parsed ``argparse.Namespace`` with attributes ``name``,
        ``cuda_device``, ``epoch`` and ``lr``.

    Raises:
        SystemExit: raised by argparse when the arguments are invalid or the
            environment name is missing.
    """
    parser = argparse.ArgumentParser(description="Behavioral cloning on classic environment in Open AI Gym")

    # The environment name has no sensible default; requiring it makes argparse
    # report a clear usage error instead of letting ``None`` propagate.
    parser.add_argument("-n", "--name", type=str, required=True,
                        help="name of environment")
    parser.add_argument("-c", "--cuda_device", type=int, default=0,
                        help="cuda device number, default 0")
    parser.add_argument("-e", "--epoch", type=int, default=200,
                        help="number of iterations")
    parser.add_argument("-l", "--lr", type=float, default=1e-3,
                        help="learning rate")

    return parser.parse_args()

if __name__ == '__main__':
    # Script entry point: load expert demonstrations, fit a BC network on
    # (observation, action) pairs with a cross-entropy loss, and keep the
    # checkpoint with the best validation accuracy.
    args = parser_args()
    print(args)

    # Use the requested CUDA device when available, otherwise run on the CPU
    # so the script still works on machines without a GPU.
    if torch.cuda.is_available():
        device = torch.device('cuda:{}'.format(args.cuda_device))
    else:
        device = torch.device('cpu')

    # NOTE(security): pickle can execute arbitrary code while loading; only
    # point this at expert files you generated or otherwise trust.
    with open('./rl_baselines_zoo/experts/{}_expert_demo.pkl'.format(args.name), 'rb') as f:
        # Only observations and actions are used below; the remaining fields
        # are unpacked solely to match the 5-tuple stored in the file.
        obs_expert, actions_expert, rewards_expert, next_obs_expert, dones_expert = pkl.load(f)
    # Each stored observation is a numpy array; they are concatenated along
    # dim 0 into a single float tensor.
    # NOTE(review): presumably every array has a leading batch dim of 1 so
    # that rows line up one-to-one with actions_expert -- confirm.
    obs_expert = torch.cat([torch.from_numpy(obs).float() for obs in obs_expert])
    actions_expert = torch.tensor(actions_expert)

    # Random 80/20 train/validation split over the sample indices.
    num_pairs = len(obs_expert)
    train_idx = random.sample(range(num_pairs), int(0.8 * num_pairs))
    # Membership is tested against a set: checking against the list made the
    # split quadratic in the number of samples.
    train_idx_set = set(train_idx)
    val_idx = [i for i in range(num_pairs) if i not in train_idx_set]

    x_train = obs_expert[train_idx]
    x_val = obs_expert[val_idx]
    y_train = actions_expert[train_idx]
    y_val = actions_expert[val_idx]

    # The environment is only created to read the observation/action sizes.
    env = gym.make(args.name)

    # NOTE(review): BC's arguments look like (input size, hidden size,
    # number of actions) -- confirm against network.BC.
    model = BC(env.observation_space.shape[0], 128, env.action_space.n).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    trainset = torch.utils.data.TensorDataset(x_train, y_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
    valset = torch.utils.data.TensorDataset(x_val, y_val)
    valloader = torch.utils.data.DataLoader(valset, batch_size=128)

    # torch.save does not create missing directories.
    os.makedirs('models', exist_ok=True)

    loss_list = []
    best_acc = 0
    for i in range(1, args.epoch+1):
        model.train()
        for x, y in trainloader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = F.cross_entropy(out, y)
            # `i` is already 1-based, so print it unchanged to stay in sync
            # with the validation line below.
            print("###################Iteration : {}, Loss : {}####################".format(i,
                                                                                            loss.item()), end='\r')
            loss_list.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()
        correct = 0
        total = 0
        # No gradients are needed for evaluation; skipping autograd saves
        # memory and time.
        with torch.no_grad():
            for x, y in valloader:
                x, y = x.to(device), y.to(device)
                pred = model(x).argmax(dim=1)
                total += pred.size(0)
                correct += (pred == y).sum().item()
        val_acc = correct / total
        print("###################Iteration : {}, Val acc: {}####################".format(i, val_acc))
        # Checkpoint only when validation accuracy strictly improves.
        if best_acc < val_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), 'models/{}_BC.bin'.format(args.name))

    print("Best validation accuracy : {}".format(best_acc))
    
    
    