from rlf import BaseIRLAlgo
import rlf.rl.utils as rutils
from rlf.il import TrajDataset
import numpy as np
import rlf.il.utils as iutils
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
import torch.nn as nn
import torch
import torch.optim as optim
from rlf.algos.nested_algo import NestedAlgo
from rlf.algos.on_policy.ppo import PPO
from collections import deque, defaultdict
from tqdm import tqdm
import torch.nn.functional as F
import goal_prox.method.utils as mutils
from goal_prox.envs.gw_helper import *
from rlf.args import str2bool

def generate_pairs(states, dist_norm_factor, sample_factor):
    """Yield randomly sampled, time-ordered state pairs from one trajectory.

    On every iteration two time-step indices are drawn uniformly with
    replacement and sorted, so ``state_1`` never comes after ``state_2`` in
    the trajectory. The target ``dist`` is the number of steps between the
    two states divided by ``dist_norm_factor``.

    Args:
        states: Indexable sequence holding the per-step states of a single
            trajectory. The first state of every pair is asserted to have
            ``shape[0] == 4``.
        dist_norm_factor: Divisor applied to the step difference.
        sample_factor: Interpolates the number of sampled pairs between
            ``len(states)`` (at 0) and ``len(states) ** 2`` (at 1).

    Yields:
        dict with the keys ``'state_1'``, ``'state_2'`` and ``'dist'``, the
        latter being a 0-dim tensor.
    """
    traj_idx = list(range(len(states)))

    max_iters = len(states) ** 2
    min_iters = len(states)
    n_iters = int((max_iters - min_iters) * sample_factor) + min_iters
    for _ in range(n_iters):
        # Sort both indices in one step. Writing the minimum into the first
        # slot before taking the maximum loses the larger index whenever the
        # draw comes out in decreasing order, which collapses the pair onto a
        # single state with distance 0.
        first_idx, second_idx = np.sort(np.random.choice(traj_idx, 2))

        state_1 = states[first_idx]
        state_2 = states[second_idx]
        # Both indices stay NumPy integers, so the quotient (and therefore
        # the tensor dtype) is the same as before the ordering fix.
        dist = (second_idx - first_idx) / dist_norm_factor
        assert state_1.shape[0] == 4
        yield {
                'state_1': state_1,
                'state_2': state_2,
                'dist': torch.tensor(dist),
                }

class PairsDataset(TrajDataset):
    """Trajectory dataset whose items are state pairs with a step distance."""

    def __init__(self, load_path, dist_norm_factor, sample_factor):
        # The sampling settings are assigned before the parent constructor
        # runs; presumably it calls `_gen_data`, which reads them -- confirm
        # against TrajDataset.
        self.sample_factor = sample_factor
        self.dist_norm_factor = dist_norm_factor
        super().__init__(load_path)

    def _gen_data(self, trajs):
        """Return the pairs sampled from every trajectory as one flat list."""
        return [
            pair
            for states, _actions in trajs
            for pair in generate_pairs(
                states, self.dist_norm_factor, self.sample_factor)
        ]

    def __getitem__(self, i):
        """Return element `i` of `self.data`."""
        return self.data[i]


class DistFunc(nn.Module):
    """Network that maps a pair of states to one scalar per pair.

    Both states pass through the same encoder network; the two encodings are
    concatenated along the last dimension and fed to a three-layer MLP head
    with a single output unit.
    """

    def __init__(self, base_encoder, hidden_dim=64):
        """Build the head on top of `base_encoder.net`.

        Args:
            base_encoder: Object exposing `net` (the encoder module) and
                `output_shape` (its first entry is used as the encoding
                width).
            hidden_dim: Width of the two hidden layers of the head.
        """
        super().__init__()
        self.base_net = base_encoder.net
        enc_dim = base_encoder.output_shape[0]
        # The head consumes both encodings side by side, hence 2 * enc_dim.
        layers = [
            nn.Linear(2 * enc_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, state_1, state_2):
        """Return the head output for the concatenated encodings."""
        encodings = [self.base_net(state) for state in (state_1, state_2)]
        return self.head(torch.cat(encodings, dim=-1))


class DynDistIL(NestedAlgo):
    """Nested algorithm that pairs a `DynDist` updater with an agent updater.

    Args:
        agent_updater: Updater placed at index 1 of the nested list, which is
            also passed as the designated RL index. A new `PPO()` is created
            when omitted.
        get_discrim: Not used by this class; kept so existing callers that
            pass it keep working.
    """

    def __init__(self, agent_updater=None, get_discrim=None):
        # Create the default per instance. A `PPO()` in the signature is
        # built once at definition time and would be shared by every
        # `DynDistIL` constructed without an explicit updater.
        if agent_updater is None:
            agent_updater = PPO()
        super().__init__([DynDist(), agent_updater], designated_rl_idx=1)


class DynDist(BaseIRLAlgo):
    def init(self, policy, args):
        """Set up the distance network, optimizer, visualizer and buffer.

        Raises:
            ValueError: If `args.sample_factor` is not within [0, 1].
        """
        super().init(policy, args)
        sample_factor_ok = 0.0 <= args.sample_factor <= 1.0
        if not sample_factor_ok:
            raise ValueError('Sample factor must be between 0 and 1')

        # The distance function uses an encoder built by the policy's base
        # network factory for the policy's observation shape.
        obs_shape = rutils.get_obs_shape(self.policy.obs_space)
        encoder = self.policy.get_base_net_fn(obs_shape)
        self.dist_func = DistFunc(encoder).to(self.args.device)

        self.opt = optim.Adam(
            self.dist_func.parameters(), lr=self.args.dist_lr)
        self.debug_viz = mutils.get_visualizer(
            args, policy, args.save_dir, args.dd_viz_type)

        # A buffer size of -1 requests the project default.
        buff_size = args.exp_buff_size
        if buff_size == -1:
            buff_size = mutils.get_default_buff_size(args)
        self.exp_buff_size = buff_size
        self.agent_exp = deque(maxlen=buff_size)


    def _compute_dist_loss(self, data_batch):
        """Return the MSE between predicted and target pair distances.

        `data_batch` must provide the keys 'state_1', 'state_2' and 'dist'.
        Predictions and targets are both flattened before the comparison.
        """
        first, second = data_batch['state_1'], data_batch['state_2']
        prediction = self.dist_func(first, second).view(-1)
        target = data_batch['dist'].view(-1)
        return F.mse_loss(prediction, target)

    def _debug_dist(self, states, action):
        """Evaluate the distance function between each state and its goal.

        For every state a goal observation is rebuilt from the goal position
        returned by `convert_to_graph`; `action` is accepted but not used.
        """
        def goal_obs_for(state):
            # Only the third value returned by convert_to_graph is needed.
            _, _, goal_pos = convert_to_graph(state)
            # The goal position is passed for both positional arguments.
            env_at_goal = get_env_for_pos(goal_pos, goal_pos, self.args)
            grid = get_grid_obs_for_env(env_at_goal)
            # Presumably moves channels first (HWC -> CHW) -- TODO confirm.
            return torch.FloatTensor(grid.transpose(2, 0, 1))

        device = self.args.device
        goal_batch = torch.stack([goal_obs_for(state) for state in states])
        return self.dist_func(states.to(device), goal_batch.to(device))

    def first_train(self, log, eval_policy):
        """Pre-train the distance function on the expert pairs.

        Runs `args.pre_num_epochs` passes over `self.expert_train_loader`,
        prints the mean loss of every epoch, plots the loss curve and
        triggers one debug visualization. `log` and `eval_policy` are not
        used here.
        """
        self.dist_func.train()
        device = self.args.device
        epoch_means = []
        for epoch_i in tqdm(range(self.args.pre_num_epochs)):
            batch_losses = []
            for expert_batch in self.expert_train_loader:
                on_device = iutils.dict_to_device(expert_batch, device)
                loss = self._compute_dist_loss(on_device)

                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
                batch_losses.append(loss.item())

            mean_loss = np.mean(batch_losses)
            epoch_means.append(mean_loss)
            print('Epoch %i: Loss %.5f' % (epoch_i, mean_loss))

        use_wb = not self.args.no_wb
        rutils.plot_line(epoch_means, 'dist_loss.png', self.args, use_wb)
        self.debug_viz.plot(0, [], {'dist': self._debug_dist})

    def _get_traj_dataset(self, traj_load_path):
        """Create the `PairsDataset` for the trajectories at the given path."""
        norm_factor = self.args.dd_norm_factor
        sample_factor = self.args.sample_factor
        return PairsDataset(traj_load_path, norm_factor, sample_factor)

    def _update_reward_func(self, storage):
        """Update the distance function on mixed agent and expert pairs.

        Returns an empty dict when `args.dd_with_agent` is off. Otherwise it
        takes one optimizer step per sampled batch and returns a mapping
        whose 'dist_loss' entry is the mean batch loss. `storage` is not
        used here.
        """
        if not self.args.dd_with_agent:
            return {}

        n_samples = self.args.exp_sample_size
        # The 0.5 is presumably the agent/expert mixing ratio -- confirm
        # against iutils.mix_data.
        mixed = iutils.mix_data(
            self.agent_exp, self.expert_dataset, n_samples, 0.5)

        # Incomplete trailing batches are dropped.
        batch_sampler = BatchSampler(
            SubsetRandomSampler(range(n_samples)),
            self.args.traj_batch_size,
            drop_last=True,
        )

        mixed = iutils.convert_list_dict(mixed, self.args.device)
        log_vals = defaultdict(list)

        self.dist_func.train()
        for idxs in batch_sampler:
            loss = self._compute_dist_loss(
                iutils.select_idx_from_dict(idxs, mixed))

            self.opt.zero_grad()
            loss.backward()
            self.opt.step()
            log_vals['dist_loss'].append(loss.item())

        # Collapse every list of per-batch values to its mean.
        for key, vals in log_vals.items():
            log_vals[key] = np.mean(vals)

        if self.update_i % self.args.dd_viz_interval == 0:
            self.debug_viz.plot(self.update_i, [], {'dist': self._debug_dist})
        return log_vals

    def _get_reward(self, state, action, masks, add_inputs, rew_hid):
        """Return the negated predicted distance from `state` to the goal.

        The goal is read from `add_inputs['goal']`. The result is a triple
        of the reward, `None` and an empty dict; `action`, `masks` and
        `rew_hid` are not used.
        """
        self.dist_func.eval()
        with torch.no_grad():
            pred_dist = self.dist_func(state, add_inputs['goal'])
            reward = -1.0 * pred_dist
        return reward, None, {}

    def get_env_settings(self, args):
        """Extend the parent's environment settings for this algorithm.

        Enables `ret_raw_obs` and appends two info keys, each paired with a
        callable that returns a shape.
        """
        def success_flag_shape(_):
            return (1,)

        def final_obs_shape(env):
            return rutils.get_obs_shape(env.observation_space)

        settings = super().get_env_settings(args)
        settings.ret_raw_obs = True
        settings.include_info_keys.extend([
            ('ep_found_goal', success_flag_shape),
            ('final_obs', final_obs_shape),
        ])
        return settings

    def on_traj_finished(self, trajs):
        """Add pairs from successful agent trajectories to the buffer.

        Does nothing beyond the parent hook when `args.dd_with_agent` is
        off.
        """
        super().on_traj_finished(trajs)
        if not self.args.dd_with_agent:
            return

        # Only the observations, masks and extra data are needed here.
        obs, _, masks, add_data, _ = iutils.traj_to_tensor(
            trajs, self.args.device)
        is_successes, end_t = mutils.get_success(add_data, masks)
        full_obs = mutils.get_full_obs(obs, add_data, end_t)

        for traj_obs, succeeded in zip(full_obs, is_successes):
            # Trajectories that did not succeed contribute no pairs.
            if not succeeded:
                continue
            self.agent_exp.extend(generate_pairs(
                traj_obs, self.args.dd_norm_factor, self.args.sample_factor))


    def get_add_args(self, parser):
        """Register this algorithm's command-line options on `parser`.

        The parent's options are added first, followed by `--no-wb` and the
        distance-learning options below.
        """
        super().get_add_args(parser)
        parser.add_argument('--no-wb', default=False, action='store_true')

        #########################################
        # New args
        parser.add_argument('--dist-lr', type=float, default=0.001)
        parser.add_argument('--pre-num-epochs', type=int, default=10)
        parser.add_argument('--exp-buff-size', type=int, default=-1)
        parser.add_argument('--sample-factor', type=float, default=0.0)
        parser.add_argument('--dd-with-agent', type=str2bool, default=False)
        # The help text has to be one string: argparse %-formats it when
        # rendering help, which fails for a tuple. Adjacent literals are
        # joined by the parser, so no commas between them.
        parser.add_argument('--dd-norm-factor', type=int, default=100,
                help=(
                    'Normalizing factor for distances. '
                    'Should probably be the max length of the trajectory.'
                    ))
        parser.add_argument('--dd-viz-type', type=str, default=None)
        parser.add_argument('--dd-viz-interval', type=int, default=10)
        parser.add_argument('--exp-sample-size', type=int, default=128)
