import os
import re
import gym
import json
import torch
import importlib
import numpy as np
from tqdm import tqdm
from diffusion_human_ai.ldm.ddpm import DDPMSampler
from diffusion_human_ai.ldm.diffusion import Diffusion
from trainer import generate_env, gen_fixed, DIVERSE_COORPERATION_STYLE_LIST, DIVERSE_ORDERS_STYLE_LIST, CENTER_POTS_STYLE_LIST, CROSSWAY_STYLE_LIST

### Diffusion ###
def get_time_embeddings(timesteps: torch.Tensor, dim=320):
    """Return sinusoidal embeddings of diffusion timesteps.

    Every scalar in ``timesteps`` is expanded into ``dim // 2`` cosine features
    followed by ``dim // 2`` sine features. The frequencies are
    ``10000 ** (-k / (dim // 2))`` for ``k = 0 .. dim // 2 - 1``.

    Args:
        timesteps: tensor of any shape (a 0-d tensor is fine), or a Python
            number / sequence, which is converted with ``torch.tensor``.
        dim: embedding width. With the default of 320 this reproduces the
            previously hard-coded 160 frequencies. For odd ``dim`` the output
            has ``dim - 1`` features.

    Returns:
        Tensor of shape ``timesteps.shape + (2 * (dim // 2),)`` on the same
        device as ``timesteps``.
    """
    if not isinstance(timesteps, torch.Tensor):
        timesteps = torch.tensor(timesteps)
    half = dim // 2
    # Build the frequency table on the input's device so that CUDA timesteps
    # do not trigger a CPU/GPU device mismatch in the product below.
    exponents = torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
    freqs = torch.pow(10000, -exponents)
    x = timesteps.unsqueeze(-1) * freqs
    return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)

def get_prompt_embeddings(prompt: torch.Tensor, dim=320):
    """Return sinusoidal embeddings of (possibly fractional) prompt values.

    Every scalar in ``prompt`` is expanded into ``dim // 2`` cosine features
    followed by ``dim // 2`` sine features. The frequencies are
    ``10000 ** (-k / (dim // 2))``.

    Args:
        prompt: tensor of any shape, or an array-like / Python number that is
            converted with ``torch.tensor``.
        dim: embedding width. For odd ``dim`` the output has ``dim - 1``
            features, because of the integer division.

    Returns:
        Tensor of shape ``prompt.shape + (2 * (dim // 2),)`` on the same device
        as ``prompt``.
    """
    if not isinstance(prompt, torch.Tensor):
        prompt = torch.tensor(prompt)
    half = dim // 2
    # Create the frequencies on the prompt's device; a CPU-only table would fail
    # with a device mismatch when the prompt tensor lives on the GPU.
    exponents = torch.arange(start=0, end=half, dtype=torch.float32, device=prompt.device) / half
    freqs = torch.pow(10000, -exponents)
    angles = prompt.unsqueeze(-1) * freqs
    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)

def get_prompt_embeddings_simple(prompt: torch.Tensor):
    """Use the raw prompt as context by inserting a singleton axis at position 1.

    For a ``(batch, k)`` tensor the result is ``(batch, 1, k)``; for a 1-D
    ``(batch,)`` tensor it is ``(batch, 1)``.
    """
    return torch.unsqueeze(prompt, dim=1)

def process_infos(info_list, delta=0.01):
    """Keep informative columns and scale each into ``[0, 1 - delta]``.

    Columns whose sum is not positive are dropped. Each remaining column is
    divided by its own maximum and then clipped to ``[0, 1 - delta]``. A
    positive column sum implies at least one positive entry, so the divisor
    is always positive.

    Args:
        info_list: 2-D array-like of shape (n, m).
        delta: margin kept below 1 for the largest normalized value.

    Returns:
        float32 ndarray of shape (n, k), where k is the number of kept columns.
    """
    infos = np.array(info_list)
    keep_mask = np.sum(infos, axis=0) > 0
    kept = infos[:, keep_mask]

    column_max = np.max(kept, axis=0)
    normalized = np.clip(kept / column_max, 0, 1 - delta)
    return normalized.astype(np.float32)


def process_infos_with_ranking(info_list, delta=0.01, n_rank=10, preprocess=True, embed=True, dim=320):
    """Quantize per-feature info values into rank offsets, optionally embedded.

    Each value (optionally normalized first by ``process_infos``) is quantized
    to ``round(value * n_rank)``. It is then shifted by
    ``feature_index * n_rank``, so that different feature columns map to
    different numeric ranges before embedding.

    Args:
        info_list: 2-D array-like of shape (n, seq_len).
        delta: clipping margin forwarded to ``process_infos``.
        n_rank: quantization levels per feature and stride between features.
        preprocess: if True, normalize with ``process_infos``. That drops
            columns with non-positive sums and scales into ``[0, 1 - delta]``.
        embed: if True, return sinusoidal embeddings of the ranked values.
        dim: embedding width forwarded to ``get_prompt_embeddings``. It is
            only used when ``embed`` is True.

    Returns:
        If ``embed``: float32 ndarray of shape (n, seq_len, 2 * (dim // 2)),
        which is (n, seq_len, 320) with the defaults.
        Otherwise: float ndarray of shape (n, seq_len) with the ranked values.
    """
    if preprocess:
        relevant_infos = process_infos(info_list, delta=delta)
    else:
        relevant_infos = np.array(info_list)

    seq_len = len(relevant_infos[0])
    rank_array = np.arange(0, seq_len) * n_rank
    # NOTE(review): np.round maps values in [0, 1 - delta] onto 0..n_rank
    # (0.99 * 10 rounds to 10). The top bucket of feature i therefore equals the
    # bottom bucket of feature i + 1, so adjacent ranges can overlap. Switching
    # to np.floor would change every embedding, presumably invalidating trained
    # models — confirm before changing.
    processed_infos = np.round(relevant_infos * n_rank) + rank_array
    if embed:
        return get_prompt_embeddings(processed_infos, dim=dim).cpu().numpy().astype(np.float32)  # (n, seq_len, dim)
    return processed_infos


### VAE ###
def get_train_data(args):
    """Build (ego policy, environment) training pairs from ``args.desc_load``.

    Reads the ``'train'`` collection of the JSON description file. For the
    i-th entry ``name`` it:

    * copies ``args.env_config``. For the layouts ``diverse_coordination``,
      ``diverse_orders``, ``center_pots`` and ``crossway``, it sets
      ``masked_events`` to the i-th element of that layout's style list;
    * creates ``gym.make(args.env, **config)`` and attaches the partner policy
      loaded from ``<policy_dir>/<name>.zip``;
    * loads the ego policy ``<policy_dir>/ego_<d>.zip``, where ``<d>`` is the
      first digit in ``name``.

    ``args.env_config`` itself is not modified.

    Returns:
        (idx2ego, idx2env, label_list), where label_list is ``[0, ..., len - 1]``.

    Raises:
        IndexError: if ``name`` contains no digit, or the style list is
            shorter than the number of entries.
    """
    with open(args.desc_load, 'r', encoding='utf-8') as f:
        desc_dict = json.load(f)['train']

    # Layout name -> per-entry masked-event configurations, indexed by position.
    style_lists = {
        "diverse_coordination": DIVERSE_COORPERATION_STYLE_LIST,
        "diverse_orders": DIVERSE_ORDERS_STYLE_LIST,
        "center_pots": CENTER_POTS_STYLE_LIST,
        "crossway": CROSSWAY_STYLE_LIST,
    }

    idx2env = []
    idx2ego = []
    for i, name in enumerate(desc_dict):
        # Work on a copy so the per-entry masked_events does not leak back
        # into the caller's args.env_config.
        env_config = dict(args.env_config)
        style_list = style_lists.get(env_config["layout_name"])
        if style_list is not None:
            env_config["masked_events"] = style_list[i]

        env = gym.make(args.env, **env_config)
        partner_load = os.path.join(args.policy_dir, name + ".zip")
        partner = gen_fixed({}, "PPO", partner_load)

        idx = re.findall(r'[0-9]', name)[0]
        ego_load = os.path.join(args.policy_dir, f"ego_{idx}.zip")
        ego = gen_fixed({}, "PPO", ego_load)

        env.add_partner_agent(partner)

        idx2env.append(env)
        idx2ego.append(ego)

    label_list = list(range(len(idx2ego)))
    return idx2ego, idx2env, label_list


def get_train_data_assistive(args):
    """Build ego/partner policy pairs and environments for Assistive Gym tasks.

    Each task name in the JSON file at ``args.desc_load`` must be one of
    ``drinking``, ``feeding``, ``scratchitch`` or ``bedbathing``; any other
    name raises KeyError. The environment class is looked up in
    ``assistive_gym.envs`` as ``<Prefix>Env``, where ``<Prefix>`` is the env id
    up to the first ``'-'`` (e.g. ``DrinkingJacoHumanEnv``). It is
    instantiated with no arguments. The policies are loaded from
    ``<policy_dir>/<task>_alt.zip`` (partner) and ``<policy_dir>/<task>_ego.zip``
    (ego).

    Returns:
        (idx2ego, idx2partner, idx2env, label_list), where label_list is
        ``[0, ..., len - 1]``.
    """
    with open(args.desc_load, 'r') as f:
        desc_dict = json.load(f)

    task2env = {
        "drinking": "DrinkingJacoHuman-v1",
        "feeding": "FeedingJacoHuman-v1",
        "scratchitch": "ScratchItchJacoHuman-v1",
        "bedbathing": "BedBathingJacoHuman-v1",
    }

    idx2env, idx2ego, idx2partner = [], [], []

    envs_module = importlib.import_module('assistive_gym.envs')
    for task_name in desc_dict:
        # "DrinkingJacoHuman-v1" -> "DrinkingJacoHumanEnv"
        class_name = task2env[task_name].split('-')[0] + 'Env'
        env_cls = getattr(envs_module, class_name)
        env = env_cls()

        partner = gen_fixed({}, "PPO", os.path.join(args.policy_dir, task_name + "_alt.zip"))
        ego = gen_fixed({}, "PPO", os.path.join(args.policy_dir, task_name + "_ego.zip"))

        idx2env.append(env)
        idx2ego.append(ego)
        idx2partner.append(partner)

    label_list = list(range(len(idx2ego)))
    return idx2ego, idx2partner, idx2env, label_list

def get_train_data_lbf(args):
    """Build (ego policy, environment) training pairs for the LBF setup.

    Each entry ``name`` of the JSON description file at ``args.desc_load`` is
    handled as follows:

    * ``d``, the first digit in ``name``, selects the ego policy
      ``<policy_dir>/ego_<d>.zip``;
    * if ``args.env_config["layout_name"]`` is ``"lbf_spread"``, the env is
      created with layout ``lbf_spread_<d>-v0``;
    * the partner policy ``<policy_dir>/<name>.zip`` is attached to the env.

    ``args.env_config`` itself is not modified.

    Returns:
        (idx2ego, idx2env, label_list), where label_list is ``[0, ..., len - 1]``.

    Raises:
        IndexError: if ``name`` contains no digit.
    """
    with open(args.desc_load, 'r', encoding='utf-8') as f:
        desc_dict = json.load(f)

    idx2env = []
    idx2ego = []
    for name in desc_dict:
        # First digit of the name: the ego index and, for lbf_spread, the
        # target food level of the layout variant.
        digit = re.findall(r'[0-9]', name)[0]

        # Fresh copy per entry. Rewriting args.env_config["layout_name"] in
        # place would make the "lbf_spread" check fail from the second entry
        # on, so every later env would silently reuse the first variant.
        env_config = dict(args.env_config)
        if env_config["layout_name"] == "lbf_spread":
            env_config["layout_name"] = f"lbf_spread_{digit}-v0"

        env = gym.make(args.env, **env_config)
        partner_load = os.path.join(args.policy_dir, name + ".zip")
        partner = gen_fixed({}, "PPO", partner_load)

        ego_load = os.path.join(args.policy_dir, f"ego_{digit}.zip")
        ego = gen_fixed({}, "PPO", ego_load)

        env.add_partner_agent(partner)

        idx2env.append(env)
        idx2ego.append(ego)

    label_list = list(range(len(idx2ego)))
    return idx2ego, idx2env, label_list


### Diffusion ### 

def generate_params(prompt_list, vae, generator, diffusion, device, n_samples=8, latent_shape=(1, 8, 8)):
    """Sample decoded parameters for each prompt with the latent diffusion model.

    For every prompt, ``n_samples`` Gaussian latents are iteratively denoised
    by ``diffusion`` under a DDPM sampler. The diffusion model is conditioned
    on the prompt, repeated ``n_samples`` times as context. The denoised
    latents are then flattened to ``(n_samples, -1)`` and decoded with
    ``vae.decoder``. Only the first decoded sample of each batch is kept.

    Both models are switched to eval mode (and left there), and sampling runs
    under ``torch.no_grad()``.

    Args:
        prompt_list: iterable of prompts; each is wrapped as
            ``torch.tensor([prompt] * n_samples)``.
        vae: model exposing a ``decoder`` callable.
        generator: passed to ``DDPMSampler``. It is also used for the initial
            latent noise, so a seeded generator makes the whole run
            reproducible. It presumably must live on ``device`` — confirm
            against DDPMSampler.
        diffusion: noise-prediction model called as
            ``diffusion(latents, context, time_embedding)``.
        device: device for latents, context and time embeddings.
        n_samples: batch size denoised per prompt (default 8, the previous
            hard-coded value).
        latent_shape: per-sample latent shape (default ``(1, 8, 8)``).

    Returns:
        list with one decoded tensor (``recon_params[0]``) per prompt.
    """
    params_list = []
    vae.eval()
    diffusion.eval()
    with torch.no_grad():
        for prompt in prompt_list:
            context = get_prompt_embeddings_simple(torch.tensor([prompt] * n_samples)).to(device)
            sampler = DDPMSampler(generator)
            sampler.set_inference_timesteps()

            # Draw the starting noise from the same generator the sampler uses;
            # the global RNG would make seeded runs non-reproducible.
            latents = torch.randn((n_samples, *latent_shape), generator=generator, device=device)

            for timestep in tqdm(sampler.timesteps):
                time_embedding = get_time_embeddings(timestep).to(device)
                noise_pred = diffusion(latents, context, time_embedding)  # predict the noise
                latents = sampler.step(timestep, latents, noise_pred)  # one denoising step

            latents = latents.reshape(n_samples, -1)  # flatten denoised latents for the decoder
            recon_params = vae.decoder(latents)  # reconstruct the parameters
            params_list.append(recon_params[0])
    return params_list
