import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import numpy as np
import gym
from torch.utils.data import Dataset, DataLoader
from utils import sample_trajectories
from env_utils import make_venv, get_env_demo_files
from tqdm import tqdm


class PairDataset(Dataset):
    """Dataset of randomly sampled pairs of demonstration trajectory batches.

    Every item is the concatenation, along the last axis, of two trajectory
    samples drawn from two demonstration files of the same environment.
    All pairs are sampled once, at construction time, and kept in memory.

    Args:
        expert_demo_dir: directory containing the demonstration files.
        env_name: environment name forwarded to ``get_env_demo_files``.
        spec: optional environment spec forwarded to ``get_env_demo_files``.
        nr_pairs: number of pairs to sample.
        traj_len: ``length`` argument forwarded to ``sample_trajectories``.
        nr_traj: ``nr_traj`` argument forwarded to ``sample_trajectories``.

    Raises:
        ValueError: if no demonstration file is found for the environment.
    """

    def __init__(self, expert_demo_dir, env_name, spec=None,
                 nr_pairs=5000, traj_len=10, nr_traj=10):
        self.expert_demo_dir = expert_demo_dir
        self.nr_pairs = nr_pairs
        self.traj_len = traj_len
        self.nr_traj = nr_traj
        self.data = self.create_pair_buffer(env_name, spec)

    def create_pair_buffer(self, env_name, spec):
        """Load the demonstration files and sample ``nr_pairs`` pairs.

        Returns:
            A list with one array per pair: the two sampled trajectory
            arrays concatenated along the last axis.
        """
        demo_files = get_env_demo_files(self.expert_demo_dir, env_name, spec)

        # Validate before touching the disk; an explicit exception (unlike
        # ``assert``) also survives ``python -O``.
        if len(demo_files) == 0:
            raise ValueError(
                'no demonstration files found for env %r (spec=%r) in %r'
                % (env_name, spec, self.expert_demo_dir))

        expert_demos = {
            df: np.load(os.path.join(self.expert_demo_dir, df))
            for df in demo_files
        }

        nr_files = len(demo_files)
        pairs = []
        # sample pairs of files
        for _ in range(self.nr_pairs):
            i = np.random.randint(0, nr_files)
            # j is drawn from [i, nr_files), so j >= i always holds.
            # NOTE(review): j == i is possible (and certain when i is the last
            # index), which pairs a file with itself -- confirm this is
            # intended given that downstream labels assume an ordering.
            j = np.random.randint(i, nr_files)
            traj_i = sample_trajectories(expert_demos[demo_files[i]],
                                         length=self.traj_len,
                                         nr_traj=self.nr_traj)
            traj_j = sample_trajectories(expert_demos[demo_files[j]],
                                         length=self.traj_len,
                                         nr_traj=self.nr_traj)
            pairs.append(np.concatenate([traj_i, traj_j], axis=-1))

        return pairs

    def __getitem__(self, index):
        """Return the pair stored at ``index``."""
        return self.data[index]

    def __len__(self):
        """Return the number of stored pairs."""
        return len(self.data)


class PreferenceReward:
    """Reward network trained from pairwise trajectory preferences.

    The reward is a tanh MLP over observations (optionally concatenated with
    actions). Training minimises a binary cross-entropy preference loss plus
    an optional IRM-style gradient penalty and an L2 weight penalty.
    """

    def __init__(self, opt, bias=False):
        """Build the reward MLP, its Adam optimizer and a TensorBoard writer.

        Args:
            opt: options object. Attributes read here: ``d_layer_dims``,
                ``lr``, ``use_actions``, ``irm_coeff``, ``env_kwargs``,
                ``env_spec_test``, ``use_subprocess``, ``use_seed_ranking``.
            bias: whether the linear layers of the reward MLP have a bias.
        """
        self.layer_dims = opt.d_layer_dims
        self.lr = opt.lr
        self.use_actions = opt.use_actions
        self.irm_coeff = opt.irm_coeff
        self.l2_coeff = 1e-3
        ts = datetime.now().strftime('%Y%m%d_%H%M%S')

        # Only the testing env is used below, to read the space shapes.
        # NOTE(review): ``opt.env_kwargs[0]`` fails when ``env_kwargs`` is
        # None, although ``train`` handles that case -- confirm which is
        # intended.
        _envs, testing_env = make_venv(opt, 1, opt.env_kwargs[0],
                                       opt.env_spec_test,
                                       {}, use_subprocess=opt.use_subprocess,
                                       use_rank=opt.use_seed_ranking)

        ob_shape = list(testing_env.observation_space.shape)
        ac_shape = list(testing_env.action_space.shape)

        # An empty action shape is treated as a single action feature.
        if not ac_shape:
            ac_shape = [1]

        if opt.use_actions:
            r_layer_dims = [ob_shape[-1] + ac_shape[-1]] + opt.d_layer_dims
        else:
            r_layer_dims = [ob_shape[-1]] + opt.d_layer_dims

        self.ob_shape = ob_shape
        self.ac_shape = ac_shape

        irm_flag = 'irm' if self.irm_coeff > 0 else ''
        output_dir = 'exp_output/preference_' + irm_flag + '_' + ts

        self.summary_writer = SummaryWriter(output_dir)

        # Hidden layers: Linear + Tanh for each consecutive pair of sizes.
        self.reward_layers = []
        for in_dim, out_dim in zip(r_layer_dims[:-1], r_layer_dims[1:]):
            self.reward_layers += [torch.nn.Linear(in_features=in_dim,
                                                   out_features=out_dim,
                                                   bias=bias),
                                   torch.nn.Tanh()]

        # Scalar reward head.
        self.reward_layers += [torch.nn.Linear(in_features=r_layer_dims[-1],
                                               out_features=1,
                                               bias=bias)]

        self.reward = nn.Sequential(*self.reward_layers)

        self.r_optimizer = Adam(self.reward.parameters(), lr=self.lr)

    def irm_penalty(self, logits, y):
        """Return the squared gradient of the BCE loss w.r.t. a dummy scale.

        The logits are multiplied by a scalar fixed at 1.0; the penalty is
        the squared derivative of the loss with respect to that scalar.
        ``create_graph=True`` keeps the penalty itself differentiable.
        """
        scale = torch.tensor(1.).requires_grad_()
        loss = F.binary_cross_entropy_with_logits(logits * scale, y)
        grad = autograd.grad(loss, [scale], create_graph=True)[0]
        return torch.sum(grad ** 2)

    def _raw_reward(self, ob, ac=None):
        """Run the reward network and keep its trailing size-1 output axis."""
        if self.use_actions and ac is not None:
            return self.reward(torch.cat([ob, ac], dim=-1))
        return self.reward(ob)

    def get_reward(self, ob, ac=None):
        """Return the predicted reward with every size-1 axis squeezed out.

        Actions are only fed to the network when ``use_actions`` is set and
        ``ac`` is given.
        """
        return torch.squeeze(self._raw_reward(ob, ac))

    def _trajectory_returns(self, tau_batch, ac_dim):
        """Sum per-step rewards over axis 2 of a 4-D trajectory batch.

        The last ``ac_dim`` features of ``tau_batch`` are the actions, the
        rest the observations. Only the network's trailing output axis is
        squeezed, so size-1 batch/trajectory/time axes are preserved.
        """
        obs = tau_batch[:, :, :, :-ac_dim]
        ac = tau_batch[:, :, :, -ac_dim:]
        step_rewards = self._raw_reward(obs, ac).squeeze(-1)
        return torch.sum(step_rewards, dim=2)

    def compute_loss(self, tau_i_batch, tau_j_batch):
        """Compute the preference loss for a batch of trajectory pairs.

        Args:
            tau_i_batch: 4-D tensor; the last axis holds observation features
                followed by action features.
            tau_j_batch: tensor laid out like ``tau_i_batch``.

        Returns:
            Tuple ``(loss, bce_loss, grad_penalty, l2_loss)``.
        """
        # Use the same action width as the network input size in __init__.
        ac_dim = self.ac_shape[-1]

        r_i = self._trajectory_returns(tau_i_batch, ac_dim)
        r_j = self._trajectory_returns(tau_j_batch, ac_dim)

        # Even-indexed trajectories contribute i's share of the summed
        # return, odd-indexed ones contribute j's share.
        # NOTE(review): this ratio is fed to a *with_logits* loss and divides
        # by r_i + r_j, which may be close to zero -- confirm this is the
        # intended preference model.
        r_total = r_i + r_j
        share_i = r_i[:, ::2] / r_total[:, ::2]
        share_j = r_j[:, 1::2] / r_total[:, 1::2]
        d_out = torch.cat([share_i, share_j], dim=1)

        # 0 if i<j, 1 if i>j
        # Built from the actual halves so an odd trajectory count still
        # yields labels with the same shape as d_out.
        labels = torch.cat([torch.zeros_like(share_i),
                            torch.ones_like(share_j)], dim=1)

        weight_norm = sum((w.norm().pow(2) for w in self.reward.parameters()),
                          torch.zeros(()))

        bce_loss = F.binary_cross_entropy_with_logits(d_out, labels)
        grad_penalty = self.irm_penalty(d_out, labels)
        l2_loss = self.l2_coeff * weight_norm

        loss = bce_loss + self.irm_coeff * grad_penalty + l2_loss
        # Rescale so a large penalty coefficient does not inflate the loss.
        if self.irm_coeff > 1.0:
            loss = loss / self.irm_coeff

        return loss, bce_loss, grad_penalty, l2_loss

    def update(self, loss):
        """Take one optimizer step on ``loss``."""
        self.r_optimizer.zero_grad()
        loss.backward()
        self.r_optimizer.step()

    def _build_dataloaders(self, opt):
        """Create one DataLoader per entry of ``opt.env_kwargs``.

        When ``opt.env_kwargs`` is None a single loader keyed ``'all'`` is
        built without a spec.
        """
        demo_dir = './demos/preference_learning'
        if opt.env_kwargs is None:
            dataset = PairDataset(expert_demo_dir=demo_dir,
                                  env_name=opt.env_name,
                                  nr_pairs=opt.nr_pairs)
            return {'all': DataLoader(dataset, batch_size=opt.mini_batch_size)}

        dataloaders = {}
        for spec in opt.env_kwargs:
            dataset = PairDataset(expert_demo_dir=demo_dir,
                                  env_name=opt.env_name,
                                  spec=spec,
                                  nr_pairs=opt.nr_pairs)
            dataloaders[str(spec)] = DataLoader(dataset,
                                                batch_size=opt.mini_batch_size)
        return dataloaders

    def train(self, opt):
        """Train the reward network and save its ``state_dict``.

        Args:
            opt: options object. Attributes read here: ``env_kwargs``,
                ``env_name``, ``nr_pairs``, ``mini_batch_size``,
                ``pref_epochs``.
        """
        dataloader_dict = self._build_dataloaders(opt)

        cnt = 0
        for _ in tqdm(range(opt.pref_epochs)):
            for loader in dataloader_dict.values():
                # Wrapping the loader itself lets tqdm show the batch total.
                for sample in tqdm(loader):
                    # split to get the two trajectories
                    tau_i_batch, tau_j_batch = torch.chunk(
                        sample.type(torch.get_default_dtype()), 2, dim=-1)
                    # ignore the reward and done flag [:-2]
                    loss, bce_loss, grad_pen, l2_loss = self.compute_loss(
                        tau_i_batch[:, :, :, :-2],
                        tau_j_batch[:, :, :, :-2])
                    self.update(loss)
                    cnt += 1
                    self.summary_writer.add_scalar('Losses/Total loss', loss.item(), cnt)
                    self.summary_writer.add_scalar('Losses/BCE Loss', bce_loss.item(), cnt)
                    self.summary_writer.add_scalar('Losses/IRM Loss', grad_pen.item(), cnt)
                    self.summary_writer.add_scalar('Losses/L2 Loss', l2_loss.item(), cnt)

        # save learned model at end of training
        os.makedirs('reward_models/preference', exist_ok=True)

        ts = datetime.now().strftime('%Y%m%d_%H%M%S')
        torch.save(self.reward.state_dict(), 'reward_models/preference/' + opt.env_name + '_preference_reward_' + ts)
