import os
import re
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import pickle
import json
import gym
import argparse
from torch.utils.data import Dataset, DataLoader
from .utils import *

from stable_baselines3.ppo.ppo import PPO
from pantheonrl.common.agents import StaticPolicyAgent
from pantheonrl.common.observation import Observation

from diffusion_human_ai.ghn.core import hyperActor
from diffusion_human_ai.ldm.vae import Conv1dEncoder, VAE, HyperDecoder
from diffusion_human_ai.ldm.utils import params_from_policy, get_recon_loss, process_recon_params_list, set_params


def custom_collate(batch):
    """Regroup a batch of ``(policy, transitions, label)`` samples by field.

    Args:
        batch: iterable of 3-item samples, as produced by
            ``AgentDataset.__getitem__``.

    Returns:
        A ``(policies, transitions, labels)`` triple of tuples, each holding
        one entry per sample in batch order. Nothing is stacked or converted.
    """
    by_field = tuple(zip(*batch))
    policies, transitions, labels = by_field
    return policies, transitions, labels

def collect_transitions(ego, env, n_episodes=25, n_steps=400):
    """Roll out ``ego`` in ``env`` and record every observation seen.

    Args:
        ego: agent exposing ``get_action(obs, False)``.
        env: environment exposing ``set_ego_extractor``, ``reset`` and a
            ``step`` that returns a 4-tuple ``(obs, reward, done, info)``.
        n_episodes: number of episodes to play.
        n_steps: cap on the number of steps per episode.

    Returns:
        One flat list with the observations of all episodes back to back:
        the reset observation followed by one observation per step.
    """
    # Identity extractor: the ego agent receives the observation unmodified.
    env.set_ego_extractor(lambda obs: obs)

    observations = []  # only observations (used for training vae)
    for _episode in range(n_episodes):
        current = env.reset()
        observations.append(current)
        steps_taken = 0
        # At least one step is always taken; the episode then ends as soon as
        # the env reports done or the step cap is reached.
        while True:
            current, _reward, done, _info = env.step(ego.get_action(current, False))
            observations.append(current)
            steps_taken += 1
            if done or steps_taken >= n_steps:
                break

    return observations

def collect_transitions_lbf(ego, env, n_episodes=200, n_steps=30):
    """Record the observations ``ego`` encounters while playing in ``env``.

    Same procedure as ``collect_transitions``; only the default episode count
    and per-episode step cap differ.

    Args:
        ego: agent exposing ``get_action(obs, False)``.
        env: environment exposing ``set_ego_extractor``, ``reset`` and a
            ``step`` that returns a 4-tuple ``(obs, reward, done, info)``.
        n_episodes: number of episodes to play.
        n_steps: cap on the number of steps per episode.

    Returns:
        Flat list of observations over all episodes; each episode contributes
        its reset observation and then one observation per step.
    """
    # Identity extractor: the ego agent receives the observation unmodified.
    env.set_ego_extractor(lambda obs: obs)

    collected = []  # only observations (used for training vae)
    for _episode in range(n_episodes):
        obs = env.reset()
        collected.append(obs)
        taken = 0
        finished = False
        while not finished:
            chosen = ego.get_action(obs, False)
            obs, _reward, done, _info = env.step(chosen)
            collected.append(obs)
            taken += 1
            # Stop on env termination or once the step cap is hit.
            finished = done or taken >= n_steps

    return collected

def collect_transitions_assistive(ego, partner, env, n_episodes=50, n_steps=200):
    """Record observations while ``ego`` and ``partner`` jointly act in ``env``.

    Args:
        ego: agent exposing ``get_action(obs)``.
        partner: second agent exposing ``get_action(obs)``; it is queried
            with the same observation as ``ego``.
        env: environment whose ``step`` takes the concatenation of both
            agents' actions and returns ``(obs, reward, done, info)``.
        n_episodes: number of episodes to play.
        n_steps: cap on the number of steps per episode.

    Returns:
        Flat list of ``Observation`` objects, one per environment step, over
        all episodes.
    """
    # NOTE(review): unlike collect_transitions / collect_transitions_lbf, the
    # observation returned by reset() is not recorded here -- confirm that
    # this is intended.
    recorded = []  # only observations (used for training vae)
    for _episode in range(n_episodes):
        obs = Observation(env.reset())
        taken = 0
        while True:
            # Both agents act on the current observation; the env consumes
            # their actions as a single concatenated vector (ego first).
            ego_action = ego.get_action(obs)
            partner_action = partner.get_action(obs)
            raw_obs, _reward, done, _info = env.step(
                np.concatenate([ego_action, partner_action])
            )
            obs = Observation(raw_obs)
            recorded.append(obs)
            taken += 1
            if done or taken >= n_steps:
                break

    return recorded

class AgentDataset(Dataset):
    """Dataset pairing each agent's policy with its observations and a label.

    Item ``idx`` is ``(policy, transitions, label)`` where ``policy`` is
    ``ego_list[idx].policy``, ``transitions`` is the list of observations
    stored for that agent and ``label`` is ``label_list[idx]``.

    Args:
        ego_list: agents; each must expose a ``.policy`` attribute.
        env_list: environments aligned with ``ego_list``; only used when
            collecting trajectories.
        label_list: per-agent labels returned by ``__getitem__``.
        collect_traj: if True, roll out every agent and pickle the result
            under ``args.traj_save``; if False, load that pickle instead.
        args: namespace providing ``env``, ``layout``, ``seed`` and
            ``traj_save``.
        alt_list: partner agents aligned with ``ego_list``. Only required
            when collecting for ``AssistiveMultiEnv-v0``.

    Raises:
        ValueError: if ``args`` is missing, if ``args.env`` is not a
            supported environment while collecting, or if partners are
            needed but ``alt_list`` was not given.
    """

    def __init__(self, ego_list, env_list, label_list, collect_traj=False, args=None, alt_list=None):
        if args is None:
            raise ValueError('AgentDataset requires `args` (env, layout, seed, traj_save).')

        self.ego_list = ego_list
        self.env_list = env_list
        self.label_list = label_list
        self.alt_list = alt_list
        self.policy_list = [ego.policy for ego in ego_list]

        if collect_traj:
            print('Collecting transitions...')
            self.transition_list = self._collect(args)

            os.makedirs(args.traj_save, exist_ok=True)
            with open(self._traj_path(args), 'wb') as f:
                pickle.dump(self.transition_list, f)
        else:
            print('Loading transitions...')
            traj_path = self._traj_path(args)
            if not os.path.exists(traj_path):
                # Fall back to the older file name without the env prefix so
                # pickles written under that name can still be loaded.
                legacy_path = self._traj_path(args, legacy=True)
                if os.path.exists(legacy_path):
                    traj_path = legacy_path
            # NOTE: unpickling can execute arbitrary code; only point
            # args.traj_save at files you trust.
            with open(traj_path, 'rb') as f:
                self.transition_list = pickle.load(f)

    @staticmethod
    def _traj_path(args, legacy=False):
        """Return the pickle path for this env/layout/seed combination."""
        if legacy:
            file_name = f'{args.layout}_transitions_{args.seed}.pickle'
        else:
            file_name = f'{args.env}_{args.layout}_transitions_{args.seed}.pickle'
        return os.path.join(args.traj_save, file_name)

    def _collect(self, args):
        """Roll out every agent with the collector matching ``args.env``."""
        n_agents = len(self.policy_list)  # lets tqdm show real progress
        if args.env == 'OvercookedMultiEnv-v0':
            return [collect_transitions(ego, env)
                    for ego, env in tqdm(zip(self.ego_list, self.env_list), total=n_agents)]
        if args.env == 'AssistiveMultiEnv-v0':
            if self.alt_list is None:
                raise ValueError(
                    'alt_list (partner agents) is required to collect '
                    'transitions for AssistiveMultiEnv-v0.'
                )
            return [collect_transitions_assistive(ego, alt, env)
                    for ego, alt, env in tqdm(zip(self.ego_list, self.alt_list, self.env_list),
                                              total=n_agents)]
        if args.env == 'LBFMultiEnv-v0':
            return [collect_transitions_lbf(ego, env)
                    for ego, env in tqdm(zip(self.ego_list, self.env_list), total=n_agents)]
        raise ValueError(f'Unsupported env for trajectory collection: {args.env!r}')

    def __len__(self):
        return len(self.policy_list)

    def __getitem__(self, idx):
        policy = self.policy_list[idx]
        transitions = self.transition_list[idx]
        label = self.label_list[idx]
        return policy, transitions, label

def preset(args):
    """Fill in unset (``None``) options on ``args`` with derived defaults.

    Values that are already set are left untouched.

    * ``env_config`` defaults to ``{'layout_name': args.layout}``.
    * ``policy_dir`` defaults to a per-environment model directory; it stays
      ``None`` for environments that have no default.
    * ``vae_save`` and ``desc_load`` default to files inside ``policy_dir``.

    Args:
        args: mutable namespace with ``env``, ``layout``, ``env_config``,
            ``policy_dir``, ``vae_save`` and ``desc_load`` attributes.

    Returns:
        The same ``args`` object, modified in place.
    """
    if args.env_config is None:
        args.env_config = {'layout_name': args.layout}

    if args.policy_dir is None:
        if args.env == "OvercookedMultiEnv-v0":
            # Overcooked models are stored per layout.
            args.policy_dir = 'diffusion_human_ai/models/%s' % (args.layout)
        else:
            fixed_dirs = {
                "AssistiveMultiEnv-v0": "diffusion_human_ai/models/assistive",
                "LBFMultiEnv-v0": "diffusion_human_ai/models/lbf_spread",
            }
            args.policy_dir = fixed_dirs.get(args.env)

    derived_files = (('vae_save', 'vae.pth'), ('desc_load', 'diverse_descriptions.json'))
    for option, file_name in derived_files:
        if getattr(args, option) is None:
            setattr(args, option, os.path.join(args.policy_dir, file_name))

    return args


def train_vae(vae, dataloader, env_list, args=None):
    """Train ``vae`` to reconstruct policies from their parameters.

    For every batch the policy parameters are encoded, decoded back into
    parameter sets, and each decoded set is loaded into a freshly built PPO
    policy. ``get_recon_loss`` compares that policy with the original one on
    the stored observations; the closed-form KL term between the encoder's
    Gaussian and a standard normal is added with weight ``args.kl_coef``.
    Only the encoder and ``vae.decoder.policy.ghn`` parameters are optimised.

    Args:
        vae: model exposing ``encoder`` and ``decoder``.
        dataloader: yields ``(policy_batch, transition_batch, label_batch)``.
        env_list: environments; ``label_batch[i]`` indexes into this list to
            pick the env used to build the reconstructed policy.
        args: namespace providing ``lr``, ``n_epochs``, ``kl_coef``,
            ``save_interval`` and ``vae_save``.

    Raises:
        ValueError: if ``args`` is missing or the dataloader yields no batch.
    """
    if args is None:
        raise ValueError('train_vae requires `args` (lr, n_epochs, kl_coef, save_interval, vae_save).')

    # Resolved locally so the function does not depend on a global that only
    # exists when this file is run as a script.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    vae = vae.to(device)
    recon_loss_fn = nn.KLDivLoss(reduction="batchmean", log_target=True).to(device)
    trainable_params = list(vae.encoder.parameters()) + list(vae.decoder.policy.ghn.parameters())
    optimizer = torch.optim.Adam(trainable_params, args.lr, weight_decay=1e-5)

    for epoch in range(args.n_epochs):
        # Epoch statistics are plain floats so they do not keep any autograd
        # history alive across iterations.
        recon_loss = 0.0
        kl_loss = 0.0
        total_loss = 0.0
        n_batches = 0
        for policy_batch, transition_batch, label_batch in dataloader:
            torch.cuda.empty_cache()
            batch_recon_loss = 0
            n_samples = 0

            params_batch = params_from_policy(policy_batch)

            latents, mean, logvar = vae.encoder(params_batch)
            recon_params_list = vae.decoder(latents)

            loadable_params_list = process_recon_params_list(recon_params_list)

            for i, params in enumerate(loadable_params_list):
                pi = policy_batch[i]
                transition = transition_batch[i]  # only observations

                # A fresh policy is built per sample so the decoded
                # parameters can be loaded into it.
                ego_hat = PPO(policy='MlpPolicy', env=env_list[label_batch[i]])
                pi_hat = set_params(ego_hat.policy, params)

                batch_recon_loss += get_recon_loss(pi_hat, pi, transition, recon_loss_fn)
                n_samples += 1

            # Average over the samples actually present: the last batch of an
            # epoch can be smaller than args.batch_size.
            batch_recon_loss = batch_recon_loss / n_samples
            batch_kl_loss = torch.mean(-0.5 * torch.sum(1 + logvar - mean**2 - torch.exp(logvar), 1), 0)

            batch_loss = batch_recon_loss + args.kl_coef * batch_kl_loss

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            recon_loss += float(batch_recon_loss)
            kl_loss += float(batch_kl_loss)
            total_loss += float(batch_loss)
            n_batches += 1

        if n_batches == 0:
            raise ValueError('dataloader yielded no batches; nothing to train on.')

        recon_loss /= n_batches
        kl_loss /= n_batches
        total_loss /= n_batches
        print(f"epoch:{epoch + 1}, total_loss:{total_loss:.6f}, recon_loss:{recon_loss:.6f}, kl_loss:{kl_loss:.6f}")

        if (epoch + 1) % args.save_interval == 0:
            torch.save(vae.state_dict(), args.vae_save)
            print('Model saved at', args.vae_save)
                    
# Script entry point: parse options, build the dataset and the VAE for the
# selected environment, then train.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='OvercookedMultiEnv-v0')
    parser.add_argument('--layout', type=str, default=None)
    parser.add_argument('--env_config', type=json.loads, default=None)

    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--n_epochs', type=int, default=1000)
    parser.add_argument('--kl_coef', type=float, default=1e-6)
    parser.add_argument('--latent_dim', type=int, default=64)

    parser.add_argument('--policy_dir', type=str, default=None)
    parser.add_argument('--cluster_load', type=str, default=None)
    parser.add_argument('--desc_load', type=str, default=None)

    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--save_interval', type=int, default=20)
    parser.add_argument('--vae_save', type=str, default=None)
    parser.add_argument('--traj_save', type=str, default='trajs')
    parser.add_argument('--framestack', '-f', type=int, default=1)
    parser.add_argument('--record', type=str, default=None)
    args = parser.parse_args()
    args = preset(args)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if args.env == 'OvercookedMultiEnv-v0':
        ego_list, env_list, label_list = get_train_data(args)
        policy_dataset = AgentDataset(ego_list, env_list, label_list, collect_traj=True, args=args)
        encoder = Conv1dEncoder(bias_shapes=[64, 64, 6], latent_dim=args.latent_dim, obs_dim=96, action_dim=6).to(device)
        hyper_actor = hyperActor(
            act_dim=env_list[0].action_space.n,
            obs_dim=env_list[0].observation_space.shape[0],
            latent_dim=args.latent_dim,
            allowable_layers=np.array([64]),
            meta_batch_size=args.batch_size,
            device=device
        )
    elif args.env == 'AssistiveMultiEnv-v0':
        obs_dim = 64
        action_dim = 7
        ego_list, partner_list, env_list, label_list = get_train_data_assistive(args)
        # Transitions are loaded from disk here (collect_traj=False), so the
        # partner agents are not passed to the dataset.
        policy_dataset = AgentDataset(ego_list, env_list, label_list, collect_traj=False, args=args)
        encoder = Conv1dEncoder(bias_shapes=[64, 64, action_dim], latent_dim=args.latent_dim, obs_dim=obs_dim, action_dim=action_dim).to(device)
        hyper_actor = hyperActor(
            act_dim=action_dim,
            obs_dim=obs_dim,
            latent_dim=args.latent_dim,
            allowable_layers=np.array([64]),
            meta_batch_size=args.batch_size,
            device=device
        )
    elif args.env == 'LBFMultiEnv-v0':
        ego_list, env_list, label_list = get_train_data_lbf(args)
        policy_dataset = AgentDataset(ego_list, env_list, label_list, collect_traj=True, args=args)

        # NOTE(review): the last bias shape is hard-coded to 6 while
        # action_dim is read from the env -- confirm they always agree.
        encoder = Conv1dEncoder(bias_shapes=[64, 64, 6], latent_dim=args.latent_dim,
                                obs_dim=4*(env_list[0].observation_space.shape[0] // 4 + 1),
                                action_dim=env_list[0].action_space.n).to(device)
        hyper_actor = hyperActor(
            act_dim=env_list[0].action_space.n,
            obs_dim=env_list[0].observation_space.shape[0],
            latent_dim=args.latent_dim,
            allowable_layers=np.array([64]),
            meta_batch_size=args.batch_size,
            device=device
        )
    else:
        raise ValueError(f'Unsupported env: {args.env!r}')

    # Built once for every environment so each branch ends up with a loader.
    dataloader = DataLoader(policy_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=custom_collate)

    decoder = HyperDecoder(hyper_actor)
    vae = VAE(encoder=encoder, decoder=decoder).to(device)

    train_vae(vae, dataloader, env_list, args=args)