import random
import heapq

import numpy as np
from torch.utils.data import Dataset

# Per-rollout increment used by the ranking buffers: when a caller supplies
# ``init_q``, rollout number ``idx`` is given quality ``init_q + idx * TINY``
# (presumably so that seeded rollouts do not share an identical quality value).
TINY = 1e-12


class TrajectoryBuffer(Dataset):
    """Dataset holding one ``{'s': ..., 'a': ...}`` dict per rollout."""

    def __init__(self, rollouts):
        """Build the buffer from ``rollouts``.

        Each rollout must expose ``obs`` and ``acts``. ``'s'`` receives every
        observation except the last one; ``'a'`` receives ``acts`` unchanged.
        """
        self.trajectories = [
            {'s': rollout.obs[:-1], 'a': rollout.acts}
            for rollout in rollouts
        ]

    def __len__(self):
        """Return the number of stored trajectories."""
        return len(self.trajectories)

    def __getitem__(self, index):
        """Return the trajectory dict stored at ``index``."""
        return self.trajectories[index]


class TrajectoryRankingBuffer(Dataset):
    """Buffer of ``(quality, trajectory)`` pairs built from rollouts.

    ``trajectories`` is a list of ``(q, tau)`` tuples where ``tau`` is a dict
    with the keys ``'s'``, ``'a'``, ``'r'``, ``'l'`` and ``'dense_r'``.
    ``quality`` lists the stored ``q`` values and ``size`` accumulates the
    ``'l'`` entry of every stored trajectory.
    """

    @staticmethod
    def sort_fn(traj: tuple):
        """Return the quality component of a ``(quality, trajectory)`` pair."""
        q, _ = traj
        return q

    def __init__(self, rollouts, discriminator, init_q=None):
        """Score the rollouts and store those with a not-yet-seen quality.

        Args:
            rollouts: Iterable of rollout objects exposing ``obs``, ``acts``,
                ``infos`` and ``len()``. The last info dict must contain
                ``['episode']['r']`` and ``['episode']['l']``.
            discriminator: Object whose ``predict(tau)`` result has an
                ``.item()`` method; only called when ``init_q`` is ``None``.
            init_q: When given, the discriminator is not used and rollout
                number ``idx`` gets quality ``init_q + idx * TINY``.

        A rollout whose quality equals one that is already stored is skipped.
        The formatted ``quality-length-dense_reward`` strings of the scored
        rollouts are printed once construction is finished.
        """
        self.trajectories = []
        self.quality = []
        self.discriminator = discriminator
        self.size = 0
        results = []
        # Mirror of ``self.quality`` for O(1) duplicate detection.
        seen_q = set()

        for idx, rollout in enumerate(rollouts):
            l = int(rollout.infos[-1]['episode']['l'])
            s = rollout.obs[:l]
            a = rollout.acts
            r = float(rollout.infos[-1]['episode']['r'])
            # 'dense_reward' is optional per step; missing entries add nothing.
            dense_r = 0
            for i in range(len(rollout)):
                info = rollout.infos[i]
                if 'dense_reward' in info:
                    dense_r += info['dense_reward']
            tau = {
                's': s,
                'a': a,
                'r': r,
                'l': l,
                'dense_r': dense_r,
            }
            assert tau['s'].shape[0] == tau['l']
            if init_q is None:
                q = discriminator.predict(tau).item()
                results.append(f'{q:.3f}-{l}-{dense_r:.3f}')
            else:
                # NOTE(review): if ``init_q`` is large enough that
                # ``idx * TINY`` is lost to float rounding, consecutive
                # rollouts get equal qualities and are dropped below.
                q = init_q + idx * TINY
            if q not in seen_q:
                seen_q.add(q)
                self.quality.append(q)
                self.size += tau['l']
                self.trajectories.append((q, tau))

        # Qualities are unique here, so tuple comparison never reaches ``tau``.
        heapq.heapify(self.trajectories)
        print(results)

    def find(self, operator, base_q):
        """Select the stored trajectories whose quality is above/below ``base_q``.

        Args:
            operator: Exactly ``'>'`` or ``'<'`` (strict comparison).
            base_q: Threshold the stored qualities are compared against.

        Returns:
            ``(trajectories, stats)``. ``trajectories`` holds the selected
            ``(q, tau)`` pairs sorted by ascending quality. ``stats`` always
            has ``'all_q'``, ``'rewards'`` and ``'steps'``; when something was
            selected it also has ``'q'`` (mean), ``'min_q'`` and ``'max_q'``.
            ``'steps'`` is the summed ``'l'`` of every trajectory scanned,
            not only of the selected ones.

        Raises:
            AssertionError: If ``operator`` is not ``'<'`` or ``'>'``.
        """
        # Tuple membership: a substring test would also accept '' and '<>'.
        assert operator in ('<', '>'), 'Operator must be one of \'<\' or \'>\'.'

        trajectories = []
        quality = []
        rewards = []
        steps = 0
        for traj in self.trajectories:
            q, tau = traj
            steps += tau['l']
            selected = q > base_q if operator == '>' else q < base_q
            if selected:
                trajectories.append(traj)
                quality.append(q)
                rewards.append(tau['r'])

        trajectories.sort(key=self.sort_fn)
        if quality:
            stats = {
                'q': sum(quality) / len(quality),
                'min_q': min(quality),
                'max_q': max(quality),
                'all_q': quality,
                'rewards': rewards,
                'steps': steps,
            }
        else:
            stats = {
                'all_q': [],
                'rewards': [],
                'steps': steps,
            }
        return trajectories, stats

    def merge(self, trajectories, stats):
        """Append ``trajectories`` and their stats to this buffer.

        Nothing happens when ``stats['all_q']`` is empty. Otherwise the
        trajectories are appended, ``stats['all_q']`` extends ``quality`` and
        ``stats['steps']`` is added to ``size``. The heap ordering created in
        ``__init__`` is not restored and no duplicate check is performed.
        """
        # NOTE(review): stats produced by ``find`` count the steps of every
        # scanned trajectory, so ``size`` may grow by more than the merged
        # trajectories contain -- confirm this is intended.
        if stats['all_q']:
            self.trajectories += trajectories
            self.quality += stats['all_q']
            self.size += stats['steps']

    def sample(self, mode, k=0):
        """Return stored ``(q, tau)`` pairs chosen according to ``mode``.

        Args:
            mode: ``'random'`` draws up to ``k`` distinct entries, ``'max'`` /
                ``'min'`` return the ``k`` highest / lowest quality entries,
                ``'all'`` returns every entry once in random order.
            k: Number of entries requested; ignored for ``'all'``.

        Raises:
            ValueError: If ``mode`` is not one of the four modes above.
        """
        if mode == 'random':
            # Without replacement; clamp so an oversized or negative k is safe.
            count = max(0, min(k, len(self.trajectories)))
            return random.sample(self.trajectories, count)
        if mode == 'max':
            return heapq.nlargest(k, self.trajectories, key=self.sort_fn)
        if mode == 'min':
            return heapq.nsmallest(k, self.trajectories, key=self.sort_fn)
        if mode == 'all':
            # A full-length sample without replacement is a shuffled copy.
            return random.sample(self.trajectories, len(self.trajectories))
        raise ValueError(
            f"Unknown sample mode {mode!r}; expected 'random', 'max', 'min' or 'all'."
        )

    @property
    def q(self):
        """Mean of the stored qualities, or ``0`` when the buffer is empty."""
        return sum(self.quality) / len(self.quality) if self.quality else 0

    def __getitem__(self, index):
        """Return the ``(q, tau)`` pair stored at ``index``."""
        return self.trajectories[index]

    def __len__(self):
        """Return the number of stored trajectories."""
        return len(self.trajectories)


def find_success(rollouts, discriminator):
    """Split ``rollouts`` into scored successful trajectories and the rest.

    A rollout counts as successful when at least one of its info dicts has a
    truthy ``'success'`` entry.

    Returns:
        ``(success_trajectories, others_rollouts, stats)``: a list of
        ``(q, tau)`` pairs scored with ``discriminator.predict(tau).item()``,
        the untouched rollouts that did not succeed, and a dict with
        ``'all_q'`` (the scores) and ``'steps'`` (summed ``len()`` of the
        successful rollouts).
    """
    scored = []
    remaining = []
    scores = []
    total_steps = 0

    for rollout in rollouts:
        flags = [bool(info['success']) for info in rollout.infos]
        if not np.any(flags):
            remaining.append(rollout)
            continue

        # Reward and length come from the episode summary in the final info.
        tau = dict(
            s=rollout.obs[:-1],
            a=rollout.acts,
            r=float(rollout.infos[-1]['episode']['r']),
            l=int(rollout.infos[-1]['episode']['l']),
        )
        score = discriminator.predict(tau).item()
        scored.append((score, tau))
        scores.append(score)
        total_steps += len(rollout)

    return scored, remaining, {'all_q': scores, 'steps': total_steps}


def find_failure(rollouts, discriminator):
    """Split ``rollouts`` into scored failed trajectories and the rest.

    A rollout counts as failed when none of its info dicts has a truthy
    ``'success'`` entry.

    Returns:
        ``(failed_trajectories, others_rollouts, stats)``: a list of
        ``(q, tau)`` pairs scored with ``discriminator.predict(tau).item()``,
        the untouched rollouts that did not fail, and a dict with ``'all_q'``
        (the scores) and ``'steps'`` (summed ``len()`` of the failed rollouts).
    """
    failures, leftovers, qualities = [], [], []
    step_count = 0

    for rollout in rollouts:
        never_succeeded = np.all(
            [not bool(info['success']) for info in rollout.infos]
        )
        if not never_succeeded:
            leftovers.append(rollout)
            continue

        # Reward and length come from the episode summary in the final info.
        tau = dict(
            s=rollout.obs[:-1],
            a=rollout.acts,
            r=float(rollout.infos[-1]['episode']['r']),
            l=int(rollout.infos[-1]['episode']['l']),
        )
        quality = discriminator.predict(tau).item()
        failures.append((quality, tau))
        qualities.append(quality)
        step_count += len(rollout)

    summary = {
        'all_q': qualities,
        'steps': step_count,
    }
    return failures, leftovers, summary



class NewTrajectoryRankingBuffer(Dataset):
    """Ranking buffer that stores a random subsample of each rollout.

    ``trajectories`` is a list of ``(q, tau)`` tuples where ``tau`` is a dict
    with the keys ``'s'``, ``'a'``, ``'r'``, ``'l'`` and ``'dense_r'``.
    ``quality`` lists the stored ``q`` values and ``size`` accumulates the
    ``'l'`` entry of every stored trajectory.
    """

    @staticmethod
    def sort_fn(traj: tuple):
        """Return the quality component of a ``(quality, trajectory)`` pair."""
        q, _ = traj
        return q

    def __init__(self, rollouts, discriminator, init_q=None, cutoff=50):
        """Subsample, score and store the rollouts.

        Args:
            rollouts: Iterable of rollout objects exposing ``obs``, ``acts``
                (both indexable with an integer array), ``infos`` and
                ``len()``. The last info dict must contain
                ``['episode']['r']``.
            discriminator: Object whose ``predict(tau)`` result has an
                ``.item()`` method; only called when ``init_q`` is ``None``.
            init_q: When given, the discriminator is not used and rollout
                number ``idx`` gets quality ``init_q + idx * TINY``.
            cutoff: Number of steps to keep per rollout. Rollouts shorter
                than ``cutoff`` keep all of their steps.

        A rollout whose quality equals one that is already stored is skipped.
        The formatted ``quality-length-dense_reward`` strings of the scored
        rollouts are printed once construction is finished.
        """
        self.trajectories = []
        self.quality = []
        self.discriminator = discriminator
        self.size = 0
        self.cutoff = cutoff
        results = []
        # Mirror of ``self.quality`` for O(1) duplicate detection.
        seen_q = set()

        for idx, rollout in enumerate(rollouts):
            n_steps = len(rollout)
            # Sampling without replacement cannot draw more steps than exist,
            # so clamp the sample size instead of failing on short rollouts.
            indices = np.random.choice(
                n_steps, size=min(cutoff, n_steps), replace=False
            )
            indices.sort()
            # Always keep the final step; it replaces the largest drawn index,
            # so the indices stay sorted and unique.
            indices[-1] = n_steps - 1
            l = len(indices)
            s = rollout.obs[indices]
            a = rollout.acts[indices]
            r = float(rollout.infos[-1]['episode']['r'])
            # Dense reward is summed over the kept steps only, and only when
            # the final info dict advertises a 'dense_reward' entry.
            dense_r = 0
            if 'dense_reward' in rollout.infos[-1]:
                for i in indices:
                    dense_r += rollout.infos[i]['dense_reward']
            tau = {
                's': s,
                'a': a,
                'r': r,
                'l': l,
                'dense_r': dense_r,
            }
            assert tau['s'].shape[0] == tau['l']
            if init_q is None:
                q = discriminator.predict(tau).item()
                results.append(f'{q:.3f}-{l}-{dense_r:.3f}')
            else:
                # NOTE(review): if ``init_q`` is large enough that
                # ``idx * TINY`` is lost to float rounding, consecutive
                # rollouts get equal qualities and are dropped below.
                q = init_q + idx * TINY
            if q not in seen_q:
                seen_q.add(q)
                self.quality.append(q)
                self.size += tau['l']
                self.trajectories.append((q, tau))

        # Qualities are unique here, so tuple comparison never reaches ``tau``.
        heapq.heapify(self.trajectories)
        print(results)

    def find(self, operator, base_q):
        """Select the stored trajectories whose quality is above/below ``base_q``.

        Args:
            operator: Exactly ``'>'`` or ``'<'`` (strict comparison).
            base_q: Threshold the stored qualities are compared against.

        Returns:
            ``(trajectories, stats)``. ``trajectories`` holds the selected
            ``(q, tau)`` pairs sorted by ascending quality. ``stats`` always
            has ``'all_q'``, ``'rewards'`` and ``'steps'``; when something was
            selected it also has ``'q'`` (mean), ``'min_q'`` and ``'max_q'``.
            ``'steps'`` is the summed ``'l'`` of every trajectory scanned,
            not only of the selected ones.

        Raises:
            AssertionError: If ``operator`` is not ``'<'`` or ``'>'``.
        """
        # Tuple membership: a substring test would also accept '' and '<>'.
        assert operator in ('<', '>'), 'Operator must be one of \'<\' or \'>\'.'

        trajectories = []
        quality = []
        rewards = []
        steps = 0
        for traj in self.trajectories:
            q, tau = traj
            steps += tau['l']
            selected = q > base_q if operator == '>' else q < base_q
            if selected:
                trajectories.append(traj)
                quality.append(q)
                rewards.append(tau['r'])

        trajectories.sort(key=self.sort_fn)
        if quality:
            stats = {
                'q': sum(quality) / len(quality),
                'min_q': min(quality),
                'max_q': max(quality),
                'all_q': quality,
                'rewards': rewards,
                'steps': steps,
            }
        else:
            stats = {
                'all_q': [],
                'rewards': [],
                'steps': steps,
            }
        return trajectories, stats

    def merge(self, trajectories, stats):
        """Append ``trajectories`` and their stats to this buffer.

        Nothing happens when ``stats['all_q']`` is empty. Otherwise the
        trajectories are appended, ``stats['all_q']`` extends ``quality`` and
        ``stats['steps']`` is added to ``size``. The heap ordering created in
        ``__init__`` is not restored and no duplicate check is performed.
        """
        # NOTE(review): stats produced by ``find`` count the steps of every
        # scanned trajectory, so ``size`` may grow by more than the merged
        # trajectories contain -- confirm this is intended.
        if stats['all_q']:
            self.trajectories += trajectories
            self.quality += stats['all_q']
            self.size += stats['steps']

    def sample(self, mode, k=0):
        """Return stored ``(q, tau)`` pairs chosen according to ``mode``.

        Args:
            mode: ``'random'`` draws up to ``k`` distinct entries, ``'max'`` /
                ``'min'`` return the ``k`` highest / lowest quality entries,
                ``'all'`` returns every entry once in random order.
            k: Number of entries requested; ignored for ``'all'``.

        Raises:
            ValueError: If ``mode`` is not one of the four modes above.
        """
        if mode == 'random':
            # Without replacement; clamp so an oversized or negative k is safe.
            count = max(0, min(k, len(self.trajectories)))
            return random.sample(self.trajectories, count)
        if mode == 'max':
            return heapq.nlargest(k, self.trajectories, key=self.sort_fn)
        if mode == 'min':
            return heapq.nsmallest(k, self.trajectories, key=self.sort_fn)
        if mode == 'all':
            # A full-length sample without replacement is a shuffled copy.
            return random.sample(self.trajectories, len(self.trajectories))
        raise ValueError(
            f"Unknown sample mode {mode!r}; expected 'random', 'max', 'min' or 'all'."
        )

    @property
    def q(self):
        """Mean of the stored qualities, or ``0`` when the buffer is empty."""
        return sum(self.quality) / len(self.quality) if self.quality else 0

    def __getitem__(self, index):
        """Return the ``(q, tau)`` pair stored at ``index``."""
        return self.trajectories[index]

    def __len__(self):
        """Return the number of stored trajectories."""
        return len(self.trajectories)
