import os
import pickle
from typing import Tuple

import gym
import torch
import numpy as np
from tqdm import tqdm


def get_discount_returns(rewards, discount=1):
    """Return the discounted sum of ``rewards``.

    The reward at position ``t`` (0-based) is weighted by ``discount ** t``,
    so the default ``discount=1`` gives the plain, undiscounted sum.

    Args:
        rewards: Iterable of per-step rewards, in time order.
        discount: Multiplicative factor applied once per step.

    Returns:
        The weighted sum; ``0`` when ``rewards`` is empty.
    """
    total = 0
    weight = 1
    for reward in rewards:
        total += weight * reward
        # Advance the weight so the next reward is discounted one more step.
        weight *= discount
    return total


def generate_trajectory(
    env: gym.Env, 
    policy: torch.nn.Module,
    start_state: np.ndarray=None,
    qpos: np.ndarray=None,
    qvel: np.ndarray=None,
    device: str='cpu',
    seg_len: int=10,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Roll out ``policy`` in ``env`` starting from a given simulator state.

    The environment is reset and then forced to ``(qpos, qvel)`` through
    ``env.set_state``, so repeated calls with the same arguments start from
    the same simulator configuration.

    Args:
        env: Environment exposing ``reset``, ``step`` (returning a 4-tuple
            ``(obs, reward, done, info)``), ``set_state``, ``action_space``
            and ``data.qpos`` / ``data.qvel``.
        policy: Callable mapping a state tensor to an action tensor.
        start_state: Observation matching ``(qpos, qvel)``. When ``None`` the
            environment is reset and the returned observation is used.
        qpos: Simulator positions to start from; defaults to a copy of the
            environment's current ``data.qpos``.
        qvel: Simulator velocities to start from; defaults to a copy of the
            environment's current ``data.qvel``.
        device: Torch device the state tensor is moved to before calling
            ``policy``.
        seg_len: Length limit for the rollout. Note that up to
            ``seg_len + 1`` transitions are collected (see below).

    Returns:
        ``(states, actions, rewards, start_state)`` as numpy arrays. The
        first three hold one entry per collected transition; the rollout
        stops when the environment reports ``done`` or once
        ``seg_len + 1`` transitions have been collected.
    """
    max_action = float(env.action_space.high[0])

    if start_state is None:
        start_state = env.reset()
    if qpos is None:
        qpos = np.copy(env.data.qpos[:])
    if qvel is None:
        qvel = np.copy(env.data.qvel[:])

    # Reset first, then overwrite the simulator state with the requested one.
    env.reset()
    env.set_state(qpos, qvel)

    done = False
    state = np.copy(start_state)
    states, actions, rewards = [], [], []
    # The length limit is checked *before* acting so that no policy call or
    # environment step is spent on a transition that would be discarded.
    # The `<=` keeps the historical output length of `seg_len + 1`
    # transitions; callers that need exactly `seg_len` slice the result.
    while not done and len(states) <= seg_len:
        with torch.no_grad():
            torch_state = torch.Tensor(state).to(device)
            action = policy(torch_state).clip(-max_action, max_action)
            action = action.cpu().detach().numpy()

        next_state, reward, done, _ = env.step(action)

        states.append(state)
        actions.append(action)
        rewards.append(np.array(reward))
        state = next_state

    return (
        np.array(states),
        np.array(actions),
        np.array(rewards),
        np.array(start_state),
    )


def _sample_preference_pair(env, policy, max_action, seg_len, max_start_state, device):
    """Try to build one preference tuple; return ``None`` if the attempt fails.

    An attempt fails when the warm-up rollout reaches a terminal state, when
    either trajectory is shorter than ``seg_len``, or when the two returned
    start states do not compare equal.
    """
    # Warm up with the policy for a random number of steps so that segments
    # start from varied points of an episode.
    warmup_steps = np.random.randint(max_start_state)
    start_state = env.reset()
    for _ in range(warmup_steps):
        with torch.no_grad():
            torch_state = torch.Tensor(start_state).to(device)
            action = policy(torch_state).clip(-max_action, max_action)
            action = action.cpu().detach().numpy()

        start_state, _reward, done, _info = env.step(action)
        if done:
            return None

    # Snapshot the simulator so both rollouts start from the same state.
    qpos = np.copy(env.data.qpos[:])
    qvel = np.copy(env.data.qvel[:])

    s0, a0, r0, x0 = generate_trajectory(
        env, policy, start_state, qpos, qvel, device=device, seg_len=seg_len)
    s1, a1, r1, x1 = generate_trajectory(
        env, policy, start_state, qpos, qvel, device=device, seg_len=seg_len)

    if len(s0) < seg_len or len(s1) < seg_len:
        return None
    if not (x0 == x1).all():
        return None

    # Compare returns over exactly the window that is stored, not over any
    # extra transitions the rollout may have collected.
    g0 = get_discount_returns(r0[:seg_len], discount=1)
    g1 = get_discount_returns(r1[:seg_len], discount=1)

    # Higher-return segment goes first; ties keep the second rollout first.
    # NOTE(review): "higher return first" is inferred from the original
    # `g1 < g0` swap branch -- confirm it matches what consumers expect.
    if g1 < g0:
        better_s, better_a, worse_s, worse_a = s0, a0, s1, a1
    else:
        better_s, better_a, worse_s, worse_a = s1, a1, s0, a0

    return (
        better_s[:seg_len],
        better_a[:seg_len],
        worse_s[:seg_len],
        worse_a[:seg_len],
        x1,
    )


def generate_preference_dataset(
    env: gym.Env, 
    policy: torch.nn.Module, 
    data_path: str='./cache/preference_dataset.pkl',
    num_pairs: int=100, 
    seg_len: int=30,
    max_trials: int=1e6,
    max_start_state: int=250,
    device: str='cpu',
    ) -> torch.utils.data.Dataset:
    """Generate pairs of trajectory segments ranked by return and pickle them.

    For each pair the policy is rolled out twice from the same simulator
    state; the two ``seg_len``-step segments are ordered by their
    undiscounted return over those ``seg_len`` steps.

    Args:
        env: Environment to roll out in (see ``generate_trajectory``).
        policy: Callable mapping a state tensor to an action tensor.
        data_path: File the pickled dataset is written to. Its parent
            directory is created if it does not exist.
        num_pairs: Number of preference pairs to generate.
        seg_len: Number of steps kept from each trajectory.
        max_trials: Total number of failed attempts tolerated across the
            whole run. Once exhausted, generation stops and the pairs
            collected so far are saved and returned.
        max_start_state: Exclusive upper bound on the random number of
            warm-up steps taken before the segments start.
        device: Torch device the state tensors are moved to.

    Returns:
        A list of tuples ``(states_hi, actions_hi, states_lo, actions_lo,
        start_state)`` where the ``hi`` segment has the higher return. (The
        annotation says ``Dataset``, but a plain list is what is returned.)
    """
    max_action = float(env.action_space.high[0])
    failures_left = int(max_trials)

    _dataset = []
    for _ in tqdm(range(num_pairs), desc='Generating Preference'):
        pair = None
        while pair is None:
            pair = _sample_preference_pair(
                env, policy, max_action, seg_len, max_start_state, device)
            if pair is None:
                failures_left -= 1
                if failures_left <= 0:
                    break

        if pair is None:
            # Budget exhausted: stop entirely instead of retrying forever
            # with a counter that has already passed zero.
            print('Warning: max_trials exhausted; generated',
                  len(_dataset), 'of', num_pairs, 'pairs')
            break
        _dataset.append(pair)

    # Make sure the target directory exists so the work is not lost on save.
    out_dir = os.path.dirname(data_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(data_path, 'wb') as f:
        pickle.dump(_dataset, f)
        print('Preference dataset saved at:', data_path)

    return _dataset

    