import sys
import random

import torch
import minari
import numpy as np
import faiss
import gymnasium as gym
import gymnasium_robotics
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.util.util import make_vec_env
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.utils import get_schedule_fn

from envs.metaworld import MetaWorldSawyerEnv
from envs.wrappers import FlatternEnv, flattern_observation
from preforl.utils import biased_bce_with_logits


def get_env_name(env_name):
    """Translate a dataset name into the matching PointMaze environment id.

    The maze-size keywords are tested as substrings of ``env_name`` in the
    order ``'umaze'``, ``'medium'``, ``'large'``; the first hit wins.

    Returns:
        The environment id string, or ``None`` when no keyword occurs in
        ``env_name``.
    """
    keyword_to_env_id = (
        ('umaze', 'PointMaze_UMaze-v3'),
        ('medium', 'PointMaze_Medium-v3'),
        ('large', 'PointMaze_Large-v3'),
    )
    for keyword, env_id in keyword_to_env_id:
        if keyword in env_name:
            return env_id
    # No known maze size mentioned in the name.
    return None


def get_max_steps(env_name):
    """Return the episode step limit associated with a dataset name.

    The maze-size keywords are tested as substrings of ``env_name`` in the
    order ``'umaze'`` (300), ``'medium'`` (600), ``'large'`` (800); the first
    hit wins.

    Returns:
        The step limit as an int, or ``None`` when no keyword occurs in
        ``env_name``.
    """
    step_limits = (('umaze', 300), ('medium', 600), ('large', 800))
    return next(
        (limit for keyword, limit in step_limits if keyword in env_name),
        None,
    )

def get_env_dims(env_name):
    """Return the length of a flattened observation vector.

    The value is the same for every ``env_name``; the argument is accepted
    but not consulted. Within this file the result is used as the vector
    dimension of the FAISS index built over flattened states.
    """
    # presumably observation + achieved_goal + desired_goal of PointMaze
    # concatenated -- TODO confirm against flattern_observation.
    flat_observation_dim = 8
    return flat_observation_dim


class Maze2DTrainer(object):
    def __init__(self,
                 algorithm,
                 env_name,
                 net_arch,
                 num_algo_iters=100,
                 PREFORL_num_samples=100,
                 CPL_density=8,
                 PREFORL_num_epochs=8,
                 PREFORL_batch_size=100,
                 PREFORL_segment_length=64,
                 alpha=0.1,
                 contrastive_bias=0.25,
                 lr=3e-4,
                 seed=123,
                 shadow_noise=0.01,
                 shadow_ratio=0.4,
                 bc_coeff=0.5,
                 device=torch.device('cpu'),
                 ) -> None:
        """Load the offline dataset and build the policy, optimizer and environments.

        Args:
            algorithm: Training variant. The rest of the class branches on
                ``'CPL'``, ``'PREFORL'`` and ``'PREFORL-NN'``.
            env_name: Dataset name; it is interpolated into the Minari id
                ``D4RL/pointmaze/{env_name}-v2`` and must contain ``'umaze'``,
                ``'medium'`` or ``'large'`` for the environment id and step
                limit lookups to succeed.
            net_arch: Forwarded unchanged to ``ActorCriticPolicy``.
            num_algo_iters: Number of train-then-evaluate iterations.
            PREFORL_num_samples: Stored on the instance; not read elsewhere
                in this file.
            CPL_density: Number of comparison episodes per sampled episode
                (CPL only).
            PREFORL_num_epochs: Gradient steps per algorithm iteration.
            PREFORL_batch_size: Episodes drawn per gradient step.
            PREFORL_segment_length: Upper bound on steps taken from each
                trajectory when building a contrastive batch.
            alpha: Multiplier applied to action log-probabilities.
            contrastive_bias: ``bias`` passed to ``biased_bce_with_logits``.
            lr: Learning rate for both the policy schedule and Adam.
            seed: Seeds ``self.rng``; ``seed + 1`` seeds ``self.eval_rng``.
            shadow_noise: Scale applied to sampled actions when perturbing
                dataset actions (PREFORL).
            shadow_ratio: Per-step probability that a perturbation is applied
                (PREFORL).
            bc_coeff: Weight of the behaviour-cloning term in the loss.
            device: Torch device the policy is moved to.
        """

        self.algorithm = algorithm
        self.env_name = get_env_name(env_name)
        self.max_episode_steps = get_max_steps(env_name)
        self.PREFORL_num_samples = PREFORL_num_samples  # Number of sampled episodes
        self.CPL_density = CPL_density  # Density of preference datasets (CPL only)
        self.PREFORL_num_epochs = PREFORL_num_epochs  # Number of epochs of PREFORL optimizing
        self.PREFORL_batch_size = PREFORL_batch_size  # Batch size of PREFORL optimizing
        self.PREFORL_segment_length = PREFORL_segment_length  # Segment length of PREFORL/CPL loss
        self.num_algo_iters = num_algo_iters  # Number of algorithm iterations
        self.alpha = alpha  # Alpha in PREFORL/CPL loss
        self.contrastive_bias = contrastive_bias  # Contrastive bias in PREFORL/CPL loss
        self.bc_coeff = bc_coeff  # Coefficient of BC loss
        self.lr = lr  # Learning rate for optimizing policy
        self.device = device
        # Separate generators for training and evaluation environments.
        # NOTE(review): the global `random` / `np.random` generators used by
        # the sampling code in this class are not seeded here.
        self.rng = np.random.default_rng(seed)
        self.eval_rng = np.random.default_rng(seed + 1)

        # Load dataset (downloaded on demand) and keep every episode in memory.
        # NOTE(review): relies on sample_episodes(total_episodes) returning each
        # episode once -- confirm against the Minari sampling semantics.
        self.dataset = minari.load_dataset(f'D4RL/pointmaze/{env_name}-v2', download=True)
        episodes = self.dataset.sample_episodes(self.dataset.total_episodes)
        self.episodes = np.array(episodes)
        self.data_length = len(self.episodes)

        # This single flattened env is only used below to read the observation
        # and action spaces for the policy.
        env = self.dataset.recover_environment()
        env = FlatternEnv(env)

        self.obs_handle = flattern_observation
        # Resulting list is [FlatternEnv, RolloutInfoWrapper]; presumably applied
        # in list order by make_vec_env -- TODO confirm.
        post_wrappers = [lambda env, _: RolloutInfoWrapper(env)]
        post_wrappers = [lambda env, _: FlatternEnv(env)] + post_wrappers

        # Build the vector environment
        self.venv = make_vec_env(
            env_name=self.env_name,
            max_episode_steps=self.max_episode_steps,
            rng=self.rng,
            n_envs=10,
            parallel=True,
            post_wrappers=post_wrappers,
        )

        # Freshly initialised policy (moved to the target device) and its optimizer
        self.policy = ActorCriticPolicy(
            observation_space=env.observation_space,
            action_space=env.action_space,
            lr_schedule=get_schedule_fn(lr),
            net_arch=net_arch,
        ).to(self.device)

        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=self.lr)

        # Build evaluation environment; unlike the training venv it uses the
        # evaluation RNG and is created with reset_target=True.
        self.eval_venv = make_vec_env(
            env_name=self.env_name,
            max_episode_steps=self.max_episode_steps,
            rng=self.eval_rng,
            n_envs=10,
            parallel=True,
            post_wrappers=post_wrappers,
            env_make_kwargs={"reset_target": True},
        )

        # Noise settings used when building perturbed ("shadow") actions
        self.shadow_noise = shadow_noise
        self.shadow_ratio = shadow_ratio

        # Number of evaluation episodes
        self.eval_nums = 50

        if self.algorithm == 'PREFORL-NN':
            # Flatten the whole dataset into per-step records so that nearest
            # neighbours of a state can be mapped back to their action/reward.
            self.states = []
            self.storage = []
            for ep in self.episodes:
                # zip stops at the shortest of the three sequences.
                for s, a, r in zip(flattern_observation(ep.observations), ep.actions, ep.rewards):
                    self.states.append(s)
                    self.storage.append({'s': s, 'a': a, 'r': r})

            self.states = np.array(self.states)
            self.storage = np.array(self.storage)

            # L2 index over states: a flat quantizer wrapped in an IVF index
            # with `nlist` cells, of which `nprobe` are visited per query.
            # Row i of the index corresponds to self.storage[i].
            self.index = faiss.IndexFlatL2(get_env_dims(env_name))
            nlist = 100
            self.index = faiss.IndexIVFFlat(self.index, get_env_dims(env_name), nlist)
            self.index.train(np.array(self.states))
            self.index.nprobe = 10
            self.index.add(np.array(self.states))

    def create_contrastive_batch(self, left_sample, right_sample, segment_length):

        tau1 = left_sample
        tau2 = right_sample
        l1, l2 = tau1['l'], tau2['l']

        if self.algorithm in ('PREFORL', 'PREFORL-NN'):
            length = min(l1, l2, segment_length)
            idx1 = np.random.choice(l1, size=length, replace=False)
            idx2 = np.random.choice(l2, size=length, replace=False)
            idx1 = np.sort(idx1)
            idx2 = np.sort(idx2)

            batch = dict()
            batch['pos_obs'] = torch.from_numpy(tau1['s'][idx1])
            batch['neg_obs'] = torch.from_numpy(tau2['s'][idx2])
            batch['pos_action'] = torch.from_numpy(tau1['a'][idx1])
            batch['neg_action'] = torch.from_numpy(tau2['a'][idx2])
            batch['length'] = length * 2

        elif self.algorithm == 'CPL':
            length = min(l1, l2, segment_length)
            idx1 = random.randint(0, l1 - length)
            idx2 = random.randint(0, l2 - length)

            batch = dict()
            batch['pos_obs'] = torch.from_numpy(tau1['s'][idx1: idx1 + length])
            batch['neg_obs'] = torch.from_numpy(tau2['s'][idx2: idx2 + length])
            batch['pos_action'] = torch.from_numpy(tau1['a'][idx1: idx1 + length])
            batch['neg_action'] = torch.from_numpy(tau2['a'][idx2: idx2 + length])
            batch['length'] = length * 2

            reward_left_total = tau1['r'][idx1: idx1 + length].sum()
            reward_right_total = tau2['r'][idx2: idx2 + length].sum()
            if reward_left_total > reward_right_total:
                batch['label'] = torch.tensor(0.0)
            else:
                batch['label'] = torch.tensor(1.0)

        return batch

    def get_contrastive_loss(self, batches):
        """Compute the contrastive preference loss plus a behaviour-cloning term.

        Args:
            batches: List of dicts as returned by ``create_contrastive_batch``,
                each carrying a scalar-tensor ``'label'``. Label ``0.0`` means
                the first segment (stored under ``pos_*``) is preferred; label
                ``1.0`` means the second segment (stored under ``neg_*``) is
                preferred.

        Returns:
            Scalar tensor ``contrastive_loss + self.bc_coeff * bc_loss``.
        """
        lengths = []
        obs, action = [], []
        labels = []
        for batch in batches:
            # Layout per pair: [first segment, second segment].
            obs.extend((batch['pos_obs'], batch['neg_obs']))
            action.extend((batch['pos_action'], batch['neg_action']))
            lengths.append(batch['length'])
            labels.append(batch['label'])

        obs = torch.cat(obs, dim=0).to(self.device)
        action = torch.cat(action, dim=0).to(self.device)

        # One forward pass over every step of every pair.
        dist = self.policy.get_distribution(obs)
        log_prob = dist.log_prob(action)

        adv1s, adv2s = [], []
        preferred_log_probs = []
        for pair_log_prob, label in zip(log_prob.split(lengths), labels):
            # Each pair holds two equally long segments back to back.
            first_lp, second_lp = torch.chunk(pair_log_prob, 2, dim=0)
            adv1s.append((self.alpha * first_lp).sum(dim=-1).unsqueeze(dim=0))
            adv2s.append((self.alpha * second_lp).sum(dim=-1).unsqueeze(dim=0))
            # The BC term must imitate the preferred segment. The `pos_*` slot
            # holds the non-preferred one whenever the label is 1.0 (e.g. for
            # pairs that were passed in flipped order), so select by label
            # rather than by slot name.
            preferred_log_probs.append(first_lp if float(label) < 0.5 else second_lp)

        # Mean over all preferred steps of all pairs; reuses the log-probs
        # computed above instead of running the policy a second time.
        bc_loss = -torch.cat(preferred_log_probs, dim=0).mean()

        # NOTE(review): adv1s/adv2s end up with shape (num_pairs, 1) while
        # labels has shape (num_pairs,) -- confirm that
        # biased_bce_with_logits does not broadcast these against each other.
        adv1s = torch.stack(adv1s)
        adv2s = torch.stack(adv2s)
        labels = torch.stack(labels).to(self.device)

        contrastive_loss = biased_bce_with_logits(adv1s, adv2s, labels, bias=self.contrastive_bias)
        return contrastive_loss + self.bc_coeff * bc_loss

    def train(self):
        """Run ``num_algo_iters`` rounds of policy optimisation and evaluation.

        Every round performs ``PREFORL_num_epochs`` gradient steps. For each
        step a batch of episodes is drawn without replacement, turned into
        (preferred, non-preferred) sample pairs according to
        ``self.algorithm``, converted into contrastive batches and used for a
        single optimizer update. The policy is evaluated after every round.

        Raises:
            ValueError: If ``self.algorithm`` is not ``'CPL'``, ``'PREFORL'``
                or ``'PREFORL-NN'``.
        """
        pair_builders = {
            'CPL': self._build_cpl_pairs,
            'PREFORL': self._build_noise_pairs,
            'PREFORL-NN': self._build_nn_pairs,
        }
        if self.algorithm not in pair_builders:
            raise ValueError(f'Unsupported algorithm: {self.algorithm!r}')
        build_pairs = pair_builders[self.algorithm]
        segment_length = self.PREFORL_segment_length

        for algo_iters in range(1, self.num_algo_iters + 1):

            #### TRAINING POLICY
            for _ in range(self.PREFORL_num_epochs):

                chosen = np.random.choice(
                    self.data_length,
                    size=min(self.data_length, self.PREFORL_batch_size),
                    replace=False,
                )
                samples = self.episodes[chosen]

                positive_samples, negative_samples = build_pairs(samples)

                self.policy.train()

                contrastive_batches = []
                for positive_sample, negative_sample in zip(positive_samples, negative_samples):

                    if self.algorithm == 'CPL':
                        # The label is derived from segment rewards inside
                        # create_contrastive_batch.
                        cpl_batch = self.create_contrastive_batch(
                            positive_sample, negative_sample, segment_length=segment_length)
                    else:
                        # Randomise which slot holds the preferred sample so
                        # that both label values occur.
                        flip = bool(random.randint(0, 1))
                        if not flip:
                            cpl_batch = self.create_contrastive_batch(
                                positive_sample, negative_sample, segment_length=segment_length)
                            cpl_batch['label'] = torch.tensor(0.0)
                        else:
                            cpl_batch = self.create_contrastive_batch(
                                negative_sample, positive_sample, segment_length=segment_length)
                            cpl_batch['label'] = torch.tensor(1.0)

                    contrastive_batches.append(cpl_batch)

                contrastive_loss = self.get_contrastive_loss(contrastive_batches)

                self.optimizer.zero_grad()
                contrastive_loss.backward()
                self.optimizer.step()

            #### EVALUATION
            self.evaluate_policy(algo_iters=algo_iters, eval_nums=self.eval_nums)

    def _build_cpl_pairs(self, samples):
        """Pair each sampled episode with randomly drawn comparison episodes (CPL)."""
        positive_samples, negative_samples = [], []
        # Cannot draw more distinct comparison episodes than the dataset holds.
        density = min(self.CPL_density, self.data_length)
        for ep in samples:
            indices = np.random.choice(self.data_length, size=density, replace=False)
            # Observations are flattened like in the other variants so they
            # can be sliced and fed to the policy.
            anchor = {'s': flattern_observation(ep.observations), 'a': ep.actions,
                      'r': ep.rewards, 'l': ep.actions.shape[0]}
            for idx in indices:
                comp_ep = self.episodes[idx]
                positive_samples.append(anchor)
                negative_samples.append({'s': flattern_observation(comp_ep.observations),
                                         'a': comp_ep.actions,
                                         'r': comp_ep.rewards,
                                         'l': comp_ep.actions.shape[0]})
        return positive_samples, negative_samples

    def _build_noise_pairs(self, samples):
        """Pair each episode with a copy whose actions are randomly perturbed (PREFORL)."""
        positive_samples, negative_samples = [], []
        for ep in samples:
            num_steps = ep.actions.shape[0]
            noisy_action = np.stack(
                [self.venv.action_space.sample() * self.shadow_noise for _ in range(num_steps)],
                axis=0,
            )
            # Per-step switch: 1 (perturb) with probability shadow_ratio.
            mask = np.random.choice([0, 1], size=(num_steps, 1),
                                    p=[1 - self.shadow_ratio, self.shadow_ratio])
            negative_actions = ep.actions + noisy_action * mask
            observations = flattern_observation(ep.observations)
            positive_samples.append({'s': observations, 'a': ep.actions,
                                     'r': ep.rewards, 'l': num_steps})
            negative_samples.append({'s': observations, 'a': negative_actions,
                                     'r': ep.rewards, 'l': num_steps})
        return positive_samples, negative_samples

    def _build_nn_pairs(self, samples):
        """Pair each episode with lower-reward actions of nearby states (PREFORL-NN)."""
        positive_samples, negative_samples = [], []
        for ep in samples:
            ep_observation = flattern_observation(ep.observations)

            _, neighbour_ids = self.index.search(np.array(ep_observation), 10)
            # Column 0 is the closest stored state (presumably the query state
            # itself, since the index was built from the same episodes); the
            # remaining columns are the candidates for a worse action.
            anchor_ids = neighbour_ids[:, 0]
            candidate_ids = neighbour_ids[:, 1:]

            negative_actions = []
            for anchor_id, candidates in zip(anchor_ids, candidate_ids):
                anchor = self.storage[anchor_id]
                r = anchor['r']
                # Fall back to the anchor's own action when no neighbour is worse.
                chosen = anchor
                for cand_id in candidates:
                    if cand_id < 0:
                        # FAISS pads with -1 when fewer neighbours were found;
                        # -1 must not be used as an index into the storage.
                        continue
                    candidate = self.storage[cand_id]
                    # Require a reward that is lower by more than 10% of |r|.
                    if candidate['r'] < r and np.abs(candidate['r'] - r) > 0.1 * np.abs(r):
                        chosen = candidate
                        break
                negative_actions.append(chosen['a'])

            positive_samples.append({'s': ep_observation, 'a': ep.actions,
                                     'r': ep.rewards, 'l': ep.actions.shape[0]})
            negative_samples.append({'s': ep_observation, 'a': np.array(negative_actions),
                                     'r': np.array(ep.rewards), 'l': ep.actions.shape[0]})
        return positive_samples, negative_samples

    def evaluate_policy(self, algo_iters, eval_nums):
        """Roll out the current policy and print return and success statistics.

        The policy acts deterministically in ``self.eval_venv`` until at least
        ``eval_nums`` episodes are collected; only the first ``eval_nums``
        trajectories are scored. A trajectory counts as a success when any of
        its step rewards is greater than zero.

        Args:
            algo_iters: Iteration number, used only in the printed line.
            eval_nums: Number of episodes to evaluate.
        """
        sample_until = rollout.make_sample_until(min_timesteps=None, min_episodes=eval_nums)
        collected = rollout.rollout(
            self.policy,
            self.eval_venv,
            sample_until,
            rng=self.eval_rng,
            exclude_infos=False,
            deterministic_policy=True,
        )
        # More episodes than requested may have been collected; keep the first ones.
        evaluation_rollouts = collected[:eval_nums]

        # Per-trajectory return and the number of successful trajectories.
        returns = [sum(traj.rews) for traj in evaluation_rollouts]
        success = sum(np.any(traj.rews > 0.0) for traj in evaluation_rollouts)

        print(f'[EVAL][{algo_iters}] \
                [EVAL_RETURNS][{np.mean(returns)}] \
                [EVAL_SUCCESS][{success} / {eval_nums}]')
        sys.stdout.flush()
