import datetime
from enum import Enum
import numpy as np
import torch as th
import torch.nn.functional as F
from ml_logger import logger

from .arrays import to_torch
from .imbalance import check_imbalance_parallel


# Lifecycle states of a trajectory being stitched. Integer values are kept
# stable in case downstream code relies on the encoding.
PathStatus = Enum(
    "PathStatus",
    [("IN_PROGRESS", 0), ("COMPLETED", 1), ("ABANDONED", 2)],
)


def check_state_pair(args, trainer, state, state_prime, discrete_action, threshold):
    """Validate a consecutive state pair with the inverse/forward dynamics models.

    The inverse model infers the action connecting ``state`` to ``state_prime``;
    the forward model replays that action from ``state`` and the pair is
    accepted when the reconstruction MSE is below ``threshold``.

    Args:
        args: run configuration (unused here; kept for signature consistency
            with the other helpers in this module).
        trainer: owner of ``ema_model`` exposing ``inv_model``, ``fwd_model``
            and ``rwd_model`` heads.
        state: current states — assumes shape (bs, na, obs_dim); TODO confirm.
        state_prime: candidate next states, same shape as ``state``.
        discrete_action: when True, inverse-model logits are converted to a
            one-hot encoding before being fed to the forward/reward models.
        threshold: MSE acceptance threshold for the reconstruction error.

    Returns:
        Tuple ``(valid, action, reward)`` where ``valid`` is a 0-dim bool
        tensor, ``action`` carries an extra time axis (bs, 1, na, ...) (argmax
        indices when ``discrete_action``), and ``reward`` is (bs, na, 1) — one
        scalar reward per batch element repeated across agents.
    """
    with th.no_grad():
        obs_comb = th.concat([state, state_prime], dim=-1)  # (bs, na, 2 * obs_dim)
        # Inverse dynamics: infer the action that connects the two states.
        action = trainer.ema_model.inv_model(obs_comb.reshape(obs_comb.shape[0], -1)).reshape(obs_comb.shape[0], obs_comb.shape[1], -1)  #  (bs, na, output_dim)
        # action should be converted from logits into onehot encoding
        if discrete_action:
            indices = th.argmax(action, dim=-1)
            action_onehot = F.one_hot(indices, num_classes=action.size(-1)).float()
            ao_comb = th.concat([action_onehot, state], dim=-1)
        else:
            ao_comb = th.concat([action, state], dim=-1)  # (bs, na, n_actions + obs_dim)
        # Forward dynamics: replay the inferred action and measure consistency.
        state_prime_recon = trainer.ema_model.fwd_model(ao_comb.reshape(ao_comb.shape[0], -1)).reshape(ao_comb.shape[0], ao_comb.shape[1], -1)  # (bs, na, obs_dim)
        recon_error = th.mean((state_prime - state_prime_recon) ** 2)
        valid = recon_error < threshold
        # One scalar reward per batch element, broadcast over agents.
        reward = trainer.ema_model.rwd_model(ao_comb.reshape(ao_comb.shape[0], -1)).reshape(ao_comb.shape[0], 1, 1).repeat(1, ao_comb.shape[1], 1)

        action = action.unsqueeze(1).detach()  # add a time axis: (bs, 1, na, output_dim)
        if discrete_action:
            action = th.argmax(action, dim=-1, keepdim=True)  # (bs, 1, na, 1) indices

    return valid, action, reward

def generate_from_obs(args, obs, trainer, dataset, num_episodes, n_agents, cond_return):
    """Sample full-horizon trajectory segments conditioned on one observation.

    Args:
        args: run configuration; only ``args.device`` is used here.
        obs: initial observations — assumes shape (num_episodes, 1, n_agents,
            obs_dim) so the time-axis padding fills the horizon; TODO confirm.
        trainer: owner of ``ema_model.conditional_sample``.
        dataset: provides ``horizon``, the segment length to generate.
        num_episodes: batch size of segments to sample.
        n_agents: number of agents per trajectory.
        cond_return: scalar return the diffusion model is conditioned on.

    Returns:
        Sampled segments of shape (num_episodes, horizon, n_agents, obs_dim).
    """
    returns = (cond_return * th.ones(num_episodes, 1, n_agents)).to(args.device)
    obs = to_torch(obs, device=args.device)
    # Zero-pad the time axis out to the full horizon; only step 0 is real.
    cond_obs = F.pad(obs, (0, 0, 0, 0, 0, dataset.horizon - 1))
    # Mask is True only at timestep 0 — the single conditioning frame.
    cond_masks = (th.arange(0, dataset.horizon) < 1).reshape(1, dataset.horizon, 1, 1).repeat(cond_obs.shape[0], 1, cond_obs.shape[2], cond_obs.shape[3])
    conditions = {"x": to_torch(cond_obs, device=args.device), "masks": to_torch(cond_masks, device=args.device)}
    samples = trainer.ema_model.conditional_sample(conditions, returns=returns)  # (num_episodes, hrz, ag, obs_dim)
    return samples

def generate_from_partially_noised_seq_parallel(args, trainer, dataset, n_agents, cond_return, cur_obs_seq, bad_agent):
    """Re-generate sequences by re-noising only the flagged agents.

    ``cur_obs_seq`` is a list of arrays shaped (seq, na, dim); ``bad_agent``
    gives, per sequence, the agent indices whose trajectory is re-sampled
    (mask False) while everything else is held fixed (mask True).

    Returns samples of shape (len(cur_obs_seq), horizon, n_agents, obs_dim).
    """
    batch = len(cur_obs_seq)
    returns = (cond_return * th.ones(batch, 1, n_agents)).to(args.device)
    seq_na = cur_obs_seq[0].shape[1]
    seq_dim = cur_obs_seq[0].shape[2]
    cond_obs = np.zeros((batch, dataset.horizon, seq_na, seq_dim))
    cond_masks = np.ones_like(cond_obs, dtype=bool)  # (bs, hrz, na, dim)
    for row, (seq, agents) in enumerate(zip(cur_obs_seq, bad_agent)):
        cond_obs[row, :len(seq)] = seq
        cond_masks[row, :, agents, :] = False
    # The first timestep is always conditioned on, even for flagged agents.
    cond_masks[:, 0] = True
    conditions = {"x": to_torch(cond_obs, device=args.device), "masks": to_torch(cond_masks, device=args.device)}
    return trainer.ema_model.conditional_sample(conditions, returns=returns)  # (bs, hrz, ag, obs_dim)


def trajectory_stitching_parallel(
        args,
        n_gen,
        gen_batch_size,
        trainer,
        dataset,
        horizon,
        n_agents,
        discrete_action,
        cond_return,
        times_of_regen_upper_limit,
        total_times_of_regen_upper_limit,
        partially_noise,
        threshold,
        each_ig_step_num,
        verbose=False
    ):
    """Generate ``n_gen`` full-length trajectories by stitching diffusion segments.

    Each path starts from the first observation of a random real episode.
    Segments of length ``horizon`` are sampled from the diffusion model and
    accepted transition by transition, keeping a step only when the
    inverse/forward dynamics models agree it is consistent (see
    ``check_state_pair``).  Paths that fail too often are abandoned and
    restarted from a fresh initial observation; optionally, freshly accepted
    stretches whose agents look imbalanced are partially re-noised and
    re-generated for the offending agents only.

    Args:
        args: run configuration (device etc.), forwarded to the helpers.
        n_gen: number of completed trajectories to return.
        gen_batch_size: in-progress paths advanced per diffusion call.
        trainer: owner of the EMA diffusion / dynamics / reward models.
        dataset: provides ``n_episodes``, ``max_path_length``, normalized
            observations and the observation/action dimensions.
        horizon: segment length sampled per diffusion call.
        n_agents: number of agents per trajectory.
        discrete_action: whether inferred actions are argmax indices.
        cond_return: return value the diffusion model is conditioned on.
        times_of_regen_upper_limit: consecutive failed-regen limit per path.
        total_times_of_regen_upper_limit: lifetime failed-regen limit per path.
        partially_noise: enable the per-agent imbalance regeneration pass.
        threshold: reconstruction-MSE threshold for accepting a transition.
        each_ig_step_num: forwarded to ``check_imbalance_parallel``.
        verbose: when True, log the bookkeeping of the imbalance pass.

    Returns:
        Tuple ``(obs, actions, rewards)`` arrays for the first ``n_gen``
        completed paths, each shaped (n_gen, seq_len, n_agents, dim).
    """

    # Keep gen_batch_size extra slots so a full batch can stay in flight
    # even as paths complete.
    soft_n_gen = n_gen + gen_batch_size
    path_status = [PathStatus.IN_PROGRESS] * soft_n_gen
    times_of_regen_list = [0] * soft_n_gen        # consecutive failures per path
    total_times_of_regen_list = [0] * soft_n_gen  # lifetime failures per path

    # Seed every path with the first observation of a random real episode.
    sampled_ep_idx = np.random.choice(dataset.n_episodes, soft_n_gen)
    sampled_obs = dataset.fields.normed_observations[sampled_ep_idx]
    init_obs = sampled_obs[:, 0:1, :, :]  # (soft_n_gen, 1, ag, obs_dim)
    seq_len = dataset.max_path_length
    # seq_len + 1 slots along time: index 0 holds the initial observation.
    gen_obs_seq = np.hstack([init_obs, np.zeros((soft_n_gen, seq_len, n_agents, dataset.observation_dim))])  # (soft_n_gen, seq_len + 1, ag, obs_dim)
    gen_action_seq = np.zeros((soft_n_gen, seq_len, n_agents, dataset.action_dim))
    gen_reward_seq = np.zeros((soft_n_gen, seq_len, n_agents, 1))
    cur = [0] * soft_n_gen  # index of the last accepted observation per path

    start_time = datetime.datetime.now()
    while len([x for x in path_status if x == PathStatus.COMPLETED]) < n_gen:
        batch_idx = [idx for idx, v in enumerate(path_status) if v == PathStatus.IN_PROGRESS][:gen_batch_size] # len(batch_idx) < gen_batch_size may happen
        # Condition each new segment on the path's latest accepted observation.
        cond_obs = np.stack([gen_obs_seq[idx][cur[idx]: cur[idx] + 1] for idx in batch_idx], axis=0)  # (bs, 1, ag, obs_dim)
        last_cur = cur[:]  # copy
        seg = generate_from_obs(args, cond_obs, trainer, dataset, len(batch_idx), n_agents, cond_return)  # (bs, hrz, ag, obs_dim)
        for i, idx in enumerate(batch_idx):
            # Accept transitions one at a time, stopping at the first invalid pair.
            for j in range(1, horizon):
                valid, action, reward = check_state_pair(args, trainer, seg[i: i + 1, j - 1], seg[i: i + 1, j], discrete_action, threshold)
                if valid:
                    cur[idx] += 1
                    times_of_regen_list[idx] = 0  # progress resets the consecutive-failure counter
                    gen_obs_seq[idx][cur[idx]] = seg[i: i + 1, j: j + 1].cpu()  # seg[i: i + 1, j: j + 1]: (1, 1, ag, obs_dim)
                    gen_action_seq[idx][cur[idx] - 1] = action.cpu()
                    gen_reward_seq[idx][cur[idx] - 1] = reward.cpu()
                    if cur[idx] >= seq_len:
                        path_status[idx] = PathStatus.COMPLETED
                        break
                else:
                    times_of_regen_list[idx] += 1
                    total_times_of_regen_list[idx] += 1
                    break
            # Abandon paths that keep failing, either recently or overall.
            if times_of_regen_list[idx] >= times_of_regen_upper_limit or total_times_of_regen_list[idx] >= total_times_of_regen_upper_limit:
                path_status[idx] = PathStatus.ABANDONED

        # Optional second pass: re-generate the freshly accepted stretch for
        # agents flagged as imbalanced, via partial re-noising.
        if partially_noise and sum(cur) - sum(last_cur) > 0:
            if verbose:
                logger.print("AFTER SEG GEN")
                logger.print(f"cur: {cur}")
                logger.print(f"last_cur: {last_cur}")
            # NOTE(review): COMPLETED paths from this batch are still included
            # here and may be re-processed below — confirm intended.
            batch_idx_not_abandoned = [idx for idx in batch_idx if path_status[idx] != PathStatus.ABANDONED]
            if verbose:
                logger.print(f"batch_idx_not_abandoned: {batch_idx_not_abandoned}")
            # Only the stretch accepted in this iteration is checked for imbalance.
            obs_balance_to_be_checked = [gen_obs_seq[idx, last_cur[idx]: cur[idx] + 1] for idx in batch_idx_not_abandoned]
            action_balance_to_be_checked = [gen_action_seq[idx, last_cur[idx]: cur[idx] + 1] for idx in batch_idx_not_abandoned]
            bad_agent = check_imbalance_parallel(args, trainer, obs_balance_to_be_checked, action_balance_to_be_checked, discrete_action, each_ig_step_num)
            if verbose:
                logger.print(f"bad_agent: {bad_agent}")

            # Positions (into batch_idx_not_abandoned) with a non-empty bad-agent list.
            bad_agent_i = [i for i, bad_agent_list in enumerate(bad_agent) if bad_agent_list]
            if verbose:
                logger.print(f"bad_agent_i: {bad_agent_i}")
            batch_idx_imbalanced = [batch_idx_not_abandoned[i] for i in bad_agent_i]
            if verbose:
                logger.print(f"batch_idx_imbalanced: {batch_idx_imbalanced}")
            bad_agent = [bad_agent[i] for i in bad_agent_i]
            if verbose:
                logger.print(f"bad_agent: {bad_agent}")
            if batch_idx_imbalanced:
                imbalanced_obs_seq = [obs_balance_to_be_checked[i] for i in bad_agent_i]
                seg = generate_from_partially_noised_seq_parallel(args, trainer, dataset, n_agents, cond_return, imbalanced_obs_seq, bad_agent)

                gen_seg_len = [len(obs) for obs in imbalanced_obs_seq]
                if verbose:
                    logger.print(f"gen_seg_len: {gen_seg_len}")
                # Rewind imbalanced paths to the start of the re-generated stretch.
                for idx in batch_idx_imbalanced:
                    cur[idx] = last_cur[idx]
                if verbose:
                    logger.print(f"cur: {cur}")
                for i in range(seg.shape[0]):
                    # Re-validate and re-append the re-generated transitions.
                    for j in range(1, gen_seg_len[i]):
                        valid, action, reward = check_state_pair(args, trainer, seg[i: i + 1, j - 1], seg[i: i + 1, j], discrete_action, threshold)
                        idx = batch_idx_imbalanced[i]
                        if valid:
                            cur[idx] += 1
                            times_of_regen_list[idx] = 0
                            gen_obs_seq[idx][cur[idx]] = seg[i: i + 1, j: j + 1].cpu()  # seg[i: i + 1, j: j + 1]: (1, 1, ag, obs_dim)
                            gen_action_seq[idx][cur[idx] - 1] = action.cpu()
                            gen_reward_seq[idx][cur[idx] - 1] = reward.cpu()
                        else:
                            # NOTE(review): this can demote a path that was marked
                            # COMPLETED earlier in this iteration — confirm intended.
                            path_status[idx] = PathStatus.IN_PROGRESS
                            break
                if verbose:
                    logger.print(f"AFTER PARTIALLY DENOISING cur: {cur}")

        cm, ab, ip = tuple(path_status.count(x) for x in (PathStatus.COMPLETED, PathStatus.ABANDONED, PathStatus.IN_PROGRESS))
        # NOTE(review): ABANDONED is logged over gen_batch_size while the other
        # counts use soft_n_gen — verify the denominator is intended.
        logger.print(f"COMPLETED {cm}/{soft_n_gen}\tABANDONED {ab}/{gen_batch_size}\tIN_PROCESS {ip}/{soft_n_gen}\t\
                        Time passed: {str(datetime.datetime.now() - start_time).split('.')[0]}.\n")

        # Recycle abandoned slots: reset counters and restart from a fresh
        # random initial observation.
        abandoned_idx = [idx for idx in range(soft_n_gen) if path_status[idx] == PathStatus.ABANDONED]
        for idx in abandoned_idx:
            cur[idx] = 0
            times_of_regen_list[idx] = 0
            total_times_of_regen_list[idx] = 0
            sampled_ep_idx = np.random.choice(dataset.n_episodes, 1)
            sampled_obs = dataset.fields.normed_observations[sampled_ep_idx]
            init_obs = sampled_obs[0, 0, :, :]  # (ag, obs_dim)
            gen_obs_seq[idx][0] = init_obs
            path_status[idx] = PathStatus.IN_PROGRESS

    completed_idx = np.array([i for i, v in enumerate(path_status) if v == PathStatus.COMPLETED])[:n_gen]
    # Drop the final observation slot so obs aligns with actions/rewards (seq_len each).
    return gen_obs_seq[completed_idx, :-1, :, :], gen_action_seq[completed_idx], gen_reward_seq[completed_idx]
