import sys
import random

import torch
import numpy as np
import gym
import gymnasium
import d4rl
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.util.util import make_vec_env
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.utils import get_schedule_fn

from envs.metaworld import MetaWorldSawyerEnv
from preforl.utils import biased_bce_with_logits


# Per-dataset return thresholds, keyed by D4RL dataset name without a '-vN'
# version suffix. MuJoCoTrainer keeps only dataset episodes whose summed reward
# exceeds the threshold, and counts an evaluation episode as a success when its
# return exceeds it.
# NOTE(review): the origin of these numbers is not recorded here — document it.
THRESHOLD = {
    'halfcheetah-medium': 4909.1,
    'halfcheetah-medium-expert': 10703.4,
    'hopper-medium': 1621.5,
    'hopper-medium-expert': 3561.9,
    'walker2d-medium': 3697.8,
    'walker2d-medium-expert': 4924.8,
}

def get_env_name(env_name):
    """Map a D4RL dataset name to the Gym MuJoCo environment id used for rollouts.

    The robot keywords are tested as substrings of ``env_name`` in the order
    halfcheetah, hopper, walker2d; the first match wins.

    Raises:
        NotImplementedError: if no supported robot keyword occurs in ``env_name``.
    """
    keyword_to_gym_id = (
        ('halfcheetah', 'HalfCheetah-v2'),
        ('hopper', 'Hopper-v2'),
        ('walker2d', 'Walker2d-v2'),
    )
    for keyword, gym_id in keyword_to_gym_id:
        if keyword in env_name:
            return gym_id
    raise NotImplementedError(f"Environment '{env_name}' is not implemented.")


class MuJoCoTrainer(object):
    """Offline preference-based policy training on D4RL MuJoCo datasets.

    The trainer loads a D4RL dataset, cuts it into episodes, keeps only the
    episodes whose return exceeds the dataset's entry in ``THRESHOLD``, and
    then repeatedly:

    * pairs each sampled episode's recorded actions ("positive") with a copy
      whose actions are perturbed by sparse uniform noise ("negative", the
      shadow actions);
    * optimizes an SB3 ``ActorCriticPolicy`` with a CPL-style contrastive loss
      on ``alpha * sum_t log pi(a_t | s_t)`` per segment, plus a behavior
      cloning term on the un-noised segment of each pair;
    * evaluates the deterministic policy on a vectorized environment.

    All data sampling done by the trainer (episode minibatches, segment
    indices, flip decisions, noise) draws from ``self.rng``, so ``seed`` makes
    the data pipeline reproducible. Network initialization is not seeded here.
    """

    def __init__(self,
                 env_name,
                 net_arch,
                 num_algo_iters=100,
                 PREFORL_num_samples=100,
                 PREFORL_density=8,
                 PREFORL_num_epochs=8,
                 PREFORL_batch_size=100,
                 PREFORL_segment_length=64,
                 alpha=0.1,
                 contrastive_bias=0.25,
                 lr=3e-4,
                 seed=123,
                 shadow_noise=0.01,
                 shadow_ratio=0.4,
                 bc_coeff=0.5,
                 device=torch.device('cpu'),
                 ) -> None:
        """Load and filter the dataset, then build the policy and environments.

        Args:
            env_name: D4RL dataset id passed to ``gym.make``. Its robot keyword
                selects the Gym environment used for rollouts. A trailing
                ``-vN`` suffix is ignored when looking up ``THRESHOLD``.
            net_arch: network architecture forwarded to ``ActorCriticPolicy``.
            num_algo_iters: outer iterations; the policy is evaluated after each.
            PREFORL_num_samples: stored on the instance; not read by this class.
            PREFORL_density: stored on the instance; not read by this class.
            PREFORL_num_epochs: gradient steps per outer iteration.
            PREFORL_batch_size: max episodes (preference pairs) per gradient step.
            PREFORL_segment_length: max time steps per compared segment; shorter
                dataset episodes are dropped.
            alpha: scale applied to log-probabilities in the contrastive loss.
            contrastive_bias: ``bias`` argument of ``biased_bce_with_logits``.
            lr: learning rate of the Adam optimizer (and SB3's lr schedule).
            seed: seeds ``self.rng`` (training) and ``self.eval_rng`` (evaluation).
            shadow_noise: scale of the uniform noise added to negative actions.
            shadow_ratio: probability that a time step's action is perturbed.
            bc_coeff: weight of the behavior cloning loss.
            device: torch device for the policy and loss computation.

        Raises:
            NotImplementedError: if the robot in ``env_name`` is unsupported.
            KeyError: if ``THRESHOLD`` has no entry for ``env_name``.
            ValueError: if no episode survives the length and return filters.
        """
        self.env_name = get_env_name(env_name)
        self.max_episode_steps = 1000
        self.PREFORL_num_samples = PREFORL_num_samples  # Number of sampled episodes
        self.PREFORL_density = PREFORL_density  # Density of preference datasets
        self.PREFORL_num_epochs = PREFORL_num_epochs  # Gradient steps per algorithm iteration
        self.PREFORL_batch_size = PREFORL_batch_size  # Episodes per gradient step
        self.PREFORL_segment_length = PREFORL_segment_length  # Max steps per compared segment
        self.num_algo_iters = num_algo_iters  # Number of algorithm iterations
        self.alpha = alpha  # Alpha in CPL loss
        self.contrastive_bias = contrastive_bias  # Contrastive bias in CPL loss
        self.lr = lr  # Learning rate for optimizing policy
        self.bc_coeff = bc_coeff  # Coefficient of BC loss
        self.device = device
        self.rng = np.random.default_rng(seed)
        self.eval_rng = np.random.default_rng(seed + 1)

        # Load the offline dataset; the D4RL env is only needed for that.
        d4rl_env = gym.make(env_name)
        try:
            self.dataset = d4rl_env.get_dataset()
        finally:
            d4rl_env.close()
        self.threshold = self._lookup_threshold(env_name)

        # Cut the flat dataset into episodes and keep only high-return ones.
        all_episodes = self._split_episodes(self.dataset, PREFORL_segment_length)
        self.episodes = [ep for ep in all_episodes if sum(ep['rewards']) > self.threshold]
        self.data_length = len(self.episodes)

        print('data length: {}'.format(self.data_length))
        if self.data_length == 0:
            raise ValueError(
                f"No episode of '{env_name}' has at least {PREFORL_segment_length} "
                f"steps and a return above {self.threshold}; nothing to train on."
            )

        # Build the training vector environment.
        self.venv = make_vec_env(
            env_name=self.env_name,
            max_episode_steps=self.max_episode_steps,
            rng=self.rng,
            n_envs=10,
            parallel=True,
            post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],  # for computing rollout
        )

        # A throwaway env is only needed to read the spaces for the policy.
        space_env = gymnasium.make(self.env_name)
        try:
            observation_space = space_env.observation_space
            action_space = space_env.action_space
        finally:
            space_env.close()

        self.policy = ActorCriticPolicy(
            observation_space=observation_space,
            action_space=action_space,
            lr_schedule=get_schedule_fn(lr),
            net_arch=net_arch,
        ).to(self.device)

        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=self.lr)

        # Build evaluation environment (separate RNG so evaluation does not
        # perturb the training sampling stream).
        self.eval_venv = make_vec_env(
            env_name=self.env_name,
            max_episode_steps=self.max_episode_steps,
            rng=self.eval_rng,
            n_envs=10,
            parallel=True,
            post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],
        )

        # Noises of building fake-actions
        self.shadow_noise = shadow_noise
        self.shadow_ratio = shadow_ratio

        # Number of evaluation episodes
        self.eval_nums = 50

    @staticmethod
    def _lookup_threshold(env_name, table=None):
        """Return the return threshold for ``env_name`` from ``table``.

        ``table`` defaults to the module-level ``THRESHOLD``. The exact name is
        tried first; otherwise a trailing ``-vN`` version suffix is stripped,
        because the table is keyed by unversioned dataset names.

        Raises:
            KeyError: if neither the exact nor the unversioned name is present.
        """
        import re

        if table is None:
            table = THRESHOLD
        if env_name in table:
            return table[env_name]
        base_name = re.sub(r'-v\d+$', '', env_name)
        if base_name in table:
            return table[base_name]
        raise KeyError(f"No return threshold defined for dataset '{env_name}'.")

    @staticmethod
    def _split_episodes(dataset, min_length):
        """Cut a flat dataset dict into a list of per-episode dicts.

        An episode ends (inclusively) at every index where ``dataset['timeouts']``
        or ``dataset['terminals']`` is truthy. Episodes with fewer than
        ``min_length`` steps are dropped with a message, keys containing
        ``'metadata'`` are not copied, and steps after the last end flag are
        discarded.
        """
        end_flags = np.logical_or(dataset['timeouts'], dataset['terminals'])
        boundaries = [-1] + list(np.where(end_flags)[0])
        keys = [k for k in dataset if 'metadata' not in k]

        episodes = []
        for prev_end, end in zip(boundaries[:-1], boundaries[1:]):
            if end - prev_end < min_length:
                print('episode too short, drop {}/{}'.format(end - prev_end, min_length))
                continue
            episodes.append({k: dataset[k][prev_end + 1:end + 1] for k in keys})
        return episodes

    @staticmethod
    def _pair_statistics(log_prob, lengths, prefer_second):
        """Split flat per-step log-probs into per-pair segment sums.

        Pair ``i`` owns ``lengths[i]`` consecutive entries of ``log_prob``: the
        first half is segment 1, the second half is segment 2.

        Args:
            log_prob: tensor whose leading dimension enumerates time steps
                (assumed 1-D per-step log-probs, as returned for Box actions).
            lengths: per-pair entry counts (each even), summing to
                ``log_prob.shape[0]``.
            prefer_second: per-pair bools; True when segment 2 is the
                un-noised one that should be cloned.

        Returns:
            ``(seg1_sums, seg2_sums, expert_log_prob)``: the first two have
            shape ``(B,)`` (one summed log-prob per pair); the third
            concatenates the log-probs of every step in the un-noised segments.
        """
        seg1_sums, seg2_sums, expert = [], [], []
        for chunk, second in zip(log_prob.split(lengths), prefer_second):
            seg1, seg2 = torch.chunk(chunk, 2, dim=0)
            seg1_sums.append(seg1.sum())
            seg2_sums.append(seg2.sum())
            expert.append(seg2 if second else seg1)
        # Stacking 0-d sums yields (B,), matching the (B,) label tensor, so no
        # accidental (B, 1) x (B,) broadcasting happens in the loss.
        return torch.stack(seg1_sums), torch.stack(seg2_sums), torch.cat(expert)

    def get_PREFORL_loss(self, batches):
        """Return the contrastive preference loss plus the weighted BC loss.

        Args:
            batches: dicts from ``create_PREFORL_batch`` with an added scalar
                tensor ``'label'``: 0.0 when the first segment (``pos_*`` slots)
                holds the un-noised actions, 1.0 when the second (``neg_*``) does.

        Returns:
            Scalar tensor ``PREFORL_loss + bc_coeff * bc_loss``.
        """
        obs = torch.cat(
            [seg for b in batches for seg in (b['pos_obs'], b['neg_obs'])], dim=0
        ).to(self.device)
        action = torch.cat(
            [seg for b in batches for seg in (b['pos_action'], b['neg_action'])], dim=0
        ).to(self.device)
        lengths = [b['length'] for b in batches]
        labels = torch.stack([b['label'] for b in batches]).to(self.device)
        prefer_second = [bool(b['label'] > 0.5) for b in batches]

        # One forward pass covers both segments of every pair.
        log_prob = self.policy.get_distribution(obs).log_prob(action)
        seg1_sums, seg2_sums, expert_log_prob = self._pair_statistics(
            log_prob, lengths, prefer_second)

        # CPL-style advantage: alpha * sum_t log pi(a_t | s_t) per segment.
        PREFORL_loss = biased_bce_with_logits(
            self.alpha * seg1_sums, self.alpha * seg2_sums, labels,
            bias=self.contrastive_bias)
        # Clone only the un-noised segment; when a pair is flipped it sits in
        # the second slot, so the label decides which half to use.
        bc_loss = -expert_log_prob.mean()
        return PREFORL_loss + self.bc_coeff * bc_loss

    def create_PREFORL_batch(self, positive_sample, negative_sample, segment_length):
        """Subsample two trajectories into a pair of equally long segments.

        Args:
            positive_sample: dict with NumPy arrays ``'s'`` (observations) and
                ``'a'`` (actions) indexed by time, and ``'l'`` their length;
                it fills the ``pos_*`` slots.
            negative_sample: same layout; it fills the ``neg_*`` slots.
            segment_length: upper bound on steps per segment.

        Returns:
            dict with ``pos_obs``, ``neg_obs``, ``pos_action``, ``neg_action``
            tensors of ``length = min(l1, l2, segment_length)`` rows each, and
            ``'length'`` equal to ``2 * length`` (rows contributed by the pair).

        Each segment is a sorted random subset of time steps drawn without
        replacement from ``self.rng``, not a contiguous window.
        """
        l1, l2 = positive_sample['l'], negative_sample['l']
        length = min(l1, l2, segment_length)
        idx1 = np.sort(self.rng.choice(l1, size=length, replace=False))
        idx2 = np.sort(self.rng.choice(l2, size=length, replace=False))

        return {
            'pos_obs': torch.from_numpy(positive_sample['s'][idx1]),
            'neg_obs': torch.from_numpy(negative_sample['s'][idx2]),
            'pos_action': torch.from_numpy(positive_sample['a'][idx1]),
            'neg_action': torch.from_numpy(negative_sample['a'][idx2]),
            'length': length * 2,
        }

    def _make_negative_actions(self, actions):
        """Return a copy of ``actions`` with sparse uniform noise added.

        Each time step (row) is perturbed with probability ``shadow_ratio``; a
        perturbation adds ``shadow_noise`` times a uniform sample from the
        training action space to every action dimension of that step. The
        result keeps ``actions.dtype``.
        """
        num_steps = actions.shape[0]
        space = self.venv.action_space
        low, high = np.asarray(space.low), np.asarray(space.high)
        if np.all(np.isfinite(low)) and np.all(np.isfinite(high)):
            samples = self.rng.uniform(low, high, size=(num_steps,) + tuple(low.shape))
        else:
            # Unbounded spaces cannot be sampled uniformly; use the space's sampler.
            samples = np.stack([space.sample() for _ in range(num_steps)], axis=0)
        mask = self.rng.random((num_steps, 1)) < self.shadow_ratio
        negative = actions + samples * self.shadow_noise * mask
        return negative.astype(actions.dtype, copy=False)

    def _sample_PREFORL_batches(self):
        """Draw a minibatch of episodes and turn each into a labelled pair."""
        num_episodes = min(self.data_length, self.PREFORL_batch_size)
        chosen = self.rng.choice(self.data_length, size=num_episodes, replace=False)

        batches = []
        for i in chosen:
            ep = self.episodes[i]
            num_steps = ep['actions'].shape[0]
            positive = {'s': ep['observations'], 'a': ep['actions'], 'r': ep['rewards'], 'l': num_steps}
            negative = {'s': ep['observations'], 'a': self._make_negative_actions(ep['actions']),
                        'r': ep['rewards'], 'l': num_steps}

            # Randomize which slot holds the un-noised segment; label 1.0 means
            # it is in the second (neg_*) slot.
            if self.rng.integers(2):
                batch = self.create_PREFORL_batch(negative, positive, segment_length=self.PREFORL_segment_length)
                batch['label'] = torch.tensor(1.0)
            else:
                batch = self.create_PREFORL_batch(positive, negative, segment_length=self.PREFORL_segment_length)
                batch['label'] = torch.tensor(0.0)
            batches.append(batch)
        return batches

    def train(self):
        """Run ``num_algo_iters`` iterations of preference optimization.

        Each iteration takes ``PREFORL_num_epochs`` gradient steps, each on a
        freshly sampled minibatch of preference pairs, and then evaluates.
        """
        for algo_iters in range(1, self.num_algo_iters + 1):

            #### TRAINING POLICY
            for _ in range(self.PREFORL_num_epochs):
                self.policy.train()
                PREFORL_loss = self.get_PREFORL_loss(self._sample_PREFORL_batches())

                self.optimizer.zero_grad()
                PREFORL_loss.backward()
                self.optimizer.step()

            #### EVALUATION
            self.evaluate_policy(algo_iters=algo_iters, eval_nums=self.eval_nums)

    def evaluate_policy(self, algo_iters, eval_nums):
        """Roll out the deterministic policy and print success count and mean return.

        Collects at least ``eval_nums`` episodes on ``eval_venv`` and keeps the
        first ``eval_nums``. An episode is a success when its return exceeds
        ``self.threshold``. The policy is returned to training mode afterwards,
        even if the rollout raises.
        """
        self.policy.eval()
        try:
            evaluation_rollouts = rollout.rollout(
                self.policy,
                self.eval_venv,
                rollout.make_sample_until(min_timesteps=None, min_episodes=eval_nums),
                rng=self.eval_rng,
                exclude_infos=False,
                deterministic_policy=True,
            )[:eval_nums]
        finally:
            self.policy.train()

        # find success trajectories and record returns
        returns = [sum(traj.rews) for traj in evaluation_rollouts]
        num_success = sum(1 for ret in returns if ret > self.threshold)

        print(f'[EVAL][{algo_iters}] \
                [EVAL_SUCC][{num_success} / {eval_nums}] \
                [EVAL_RETURNS][{np.mean(returns)}]')
        sys.stdout.flush()
