import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
from utils import RNNEncoder, EnsembleDynamicsModel, ReplayBuffer
import argparse
import os

class TemporalConsistentRepresentationLearner:
    def __init__(self, state_dim, action_dim, z_dim, hidden_dim, ensemble_size=3, device='cpu'):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.z_dim = z_dim
        self.hidden_dim = hidden_dim
        self.ensemble_size = ensemble_size
        self.device = device
        
        self.encoder = RNNEncoder(state_dim, action_dim, hidden_dim, z_dim).to(device)
        self.dynamics_model = EnsembleDynamicsModel(state_dim, action_dim, z_dim, hidden_dim, ensemble_size).to(device)
        
        self.encoder_optimizer = optim.Adam(self.encoder.parameters(), lr=1e-4)
        self.dynamics_optimizer = optim.Adam(self.dynamics_model.parameters(), lr=1e-4)
    
    def compute_kl_divergence(self, mean, std):
        prior_mean = torch.zeros_like(mean)
        prior_std = torch.ones_like(std)
        
        kl = 0.5 * (torch.log(prior_std**2 / std**2) + (std**2 + (mean - prior_mean)**2) / prior_std**2 - 1).sum(dim=-1)
        return kl
    
    def e_step(self, trajectories):
        trajectories = torch.FloatTensor(trajectories).to(self.device)
        
        self.encoder_optimizer.zero_grad()
        
        mean_z, std_z = self.encoder(trajectories)
        
        z = Normal(mean_z, std_z).sample()
        z_expanded = z.unsqueeze(1).repeat(1, trajectories.shape[1], 1)
        
        states = trajectories[:, :-1, :self.state_dim]
        actions = trajectories[:, :-1, self.state_dim:self.state_dim+self.action_dim]
        next_states = trajectories[:, 1:, :self.state_dim]
        
        model_idx = np.random.randint(0, self.ensemble_size)
        log_prob = self.dynamics_model.compute_log_prob(states, actions, next_states, z_expanded[:, :-1], model_idx)
        log_prob = log_prob.sum(dim=1)
        
        kl_div = self.compute_kl_divergence(mean_z, std_z)
        
        elbo = (log_prob - kl_div).mean()
        loss = -elbo
        
        loss.backward()
        self.encoder_optimizer.step()
        
        return mean_z, std_z, loss.item()
    
    def m_step(self, trajectories, mean_z, std_z):
        trajectories = torch.FloatTensor(trajectories).to(self.device)
        z = Normal(mean_z, std_z).sample()
        z_expanded = z.unsqueeze(1).repeat(1, trajectories.shape[1], 1)
        
        self.dynamics_optimizer.zero_grad()
        
        states = trajectories[:, :-1, :self.state_dim]
        actions = trajectories[:, :-1, self.state_dim:self.state_dim+self.action_dim]
        next_states = trajectories[:, 1:, :self.state_dim]
        
        total_log_prob = 0
        for model_idx in range(self.ensemble_size):
            log_prob = self.dynamics_model.compute_log_prob(states, actions, next_states, z_expanded[:, :-1], model_idx)
            total_log_prob += log_prob.sum(dim=1).mean()
        
        loss = -total_log_prob
        
        loss.backward()
        self.dynamics_optimizer.step()
        
        return loss.item()
    
    def train(self, replay_buffer, num_steps, batch_size, seq_len, em_iterations=1):
        for step in range(num_steps):
            trajectories = replay_buffer.sample_trajectories(batch_size, seq_len)
            if trajectories is None:
                continue
            for em_iter in range(em_iterations):
                mean_z, std_z, e_loss = self.e_step(trajectories)
                m_loss = self.m_step(trajectories, mean_z, std_z)
    
    def get_representation(self, trajectories):
        trajectories = torch.FloatTensor(trajectories).to(self.device)
        mean_z, std_z = self.encoder(trajectories)
        return mean_z.detach().cpu().numpy()

if __name__ == "__main__":
    from pathlib import Path  # used for the asset and dataset paths below

    # NOTE(review): gym, d4rl, h5py, tqdm, TimeLimit, HalfCheetahEnv, HopperEnv,
    # Walker2dEnv, AntEnv and get_keys are referenced below but never imported
    # in this file, so the script fails with NameError until the matching
    # imports are added at the top of the file.

    parser = argparse.ArgumentParser()
    parser.add_argument("--dir", default="./logs/Offline")
    parser.add_argument("--policy", default="SQL", help='policy to use') # support IQL, SQL
    parser.add_argument("--env", default="ant-kinematic")
    parser.add_argument('--srctype', default="random", help='dataset type used in the source domain (and the target domain)') # only useful when source domain is offline
    parser.add_argument('--mode', default=3, type=int, help='the training mode, there are four types, 0: online-online, 1: offline-online, 2: online-offline, 3: offline-offline')
    parser.add_argument("--seed", default=100, type=int)
    # Save model and optimizer parameters. A plain type=bool would turn any
    # non-empty string, including "False", into True.
    parser.add_argument("--save_model", default=True, type=lambda s: s.strip().lower() in ("1", "true", "yes", "y"))
    parser.add_argument('--tar_env_interact_interval', help='interval of interacting with target env', default=10, type=int)
    parser.add_argument('--max_step', default=int(1e6), type=int)  # the maximum gradient step for off-dynamics rl learning
    parser.add_argument('--params', default=None, help='Hyperparameters for the adopted algorithm, ought to be in JSON format')
    parser.add_argument('--device', default='cuda:0', type=str)
    parser.add_argument('--num_steps', default=int(1e6), type=int)
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--seq_len', default=10, type=int)
    parser.add_argument('--hidden_dim', default=256, type=int)
    parser.add_argument('--ensemble_size', default=3, type=int)
    parser.add_argument('--z_dim', default=256, type=int)
    parser.add_argument('--relabeled_path', default=None, type=str, help='output .hdf5 for the relabeled dataset (default: dataset/source/<env>-<srctype>-relabeled.hdf5 next to this script)')
    args = parser.parse_args()

    # Seed the RNGs used below (dataset window choice, ensemble-member choice in
    # e_step, torch initialization and latent sampling) so --seed makes runs
    # repeatable, not just the environment.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    if '_' in args.env:
        args.env = args.env.replace('_', '-')

    if "halfcheetah" in args.env:
        src_env = HalfCheetahEnv
    elif "hopper" in args.env:
        src_env = HopperEnv
    elif "walker2d" in args.env:
        src_env = Walker2dEnv
    elif "ant" in args.env:
        src_env = AntEnv
    else:
        raise NotImplementedError

    project_dir = Path(__file__).parent.absolute()
    source_dir = project_dir / "dataset" / "source"

    # Load env and dataset.
    if args.env in ["halfcheetah", "hopper", "walker2d", "ant"]:
        src_eval_env = gym.make(args.env + "-" + args.srctype + "-v2")
        src_eval_env.seed(args.seed)
        src_dataset = d4rl.qlearning_dataset(src_eval_env)

        num_transitions = src_dataset["observations"].shape[0]
        size = int(num_transitions * 0.1)
        # Keep a contiguous 10% window instead of random per-transition indices.
        # The learner treats trajectories[:, 1:] as the successors of
        # trajectories[:, :-1], so scattered indices would splice unrelated
        # transitions together (assumes ReplayBuffer.sample_trajectories reads
        # consecutive buffer entries -- TODO confirm in utils).
        start = np.random.randint(0, num_transitions - size + 1)
        window = slice(start, start + size)

        src_dataset = {
            key: src_dataset[key][window]
            for key in ("observations", "actions", "next_observations", "rewards", "terminals")
        }

    else:
        src_eval_env = TimeLimit(
            src_env(xml_file=f"{project_dir}/envs/mujoco/assets/{args.env.replace('-', '_')}.xml",),
            max_episode_steps=1000,
        )
        src_eval_env.seed(args.seed)

        src_dataset_path = f"{source_dir}/{args.env}-{args.srctype}.hdf5"
        data_dict = {}
        with h5py.File(src_dataset_path, 'r') as dataset_file:
            for k in tqdm(get_keys(dataset_file), desc="load datafile"):
                try:  # first try loading as an array
                    data_dict[k] = dataset_file[k][:]
                except ValueError:  # try loading as a scalar
                    data_dict[k] = dataset_file[k][()]
        src_dataset = data_dict

    state_dim = src_eval_env.observation_space.shape[0]
    action_dim = src_eval_env.action_space.shape[0]

    replay_buffer = ReplayBuffer(state_dim, action_dim)
    replay_buffer.convert_dataset(src_dataset)

    learner = TemporalConsistentRepresentationLearner(
        state_dim, action_dim, args.z_dim, args.hidden_dim, args.ensemble_size, device
    )

    learner.train(replay_buffer, args.num_steps, args.batch_size, args.seq_len)

    all_trajectories = replay_buffer.get_all_trajectories()
    if all_trajectories is not None:
        z = learner.get_representation(all_trajectories)
        replay_buffer.relabel_with_z(z)
        print("Successfully relabeled replay buffer with learned representations.")

    relabeled_dataset = replay_buffer.convert_to_dataset()
    # Write to a separate file by default: the custom-env branch above reads
    # its input from <source_dir>/<env>-<srctype>.hdf5, and opening that same
    # path with mode 'w' would truncate the source dataset.
    if args.relabeled_path is not None:
        relabeled_dataset_path = Path(args.relabeled_path)
    else:
        relabeled_dataset_path = source_dir / f"{args.env}-{args.srctype}-relabeled.hdf5"
    relabeled_dataset_path.parent.mkdir(parents=True, exist_ok=True)
    with h5py.File(str(relabeled_dataset_path), 'w') as dataset_file:
        for k, v in relabeled_dataset.items():
            dataset_file.create_dataset(k, data=v)