import copy
import sys
import random

import dill
import torch
import numpy as np
from imitation.data import rollout
from imitation.util.util import make_vec_env
from imitation.data.wrappers import RolloutInfoWrapper

from preforl.buffer import TrajectoryRankingBuffer
from preforl.utils import biased_bce_with_logits
from envs.metaworld import MetaWorldSawyerEnv


class MetaWorldTrainer(object):
    """Train a policy on a MetaWorld task with a contrastive (CPL-style) preference loss.

    Every algorithm iteration runs ``PREFORL_epochs`` optimisation steps. Each
    step pairs expert trajectories ("positive" samples) with noise-perturbed
    copies of expert trajectories ("shadow"/negative samples). It scores
    sub-sampled segments with ``alpha * log pi(a|s)`` summed over time and
    minimises ``biased_bce_with_logits`` on those scores. After the epochs,
    the policy is evaluated and the success count is printed.
    """

    def __init__(self,
                 env_name,
                 observation_mode,
                 max_episode_steps,
                 policy,
                 expert_path,
                 num_expert_demos=50,
                 num_algo_iters=100,
                 PREFORL_samples=20,
                 PREFORL_epochs=8,
                 PREFORL_batch_size=20,
                 PREFORL_segment_length=64,
                 alpha=0.1,
                 contrastive_bias=0.25,
                 lr=0.0003,
                 shadow_noise=0.01,
                 shadow_ratio=0.4,
                 seed=123,
                 device=torch.device('cpu'),
                 **kwargs,
                 ) -> None:
        """Set up environments, policy/optimizer and the expert/shadow buffers.

        Args:
            env_name: MetaWorld environment id (without the image prefix).
            observation_mode: ``'image'`` selects the ``meta_image_`` env variant.
            max_episode_steps: episode step limit passed to ``make_vec_env``.
            policy: policy module; it is deep-copied, so the caller's object is not trained.
            expert_path: file loaded with ``torch.load`` (dill pickle) holding demonstrations.
            num_expert_demos: number of leading demonstrations to keep.
            seed: seeds the training env RNG (``seed``) and evaluation env RNG (``seed + 1``).
            **kwargs: accepted and ignored.
        """
        # Hyper-parameters
        self.num_algo_iters = num_algo_iters  # Number of algorithm iterations
        self.alpha = alpha  # Scale applied to log-probabilities in the CPL loss
        self.contrastive_bias = contrastive_bias  # Contrastive bias in CPL loss
        self.lr = lr  # Learning rate for optimizing policy
        self.device = device

        # Setting up the environment
        self.env_name = env_name
        self.max_episode_steps = max_episode_steps
        self.rng = np.random.default_rng(seed)
        self.eval_rng = np.random.default_rng(seed + 1)
        self.observation_mode = observation_mode

        if self.observation_mode == 'image':
            # Image-observation tasks use the 'meta_image_' env id prefix.
            env_name = f'meta_image_{env_name}'
        self.venv = self._make_venv(env_name, self.rng)
        self.eval_venv = self._make_venv(env_name, self.eval_rng)

        # Noise used to build shadow (fake) actions
        self.shadow_noise = shadow_noise  # Scale multiplied onto action-space samples
        self.shadow_ratio = shadow_ratio  # Probability that a time step's action is perturbed

        # PREFORL hyper-parameters
        self.PREFORL_samples = PREFORL_samples  # Number of policy evaluations
        self.PREFORL_epochs = PREFORL_epochs  # Number of optimisation steps per iteration
        self.PREFORL_batch_size = PREFORL_batch_size  # Max number of segment pairs per step
        self.PREFORL_segment_length = PREFORL_segment_length  # Max segment length in the CPL loss

        # Initialize policy
        self.policy = copy.deepcopy(policy).to(self.device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=self.lr)

        # Load expert demonstrations.
        # NOTE: torch.load with a pickle module (dill) can execute arbitrary code;
        # only point expert_path at trusted files.
        self.num_expert_demos = num_expert_demos
        expert_data = torch.load(expert_path, pickle_module=dill)[:self.num_expert_demos]

        # Expert and shadow buffers are built from the SAME demonstrations
        # (init_q 1.0 vs 0.0). Shadow actions are therefore perturbed on
        # copies at sampling time (see _make_shadow_samples), never in place.
        self.expert_buffer = TrajectoryRankingBuffer(expert_data, discriminator=None, init_q=1.0)
        self.shadow_buffer = TrajectoryRankingBuffer(expert_data, discriminator=None, init_q=0.0)

        # Number of episodes rolled out per evaluation
        self.eval_nums = 50

    def _make_venv(self, env_name, rng):
        """Build a 10-env parallel vector env with `early_termination` and `sparse` enabled."""
        return make_vec_env(
            env_name=env_name,
            max_episode_steps=self.max_episode_steps,
            rng=rng,
            n_envs=10,
            parallel=True,
            post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],  # for computing rollout
            env_make_kwargs={'early_termination': True, 'sparse': True},
        )

    def _pairwise_advantages(self, batches):
        """Return summed segment advantages and labels for each pair, all of shape (B,).

        For every batch, ``alpha * log pi(a|s)`` is summed over the first
        segment (``pos_*`` keys) and over the second segment (``neg_*`` keys).

        Returns:
            (adv1s, adv2s, labels), each a 1-D tensor with one entry per batch.
        """
        obs, action, lengths, labels = [], [], [], []
        for batch in batches:
            obs.extend((batch['pos_obs'], batch['neg_obs']))
            action.extend((batch['pos_action'], batch['neg_action']))
            lengths.append(int(batch['length']))
            labels.append(batch['label'])

        obs = torch.cat(obs, dim=0).to(self.device)
        action = torch.cat(action, dim=0).to(self.device)

        # One forward pass over all segments, then split back into per-pair chunks.
        adv = self.alpha * self.policy.get_distribution(obs).log_prob(action)

        adv1s, adv2s = [], []
        for pair_adv in adv.split(lengths):
            half = pair_adv.shape[0] // 2
            adv1s.append(pair_adv[:half].sum())
            adv2s.append(pair_adv[half:].sum())

        # Stack 0-d sums into shape (B,) so they line up element-wise with
        # `labels` (B,). A (B, 1) shape would broadcast to (B, B) in an
        # element-wise loss and pair every label with every sample.
        return (torch.stack(adv1s),
                torch.stack(adv2s),
                torch.stack(labels).to(self.device))

    def get_PREFORL_loss(self, batches):
        """Compute the contrastive preference loss over a list of segment-pair batches.

        Each batch comes from ``create_PREFORL_batch`` plus a ``'label'``
        tensor. Label 0.0 means the first (``pos_*``) segment is the expert
        one; label 1.0 means the second (``neg_*``) segment is.
        """
        adv1s, adv2s, labels = self._pairwise_advantages(batches)
        return biased_bce_with_logits(adv1s, adv2s, labels, bias=self.contrastive_bias)

    def create_PREFORL_batch(self, pos_sample, neg_sample, segment_length):
        """Sub-sample equally long segments from two ``(q, trajectory)`` samples.

        Both trajectories are sub-sampled to ``min(l1, l2, segment_length)``
        time steps. Indices are drawn at random without replacement and kept
        in temporal order. Trajectories are indexed by ``'s'``, ``'a'`` and
        ``'l'`` (its length).

        Returns:
            dict with ``pos_obs``/``pos_action`` taken from ``pos_sample``,
            ``neg_obs``/``neg_action`` taken from ``neg_sample``, and
            ``length``, the total row count of both segments.
        """
        _, tau1 = pos_sample
        _, tau2 = neg_sample
        l1, l2 = tau1['l'], tau2['l']

        length = int(min(l1, l2, segment_length))
        idx1 = np.sort(np.random.choice(l1, size=length, replace=False))
        idx2 = np.sort(np.random.choice(l2, size=length, replace=False))

        return {
            'pos_obs': torch.from_numpy(tau1['s'][idx1]),
            'neg_obs': torch.from_numpy(tau2['s'][idx2]),
            'pos_action': torch.from_numpy(tau1['a'][idx1]),
            'neg_action': torch.from_numpy(tau2['a'][idx2]),
            'length': length * 2,
        }

    def _make_shadow_samples(self, samples):
        """Return copies of ``(q, tau)`` samples with randomly perturbed actions.

        Each time step is perturbed with probability ``shadow_ratio`` by
        ``shadow_noise * action_space.sample()``. The input trajectories are
        NOT modified. Both buffers are built from the same demonstrations, so
        an in-place update could leak noise into expert samples (if the
        buffers hold references). It would also accumulate noise across epochs.
        """
        shadow_samples = []
        for q, tau in samples:
            actions = tau['a']
            l = actions.shape[0]
            mask = np.random.choice([0, 1], size=(l, 1), p=[1 - self.shadow_ratio, self.shadow_ratio])
            noisy_action = np.stack([self.venv.action_space.sample() * self.shadow_noise for _ in range(l)], axis=0)
            perturbed_actions = actions.copy()
            # In-place add on the copy keeps the original action dtype (e.g. float32).
            perturbed_actions += noisy_action * mask
            shadow_samples.append((q, {**tau, 'a': perturbed_actions}))
        return shadow_samples

    def train(self):
        """Run ``num_algo_iters`` iterations of preference optimisation, each followed by evaluation."""
        for algo_iters in range(self.num_algo_iters):

            #### TRAINING POLICY
            num_samples = min(len(self.expert_buffer), self.PREFORL_batch_size)
            self.policy.train()
            for _ in range(self.PREFORL_epochs):
                positive_samples = self.expert_buffer.sample(mode='random', k=num_samples)
                negative_samples = self._make_shadow_samples(
                    self.shadow_buffer.sample(mode='random', k=num_samples))

                PREFORL_batches = []
                for positive_sample, negative_sample in zip(positive_samples, negative_samples):
                    # Randomise which segment comes first; label 1.0 marks the
                    # expert segment as the second one.
                    if random.randint(0, 1):
                        PREFORL_batch = self.create_PREFORL_batch(
                            negative_sample, positive_sample, segment_length=self.PREFORL_segment_length)
                        PREFORL_batch['label'] = torch.tensor(1.0)
                    else:
                        PREFORL_batch = self.create_PREFORL_batch(
                            positive_sample, negative_sample, segment_length=self.PREFORL_segment_length)
                        PREFORL_batch['label'] = torch.tensor(0.0)
                    PREFORL_batches.append(PREFORL_batch)

                PREFORL_loss = self.get_PREFORL_loss(PREFORL_batches)

                self.optimizer.zero_grad()
                PREFORL_loss.backward()
                self.optimizer.step()

            #### EVALUATION
            self.evaluate_policy(algo_iters=algo_iters, eval_nums=self.eval_nums)

    def evaluate_policy(self, algo_iters, eval_nums):
        """Roll out the deterministic policy for ``eval_nums`` episodes and print the success count.

        An episode counts as a success when any step's info has a truthy
        ``'success'`` entry.
        """
        self.policy.eval()
        evaluation_rollouts = rollout.rollout(
                self.policy,
                self.eval_venv,
                rollout.make_sample_until(min_timesteps=None, min_episodes=eval_nums),
                rng=self.eval_rng,
                exclude_infos=False,
                deterministic_policy=True,
            )[:eval_nums]

        # Count trajectories that reached success at any step
        num_success = sum(
            int(any(bool(info['success']) for info in traj.infos))
            for traj in evaluation_rollouts
        )

        print(f'[EVAL][{algo_iters}] [EVAL_SUCC][{num_success} / {eval_nums}]', flush=True)
