import os
import json
import gym
import sys
import torch
import re
import torch.nn as nn
import argparse
from tqdm import tqdm
import numpy as np
from torch.utils.data import Dataset, DataLoader
from .utils import *
from diffusion_human_ai.ldm.utils import *

from diffusion_human_ai.ldm.ddpm import DDPMSampler
from diffusion_human_ai.ldm.diffusion import Diffusion, PositionEmbedding
from diffusion_human_ai.ghn.core import hyperActor
from diffusion_human_ai.ldm.vae import Conv1dEncoder, HyperDecoder, VAE

from overcooked_ai_py.agents.agent import AgentPair, PantheonRLAgent
from pantheonrl.algos.diffusion_human_ai.translator.vae import Translator
from overcooked_ai_py.mdp.overcooked_mdp import generate_env, EVENT_TYPES
from trainer import DIVERSE_COORPERATION_STYLE_LIST, DIVERSE_ORDERS_STYLE_LIST, \
    CENTER_POTS_STYLE_LIST, CROSSWAY_STYLE_LIST
from tester import gen_fixed

import importlib


def custom_collate(batch):
    """Collate ``(policy, prompt)`` samples into a batch.

    Policies are kept as a plain tuple of objects (they are not stacked),
    while the prompts are stacked through NumPy into one ``torch.Tensor``.

    Args:
        batch: Sequence of ``(policy, prompt)`` pairs produced by the dataset.

    Returns:
        ``(policies, prompts)`` where ``policies`` is a tuple and ``prompts``
        is a tensor whose first axis indexes the samples of the batch.
    """
    columns = list(zip(*batch))
    policies, prompts = columns
    stacked_prompts = np.array(prompts)
    return policies, torch.tensor(stacked_prompts)

class AgentDataset(Dataset):
    """Dataset pairing each ego agent's policy with its prompt.

    Args:
        ego_list: Agents exposing a ``policy`` attribute.
        prompt_list: One prompt per agent. A NumPy array is converted to a
            list of its rows; any other container is stored as given.
        args: Namespace providing ``multi_batch`` and ``batch_size``. When
            ``multi_batch`` is truthy, both the policy list and the prompt
            list are repeated ``batch_size`` times.
    """

    def __init__(self, ego_list, prompt_list, args):
        self.ego_list = ego_list

        policies = [agent.policy for agent in ego_list]
        if isinstance(prompt_list, np.ndarray):
            prompts = list(prompt_list)
        else:
            prompts = prompt_list

        if args.multi_batch:
            # Tile the whole sequence (a, b, a, b, ...) rather than repeating
            # each element in place.
            repeats = args.batch_size
            policies = policies * repeats
            prompts = prompts * repeats

        self.policy_list = policies
        self.prompt_list = prompts

    def __len__(self):
        """Return the number of (policy, prompt) samples."""
        return len(self.policy_list)

    def __getitem__(self, idx):
        """Return the ``(policy, prompt)`` pair stored at position ``idx``."""
        return self.policy_list[idx], self.prompt_list[idx]



def train_diffusion(args, sampler, diffusion, dataloader, vae, device):
    """Train ``diffusion`` to predict the noise added to VAE-encoded policies.

    For every batch the policies are converted to parameters, encoded with the
    (frozen) VAE encoder, noised by ``sampler`` at random timesteps, and the
    diffusion model is optimised with an MSE loss against the sampled noise.
    The model's state dict is written to ``args.diffusion_save`` every 100
    epochs and once more when training finishes.

    Args:
        args: Namespace providing ``lr``, ``n_epochs``, ``latent_shape`` and
            ``diffusion_save``.
        sampler: Noise scheduler exposing ``num_train_timesteps`` and
            ``add_noise(x, timesteps) -> (x_t, eps)``.
        diffusion: Module called as ``diffusion(x_t, prompt, time_embeddings)``.
        dataloader: Iterable yielding ``(policy_batch, prompt_batch)``.
        vae: Model whose ``encoder`` returns a 3-tuple; only the first element
            is used as the latent.
        device: Device for the loss, the prompts and the time embeddings.

    Note:
        ``obs_dim`` is read from module scope (it is assigned by the script
        entry point of this file), so it must exist before calling this.
    """
    n_steps = sampler.num_train_timesteps
    loss_fn = nn.MSELoss().to(device)
    optimizer = torch.optim.Adam(diffusion.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=args.n_epochs)

    # The VAE is only used as a fixed encoder here.
    vae.eval()

    for epoch in range(args.n_epochs):
        epoch_loss = 0

        for policy_batch, prompt_batch in dataloader:
            torch.cuda.empty_cache()

            params_batch = params_from_policy(policy_batch, obs_dim=obs_dim)

            with torch.no_grad():
                x, _, _ = vae.encoder(params_batch)

            # Use the size of *this* batch: the final batch of an epoch (or a
            # dataset smaller than args.batch_size) can hold fewer samples.
            cur_batch_size = x.shape[0]

            # Rank-2 prompts get an extra axis at position 1; the check is on
            # the tensor rank, not on how many samples the batch holds.
            prompt = prompt_batch.to(device)
            if prompt.dim() == 2:
                prompt = prompt.unsqueeze(1)

            timesteps = torch.randint(0, n_steps, (cur_batch_size, ))
            time_embeddings = get_time_embeddings(timesteps).to(device)

            x = x.reshape(cur_batch_size, 1, *args.latent_shape)
            x_t, eps = sampler.add_noise(x, timesteps)
            eps_theta = diffusion(x_t, prompt, time_embeddings).reshape(cur_batch_size, -1)

            loss = loss_fn(eps_theta, eps.reshape(cur_batch_size, -1))
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(diffusion.parameters(), 1.)
            optimizer.step()

            epoch_loss += loss.item()

        scheduler.step()
        # Guard the average against an empty dataloader.
        epoch_loss /= max(len(dataloader), 1)
        print(f'epoch: {epoch + 1} / {args.n_epochs}, loss: {epoch_loss:.6f}')

        if (epoch + 1) % 100 == 0:
            print('Saving model to {}...'.format(args.diffusion_save))
            torch.save(diffusion.state_dict(), args.diffusion_save)

    print('Done!')
    torch.save(diffusion.state_dict(), args.diffusion_save)
    print("Save model to {}.".format(args.diffusion_save))

    
def get_event_infos(args, ego_list, alt_list, env_list=None, embed=True):
    """Roll out every ego/partner pair and summarise the partner's events.

    For each pair, ``args.rollout_games`` games are played and, for every
    event name, the number of partner occurrences per game is divided by
    ``args.horizon`` and averaged over the environment's game-stats buffer.
    The per-pair vectors are then passed to ``process_infos_with_ranking``.

    Args:
        args: Namespace providing ``rollout_games`` and ``horizon`` (and
            whatever ``generate_env`` needs when ``env_list`` is empty).
        ego_list: Ego agents, one per pair.
        alt_list: Partner agents, index-aligned with ``ego_list``.
        env_list: Optional environments, index-aligned with ``ego_list``.
            When empty or ``None`` a single environment is generated from
            ``args`` and shared by all pairs.
        embed: Forwarded to ``process_infos_with_ranking``.

    Returns:
        The value returned by ``process_infos_with_ranking`` for the list of
        partner event vectors.
    """
    if not env_list:
        env, _ = generate_env(args)
        base_env = env.base_env

    partner_info_list = []

    print("Collecting event infos of partner...")

    for idx, ego in enumerate(ego_list):
        if env_list:
            env = env_list[idx]
            base_env = env.base_env

        ego_agent = PantheonRLAgent(ego, env)
        alt_agent = PantheonRLAgent(alt_list[idx], env)
        agent_pair = AgentPair(ego_agent, alt_agent)

        base_env.get_rollouts(agent_pair, num_games=args.rollout_games)

        # Work on a copy: appending to EVENT_TYPES itself would grow the
        # imported module-level list on every iteration.
        event_list = list(EVENT_TYPES)
        event_list.extend(
            'pos_%s_%s' % (pos[0], pos[1])
            for pos in base_env.mdp._get_terrain_type_pos_dict()[' ']
        )

        partner_stats = {event: [] for event in event_list}
        for game_stats in base_env.game_stats_buffer:
            for event in game_stats:
                # Only list-valued, non-positional entries are aggregated.
                if not isinstance(game_stats[event], list):
                    continue
                if 'pos_' in event:
                    continue

                # Index 1 is the partner's slot (index 0 is the ego's).
                partner_stats[event].append(len(game_stats[event][1]) / args.horizon)

        # Events with no recorded values (including every 'pos_' key, which
        # is skipped above) average to NaN, as np.mean of an empty list does.
        partner_info = [np.mean(values) for values in partner_stats.values()]
        partner_info_list.append(partner_info)

    processed_infos = process_infos_with_ranking(partner_info_list, embed=embed)

    return processed_infos

def get_event_infos_from_translator(args, translator):
    """Convert the training descriptions into position-embedded event infos.

    Reads the ``'train'`` section of the JSON file at ``args.desc_load`` and,
    for every partner label, converts its descriptions with
    ``translator.convert`` and passes the result through a freshly built
    ``PositionEmbedding``.

    Args:
        args: Namespace providing ``desc_load`` (path to the JSON file).
        translator: Object exposing ``convert(descriptions)``.

    Returns:
        Dict mapping each partner label to a NumPy array of embedded infos.
    """
    with open(args.desc_load, 'r') as desc_file:
        train_descs = json.load(desc_file)['train']

    embed_dim = 32
    infos_by_label = {}

    for label in train_descs:
        descriptions = train_descs[label]

        with torch.no_grad():
            raw_infos = translator.convert(descriptions)
            # The sequence length is taken from the first converted entry.
            embedder = PositionEmbedding(embed_dim, len(raw_infos[0]))
            embedded = embedder(raw_infos.cpu().unsqueeze(-1))

            infos_by_label[label] = embedded.cpu().numpy()

    return infos_by_label

def get_train_data(args, translator=None):
    """Build environments, ego agents, partners and prompts for Overcooked.

    One entry is created per key of the description file. For the layouts
    that have a style list, the partner's ``masked_events`` are added to the
    environment config before the environment is created.

    Args:
        args: Namespace providing ``desc_load``, ``diverse_desc``,
            ``env_config``, ``env`` and ``model_load`` (plus the fields
            required by the prompt helpers).
        translator: Optional translator. When given, prompts come from
            ``get_event_infos_from_translator``; otherwise they are computed
            from rollouts with ``get_event_infos``.

    Returns:
        ``(idx2env, idx2ego, idx2partner, prompts)``.
    """
    with open(args.desc_load, 'r') as f:
        loaded = json.load(f)
    desc_dict = loaded['train'] if args.diverse_desc else loaded

    idx2env = []
    idx2ego = []
    idx2partner = []
    for i, name in enumerate(desc_dict):
        # Per-partner copy so the caller's args.env_config is not mutated.
        env_config = dict(args.env_config)
        layout_name = env_config["layout_name"]
        if layout_name == "diverse_coordination":
            env_config["masked_events"] = DIVERSE_COORPERATION_STYLE_LIST[i]
        elif layout_name == "diverse_orders":
            # This layout cycles through its styles instead of indexing 1:1.
            style_id = i % len(DIVERSE_ORDERS_STYLE_LIST)
            env_config["masked_events"] = DIVERSE_ORDERS_STYLE_LIST[style_id]
        elif layout_name == "center_pots":
            env_config["masked_events"] = CENTER_POTS_STYLE_LIST[i]
        elif layout_name == "crossway":
            env_config["masked_events"] = CROSSWAY_STYLE_LIST[i]

        env = gym.make(args.env, **env_config)
        partner_load = os.path.join(args.model_load, name + ".zip")
        partner = gen_fixed({}, "PPO", partner_load)

        # NOTE(review): only the first single digit of the name selects the
        # ego checkpoint, so a multi-digit id would be truncated - confirm
        # this matches the naming scheme of the saved models.
        idx = re.findall('[0-9]', name)[0]
        ego_load = os.path.join(args.model_load, f"ego_{idx}.zip")
        ego = gen_fixed({}, "PPO", ego_load)

        env.add_partner_agent(partner)

        idx2env.append(env)
        idx2ego.append(ego)
        idx2partner.append(partner)

    if translator is not None:
        prompts = get_event_infos_from_translator(args, translator)
    else:
        prompts = get_event_infos(args, idx2ego, idx2partner, idx2env)

    return idx2env, idx2ego, idx2partner, prompts

def get_train_data_lbf(args):
    """Build environments, ego agents, partners and prompts for LBF.

    One entry is created per key of the JSON file at ``args.desc_load``. The
    first digit found in the key selects both the environment layout and the
    ego checkpoint. ``args.env_config["layout_name"]`` is overwritten in
    place for every entry.

    Args:
        args: Namespace providing ``desc_load``, ``env_config``, ``env`` and
            ``model_load``.

    Returns:
        ``(idx2env, idx2ego, idx2partner, prompt_list)`` where
        ``prompt_list`` is ``get_prompt_embeddings`` applied to the integer
        labels ``0..N-1``, converted to NumPy.
    """
    with open(args.desc_load, 'r') as desc_file:
        partner_names = json.load(desc_file)

    envs, egos, partners = [], [], []

    for partner_name in partner_names:
        food_level = re.findall('[0-9]', partner_name)[0]

        config = args.env_config
        config["layout_name"] = f"lbf_spread_{food_level}-v0"
        env = gym.make(args.env, **config)

        partner_path = os.path.join(args.model_load, partner_name + ".zip")
        partner = gen_fixed({}, "PPO", partner_path)

        ego_path = os.path.join(args.model_load, f"ego_{food_level}.zip")
        ego = gen_fixed({}, "PPO", ego_path)

        envs.append(env)
        egos.append(ego)
        partners.append(partner)

    labels = torch.tensor(np.arange(len(egos)))
    prompt_list = get_prompt_embeddings(labels).numpy()
    return envs, egos, partners, prompt_list

def get_train_data_assistive(args):
    """Build environments, ego agents, partners and prompts for Assistive Gym.

    One entry is created per key of the JSON file at ``args.desc_load``; each
    key must be one of the task names mapped below. The environment class is
    looked up in ``assistive_gym.envs`` under the mapped id without its
    version suffix, followed by ``'Env'``.

    Args:
        args: Namespace providing ``desc_load`` and ``model_load``.

    Returns:
        ``(idx2env, idx2ego, idx2partner, prompt_list)`` where
        ``prompt_list`` is ``get_prompt_embeddings`` applied to the integer
        labels ``0..N-1``, converted to NumPy.
    """
    with open(args.desc_load, 'r') as desc_file:
        task_names = json.load(desc_file)

    task2env = {
        "drinking": "DrinkingJacoHuman-v1",
        "feeding": "FeedingJacoHuman-v1",
        "scratchitch": "ScratchItchJacoHuman-v1",
        "bedbathing": "BedBathingJacoHuman-v1",
    }

    envs, egos, partners = [], [], []

    env_module = importlib.import_module('assistive_gym.envs')
    for task_name in task_names:
        # Drop everything from the first '-' on (the "-v1" suffix).
        base_name = task2env[task_name].partition('-')[0]
        env = getattr(env_module, base_name + 'Env')()

        partner_path = os.path.join(args.model_load, task_name + "_alt.zip")
        partner = gen_fixed({}, "PPO", partner_path)

        ego_path = os.path.join(args.model_load, task_name + "_ego.zip")
        ego = gen_fixed({}, "PPO", ego_path)

        envs.append(env)
        egos.append(ego)
        partners.append(partner)

    labels = torch.tensor(np.arange(len(egos)))
    prompt_list = get_prompt_embeddings(labels).numpy()
    return envs, egos, partners, prompt_list


def preset(args):
    """Fill every option left as ``None`` with its default and return ``args``.

    Defaults applied (only when the corresponding attribute is ``None``):
      * ``env_config``: ``{'layout_name': args.layout}``.
      * ``model_load``: a model directory chosen from ``args.env``; it stays
        ``None`` when the environment is not one of the three known ids.
      * ``vae_load`` / ``diffusion_save`` / ``desc_load``: files inside
        ``args.model_load``; the description file name depends on
        ``args.diverse_desc``.

    Args:
        args: Namespace to complete; it is modified in place.

    Returns:
        The same ``args`` object.
    """
    if args.env_config is None:
        args.env_config = {'layout_name': args.layout}

    if args.model_load is None:
        if args.env == "OvercookedMultiEnv-v0":
            # Overcooked models are stored per layout.
            args.model_load = "diffusion_human_ai/models/%s" % (args.layout)
        else:
            fixed_model_dirs = (
                ("AssistiveMultiEnv-v0", "diffusion_human_ai/models/assistive"),
                ("LBFMultiEnv-v0", "diffusion_human_ai/models/lbf_spread"),
            )
            for env_id, model_dir in fixed_model_dirs:
                if args.env == env_id:
                    args.model_load = model_dir
                    break

    if args.vae_load is None:
        args.vae_load = os.path.join(args.model_load, 'vae.pth')

    if args.diffusion_save is None:
        args.diffusion_save = os.path.join(args.model_load, 'diffusion_robust.pth')

    if args.desc_load is None:
        desc_name = 'diverse_descriptions.json' if args.diverse_desc else 'descriptions.json'
        args.desc_load = os.path.join(args.model_load, desc_name)

    return args
    
def cli_bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``type=bool`` cannot be used with argparse because every non-empty
    string, including ``"False"``, is truthy. The empty string is still
    accepted as ``False`` because ``bool("")`` was ``False`` as well.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def cli_int_tuple(value):
    """Parse a shape such as ``"8,8"``, ``"8x8"`` or ``"(8, 8)"`` into ints.

    Raises:
        argparse.ArgumentTypeError: If no positive integers can be parsed.
    """
    tokens = [tok for tok in re.split(r'[,x\s]+', value.strip().strip('()[]')) if tok]
    try:
        shape = tuple(int(tok) for tok in tokens)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected integers like '8,8', got {value!r}") from None
    if not shape or any(dim <= 0 for dim in shape):
        raise argparse.ArgumentTypeError(f"expected positive integers like '8,8', got {value!r}")
    return shape


if __name__ == '__main__':
    # Script entry point: parse options, load the trained agents and the
    # VAE, then train the diffusion model on the VAE latents.
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='OvercookedMultiEnv-v0')
    parser.add_argument('--layout', type=str, default='diverse_coordination')
    parser.add_argument('--env_config', type=json.loads, default=None)
    parser.add_argument('--horizon', type=int, default=400)
    parser.add_argument('--rollout_games', type=int, default=10)
    parser.add_argument('--seed', type=int, default=42)

    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--n_epochs', type=int, default=1000)
    parser.add_argument('--latent_dim', type=int, default=64)

    parser.add_argument('--diffusion_save', type=str, default=None)
    parser.add_argument('--vae_load', type=str, default=None)
    parser.add_argument('--model_load', type=str, default=None)
    parser.add_argument('--traj_save', type=str, default='trajs')
    parser.add_argument('--desc_load', type=str, default=None)
    parser.add_argument('--event_info_dim', type=int, default=19)

    parser.add_argument('--multi_batch', type=cli_bool, default=True)
    parser.add_argument('--framestack', '-f', type=int, default=1)
    parser.add_argument('--record', type=str, default=None)
    parser.add_argument('--diverse_desc', type=cli_bool, default=True)

    # When omitted, the latent shape is derived from --latent_dim below.
    parser.add_argument('--latent_shape', type=cli_int_tuple, default=None)
    parser.add_argument('--context_dim', type=int, default=320)
    parser.add_argument('--use_translator', type=cli_bool, default=False)
    parser.add_argument('--diffusion_embed', type=cli_bool, default=True)
    args = parser.parse_args()
    args = preset(args)

    if args.latent_shape is None:
        # Square layout of the latent vector, e.g. 64 -> (8, 8).
        latent_side = int(np.sqrt(args.latent_dim))
        args.latent_shape = (latent_side, latent_side)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    translator_finetuned_path = os.path.join(args.model_load, 'diverse_translator.pth')
    if os.path.exists(translator_finetuned_path) and args.use_translator:
        print("Using finetuned translator...")
        translator = Translator(event_info_dim=args.event_info_dim).to(device)
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        translator.load_state_dict(torch.load(translator_finetuned_path, map_location=device))
        args.diffusion_save = os.path.join(args.model_load, 'diffusion_trans.pth')
    else:
        translator = None

    if args.env == "OvercookedMultiEnv-v0":
        env_list, ego_list, alt_list, prompts = get_train_data(args, translator)

        if isinstance(prompts, dict):
            # One ego per partner label, repeated once per prompt of that
            # label so egos and prompts stay index-aligned.
            ego_list_repeated = []
            prompt_list = []
            for i, partner_label in enumerate(prompts):
                ego_list_repeated.extend([ego_list[i]] * len(prompts[partner_label]))
                prompt_list.extend(prompts[partner_label])
            ego_list = ego_list_repeated
        else:
            prompt_list = prompts

    elif args.env == "AssistiveMultiEnv-v0":
        env_list, ego_list, alt_list, prompt_list = get_train_data_assistive(args)
    elif args.env == "LBFMultiEnv-v0":
        env_list, ego_list, alt_list, prompt_list = get_train_data_lbf(args)
    else:
        # Fail with a clear message instead of a NameError further down.
        raise ValueError(f"Unsupported --env {args.env!r}")

    policy_dataset = AgentDataset(ego_list, prompt_list, args=args)
    dataloader = DataLoader(policy_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=custom_collate)

    # NOTE: `obs_dim` has to remain a module-level name because
    # train_diffusion reads it from module scope.
    if args.env == "OvercookedMultiEnv-v0":
        action_dim = 6
        obs_dim = 96
        encoder_obs_dim = obs_dim
    elif args.env == "AssistiveMultiEnv-v0":
        action_dim = 7
        obs_dim = 64
        encoder_obs_dim = obs_dim
    else:  # "LBFMultiEnv-v0" (any other value was rejected above)
        action_dim = env_list[0].action_space.n
        obs_dim = env_list[0].observation_space.shape[0]
        # The encoder receives the next multiple of 4 strictly above obs_dim.
        encoder_obs_dim = 4 * (obs_dim // 4 + 1)

    last_bias = action_dim
    encoder = Conv1dEncoder(bias_shapes=[64, 64, last_bias], latent_dim=args.latent_dim,
                            obs_dim=encoder_obs_dim, action_dim=action_dim).to(device)

    hyper_actor = hyperActor(
        act_dim=last_bias,
        obs_dim=obs_dim,
        latent_dim=args.latent_dim,
        allowable_layers=np.array([64]),
        meta_batch_size=args.batch_size,
        device=device
    )
    decoder = HyperDecoder(hyper_actor)
    vae = VAE(encoder=encoder, decoder=decoder).to(device)
    vae.load_state_dict(torch.load(args.vae_load, map_location=device))

    generator = torch.Generator(device=device)
    if args.seed is None:
        generator.seed()
    else:
        generator.manual_seed(args.seed)

    ddpm = DDPMSampler(generator)

    diffusion = Diffusion(d_context=args.context_dim, context_embed=args.diffusion_embed).to(device)
    train_diffusion(args, ddpm, diffusion, dataloader, vae, device)