import random
from copy import deepcopy
import numpy as np
import torch
from torch.optim import Adam
from numpy import linalg as LA
import gym
import d4rl
import argparse
import json
from utils import redirect_stdout
import torch.nn as nn
import itertools
import torch.nn.functional as F
from torch.distributions.normal import Normal

def load_dataset(env, traj_type, sample_type, pref_type):
    """Load a preference dataset stored as ``../dataset/<env>_<traj>_<sample>_<pref>.npz``.

    Args:
        env: environment name used in the file name.
        traj_type: trajectory-source tag used in the file name; it is also
            returned as the last element of the result.
        sample_type: sampling-scheme tag used in the file name.
        pref_type: preference-label tag used in the file name.

    Returns:
        A list ``[traj_obs, traj_act, traj_rew, traj_idx_1, traj_idx_2, pref, traj_type]``
        whose first six entries are the arrays stored under those keys in
        the archive.
    """
    path = '../dataset/%s_%s_%s_%s.npz' %(env, traj_type, sample_type, pref_type)
    # An NpzFile keeps the archive open until closed; the context manager
    # releases it. Indexing an NpzFile reads the member into memory, so the
    # arrays stay valid after the file is closed.
    with np.load(path) as npzfile:
        dataset = [npzfile['traj_obs'], npzfile['traj_act'], npzfile['traj_rew'],
                   npzfile['traj_idx_1'], npzfile['traj_idx_2'], npzfile['pref'], traj_type]
    return dataset

def mlp(sizes, activation, output_activation=nn.Identity, device=None):
    """Build a fully connected network as an ``nn.Sequential``.

    Args:
        sizes: layer widths ``[in, hidden..., out]``; one ``nn.Linear`` is
            created per consecutive pair.
        activation: module class instantiated (with no arguments) after every
            hidden ``nn.Linear``.
        output_activation: module class instantiated after the last
            ``nn.Linear``; defaults to ``nn.Identity``.
        device: torch device (or device string) the network is moved to.
            Defaults to CUDA, which was previously the only option.

    Returns:
        The ``nn.Sequential`` network placed on ``device``.
    """
    if device is None:
        device = torch.device('cuda')
    n_linear = len(sizes) - 1
    layers = []
    for j, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        # Only the final Linear layer gets the output activation.
        act = activation if j < n_linear - 1 else output_activation
        layers += [nn.Linear(n_in, n_out), act()]
    return nn.Sequential(*layers).to(device)

class action_nn(nn.Module):
    """Deterministic policy network mapping observations to actions.

    The underlying MLP ends in ``nn.Tanh``, so its output lies in [-1, 1]
    before being multiplied by ``act_limit``. Parameters and ``act_limit``
    are placed on the CUDA device.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, act_limit=1, activation=nn.ReLU):
        # act_limit is assumed to be 1 by the callers in this file.
        super().__init__()
        layer_sizes = [obs_dim, *hidden_sizes, act_dim]
        self.pi = mlp(layer_sizes, activation, nn.Tanh)
        cuda = torch.device('cuda')
        limit = torch.as_tensor(act_limit, dtype=torch.float32)
        self.act_limit = limit.to(cuda)
        self.device = cuda

    def forward(self, obs):
        """Return ``act_limit`` times the tanh-bounded MLP output for ``obs``."""
        squashed = self.pi(obs)
        return self.act_limit * squashed

class action_model(object):
    """Behaviour-cloning wrapper around an ``action_nn`` policy and its optimizer.

    Args:
        env_name: Gym environment id. An environment is created to read the
            observation and action dimensions; it is kept as ``self.env``.
        ac_kwargs: extra keyword arguments forwarded to ``action_nn``
            (for example ``hidden_sizes``). ``None`` means no extra arguments.
        lr: Adam learning rate (weight decay is fixed at 1e-5).
    """

    def __init__(self, env_name, ac_kwargs=None, lr=1e-3):
        # A None default avoids sharing one mutable dict across calls.
        if ac_kwargs is None:
            ac_kwargs = {}
        self.env_name = env_name
        self.env = gym.make(env_name)
        self.obs_dim = self.env.observation_space.shape[0]
        self.act_dim = self.env.action_space.shape[0]
        self.action_nn = action_nn(self.obs_dim, self.act_dim, **ac_kwargs)
        self.optimizer = Adam(self.action_nn.pi.parameters(), lr=lr, weight_decay=1e-5)

    def compute_loss(self, obs, act):
        """Return the summed squared error between ``act`` and the policy's actions.

        The last row of ``obs`` is dropped before prediction, so ``obs`` is
        expected to have exactly one more row than ``act``. The result is a
        0-dim tensor equal to the squared Frobenius norm of the action error.
        """
        device = self.action_nn.device
        obs = torch.as_tensor(obs[:-1], dtype=torch.float32, device=device)
        act = torch.as_tensor(act, dtype=torch.float32, device=device)
        predict_act = self.action_nn(obs)
        # Sum of squares == norm(...)**2, without the sqrt/square round-trip.
        return ((act - predict_act) ** 2).sum()

    def update(self, loss):
        """Take one optimizer step on ``loss``."""
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

def training(action_model, dataset, clone_type, epoch=100, epoch_steps=150, batch_size=64,
             save_path=None):
    """Train ``action_model`` by behaviour cloning on preference-pair trajectories.

    Pairs are shuffled and split 80/20 into training and testing sets. Each
    epoch performs ``epoch_steps`` updates on random mini-batches of training
    pairs. It then prints the mean per-pair loss over a random sample of up
    to ``10 * batch_size`` testing pairs. The network's ``state_dict`` is
    saved at the end.

    Args:
        action_model: object exposing ``compute_loss(obs, act)``,
            ``update(loss)``, ``action_nn`` and ``env_name``.
        dataset: list as returned by ``load_dataset``:
            ``[traj_obs, traj_act, traj_rew, traj_idx_1, traj_idx_2, prefs, traj_type]``.
        clone_type: ``'full'`` clones both trajectories of each pair.
            ``'better'`` clones only the first trajectory when
            ``prefs[pair] == 1``, and otherwise the second.
        epoch: number of epochs.
        epoch_steps: optimizer updates per epoch.
        batch_size: pairs per update. Batch sizes are capped at the number
            of pairs available.
        save_path: destination for the weights. Defaults to
            ``'../action/<env_name>_<traj_type>_<clone_type>'``.

    Raises:
        Exception: if ``clone_type`` is neither ``'full'`` nor ``'better'``.
        ValueError: if the dataset is too small to leave any training pair.
    """
    if clone_type not in ('full', 'better'):
        raise Exception('behavior clone type undefined')
    traj_obs, traj_act, _, traj_idx_1, traj_idx_2, prefs, traj_type = dataset

    total_num = len(traj_idx_1)
    training_num = int(total_num * 0.8)
    split = list(range(total_num))
    random.shuffle(split)
    training_set = split[:training_num]
    testing_set = split[training_num:]
    if not training_set:
        raise ValueError('not enough preference pairs to form a training set')
    # random.sample raises when k exceeds the population, so cap both batch sizes.
    train_batch = min(batch_size, len(training_set))
    test_batch = min(10 * batch_size, len(testing_set))

    def pair_loss(pair):
        """Behaviour-cloning loss for the trajectories of one preference pair."""
        idx_1, idx_2 = traj_idx_1[pair], traj_idx_2[pair]
        if clone_type == 'full':
            return (action_model.compute_loss(traj_obs[idx_1], traj_act[idx_1])
                    + action_model.compute_loss(traj_obs[idx_2], traj_act[idx_2]))
        # 'better': clone only the trajectory this pair's own label selects.
        chosen = idx_1 if prefs[pair] == 1 else idx_2
        return action_model.compute_loss(traj_obs[chosen], traj_act[chosen])

    for i in range(epoch):
        training_loss = 0
        for _ in range(epoch_steps):
            loss = sum(pair_loss(pair) for pair in random.sample(training_set, train_batch))
            action_model.update(loss)
            training_loss += loss.item()

        # Evaluation only: no autograd graph is needed for the held-out loss.
        test_loss = 0.0
        with torch.no_grad():
            for pair in random.sample(testing_set, test_batch):
                test_loss += pair_loss(pair).item()
        mean_train = training_loss / (epoch_steps * train_batch) if epoch_steps else float('nan')
        mean_test = test_loss / test_batch if test_batch else float('nan')
        print('epoch:', i, 'training_loss', mean_train, 'test_loss:', mean_test)

    if save_path is None:
        save_path = '../action/%s_%s_%s' % (action_model.env_name, traj_type, clone_type)
    torch.save(action_model.action_nn.state_dict(), save_path)



if __name__ == '__main__':
    # Command-line entry point: train a behaviour-cloning policy on a stored
    # preference dataset; `training` saves the resulting weights.
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type = str, default = 'HalfCheetah-v2')
    parser.add_argument('--traj', type = str, default = 'expert')
    parser.add_argument('--sample', type = str, default = 'uniform')
    parser.add_argument('--pref', type=str, default='regular')
    parser.add_argument('--hid', type=int, default=64)
    parser.add_argument('--l', type=int, default=3)
    parser.add_argument('--clone', type=str, default='full')
    args = parser.parse_args()

    # NOTE(review): presumably utils.redirect_stdout routes later prints (per-epoch
    # losses) into this log file; the file is intentionally left open for the run.
    redirect_stdout(open('../tmp/action_%s_%s_%s' % (args.env, args.traj,args.clone), 'w'))

    # Use a distinct name so the `action_model` class is not shadowed by its instance.
    model = action_model(env_name=args.env, ac_kwargs=dict(hidden_sizes=[args.hid] * args.l))
    dataset = load_dataset(args.env, args.traj, args.sample, args.pref)
    training(model, dataset, args.clone)




