# Identifier of this training script's algorithm variant (not referenced
# elsewhere in this file).
ALGO_NAME = "DrS-learn-reward"

import argparse
import random
import time
import yaml
import sys
from pathlib import Path
from distutils.util import strtobool
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from torch import nn, optim
from copy import deepcopy
import gymnasium as gym

from utils import recursive_objectify

from env import make_env, CustomVectorEnv
from sac import SAC, CycleSAC
from ldba import make_automaton
from logger import Logger
import json

# Request single-threaded OpenMP and restrict CUDA to device "0".
# NOTE(review): these are assigned after numpy/torch have already been imported;
# whether the loaded libraries still honour them depends on when they read the
# variables — consider moving this above the imports if it matters.
os.environ.update({"OMP_NUM_THREADS": "1", "CUDA_VISIBLE_DEVICES": "0"})

def fetch_align_state_map(arr):
    """Relabel LTLFetchAlign automaton states element-wise.

    Automaton states 1 and 2 are swapped; 0 and 3 are unchanged. The result is
    used as a stage index for the discriminator buffers.

    Args:
        arr: array-like of automaton state ids (ints or integral floats).

    Returns:
        Integer ndarray of the same shape with the relabelled ids. Empty input
        yields an empty integer array.
    """
    mapping = {0: 0, 2: 1, 1: 2, 3: 3}

    # Fixing otypes lets np.vectorize handle size-0 inputs; without it, numpy
    # must call the function on a first element to infer the output dtype and
    # raises on empty arrays.
    return np.vectorize(mapping.get, otypes=[int])(arr)

def fetch_align_tensor_map(tensor):
    """Tensor counterpart of the LTLFetchAlign state relabelling.

    Swaps automaton states 1 and 2 and keeps 0 and 3. Any other value becomes 0,
    because the output starts as zeros and only the listed ids are written.
    """
    relabel = {0: 0, 2: 1, 1: 2, 3: 3}
    out = torch.zeros_like(tensor)
    # Masks are computed on the input, so the write order cannot interfere.
    for src_state, dst_state in relabel.items():
        out[tensor == src_state] = dst_state
    return out

def flat_cycle_state_map(arr):
    """Relabel LTLFlatRound automaton states element-wise.

    Automaton states 1 and 2 are swapped; -1 (sink), 0 and 3 are unchanged.

    Args:
        arr: array-like of automaton state ids (ints or integral floats).

    Returns:
        Integer ndarray of the same shape with the relabelled ids. Empty input
        yields an empty integer array.
    """
    mapping = {-1: -1, 0: 0, 2: 1, 1: 2, 3: 3}

    # otypes is fixed so np.vectorize does not need a first element to infer
    # the output dtype (it raises on size-0 inputs otherwise).
    return np.vectorize(mapping.get, otypes=[int])(arr)

def flat_cycle_tensor_map(tensor):
    """Tensor counterpart of the LTLFlatRound state relabelling.

    Swaps automaton states 1 and 2 and keeps -1, 0 and 3. Any other value
    becomes 0, since only the listed ids are written into a zero tensor.
    """
    relabel = {-1: -1, 0: 0, 2: 1, 1: 2, 3: 3}
    out = torch.zeros_like(tensor)
    for src_state, dst_state in relabel.items():
        out[tensor == src_state] = dst_state
    return out

def _str2bool(value):
    """Convert a yes/no style string to ``bool``.

    Accepts the same case-insensitive spellings as ``distutils.util.strtobool``:
    'y', 'yes', 't', 'true', 'on', '1' are true; 'n', 'no', 'f', 'false',
    'off', '0' are false. Defined locally because ``distutils`` was removed
    from the standard library in Python 3.12.

    Raises:
        ValueError: for any other value.
    """
    text = str(value).lower()
    if text in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    if text in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    raise ValueError(f"invalid truth value {value!r}")


def parse_args(config_file="config.json"):
    """Build the experiment configuration from a JSON file plus CLI overrides.

    Every option's default is read from a JSON config file, and any flag given
    on the command line overrides it. The file is the one named by
    ``--config``, or ``config_file`` when ``--config`` is not passed.

    Args:
        config_file: Fallback path of the JSON config file.

    Returns:
        argparse.Namespace with one attribute per option below (``config`` is
        the raw ``--config`` value, i.e. None when the fallback was used).

    Raises:
        FileNotFoundError: if the selected config file does not exist.
        KeyError: if the config file lacks a key needed as an option default.
    """
    # Pre-pass: only discover which config file supplies the defaults. It has no
    # -h of its own, so --help is handled by the full parser below.
    pre_parser = argparse.ArgumentParser(add_help=False)
    pre_parser.add_argument("--config", type=str, default=None)
    pre_args, _unknown = pre_parser.parse_known_args()

    config_path = pre_args.config or config_file
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    def cfg(key):
        """Return config[key], failing with a message that names the file and key."""
        try:
            return config[key]
        except KeyError:
            raise KeyError(f"config file {config_path!r} is missing required key {key!r}") from None

    parser = argparse.ArgumentParser(description="Experiment Arguments")

    # Add a command-line argument for the config file
    parser.add_argument("--config", type=str, default=None, help="Path to the configuration file")

    # Basic arguments
    parser.add_argument('--exp_name', type=str, default=cfg("exp_name"), help='Experiment name')
    parser.add_argument('--seed', type=int, default=cfg("seed"), help='Random seed')
    parser.add_argument('--device', type=str, default=cfg("device"), help='Device to use (e.g., "cpu", "cuda")')
    parser.add_argument('--working_dir', type=str, default=cfg("working_dir"), help='Working directory')
    parser.add_argument('--log_frequency', type=int, default=cfg("log_frequency"), help='Log frequency')

    # Environment settings
    parser.add_argument('--env_id', type=str, default=cfg("env_id"), help='Environment ID')
    # String defaults are passed through `type`, so this default becomes a dict again.
    parser.add_argument('--env_kwargs', type=json.loads, default=json.dumps(cfg("env_kwargs")), help='Environment kwargs (in JSON format)')

    # Algorithm settings
    parser.add_argument('--algo', type=str, default=cfg("algo"), help='Algorithm to use')
    parser.add_argument('--num_envs', type=int, default=cfg("num_envs"), help='Number of environments')
    parser.add_argument('--num_steps', type=int, default=cfg("num_steps"), help='Number of steps per environment')
    parser.add_argument('--learning_starts', type=int, default=cfg("learning_starts"), help='Number of steps before learning starts')
    parser.add_argument('--total_timesteps', type=int, default=cfg("total_timesteps"), help='Total timesteps for training')
    parser.add_argument('--buffer_size', type=int, default=cfg("buffer_size"), help='Size of the replay buffer')

    # Hyperparameters
    parser.add_argument('--gamma', type=float, default=cfg("gamma"), help='Discount factor')
    parser.add_argument('--learning_rate', type=float, default=cfg("learning_rate"), help='Learning rate')
    parser.add_argument('--batch_size', type=int, default=cfg("batch_size"), help='Batch size')
    parser.add_argument('--target_network_frequency', type=int, default=cfg("target_network_frequency"), help='Frequency of target network updates')
    parser.add_argument('--tau', type=float, default=cfg("tau"), help='Soft update coefficient')
    parser.add_argument('--alpha', type=float, default=cfg("alpha"), help='Alpha for entropy regularization')
    parser.add_argument('--autotune', type=_str2bool, default=cfg("autotune"), help='Whether to autotune alpha')

    # Policy and Q-function settings
    parser.add_argument('--policy_frequency', type=int, default=cfg("policy_frequency"), help='Frequency of policy updates')
    parser.add_argument('--q_lr', type=float, default=cfg("q_lr"), help='Learning rate for Q-function')
    parser.add_argument('--q_epsilon', type=float, default=cfg("q_epsilon"), help='Epsilon for exploration in Q-function')
    parser.add_argument('--q_eventual', type=_str2bool, default=cfg("q_eventual"), help='Whether Q-learning is eventual')

    # Optimizer settings
    parser.add_argument('--opt_coeff', type=float, default=cfg("opt_coeff"), help='Coefficient for optimization')
    parser.add_argument('--opt_staleness', type=int, default=cfg("opt_staleness"), help='Staleness tolerance for optimization')
    parser.add_argument('--opt_init', type=int, default=cfg("opt_init"), help='Initial value for optimization')
    parser.add_argument('--opt_thompson_samples', type=int, default=cfg("opt_thompson_samples"), help='Number of Thompson samples')

    parser.add_argument("--demo_path", type=str, default=cfg("demo_path"), help="The path of the demo file")

    parser.add_argument("--next_q_value", type=int, default=cfg("next_q_value"), help="Next Q-value")

    # Discriminator settings
    parser.add_argument("--disc_lr", type=float, default=cfg("disc_lr"), help="The learning rate of the discriminator optimizer")
    parser.add_argument("--disc_frequency", type=int, default=cfg("disc_frequency"), help="The frequency of training discriminator (delayed)")
    parser.add_argument("--disc_th", type=float, default=cfg("disc_th"), help="The success rate threshold for early stopping discriminator training")

    args = parser.parse_args()
    return args

class Discriminator(nn.Module):
    """Per-stage binary classifiers used to build a learned, stage-aware reward.

    Holds one small MLP per stage. Each MLP maps a (flattened) observation to a
    single logit. ``trained[i]`` records whether ``nets[i]`` has been updated;
    untrained stages contribute no learned bonus in :meth:`get_reward`.

    NOTE(review): ``trained`` is a plain Python list, so it is not included in
    ``state_dict()`` and is lost when a saved discriminator is reloaded.
    """

    def __init__(self, envs, n_stages):
        super().__init__()
        self.n_stages = n_stages
        # int() so nn.Linear receives a Python int rather than a NumPy scalar.
        state_shape = int(np.prod(envs.single_observation_space.shape))
        self.nets = nn.ModuleList([
            nn.Sequential(
                nn.Linear(state_shape, 32),
                nn.Sigmoid(),
                nn.Linear(32, 1),
            ) for _ in range(n_stages)
        ])
        self.trained = [False] * n_stages

    def set_trained(self, stage_idx):
        """Mark stage ``stage_idx``'s network as trained so get_reward uses it."""
        self.trained[stage_idx] = True

    def forward(self, next_s, stage_idx):
        """Return the logits of stage ``stage_idx``'s network for ``next_s`` (cast to float32)."""
        next_s = next_s.to(torch.float32)
        net = self.nets[stage_idx]
        return net(next_s)

    def get_reward(self, next_s, success, infos):
        """Compute the stage-based shaped reward for a batch, without gradients.

        reward = (k * stage + bonus[stage']) / (k * n_stages) with k = 3. Here
        ``bonus[i]`` is ``tanh`` of ``nets[i]``'s logit (0 while untrained, and
        always 0 for index ``n_stages``). ``stage'`` equals ``stage`` except that
        the sink stage -1 uses stage 0's bonus. So the stage term contributes
        ``stage / n_stages`` and the bonus at most ``±1 / (k * n_stages)``.

        Args:
            next_s: batch of next-observations; first dimension is the batch size.
            success: per-sample success flag. Used as the stage index only when
                ``n_stages == 1``; converted to a tensor if it is not one.
            infos: per-sample stage indices (-1 marks a sink state). Used only
                when ``n_stages > 1``.

        Returns:
            Tensor of rewards, one per sample.
        """
        next_s = next_s.to(torch.float32)
        with torch.no_grad():
            bs = next_s.shape[0]
            if self.n_stages > 1:
                stage_idx = infos.squeeze()
                pro_stage_idx = stage_idx.clone()
                pro_stage_idx[pro_stage_idx == -1] = 0  # for sink state
            else:
                if not torch.is_tensor(success):
                    success = torch.tensor(success, device=next_s.device)
                    success = success.reshape(bs, 1)
                stage_idx = success.squeeze(-1)
                # Single stage: the success flag itself is the stage index.
                # The bonus lookup below needs pro_stage_idx on this path too.
                pro_stage_idx = stage_idx
            stage_rewards = [
                torch.tanh(self(next_s, stage_idx=i)) if self.trained[i] else torch.zeros(bs, 1, device=next_s.device)
                for i in range(self.n_stages)
            ]
            # Extra zero column: index n_stages (all stages done) has no learned bonus.
            stage_rewards = torch.cat(stage_rewards + [torch.zeros(bs, 1, device=next_s.device)], dim=1)

            k = 3  # alpha = 1/3
            reward = k * stage_idx + stage_rewards[torch.arange(bs, device=next_s.device), pro_stage_idx.long()]
            reward = reward / (k * self.n_stages)

            return reward

    def classify_demo(self, next_s):
        """Return stage-0 logits for ``next_s`` as a 1-D tensor, without gradients."""
        next_s = next_s.to(torch.float32)
        with torch.no_grad():
            reward = self(next_s, 0).squeeze(1)

            return reward

class DiscriminatorBuffer(object):
    """Fixed-capacity ring buffer of next-observations for discriminator training.

    Rows are written at ``pos`` and wrap around to the start once the end of the
    storage is reached; ``full`` becomes True the first time that happens.
    """
    # can be optimized by create a buffer of size (n_traj, len_traj, dim)

    def __init__(self, buffer_size, obs_space, device):
        self.buffer_size = buffer_size
        obs_dim = obs_space.shape[0]
        self.next_observations = np.zeros((buffer_size, obs_dim), dtype=obs_space.dtype)
        self.device = device
        self.pos = 0
        self.full = False

    @property
    def size(self) -> int:
        """Number of valid rows currently stored."""
        if self.full:
            return self.buffer_size
        return self.pos

    def add(self, next_obs):
        """Append the rows of the ndarray ``next_obs``, overwriting the oldest rows when wrapping."""
        remaining = next_obs
        # Each pass fills storage up to its end and restarts at row 0.
        while self.pos + remaining.shape[0] >= self.buffer_size:
            self.full = True
            room = self.buffer_size - self.pos
            self.next_observations[self.pos:] = remaining[:room]
            self.pos = 0
            remaining = remaining[room:]

        count = remaining.shape[0]
        self.next_observations[self.pos:self.pos + count] = remaining.copy()
        self.pos = (self.pos + count) % self.buffer_size

    def sample(self, batch_size):
        """Draw ``batch_size`` rows uniformly with replacement, as tensors on ``self.device``."""
        picks = np.random.randint(0, self.size, size=batch_size)
        batch = {'next_observations': self.next_observations[picks]}
        return {name: torch.tensor(arr).to(self.device) for name, arr in batch.items()}

def sample_from_multi_buffers(buffers, batch_size):
    """Sample about ``batch_size`` rows spread across ``buffers`` in proportion to their sizes.

    Each buffer contributes ``int(size / total * batch_size)`` rows. Any
    rounding shortfall is taken from the largest buffer. The per-buffer batch
    dicts are concatenated key by key, in buffer order.

    Raises:
        Exception: 'All buffers are empty!' when the buffers hold no data.
    """
    # Warning: when the buffers are full, this will make samples not uniform
    sizes = [buf.size for buf in buffers]
    total = sum(sizes)
    if total == 0:
        raise Exception('All buffers are empty!')

    counts = [int(size / total * batch_size) for size in sizes]
    shortfall = batch_size - sum(counts)
    if shortfall > 0:
        counts[np.argmax(sizes)] += shortfall

    batches = []
    for buf, count in zip(buffers, counts):
        if count <= 0:
            continue
        if buf.size == 0:
            raise Exception('Buffer is empty!')
        batches.append(buf.sample(count))

    keys = batches[0].keys()
    return {key: torch.cat([batch[key] for batch in batches], dim=0) for key in keys}

def eval_episode(algo, envs, seed, device, args):
    """Roll out one evaluation episode with ``algo.eval_action``.

    Returns ``{'eval_return': r}``, where ``r`` is ``final_info[0]['episode']['r']``
    from the first step that reports ``final_info``. Returns
    ``{'eval_return': -1.}`` if no such step occurs within 1000 steps. ``args``
    is accepted but unused.
    """
    def as_tensor(value):
        return torch.tensor(value, device=device, dtype=torch.float32)

    def info_as_tensors(info_dict):
        # 'final_*' entries hold per-env Python objects and are skipped.
        return {key: as_tensor(val) for key, val in info_dict.items() if 'final' not in key}

    raw_obs, raw_info = envs.reset(seed=seed)
    obs = as_tensor(raw_obs)
    info = info_as_tensors(raw_info)
    for _step in range(1000):
        with torch.no_grad():
            action = algo.eval_action(obs, info=info)
        next_obs, _, _, _, next_info = envs.step(action.cpu().numpy())
        if "final_info" in next_info:
            episode_return = next_info['final_info'][0]['episode']['r']
            return {'eval_return': episode_return}
        obs = as_tensor(next_obs)
        info = info_as_tensors(next_info)
    return {'eval_return': -1.}


def get_true_final(next_obs, next_info, next_done):
    """Return copies of a vector-env step's obs/info with terminal data restored.

    For every env flagged in ``next_done``, when ``next_info`` contains
    ``final_info``, the env's row of the returned observations is replaced by
    ``next_info['final_observation'][idx]``. Every key of that env's
    ``final_info[idx]`` that also exists in the info dict is written into the
    copy at position ``idx``. The inputs are not modified.

    Args:
        next_obs: ndarray of per-env observations (copied with ``.copy()``).
        next_info: info dict from the vector env (deep-copied).
        next_done: iterable of per-env done flags.

    Returns:
        Tuple ``(real_next_obs, real_next_info)``.
    """
    real_next_obs = next_obs.copy()
    real_next_info = deepcopy(next_info)
    if "final_info" in next_info:
        for idx, done in enumerate(next_done):
            if not done:
                continue
            real_next_obs[idx] = next_info['final_observation'][idx]
            # Use this env's own final info for the keys too: entries for envs
            # that did not finish may be None, so reading final_info[0] breaks
            # whenever env 0 is still running.
            final = next_info['final_info'][idx]
            for k in final:
                if k in real_next_info:
                    real_next_info[k][idx] = final[k]
    return real_next_obs, real_next_info


def main(**kwargs):
    """Run reward-learning training for seeds 0-9, one after another.

    For every seed this:
      1. builds training/evaluation vector envs and the automaton for
         ``args.env_kwargs['formula']`` (the first accepting node's id is used
         as the number of stages);
      2. creates SAC (CycleSAC for formulas starting with "GF"), a
         ``Discriminator`` and ``n_stages + 2`` next-observation buffers, and
         adds demonstration data to buffer ``n_stages``;
      3. steps the environment and files each finished episode (or, for "GF"
         formulas, each segment between visits to the final stage) into the
         buffer of the best stage it reached. Transitions are stored with the
         discriminator reward, and the discriminator and agent are trained
         periodically;
      4. saves the agent and discriminator under ``args.working_dir``.

    ``kwargs`` is accepted but unused; all settings come from ``parse_args``.
    The ``--seed`` option is overridden by the seed loop.
    """
    args = parse_args()
    for seed in range(10):
        args.seed = seed
        current_time = int(time.time())  # Get the current time as an integer for readability
        args.working_dir = f"results/{args.env_id}/{args.algo}/{args.seed}/{current_time}"

        to_tensor = lambda x: torch.tensor(x, device=args.device, dtype=torch.float32)
        # 'final_*' entries hold per-env Python objects and are not converted.
        dict_to_tensor = lambda d: {k: to_tensor(v) for k, v in d.items() if 'final' not in k}
        to_zeros = lambda x: torch.zeros(x, device=args.device)
        logger = Logger(args)
        for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
            seed_fn(args.seed)
        torch.backends.cudnn.deterministic = True
        automaton = make_automaton(args.env_kwargs['formula'])
        # NOTE(review): assumes automaton nodes are numbered by task progress so the
        # first accepting node's id equals the number of stages — confirm in ldba.
        accepting_node = [node for node, data in automaton.graph.nodes(data=True) if data.get('accepting', False)][0]
        n_stages = accepting_node

        envs = CustomVectorEnv([make_env(args.env_id, args.gamma, args.env_kwargs) for _ in range(args.num_envs)])
        eval_envs = CustomVectorEnv([make_env(args.env_id, args.gamma, args.env_kwargs) for _ in range(args.num_envs)])
        envs.single_action_space.seed(args.seed)
        eval_envs.single_action_space.seed(args.seed)
        logger.log_ldba(envs)

        # Instantiate the learning algorithm ("GF ..." formulas use the cyclic variant)
        if args.env_kwargs['formula'].startswith("GF"):
            algo = CycleSAC(envs, args)
        else:
            algo = SAC(envs, args)

        # Instantiate the Discriminator (one classifier per stage)
        disc = Discriminator(envs, n_stages).to(args.device)
        disc_optimizer = optim.Adam(disc.parameters(), lr=args.disc_lr)
        # NOTE(review): never set to False below and args.disc_th is unused, so
        # discriminator training is never stopped early — confirm this is intended.
        disc_training = [True] * n_stages

        # stage_buffers[i], i in 0..n_stages: next-observations of data whose best
        # reached stage is i. The last buffer (index n_stages + 1) receives
        # sink-state (-1) data.
        stage_buffers = [DiscriminatorBuffer(
            args.buffer_size,
            envs.single_observation_space,
            args.device,
        ) for _ in range(n_stages + 2)]

        # allow_pickle lets the file run arbitrary code on load: use trusted demo files only.
        demo_data = np.load(args.demo_path, allow_pickle=True)

        # Keep the 5 longest demonstrations.
        key_lengths = {key: demo_data[key].shape[0] for key in demo_data.keys()}
        sorted_keys = sorted(key_lengths, key=key_lengths.get, reverse=True)
        top_5_keys = sorted_keys[:5]
        limited_demo_data = {key: demo_data[key] for key in top_5_keys}

        if args.env_id == "LTLCheetahFrontround-v0":
            # Cheetah: first 5 demonstrations, each truncated to 100 steps.
            first_5_keys = list(demo_data.keys())[:5]
            limited_demo_data = {key: demo_data[key][:100] for key in first_5_keys}

        # Demonstration states go into the buffer of the final stage.
        values = np.concatenate(list(limited_demo_data.values()), axis=0)
        stage_buffers[n_stages].add(values)

        if args.env_id == "LTLFlatRound-v0":
            # Feed recorded demonstration transitions to algo.store (presumably its replay buffer).
            observations_dict = np.load("data/flatround_demo/observations_50.pkl", allow_pickle=True)
            next_observations_dict = np.load("data/flatround_demo/next_observations_50.pkl", allow_pickle=True)
            actions_dict = np.load("data/flatround_demo/actions_50.pkl", allow_pickle=True)
            infos_dict = np.load("data/flatround_demo/infos_50.pkl", allow_pickle=True)
            next_infos_dict = np.load("data/flatround_demo/next_infos_50.pkl", allow_pickle=True)
            rewards_dict = np.load("data/flatround_demo/rewards_50.pkl", allow_pickle=True)
            next_dones_dict = np.load("data/flatround_demo/next_dones_50.pkl", allow_pickle=True)

            obss = np.concatenate(list(observations_dict.values()), axis=0)
            infos = np.concatenate(list(infos_dict.values()), axis=0)
            actions = np.concatenate(list(actions_dict.values()), axis=0)
            rewards = np.concatenate(list(rewards_dict.values()), axis=0)
            real_next_obss = np.concatenate(list(next_observations_dict.values()), axis=0)
            real_next_infos = np.concatenate(list(next_infos_dict.values()), axis=0)
            next_dones = np.concatenate(list(next_dones_dict.values()), axis=0)

            obss = torch.tensor(obss).unsqueeze(1).to(args.device)
            real_next_obss = torch.tensor(real_next_obss).unsqueeze(1).to(args.device)
            actions = torch.tensor(actions).unsqueeze(1).to(args.device)
            rewards = torch.tensor(rewards).unsqueeze(1).to(args.device).float()
            next_dones = torch.tensor(next_dones).to(args.device).float()
            infos = [{'ldba_obs': torch.tensor([val], dtype=torch.float32, device=args.device)} for val in infos]
            real_next_infos = [{'ldba_obs': torch.tensor([val], dtype=torch.float32, device=args.device)} for val in real_next_infos]

            for i_b in range(len(obss)):
                algo.store(obss[i_b], infos[i_b], False, actions[i_b], rewards[i_b], 0, 0,
                        real_next_obss[i_b], real_next_infos[i_b], next_dones[i_b], 0)

        start_time = time.time()
        obs, info = envs.reset(seed=args.seed)
        # Per-step stage labels and next-observations of the current episode.
        # Only env 0 is tracked, and `.item()` below requires num_envs == 1.
        stage_indices = []
        traj_obs = []
        obs, info = to_tensor(obs), dict_to_tensor(info)
        done = to_zeros(args.num_envs)
        # Defined up front so logging cannot hit a NameError when a log step
        # comes before the first agent update.
        train_metrics = {}

        for step in range(args.total_timesteps):

            if step < args.learning_starts:
                # Warm-up: uniformly random actions.
                action = torch.tensor(np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]))
                if isinstance(envs.single_action_space, gym.spaces.Discrete):
                    action = action.unsqueeze(-1)
                logprob, value = (action * 0).sum(-1), (action * 0).sum(-1, keepdims=True)
            else:
                with torch.no_grad():
                    action, logprob, _, value = algo.get_action_and_value(obs, info=info)

            next_obs, reward, _, next_done, next_info = envs.step(action.cpu().numpy())

            if "final_info" in next_info:
                for inf in next_info['final_info']:
                    logger.log_env(inf, step)
            real_next_obs, real_next_info = get_true_final(next_obs, next_info, next_done)
            stage_indices.append(real_next_info['ldba_obs'].item())
            traj_obs.append(real_next_obs[0])

            if "final_info" in next_info:
                if args.env_kwargs['formula'].startswith("GF"):
                    # Split the episode after every step at the final stage and file
                    # each segment under the best stage it reached.
                    traj_obs = np.array(traj_obs)
                    seg_indices = np.where(np.array(stage_indices) == n_stages)[0]
                    indices_segments = np.split(stage_indices, seg_indices + 1)
                    traj_segments = np.split(traj_obs, seg_indices + 1)

                    indices_segments = [seg for seg in indices_segments if len(seg) > 0]
                    traj_segments = [seg for seg in traj_segments if len(seg) > 0]

                    for i, indice_segment in enumerate(indices_segments):
                        l = len(indice_segment)
                        if -1 in indice_segment:
                            # Segment hit the sink state: keep all of it, file it as sink data.
                            best_step = l - 1
                            stage_idx = n_stages + 1
                        else:
                            # Last step at which the segment's highest stage was seen.
                            best_step = l - 1 - np.argmax(indice_segment[::-1])
                            stage_idx = int(indice_segment[best_step])

                        traj_segments[i] = traj_segments[i][:best_step+1]

                        stage_buffers[stage_idx].add(traj_segments[i])

                    stage_indices = []
                    traj_obs = []
                else:
                    traj_obs = np.array(traj_obs)
                    stage_indices = np.array(stage_indices)
                    if args.env_id == "LTLFetchAlign-v0":
                        stage_indices = fetch_align_state_map(stage_indices)
                    elif args.env_id == "LTLFlatRound-v0":
                        stage_indices = flat_cycle_state_map(stage_indices)
                    # Truncate at the last step where the episode's highest stage was seen.
                    l = len(stage_indices)
                    best_step = l - 1 - np.argmax(stage_indices[::-1])
                    stage_idx = int(stage_indices[best_step])

                    traj_obs = traj_obs[:best_step+1]

                    stage_buffers[stage_idx].add(traj_obs)

                    stage_indices = []
                    traj_obs = []

            next_obs, real_next_obs, next_done, reward = (to_tensor(e) for e in [next_obs, real_next_obs, next_done, reward])
            next_info, real_next_info = dict_to_tensor(next_info), dict_to_tensor(real_next_info)

            # The environment reward is replaced by the discriminator reward.
            if args.env_id == "LTLFetchAlign-v0":
                mapped_real_next_info = fetch_align_tensor_map(real_next_info['ldba_obs'])
                disc_reward = disc.get_reward(real_next_obs, reward, mapped_real_next_info)
            elif args.env_id == "LTLFlatRound-v0":
                mapped_real_next_info = flat_cycle_tensor_map(real_next_info['ldba_obs'])
                disc_reward = disc.get_reward(real_next_obs, reward, mapped_real_next_info)
            else:
                disc_reward = disc.get_reward(real_next_obs, reward, real_next_info['ldba_obs'])

            algo.store(obs, info, done, action, disc_reward, logprob, value.flatten(),
                    real_next_obs, real_next_info, next_done, step % args.num_steps)
            logger.store(obs)
            obs, info, done = next_obs, next_info, next_done

            if step < args.learning_starts:
                continue

            # Train the discriminator
            if (step + 1) % args.disc_frequency == 0:
                for stage_idx in range(n_stages):
                    if not disc_training[stage_idx]:
                        continue
                    # Success: data that got past this stage (sink buffer excluded).
                    # Fail: data that stopped at or before it, plus sink-state data.
                    success_buffers = stage_buffers[stage_idx+1:-1]
                    fail_buffers = stage_buffers[:stage_idx+1] + [stage_buffers[-1]]
                    # sample_from_multi_buffers raises if all of its buffers are empty
                    # (e.g. before any episode has finished); wait until both sides have data.
                    if not any(b.size for b in success_buffers) or not any(b.size for b in fail_buffers):
                        continue
                    success_data = sample_from_multi_buffers(success_buffers, args.batch_size)['next_observations']
                    fail_data = sample_from_multi_buffers(fail_buffers, args.batch_size)['next_observations']

                    obs_dim = envs.single_observation_space.shape[0]
                    disc_next_obs = torch.cat([fail_data[:, :obs_dim],
                                               success_data[:, :obs_dim]], dim=0)
                    disc_labels = torch.cat([
                        torch.zeros((fail_data.shape[0], 1), device=args.device), # fail label is 0
                        torch.ones((success_data.shape[0], 1), device=args.device), # success label is 1
                    ], dim=0)

                    logits = disc(disc_next_obs, stage_idx)
                    disc_loss = F.binary_cross_entropy_with_logits(logits, disc_labels)

                    disc_optimizer.zero_grad()
                    disc_loss.backward()
                    disc_optimizer.step()

                    disc.set_trained(stage_idx)

            # learning update
            if (step + 1) % algo.update_frequency == 0:
                algo.anneal(step, args.total_timesteps)
                train_metrics = algo.train(step, args.next_q_value)

            if (step + 1) % args.log_frequency == 0:
                train_metrics['SPS'] = int(step / (time.time() - start_time))
                logger.log_train(train_metrics, step)
                logger.log_eval(eval_episode(algo, eval_envs, seed=step, device=args.device, args=args), step)

        logger.log_hparams(args)
        envs.close()
        eval_envs.close()
        logger.close()
        save_dict = {
            'agent': algo.agent,
            'discriminator': disc.state_dict()
        }

        torch.save(save_dict, f'{args.working_dir}/model.pt')
        algo.save(args.working_dir)


# Run training only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
