import os
import re
import json
import gym
import torch
import argparse
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from trainer import generate_env, DIVERSE_COORPERATION_STYLE_LIST, DIVERSE_ORDERS_STYLE_LIST, CENTER_POTS_STYLE_LIST, CROSSWAY_STYLE_LIST
from tester import gen_fixed

from overcooked_ai_py.agents.agent import AgentPair, PantheonRLAgent
from overcooked_ai_py.mdp.overcooked_mdp import EVENT_TYPES

from diffusion_human_ai.translator.vae_fc import Translator

from transformers import BertTokenizer, BertModel
from .utils import *


def custom_collate(batch):
    """Collate ``(description, event_info, label)`` samples into one batch.

    The descriptions are returned unchanged as a tuple of the raw items;
    the event infos are stacked through ``np.array`` and, together with the
    labels, converted to tensors placed on the module-level ``device``.
    """
    descriptions, infos, labels = zip(*batch)
    info_tensor = torch.tensor(np.array(infos), device=device)
    label_tensor = torch.tensor(labels, device=device)
    return descriptions, info_tensor, label_tensor

class DescriptionDataset(Dataset):
    """Flat dataset of ``(human description, event info, partner label)`` triples.

    ``desc_dict`` maps each partner name to a list of descriptions and
    ``info_list[i]`` is the event-info entry of the ``i``-th partner in
    ``desc_dict``'s iteration order. Every description of that partner is
    paired with that entry and labelled ``i``.
    """

    def __init__(self, desc_dict, info_list, multi_batch=None, batch_size=None):
        """Build the flattened sample lists.

        Args:
            desc_dict: mapping of partner name -> list of descriptions.
            info_list: per-partner event infos, aligned with ``desc_dict`` order.
            multi_batch: when truthy, the whole sample list is repeated
                ``batch_size`` times. ``None`` falls back to the script-level
                ``args.multi_batch`` (the previous, implicit behavior).
            batch_size: repeat factor used when ``multi_batch`` is truthy.
                ``None`` falls back to the script-level ``args.batch_size``.
        """
        # Explicit arguments let the class be used without the global
        # ``args`` that only exists when this file runs as __main__.
        if multi_batch is None:
            multi_batch = args.multi_batch

        self.human_descriptions = []
        self.event_infos = []
        self.label_list = []

        for label, partner_name in enumerate(desc_dict):
            descriptions = desc_dict[partner_name]
            self.human_descriptions.extend(descriptions)
            self.event_infos.extend([info_list[label]] * len(descriptions))
            self.label_list.extend([label] * len(descriptions))

        if multi_batch:
            if batch_size is None:
                batch_size = args.batch_size
            # Repeating the data makes one epoch span batch_size times as many samples.
            self.human_descriptions = self.human_descriptions * batch_size
            self.event_infos = self.event_infos * batch_size
            self.label_list = self.label_list * batch_size

    def __len__(self):
        return len(self.human_descriptions)

    def __getitem__(self, idx):
        """Return ``(description, event_info, label)`` for sample ``idx``."""
        return self.human_descriptions[idx], self.event_infos[idx], self.label_list[idx]
            
            
def preset(args):
    """Fill every unset path/config option on ``args`` with its default.

    Defaults that live under the model directory are resolved after
    ``args.model_load`` itself, so a user-supplied ``--model_load`` also
    relocates the description file, finetuned BERT and translator checkpoint.
    ``args`` is mutated in place and returned.
    """
    if args.env_config is None:
        args.env_config = {'layout_name': args.layout}

    if args.model_load is None:
        args.model_load = f"diffusion_human_ai/models/{args.layout}"
    model_dir = args.model_load

    if args.desc_load is None:
        desc_name = "diverse_descriptions.json" if args.diverse_desc else "descriptions.json"
        args.desc_load = os.path.join(model_dir, desc_name)

    if args.bert_path is None:
        args.bert_path = "diffusion_human_ai/models/bert-base-uncased"

    if args.finetuned_bert_path is None:
        args.finetuned_bert_path = os.path.join(model_dir, "finetuned_bert")

    if args.translator_save is None:
        args.translator_save = os.path.join(model_dir, "translator.pth")

    return args

def get_event_infos(args, ego_list, alt_list, env_list=None):
    """Roll out each (ego, partner) pair and summarise the partner's event rates.

    For every index ``k``, ``ego_list[k]`` is paired with ``alt_list[k]`` and
    run for ``args.rollout_games`` games. For each list-valued, non-position
    event in the recorded game stats, the partner's per-game event count is
    divided by ``args.horizon`` and averaged over the recorded games.

    Args:
        args: namespace providing ``rollout_games`` and ``horizon`` (plus
            whatever ``generate_env`` reads when ``env_list`` is empty).
        ego_list: ego policies, one per pair.
        alt_list: partner policies, indexed in lockstep with ``ego_list``.
        env_list: optional per-pair environments. When empty/None, a single
            environment from ``generate_env(args)`` is shared by all pairs.

    Returns:
        ``process_infos`` applied to the list of per-partner event-rate lists.
    """
    shared_env = None
    if not env_list:
        shared_env, _ = generate_env(args)

    partner_info_list = []

    print("Collecting event infos of partner...")

    for k, ego_policy in enumerate(ego_list):
        env = env_list[k] if env_list else shared_env
        base_env = env.base_env

        ego_agent = PantheonRLAgent(ego_policy, env)
        alt_agent = PantheonRLAgent(alt_list[k], env)
        agent_pair = AgentPair(ego_agent, alt_agent)

        # NOTE(review): with a shared env, confirm game_stats_buffer is reset
        # between pairs; otherwise later pairs also average earlier games.
        base_env.get_rollouts(agent_pair, num_games=args.rollout_games)

        # Copy before extending. Appending to EVENT_TYPES directly mutated the
        # imported module-level list, growing it on every iteration and call.
        event_list = list(EVENT_TYPES)
        for pos in base_env.mdp._get_terrain_type_pos_dict()[' ']:
            event_list.append('pos_%s_%s' % (pos[0], pos[1]))

        partner_stats = {event: [] for event in event_list}
        for game_stats in base_env.game_stats_buffer:
            for event in game_stats:
                if not isinstance(game_stats[event], list):
                    continue
                if 'pos_' in event:
                    continue
                # Index 1 holds the partner's events; normalise by the horizon.
                partner_stats[event].append(len(game_stats[event][1]) / args.horizon)

        # Position keys receive no values above, so their np.mean is NaN.
        # They are kept to preserve the layout process_infos receives.
        # TODO confirm that process_infos expects these entries.
        partner_info_list.append([np.mean(values) for values in partner_stats.values()])

    # processed_infos = process_infos_with_position_embedding(partner_info_list, pos_embed=False)
    return process_infos(partner_info_list)

def get_train_data(args):
    """Build the description dict and per-partner event infos used for training.

    For each partner named in the description file:
    - build an environment, using the layout's masked-event style for that
      partner's index where the layout defines one;
    - load the partner checkpoint ``<name>.zip`` and the matching ego
      checkpoint ``ego_<idx>.zip``, where ``idx`` is the first run of digits
      in the partner name;
    - collect event infos via ``get_event_infos``.

    Returns:
        ``(desc_dict, prompt_list)``: descriptions keyed by partner name, and
        the processed event infos in the same order as ``desc_dict``.

    Raises:
        ValueError: if a partner name contains no digits.
    """
    with open(args.desc_load, 'r', encoding='utf-8') as f:
        desc_dict = json.load(f)
    if args.diverse_desc:
        desc_dict = {key: descs[: args.n_train_desc] for key, descs in desc_dict['train'].items()}

    idx2env = []
    idx2ego = []
    idx2partner = []
    for i, name in enumerate(desc_dict):
        # Per-partner copy: previously the loop mutated args.env_config,
        # leaving the last partner's masked_events behind on args.
        env_config = dict(args.env_config)
        layout = env_config["layout_name"]
        if layout == "diverse_coordination":
            env_config["masked_events"] = DIVERSE_COORPERATION_STYLE_LIST[i]
        elif layout == "diverse_orders":
            # Styles are cycled when there are more partners than styles.
            env_config["masked_events"] = DIVERSE_ORDERS_STYLE_LIST[i % len(DIVERSE_ORDERS_STYLE_LIST)]
        elif layout == "center_pots":
            env_config["masked_events"] = CENTER_POTS_STYLE_LIST[i]
        elif layout == "crossway":
            env_config["masked_events"] = CROSSWAY_STYLE_LIST[i]

        env = gym.make(args.env, **env_config)
        partner_load = os.path.join(args.model_load, name + ".zip")
        partner = gen_fixed({}, "PPO", partner_load)

        # Match the whole number: the old single-digit pattern mapped
        # e.g. "partner_10" to "ego_1".
        match = re.search(r"[0-9]+", name)
        if match is None:
            raise ValueError(f"Cannot infer ego index from partner name {name!r}")
        ego_load = os.path.join(args.model_load, f"ego_{match.group()}.zip")
        ego = gen_fixed({}, "PPO", ego_load)

        env.add_partner_agent(partner)

        idx2env.append(env)
        idx2ego.append(ego)
        idx2partner.append(partner)

    prompt_list = get_event_infos(args, idx2ego, idx2partner, idx2env)
    return desc_dict, prompt_list

def get_train_data_(args):
    """Get human description dict and event-based infos (fixed 10-agent pool variant).

    Loads PPO checkpoints ``partner_{i}``/``partner_mid_{i}`` and
    ``ego_{i}``/``ego_mid_{i}`` for ``i`` in 0..9 from ``args.model_load``,
    then measures partner event infos by calling ``get_event_infos`` without
    an explicit environment list.

    Returns:
        ``(label2desc, info_list)``: descriptions keyed by partner label, and
        the processed event infos in the same partner order.
    """
    with open(args.desc_load, 'r') as desc_file:
        descriptions = json.load(desc_file)

    label2ego = {}
    label2partner = {}
    label2desc = {}

    def _load_ppo(label):
        # Every agent checkpoint is stored as <model_load>/<label>.zip
        return gen_fixed({}, 'PPO', os.path.join(args.model_load, label + '.zip'))

    for agent_idx in range(10):
        for kind in ('partner', 'partner_mid'):
            label = f'{kind}_{agent_idx}'
            label2partner[label] = _load_ppo(label)
            label2desc[label] = descriptions[label]

    # Egos are inserted in the same order as partners, so list positions pair up.
    for agent_idx in range(10):
        for kind in ('ego', 'ego_mid'):
            label = f'{kind}_{agent_idx}'
            label2ego[label] = _load_ppo(label)

    info_list = get_event_infos(args, list(label2ego.values()), list(label2partner.values()))
    return label2desc, info_list

def train(translator, dataloader, args):
    """Optimise ``translator.vae`` to reconstruct event infos from descriptions.

    Each batch minimises ``mse(recon, info) + args.kl_coef * KL``. Per-epoch
    mean losses are printed. ``translator.state_dict()`` is saved to
    ``args.translator_save`` every ``args.save_interval`` epochs.

    Args:
        translator: module moved to the module-level ``device``. Calling it on
            a description batch returns ``(recon_infos, latents, mean, logvar)``.
            It must expose a ``.vae`` submodule.
        dataloader: iterable of ``(desc_batch, info_batch, label_batch)`` that
            supports ``len()``.
        args: namespace with ``lr``, ``n_epochs``, ``kl_coef``,
            ``save_interval`` and ``translator_save``.
    """
    translator = translator.to(device)
    # Only translator.vae parameters are handed to the optimiser.
    optimizer = torch.optim.Adam(translator.vae.parameters(), args.lr, weight_decay=1e-5)
    # Guard against an empty dataloader (previously a ZeroDivisionError).
    n_batches = max(len(dataloader), 1)

    for epoch in range(args.n_epochs):
        recon_loss = 0.0
        kl_loss = 0.0
        total_loss = 0.0

        for desc_batch, info_batch, _label_batch in dataloader:
            recon_infos, _latents, mean, logvar = translator(desc_batch)

            batch_recon_loss = F.mse_loss(recon_infos, info_batch)
            # Closed-form Gaussian KL against N(0, I): summed over dim 1,
            # averaged over dim 0 (presumably latent / batch dims).
            batch_kl_loss = torch.mean(-0.5 * torch.sum(1 + logvar - mean**2 - torch.exp(logvar), 1), 0)
            batch_loss = batch_recon_loss + args.kl_coef * batch_kl_loss

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            # .item() detaches the values. Summing the loss tensors directly
            # kept every batch's autograd history alive for the whole epoch.
            # That was the cause the per-batch cuda.empty_cache() worked around.
            recon_loss += batch_recon_loss.item()
            kl_loss += batch_kl_loss.item()
            total_loss += batch_loss.item()

        recon_loss /= n_batches
        kl_loss /= n_batches
        total_loss /= n_batches

        print(f"Epoch:{epoch + 1}, total_loss:{total_loss:.6f}, recon_loss:{recon_loss:.6f}, kl_loss:{kl_loss:.6f}")

        if (epoch + 1) % args.save_interval == 0:
            torch.save(translator.state_dict(), args.translator_save)
            print("Model saved at", args.translator_save)
        
if __name__ == '__main__':
    def _str2bool(value):
        """Parse a command-line boolean.

        argparse with ``type=bool`` calls ``bool("False")``, which is True, so
        any non-empty value used to enable the flag.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='OvercookedMultiEnv-v0')
    parser.add_argument('--layout', type=str, default='diverse_coordination')
    parser.add_argument('--env_config', type=json.loads, default=None)

    parser.add_argument('--horizon', type=int, default=400)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--rollout_games', type=int, default=10)

    parser.add_argument('--event_info_dim', type=int, default=25)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--latent_dim', type=int, default=64)
    parser.add_argument('--n_epochs', type=int, default=1000)
    parser.add_argument('--kl_coef', type=float, default=1e-6)

    parser.add_argument('--save_interval', type=int, default=20)
    parser.add_argument('--multi_batch', type=_str2bool, default=True)

    parser.add_argument('--model_load', type=str, default=None)
    parser.add_argument('--translator_save', type=str, default=None)
    parser.add_argument('--desc_load', type=str, default=None)
    parser.add_argument('--bert_path', type=str, default=None)
    parser.add_argument('--framestack', '-f', type=int, default=1)
    parser.add_argument('--record', type=str, default=None)

    parser.add_argument('--diverse_desc', type=_str2bool, default=True)
    parser.add_argument('--n_train_desc', type=int, default=10)
    parser.add_argument('--use_finetuned_bert', type=_str2bool, default=True)
    parser.add_argument('--finetuned_bert_path', type=str, default=None)
    parser.add_argument('--max_seq_len', type=int, default=32)
    parser.add_argument('--bert_output', type=str, default='last_hidden_state')

    args = parser.parse_args()
    args = preset(args)

    # --seed was previously parsed but never applied.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Module-level global: custom_collate and train read it at call time.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    descriptions, event_infos = get_train_data(args)
    args.event_info_dim = len(event_infos[0])
    description_dataset = DescriptionDataset(descriptions, event_infos)
    dataloader = DataLoader(description_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=custom_collate)

    if args.use_finetuned_bert:
        bert_encoder = BertModel.from_pretrained(args.finetuned_bert_path).to(device)
    else:
        # Previously the encoder was left undefined here (NameError).
        # Fall back to the base checkpoint at --bert_path.
        # NOTE(review): assumes Translator accepts any BertModel as
        # `finetuned_bert` — confirm.
        bert_encoder = BertModel.from_pretrained(args.bert_path).to(device)
    pooler_output = args.bert_output == 'pooler_output'
    translator = Translator(args.event_info_dim, finetuned_bert=bert_encoder, max_seq_len=args.max_seq_len, pooler_output=pooler_output).to(device)

    train(translator, dataloader, args)