import random
from copy import deepcopy
import numpy as np
import torch
from torch.optim import Adam
from numpy import linalg as LA
import gym
import d4rl
import argparse
import json
from utils import redirect_stdout, TrainLogger
import torch.nn as nn
import itertools
import torch.nn.functional as F
from torch.distributions.normal import Normal
from tqdm import tqdm


def load_dataset(env, traj_type, sample_type, pref_type):
    """Load a preference dataset archive from ``../dataset``.

    The archive is read from
    ``../dataset/<env>_<traj_type>_<sample_type>_<pref_type>.npz``, resolved
    relative to the current working directory.

    Args:
        env: First component of the archive file name.
        traj_type: Second component of the archive file name; also echoed
            back as the last element of the result.
        sample_type: Third component of the archive file name.
        pref_type: Fourth component of the archive file name.

    Returns:
        A 7-element list ``[traj_obs, traj_act, traj_rew, traj_idx_1,
        traj_idx_2, pref, traj_type]``. The first six entries are the arrays
        stored under those keys in the archive.
    """
    path = '../dataset/%s_%s_%s_%s.npz' % (env, traj_type, sample_type, pref_type)
    keys = ('traj_obs', 'traj_act', 'traj_rew', 'traj_idx_1', 'traj_idx_2', 'pref')
    # np.load on an .npz keeps the archive file open; arrays are materialised
    # on key access, so read them all and then release the handle.
    with np.load(path) as npzfile:
        dataset = [npzfile[key] for key in keys]
    dataset.append(traj_type)
    return dataset

def mlp(sizes, activation, output_activation=nn.Identity, device=None):
    """Build a fully connected network as an ``nn.Sequential``.

    Each consecutive pair in ``sizes`` becomes one ``nn.Linear`` layer followed
    by an activation module. Every layer except the last uses ``activation``;
    the last uses ``output_activation``.

    Args:
        sizes: Layer widths, input width first and output width last.
        activation: Zero-argument callable (e.g. ``nn.ReLU``) producing the
            activation module for hidden layers.
        output_activation: Zero-argument callable producing the activation
            module that follows the final linear layer.
        device: Device to place the network on. Defaults to CUDA when it is
            available and to the CPU otherwise.

    Returns:
        The assembled ``nn.Sequential`` moved to ``device``.
    """
    if device is None:
        # Fall back to the CPU so the helper is usable on hosts without a GPU.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    last_layer = len(sizes) - 2
    layers = []
    for j, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        act = activation if j < last_layer else output_activation
        layers += [nn.Linear(n_in, n_out), act()]
    return nn.Sequential(*layers).to(device)

class action_nn(nn.Module):
    """Deterministic policy: an MLP with a tanh output scaled by ``act_limit``."""

    def __init__(self, obs_dim, act_dim, hidden_sizes, act_limit=1, activation=nn.ReLU, device=None):
        """Create the policy network.

        Args:
            obs_dim: Width of the observation vector fed to the network.
            act_dim: Width of the action vector the network outputs.
            hidden_sizes: Iterable of hidden-layer widths.
            act_limit: Scale applied to the tanh output (assumed to be 1).
            activation: Zero-argument callable producing the hidden activation.
            device: Device holding the parameters and ``act_limit``. Defaults
                to CUDA when available and to the CPU otherwise.
        """
        super().__init__()
        if device is None:
            # Previously hard-coded to 'cuda'; fall back so CPU-only hosts work.
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        pi_sizes = [obs_dim] + list(hidden_sizes) + [act_dim]
        # The explicit .to() keeps the network on self.device regardless of
        # where the mlp helper placed it.
        self.pi = mlp(pi_sizes, activation, nn.Tanh).to(self.device)
        self.act_limit = torch.as_tensor(act_limit, dtype=torch.float32).to(self.device)

    def forward(self, obs):
        """Return ``act_limit * pi(obs)``; ``obs`` must already be on ``self.device``."""
        return self.act_limit * self.pi(obs)

class behavior_model(object):
    """Deterministic behavior-cloning model trained with a squared-error loss."""

    def __init__(self, env_name, ac_kwargs=None, lr=1e-3):
        """Create the environment, the policy network and its optimizer.

        Args:
            env_name: Environment id passed to ``gym.make``.
            ac_kwargs: Extra keyword arguments forwarded to ``action_nn``
                (for example ``hidden_sizes``). ``None`` means no extras.
            lr: Learning rate of the Adam optimizer.
        """
        # Avoid a shared mutable default argument.
        ac_kwargs = {} if ac_kwargs is None else ac_kwargs
        self.env_name = env_name
        self.env = gym.make(env_name)
        self.obs_dim = self.env.observation_space.shape[0]
        self.act_dim = self.env.action_space.shape[0]
        self.action_nn = action_nn(self.obs_dim, self.act_dim, **ac_kwargs)
        self.optimizer = Adam(self.action_nn.pi.parameters(), lr=lr, weight_decay=1e-5)

    def compute_loss(self, obs, act):
        """Return the summed squared error between predicted and given actions.

        Args:
            obs: Sequence of observations. The last entry is dropped before
                prediction, so ``obs`` is expected to hold one more entry than
                ``act`` (presumably the final observation - confirm against
                the dataset producer).
            act: Sequence of actions aligned with ``obs[:-1]``.

        Returns:
            A scalar tensor: the squared norm of ``act - prediction`` taken
            over every element of the trajectory.
        """
        device = self.action_nn.device
        obs = torch.as_tensor(obs[:-1], dtype=torch.float32).detach().to(device)
        act = torch.as_tensor(act, dtype=torch.float32).detach().to(device)
        predict_act = self.action_nn(obs)
        # torch.norm without dim reduces over all elements, so this is a sum
        # (not a mean) of squared errors; the value scales with trajectory
        # length. The former trailing .mean() acted on a scalar and was a no-op.
        return torch.norm(act - predict_act) ** 2

    def update(self, loss):
        """Take one optimizer step on ``loss``."""
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def test(self, num_episodes=20, max_ep_len=1000):
        """Roll out the policy in the environment and report the mean return.

        Args:
            num_episodes: Number of evaluation episodes.
            max_ep_len: Step cap per episode.

        Returns:
            The average undiscounted episode return, which is also printed.
        """
        device = self.action_nn.device
        ep_rets = []
        with torch.no_grad():
            for _ in range(num_episodes):
                o, d, ep_ret, ep_len = self.env.reset(), False, 0, 0
                while not d and ep_len < max_ep_len:
                    # The policy is deterministic, so no sampling is involved.
                    a = self.action_nn(torch.as_tensor(o, dtype=torch.float32).to(device))
                    o, r, d, _ = self.env.step(a.cpu().numpy())
                    ep_ret += r
                    ep_len += 1
                ep_rets.append(ep_ret)
        avg_ret = sum(ep_rets) / len(ep_rets)
        print(avg_ret)
        return avg_ret

def _pair_loss(behavior_model, traj_obs, traj_act, traj_idx_1, traj_idx_2, prefs, pair, clone_type):
    """Return the cloning loss contributed by preference pair ``pair``."""
    first, second = traj_idx_1[pair], traj_idx_2[pair]
    if clone_type == 'full':
        # Clone both trajectories of the pair.
        return (behavior_model.compute_loss(traj_obs[first], traj_act[first])
                + behavior_model.compute_loss(traj_obs[second], traj_act[second]))
    if clone_type == 'better':
        # A preference label of 1 selects the first trajectory, else the second.
        chosen = first if prefs[pair] == 1 else second
        return behavior_model.compute_loss(traj_obs[chosen], traj_act[chosen])
    raise ValueError('behavior clone type undefined')


def training(behavior_model, dataset, clone_type = 'full', epoch=30, epoch_steps=100, batch_size=64, behavior_path=None, logger=None):
    """Train ``behavior_model`` by cloning trajectories referenced by preference pairs.

    The preference pairs are shuffled and split 80/20 into training and
    held-out pairs. Every epoch runs ``epoch_steps`` optimizer steps on random
    batches of training pairs, then evaluates on a random sample of at most
    ``20 * batch_size`` held-out pairs and logs both average losses.

    Args:
        behavior_model: Object exposing ``compute_loss(obs, act)``,
            ``update(loss)`` and an ``action_nn`` module.
        dataset: Sequence whose first six entries are ``traj_obs``,
            ``traj_act``, ``traj_rew`` (unused here), ``traj_idx_1``,
            ``traj_idx_2`` and ``prefs``. Trailing entries are ignored.
        clone_type: ``'full'`` clones both trajectories of each pair;
            ``'better'`` clones the first trajectory when the pair's
            preference equals 1 and the second otherwise.
        epoch: Number of epochs.
        epoch_steps: Optimizer steps per epoch.
        batch_size: Preference pairs per optimizer step (capped at the number
            of training pairs).
        behavior_path: Where to save the ``action_nn`` state dict; ``None``
            disables saving.
        logger: Object with a ``log(dict)`` method. A new ``TrainLogger`` is
            created when omitted.

    Raises:
        ValueError: If there are no training pairs or ``clone_type`` is unknown.
    """
    if logger is None:
        # Built lazily: a default of TrainLogger() would be evaluated once at
        # definition time and shared by every call.
        logger = TrainLogger()

    # load_dataset returns a 7th entry (traj_type); only the first six are used.
    traj_obs, traj_act, _, traj_idx_1, traj_idx_2, prefs = dataset[:6]

    total_num = len(traj_idx_1)
    training_num = int(total_num * 0.8)
    split = list(range(total_num))
    random.shuffle(split)
    training_set = split[:training_num]
    testing_set = split[training_num:]
    if not training_set:
        raise ValueError('dataset has too few preference pairs to train on')

    train_batch = min(batch_size, len(training_set))
    test_batch = min(20 * batch_size, len(testing_set))

    def pair_loss(pair):
        return _pair_loss(behavior_model, traj_obs, traj_act, traj_idx_1,
                          traj_idx_2, prefs, pair, clone_type)

    best_test_loss = float('inf')
    for i in tqdm(range(epoch), desc = 'Behavior Num Epochs'):
        training_loss = 0
        for j in tqdm(range(epoch_steps), desc = 'Num Steps', colour = 'blue'):
            loss = 0
            for pair in random.sample(training_set, train_batch):
                loss += pair_loss(pair)
            behavior_model.update(loss)
            training_loss += loss.item()

        with torch.no_grad():
            test_pairs = random.sample(testing_set, test_batch)
            test_loss = 0
            for pair in test_pairs:
                # Uses the held-out pair's own preference label.
                test_loss += pair_loss(pair)
            # float() also covers the empty-sample case where the sum is still 0.
            test_loss = float(test_loss)

        logger.log({
            'epoch': i,
            'behavior_training_loss': training_loss / max(epoch_steps * train_batch, 1),
            # Normalise by the number of pairs actually evaluated.
            'behavior_test_loss': test_loss / max(len(test_pairs), 1),
        })
        # NOTE(review): best_test_loss is also updated during the first 10
        # epochs, when nothing is saved. If the best epoch falls in that
        # window, no checkpoint is ever written - confirm this is intended.
        if test_loss < best_test_loss:
            best_test_loss = test_loss
            if i >= 10 and behavior_path is not None:
                torch.save(behavior_model.action_nn.state_dict(), behavior_path)

# Clamp range for the predicted log standard deviation of the stochastic
# policy; it is applied before exponentiating to obtain the std.
LOG_STD_MIN = -20
LOG_STD_MAX = 2

class stochastic_action_nn(nn.Module):
    """Diagonal-Gaussian policy that scores given actions by log-likelihood."""

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation = nn.ReLU, device=None):
        """Create the shared trunk and the mean / log-std heads.

        Args:
            obs_dim: Width of the observation vector.
            act_dim: Width of the action vector.
            hidden_sizes: Non-empty sequence of hidden-layer widths; the last
                one is the input width of both heads.
            activation: Zero-argument callable producing the activation used
                after every trunk layer, including the last.
            device: Device holding the parameters. Defaults to CUDA when
                available and to the CPU otherwise.
        """
        super().__init__()
        if device is None:
            # Previously hard-coded to 'cuda'; fall back so CPU-only hosts work.
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        # The explicit .to() keeps the trunk on self.device regardless of
        # where the mlp helper placed it.
        self.net = mlp([obs_dim] + list(hidden_sizes), activation, activation).to(self.device)
        self.mu_layer = nn.Linear(hidden_sizes[-1], act_dim).to(self.device)
        self.log_std_layer = nn.Linear(hidden_sizes[-1], act_dim).to(self.device)

    def forward(self, obs, act):
        """Return the log-probability of ``act`` given ``obs``.

        Both inputs are detached and moved to ``self.device``. The
        per-dimension log-probabilities are summed over the last axis.
        """
        obs = obs.detach().to(self.device)
        act = act.detach().to(self.device)
        net_out = self.net(obs)
        mu = self.mu_layer(net_out)
        log_std = torch.clamp(self.log_std_layer(net_out), LOG_STD_MIN, LOG_STD_MAX)
        std = torch.exp(log_std)

        # Actions are scored directly under Normal(mu, std); no tanh-squash
        # correction is applied to the log-probability.
        pi_distribution = Normal(mu, std)
        return pi_distribution.log_prob(act).sum(dim=-1)



class stochastic_behavior_model(object):
    """Behavior-cloning model trained by maximising action log-likelihood."""

    def __init__(self, env_name, ac_kwargs=None, lr=1e-3):
        """Create the environment, the Gaussian policy and its optimizer.

        Args:
            env_name: Environment id passed to ``gym.make``.
            ac_kwargs: Extra keyword arguments forwarded to
                ``stochastic_action_nn`` (for example ``hidden_sizes``).
                ``None`` means no extras.
            lr: Learning rate of the Adam optimizer.
        """
        # Avoid a shared mutable default argument.
        ac_kwargs = {} if ac_kwargs is None else ac_kwargs
        self.env_name = env_name
        self.env = gym.make(env_name)
        self.obs_dim = self.env.observation_space.shape[0]
        self.act_dim = self.env.action_space.shape[0]
        self.action_nn = stochastic_action_nn(self.obs_dim, self.act_dim, **ac_kwargs)
        self.optimizer = Adam(self.action_nn.parameters(), lr=lr, weight_decay=1e-5)

    def compute_loss(self, obs, act):
        """Return the mean negative log-likelihood of ``act`` under the policy.

        Args:
            obs: Sequence of observations. The last entry is dropped, so
                ``obs`` is expected to hold one more entry than ``act``
                (presumably the final observation - confirm against the
                dataset producer).
            act: Sequence of actions aligned with ``obs[:-1]``.

        Returns:
            A scalar tensor: minus the mean of the log-probabilities returned
            by ``self.action_nn``.
        """
        device = self.action_nn.device
        obs = torch.as_tensor(obs[:-1], dtype=torch.float32).detach().to(device)
        act = torch.as_tensor(act, dtype=torch.float32).detach().to(device)
        log_prob = self.action_nn(obs, act)
        return -log_prob.mean()

    def update(self, loss):
        """Take one optimizer step on ``loss``."""
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def test(self, num_episodes=5, max_ep_len=1000):
        """Roll out the mean action in the environment and report the mean return.

        Args:
            num_episodes: Number of evaluation episodes.
            max_ep_len: Step cap per episode.

        Returns:
            The average undiscounted episode return, which is also printed.
        """
        device = self.action_nn.device
        ep_rets = []
        with torch.no_grad():
            for _ in range(num_episodes):
                o, d, ep_ret, ep_len = self.env.reset(), False, 0, 0
                while not d and ep_len < max_ep_len:
                    # Deterministic evaluation: act with the Gaussian mean.
                    obs = torch.as_tensor(o, dtype=torch.float32).to(device)
                    a = self.action_nn.mu_layer(self.action_nn.net(obs))
                    o, r, d, _ = self.env.step(a.cpu().numpy())
                    ep_ret += r
                    ep_len += 1
                ep_rets.append(ep_ret)
        avg_ret = sum(ep_rets) / len(ep_rets)
        print(avg_ret)
        return avg_ret
