import torch
import os
from datetime import datetime
import numpy as np
import random
import pickle
import csv
from discrete_action_robots_modules.replay_buffer import replay_buffer
from discrete_action_robots_modules.models import ForwardMap, BackwardMap
from discrete_action_robots_modules.her import her_sampler
from discrete_action_robots_modules.normalizer import normalizer
from discrete_action_robots_modules.robots import goal_distance
from discrete_action_robots_modules.mdp_utils import extract_policy
from torch.distributions.cauchy import Cauchy

"""
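FB (forward-backward) agent with an additional Laplacian loss on the backward
map B (LaplacianAgent). Training is offline: it samples HER-relabelled
transitions from a replay buffer loaded from disk. No MPI calls are made in
this file, despite the original "MPI-version" label.
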
"""

class SamplingSeedActor(torch.nn.Module):
    """Deterministic pseudo-random actor keyed on a binary code ``z`` and an observation hash.

    At construction time every integer seed ``s`` in ``[0, 2**z_dim + hash_range)`` is
    mapped to the action that ``torch.randint(0, action_dim, (1,))`` draws from a CPU
    generator seeded with ``s``.  ``forward`` reads each row of ``z`` as a big-endian
    binary number (first entry = most significant bit), adds the observation hash and
    looks the resulting seed up in that table.

    Args:
        action_dim: number of discrete actions.
        z_dim: length of the binary code ``z``.
        batch_size: only used to shape ``self.powers`` (kept for backward
            compatibility); ``forward`` accepts any batch size.
        device: device holding the lookup tables and the returned actions
            (default ``'cuda'``, as before).
        hash_range: number of observation-hash offsets supported; every
            ``seed(z) + obs_hash`` must stay below ``2**z_dim + hash_range``.
    """

    def __init__(self, action_dim, z_dim, batch_size, device='cuda', hash_range=20000):
        super().__init__()
        self.z_dim = z_dim
        self.action_dim = action_dim
        self.device = device
        # Big-endian bit weights [2**(z_dim-1), ..., 2, 1], one row per batch element
        # (the attribute keeps its original (batch_size, z_dim) shape).
        self.powers = torch.tensor([2 ** i for i in range(self.z_dim)][::-1]).to(device).repeat(batch_size, 1)
        self.max_seed = 2 ** z_dim + hash_range

        # A private generator draws exactly what torch.manual_seed(seed) followed by
        # torch.randint would, without resetting the process-wide RNG state.
        generator = torch.Generator()
        table = torch.empty((self.max_seed, 1, 1), dtype=torch.long)
        for seed in range(self.max_seed):
            generator.manual_seed(seed)
            table[seed] = torch.randint(0, self.action_dim, (1,), generator=generator)
        self.seed_to_action = table.to(device)

    def forward(self, obs_hash, z):
        """Return the table action for each row as a ``(batch, 1)`` long tensor.

        Args:
            obs_hash: per-row integer offsets (any shape that flattens to ``batch``).
            z: ``(batch, z_dim)`` 0/1 codes; a single row broadcasts over ``obs_hash``.
        """
        # powers[0] is the 1-D weight vector, so any batch size broadcasts correctly
        seed = (z * self.powers[0]).sum(1) + obs_hash.reshape(-1)
        actions = self.seed_to_action[seed.long()]
        return actions.reshape(-1, 1).to(self.device)

class LaplacianAgent:
    """Forward-backward (FB) agent whose backward map B also receives a Laplacian loss.

    Training is offline: ``__init__`` loads a replay buffer from ``buffer_path`` and
    ``learn`` only samples from it; the environment is used for evaluation rollouts.
    Observations and goals are standardised with the mean/std of the loaded buffer.

    Args:
        args: hyper-parameter namespace (embed_dim, lr, batch_size, gamma, polyak,
            reg_coef, w_sampling, soft_update, temp, cuda, save_dir, seed, ...).
        env: environment exposing ``reset``, ``step`` and ``compute_reward``.
        env_params: dict with at least the 'obs', 'goal' and 'action' sizes.
        buffer_path: directory containing ``fetch_reach_buffer.pkl``.
    """

    # Hard cap on evaluation episode length (a reference to
    # env_params['max_timesteps'] was commented out in favour of 25).
    _EVAL_MAX_STEPS = 25

    def __init__(self, args, env, env_params, buffer_path='./'):
        self.args = args
        self.env = env
        self.env_params = env_params
        self.cauchy = Cauchy(torch.tensor([0.0]), torch.tensor([0.5]))
        # online networks plus target copies that start with identical weights
        self.forward_network = ForwardMap(env_params, args.embed_dim)
        self.backward_network = BackwardMap(env_params, args.embed_dim)
        self.forward_target_network = ForwardMap(env_params, args.embed_dim)
        self.backward_target_network = BackwardMap(env_params, args.embed_dim)
        self.forward_target_network.load_state_dict(self.forward_network.state_dict())
        self.backward_target_network.load_state_dict(self.backward_network.state_dict())
        if self.args.cuda:
            self.forward_network.cuda()
            self.backward_network.cuda()
            self.forward_target_network.cuda()
            self.backward_target_network.cuda()
        # a single optimiser updates F and B jointly
        fb_params = list(self.forward_network.parameters()) + list(self.backward_network.parameters())
        self.fb_optim = torch.optim.Adam(fb_params, lr=self.args.lr)
        self.her_module = her_sampler(self.args.replay_strategy, self.args.replay_k, self.env.compute_reward)
        self.buffer = replay_buffer(self.env_params, self.args.buffer_size, self.her_module.sample_her_transitions)

        # only used by update_normalizer(); training uses the buffer statistics below
        self.o_norm = normalizer(size=env_params['obs'], default_clip_range=self.args.clip_range)
        self.g_norm = normalizer(size=env_params['goal'], default_clip_range=self.args.clip_range)

        # Run outputs go to <save_dir>/LAPLACE/seed-<seed>/<timestamp>.  When
        # args.save_dir is None nothing is written and self.save_dir stays None.
        self.save_dir = None
        if args.save_dir is not None:
            self.save_dir = f'{args.save_dir}/LAPLACE/seed-{args.seed}/{datetime.now().strftime("%Y%m%d-%H%M%S")}'
            os.makedirs(self.save_dir, exist_ok=True)

            print(' ' * 26 + 'Options')
            for k, v in vars(self.args).items():
                print(' ' * 26 + k + ': ' + str(v))

            with open(self.save_dir + "/arguments.pkl", 'wb') as f:
                pickle.dump(self.args, f)

            with open('{}/score_monitor.csv'.format(self.save_dir), "wt", newline='') as monitor_file:
                monitor = csv.writer(monitor_file)
                monitor.writerow(['epoch', 'eval', 'avg dist', 'eval (GPI)', 'avg dist (GPI)'])

        self.buffer_path = buffer_path
        # NOTE(review): replay_buffer.load presumably unpickles this file -- only load trusted data.
        self.buffer.load(os.path.join(self.buffer_path, 'fetch_reach_buffer.pkl'))
        # Standardisation statistics over every stored entry.
        # NOTE(review): if the buffer storage is pre-allocated and only partly filled,
        # unused slots are included here -- confirm against replay_buffer.
        all_obs = self.buffer.buffers['obs'].reshape(-1, self.env_params['obs'])
        all_g = self.buffer.buffers['g'].reshape(-1, self.env_params['goal'])
        self.obs_mean = all_obs.mean(0)
        self.obs_std = all_obs.std(0)
        self.g_mean = all_g.mean(0)
        self.g_std = all_g.std(0)

    def learn(self):
        """Train F and B from the pre-loaded buffer, evaluating after every epoch.

        Each epoch runs ``n_cycles`` cycles of ``n_batches`` gradient steps followed by
        a soft target-network update.  Then both the plain and the GPI policies are
        evaluated.  When a save directory is configured, scores are appended to
        ``score_monitor.csv`` and the online F/B weights are written to ``model.pt``.
        """
        for epoch in range(self.args.n_epochs):
            for _ in range(self.args.n_cycles):
                # No online data collection: training only uses the buffer loaded in __init__.
                for _ in range(self.args.n_batches):
                    self._update_network()
                self._soft_update_target_network(self.forward_target_network, self.forward_network)
                self._soft_update_target_network(self.backward_target_network, self.backward_network)
            success_rate, avg_dist = self._eval_agent()
            success_rate_gpi, avg_dist_gpi = self._eval_gpi_agent(num_gpi=self.args.num_gpi)
            print('[{}] epoch is: {}, eval: {:.3f}, avg_dist : {:.3f}, '
                  'eval (GPI): {:.3f}, avg_dist (GPI): {:.3f}'.format(datetime.now(), epoch, success_rate, avg_dist,
                                                                      success_rate_gpi, avg_dist_gpi))
            if self.save_dir is not None:
                with open('{}/score_monitor.csv'.format(self.save_dir), "a", newline='') as monitor_file:
                    monitor = csv.writer(monitor_file)
                    monitor.writerow([epoch, success_rate, avg_dist, success_rate_gpi, avg_dist_gpi])
                torch.save([self.forward_network.state_dict(), self.backward_network.state_dict()],
                           os.path.join(self.save_dir, 'model.pt'))

    # ------------------------------------------------------------------ task vectors

    def _random_directions(self, n, eps):
        """Return ``n`` Gaussian vectors of size ``embed_dim`` scaled to (almost) unit norm."""
        gaussian_rdv = torch.FloatTensor(n, self.args.embed_dim).normal_(mean=0, std=1)
        return gaussian_rdv / (torch.norm(gaussian_rdv, dim=-1, keepdim=True) + eps)

    def sample_uniform_ball(self, n, eps=1e-10):
        """Sample ``n`` task vectors: unit direction * sqrt(embed_dim) * U(0, 1) radius."""
        directions = self._random_directions(n, eps)
        radius = torch.FloatTensor(n, 1).uniform_()
        return self._to_device(np.sqrt(self.args.embed_dim) * directions * radius)

    def sample_cauchy_ball(self, n, eps=1e-10):
        """Sample ``n`` task vectors: unit direction * sqrt(embed_dim) * Cauchy(0, 0.5) radius."""
        directions = self._random_directions(n, eps)
        radius = self.cauchy.sample((n, ))
        return self._to_device(np.sqrt(self.args.embed_dim) * directions * radius)

    def _sample_w(self, n, g_tensor=None):
        """Draw task vectors according to ``args.w_sampling``.

        'goal_oriented' returns B(g_tensor) without gradient (``g_tensor`` must already
        be standardised; ``n`` is ignored and one row per goal is returned).
        'uniform_ball' / 'cauchy_ball' draw ``n`` vectors with the matching sampler.

        Raises:
            ValueError: for any other sampling scheme.
        """
        mode = self.args.w_sampling
        if mode == 'goal_oriented':
            with torch.no_grad():
                return self.backward_network(g_tensor)
        if mode == 'uniform_ball':
            return self.sample_uniform_ball(n)
        if mode == 'cauchy_ball':
            return self.sample_cauchy_ball(n)
        raise ValueError(f'unknown w_sampling: {mode!r}')

    # ------------------------------------------------------------------ preprocessing

    def _normalize_o(self, obs):
        """Standardise observations with the buffer mean/std."""
        return (obs - self.obs_mean) / (self.obs_std + 1e-6)

    def _normalize_g(self, g):
        """Standardise goals with the buffer mean/std."""
        return (g - self.g_mean) / (self.g_std + 1e-6)

    def _to_device(self, tensor):
        """Move ``tensor`` to the GPU when ``args.cuda`` is set."""
        return tensor.cuda() if self.args.cuda else tensor

    def _to_tensor(self, array, dtype=torch.float32):
        """Convert ``array`` to a tensor of ``dtype`` on the training device."""
        return self._to_device(torch.tensor(array, dtype=dtype))

    def _preproc_o(self, obs):
        """Standardise one observation and add a leading batch axis."""
        return self._to_tensor(self._normalize_o(obs)).unsqueeze(0)

    def _preproc_g(self, g):
        """Standardise one goal and add a leading batch axis."""
        return self._to_tensor(self._normalize_g(g)).unsqueeze(0)

    def update_normalizer(self):
        """Feed the whole buffer, ``batch_size`` entries at a time, to the running normalisers."""
        print('Buffer size: ', self.buffer.size)
        step = self.args.batch_size
        for start in range(0, self.buffer.size, step):
            stop = min(start + step, self.buffer.size)
            chunk = [self.buffer.buffers[key][start:stop] for key in ('obs', 'ag', 'g', 'actions')]
            self._update_normalizer(chunk)

    def _update_normalizer(self, episode_batch):
        """Update ``o_norm``/``g_norm`` with HER-sampled transitions from ``episode_batch``.

        ``episode_batch`` is ``[obs, ag, g, actions]``; obs/ag carry one more step than
        actions along axis 1 (the ``[:, 1:, :]`` slices below rely on this).
        """
        mb_obs, mb_ag, mb_g, mb_actions = episode_batch
        mb_obs_next = mb_obs[:, 1:, :]
        mb_ag_next = mb_ag[:, 1:, :]
        num_transitions = mb_actions.shape[1]
        buffer_temp = {'obs': mb_obs,
                       'ag': mb_ag,
                       'g': mb_g,
                       'actions': mb_actions,
                       'obs_next': mb_obs_next,
                       'ag_next': mb_ag_next,
                       }
        transitions = self.her_module.sample_her_transitions(buffer_temp, num_transitions)
        obs, g = transitions['obs'], transitions['ag']  # replace g by ag
        transitions['obs'], transitions['g'] = self._clip(obs), self._clip(g)
        self.o_norm.update(transitions['obs'])
        self.g_norm.update(transitions['g'])
        self.o_norm.recompute_stats()
        self.g_norm.recompute_stats()

    # ------------------------------------------------------------------ acting

    def act_gpi(self, obs, w_train, w_eval):
        """GPI action: argmax_a max_k F(obs, a, w_train[k]) . w_eval.

        ``obs`` and ``w_eval`` are single rows; they are repeated once per ``w_train`` row.
        """
        num_gpi = w_train.shape[0]
        obs_repeat = obs.repeat(num_gpi, 1)
        w_eval_repeat = w_eval.repeat(num_gpi, 1)
        f = self.forward_network(obs_repeat, w_train)
        # best value per action over the num_gpi policies, then the best action
        z = torch.einsum('sda, sd -> sa', f, w_eval_repeat).max(0)[0]
        return z.max(0)[1]

    def act(self, obs, w, target_network=False):
        """Greedy action per row: argmax_a F(obs, a, w) . w (optionally with the target F)."""
        network = self.forward_target_network if target_network else self.forward_network
        f = network(obs, w)
        z = torch.einsum('sda, sd -> sa', f, w)
        return z.max(1)[1]

    def get_policy(self, obs, w, policy_type='boltzmann', temp=1, eps=0.01, target_network=False):
        """Turn the Q-values F(obs, ., w) . w into a policy via ``extract_policy``."""
        network = self.forward_target_network if target_network else self.forward_network
        f = network(obs, w)
        z = torch.einsum('sda, sd -> sa', f, w)
        return extract_policy(z, policy_type=policy_type, temp=temp, eps=eps)

    def act_e_greedy(self, obs, g, update_eps=0.2):
        """Random action with probability ``update_eps``, otherwise the greedy action."""
        return random.randrange(self.env_params['action']) if random.random() < update_eps else self.act(obs, g).item()

    def _clip(self, o):
        """Clip values to [-clip_obs, clip_obs]."""
        return np.clip(o, -self.args.clip_obs, self.args.clip_obs)

    def _soft_update_target_network(self, target, source):
        """Polyak update: target <- (1 - polyak) * source + polyak * target."""
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_((1 - self.args.polyak) * param.data + self.args.polyak * target_param.data)

    # ------------------------------------------------------------------ training step

    def _update_network(self):
        """One gradient step on the joint loss.

        loss = FB TD loss (B detached, so it trains F only)
             + Laplacian loss ||B(ag) - B_target(ag_next)||^2 (trains B)
             + reg_coef * orthonormality regulariser on B (trains B).
        """
        transitions = self.buffer.sample(self.args.batch_size)
        other_transitions = self.buffer.sample(self.args.batch_size)

        obs_tensor = self._to_tensor(self._normalize_o(transitions['obs']))
        obs_next_tensor = self._to_tensor(self._normalize_o(transitions['obs_next']))
        g_tensor = self._to_tensor(self._normalize_g(transitions['g']))
        ag_tensor = self._to_tensor(self._normalize_g(transitions['ag']))
        ag_next_tensor = self._to_tensor(self._normalize_g(transitions['ag_next']))
        ag_other_tensor = self._to_tensor(self._normalize_g(other_transitions['ag']))
        actions_tensor = self._to_tensor(transitions['actions'], dtype=torch.long)

        w = self._sample_w(self.args.batch_size, g_tensor)

        # Targets: no gradients flow through the target networks.
        with torch.no_grad():
            if self.args.soft_update:
                pi = self.get_policy(obs_next_tensor, w, policy_type='boltzmann', temp=self.args.temp,
                                     target_network=True)
                f_next = torch.einsum('sda, sa -> sd', self.forward_target_network(obs_next_tensor, w), pi)
            else:
                actions_next_tensor = self.act(obs_next_tensor, w, target_network=True)
                next_idxs = actions_next_tensor[:, None].repeat(1, self.args.embed_dim)[:, :, None]
                # squeeze only the action axis so a batch of one keeps its batch axis
                f_next = self.forward_target_network(obs_next_tensor, w).gather(-1, next_idxs).squeeze(-1)
            b_next = self.backward_target_network(ag_other_tensor)  # batch x dim
            z_next = torch.einsum('sd, td -> st', f_next, b_next)  # batch x batch
            b_next_obs = self.backward_target_network(ag_next_tensor)

        # FB loss: F(s, a, w) with B detached.
        idxs = actions_tensor[:, None].repeat(1, self.args.embed_dim)[:, :, None]
        f = self.forward_network(obs_tensor, w).gather(-1, idxs).squeeze(-1)
        b_obs = self.backward_network(ag_tensor)  # live: carries B's gradients
        b = b_obs.detach()
        b_other = self.backward_network(ag_other_tensor).detach()
        z_diag = torch.einsum('sd, sd -> s', f, b)  # batch
        z = torch.einsum('sd, td -> st', f, b_other)  # batch x batch

        # Laplacian loss: B(ag) should stay close to B_target(ag_next).
        b_loss = ((b_obs - b_next_obs) ** 2).mean()

        fb_loss = 0.5 * (z - self.args.gamma * z_next).pow(2).mean() - z_diag.mean() + b_loss

        # Orthonormality regulariser on B.  It must use the live b_obs: built only
        # from detached embeddings it would be a constant with zero gradient.
        # (b_b_detach has shape (batch,) and broadcasts along the last axis of
        # b_b_other, as in the original formula.)
        b_b_other = torch.einsum('sd, xd -> sx', b_obs, b_other)  # batch x batch
        b_b_detach = torch.einsum('sd, sd -> s', b_obs, b)  # batch
        reg_loss = (b_b_detach * b_b_other.detach()).mean() - b_b_other.mean()
        fb_loss = fb_loss + self.args.reg_coef * reg_loss

        self.fb_optim.zero_grad()
        fb_loss.backward()
        self.fb_optim.step()

    # ------------------------------------------------------------------ evaluation

    def _run_episode(self, observation, select_action):
        """Roll out one evaluation episode starting from a freshly reset ``observation``.

        ``select_action(obs_tensor, w)`` receives the preprocessed observation and
        w = B(current desired goal) and returns an int action.  The episode stops at
        the first success or after ``_EVAL_MAX_STEPS`` steps.

        Returns:
            (is_success, goal distance) reported at the last step taken.
        """
        obs = observation['observation']
        g = observation['desired_goal']
        success, dist = None, None
        for _ in range(self._EVAL_MAX_STEPS):
            with torch.no_grad():
                w = self.backward_network(self._preproc_g(g))
                action = select_action(self._preproc_o(obs), w)
            observation_new, _, _, info = self.env.step(action)
            obs = observation_new['observation']
            g = observation_new['desired_goal']
            dist = goal_distance(observation_new['achieved_goal'], observation_new['desired_goal'])
            success = info['is_success']
            if success > 0:
                break
        return success, dist

    def _eval_agent(self):
        """Evaluate the greedy policy with w = B(g); return (mean success, mean final distance)."""
        successes, dists = [], []
        for _ in range(self.args.n_test_rollouts):
            observation = self.env.reset()
            success, dist = self._run_episode(observation, lambda obs_tensor, w: self.act(obs_tensor, w).item())
            successes.append(success)
            dists.append(dist)
        return np.mean(np.array(successes)), np.mean(np.array(dists))

    def _eval_gpi_agent(self, num_gpi=20):
        """Evaluate the GPI policy; return (mean success, mean final distance).

        For each episode ``num_gpi`` training task vectors are drawn with the training
        ``w_sampling`` scheme.  For 'goal_oriented' they are B of standardised goals
        sampled from the buffer.
        """
        successes, dists = [], []
        for _ in range(self.args.n_test_rollouts):
            observation = self.env.reset()
            g_train_tensor = None
            if self.args.w_sampling == 'goal_oriented':
                g_train = self.buffer.sample(num_gpi)['g']
                # B was trained on standardised goals, so standardise these as well
                g_train_tensor = self._to_tensor(self._normalize_g(g_train))
            w_train = self._sample_w(num_gpi, g_train_tensor)
            success, dist = self._run_episode(
                observation, lambda obs_tensor, w: self.act_gpi(obs_tensor, w_train, w).item())
            successes.append(success)
            dists.append(dist)
        return np.mean(np.array(successes)), np.mean(np.array(dists))
