# Inspired by:
# 1. paper introducing SAC-N and EDAC (An et al., 2021): https://arxiv.org/abs/2110.01548
# 2. reference implementation: https://github.com/snu-mllab/EDAC
import math
import os
import random
import uuid
from copy import deepcopy
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import d4rl
import gym
import numpy as np
import pyrallis
import torch
import torch.nn as nn
import wandb
from torch.distributions import Normal
from tqdm import trange
import sys

import mediapy as media

# Ablation flag: when True, wandb_init prefixes both the W&B run name and the
# W&B group with "Ablation-", so ablation runs land in a separate group.
ABLAT = True

@dataclass
class TrainConfig:
    """Training configuration, filled in by ``pyrallis`` for ``train``.

    After construction, ``name`` is extended with the environment name, the
    training seed and the MI weight. When ``checkpoints_path`` is given, it is
    turned into a per-run sub-directory named after that extended run name.
    """

    # wandb params
    project: str = "CORL"
    group: str = "EDAC-D4RL"
    name: str = "EDAC"
    # model params
    hidden_dim: int = 256
    num_critics: int = 5  # 10
    gamma: float = 0.99
    tau: float = 5e-3
    eta: float = 1.0
    actor_learning_rate: float = 3e-4
    critic_learning_rate: float = 3e-4
    alpha_learning_rate: float = 3e-4
    max_action: float = 1.0
    # training params
    buffer_size: int = 1_000_000
    env_name: str = "halfcheetah-medium-v2"
    batch_size: int = 256
    num_epochs: int = 1000  # 3000
    num_updates_on_epoch: int = 1000
    normalize_reward: bool = False
    # evaluation params
    eval_episodes: int = 10
    eval_every: int = 5
    # general params
    checkpoints_path: Optional[str] = None
    deterministic_torch: bool = False
    train_seed: int = 10
    eval_seed: int = 42
    log_every: int = 100
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    mi_w: float = 1.0

    def __post_init__(self):
        # Tag the run name with env, seed and MI weight.
        run_suffix = f"-{self.env_name}--seed{self.train_seed}--miW{self.mi_w}"
        self.name += run_suffix
        if self.checkpoints_path is None:
            return
        # Each run writes into its own sub-directory.
        self.checkpoints_path = os.path.join(self.checkpoints_path, self.name)


# general utils
# Batch type produced by ReplayBuffer.sample and consumed by EDAC.update.
# NOTE(review): ReplayBuffer.sample's last element is a NumPy index array,
# not a torch.Tensor, so this alias only approximately describes it.
TensorBatch = List[torch.Tensor]


def soft_update(target: nn.Module, source: nn.Module, tau: float):
    """Polyak-average ``source`` into ``target`` in place.

    Every parameter becomes ``(1 - tau) * target + tau * source``; the two
    modules are paired parameter-by-parameter in registration order.
    """
    keep = 1 - tau
    for tgt, src in zip(target.parameters(), source.parameters()):
        blended = keep * tgt.data + tau * src.data
        tgt.data.copy_(blended)


def wandb_init(config: dict) -> None:
    """Start a Weights & Biases run described by ``config``.

    The project is ``"MI" + config["project"]``. When the module-level
    ``ABLAT`` flag is set, both the run name and the group get an
    ``"Ablation-"`` prefix. The run name also embeds ``eta`` and
    ``num_critics`` (empty strings when absent).
    """
    prefix = "Ablation-" if ABLAT else ""

    run_name = (
        prefix
        + "MI(One_SameCriticW)"
        + config["name"]
        + "_eta"
        + str(config.get("eta", ""))
        + "_critics"
        + str(config.get("num_critics", ""))
    )
    run_group = prefix + config["group"]

    wandb.init(
        config=config,
        project="MI" + config["project"],
        group=run_group,
        name=run_name,
    )
    wandb.run.save()

def set_seed(
    seed: int, env: Optional[gym.Env] = None, deterministic_torch: bool = False
):
    """Seed every RNG used in training.

    Seeds Python's ``random``, NumPy, PyTorch and ``PYTHONHASHSEED``. If
    ``env`` is given, it also seeds the environment and its action space.

    Args:
        seed: seed value applied everywhere.
        env: optional gym environment to seed as well.
        deterministic_torch: if True, ask PyTorch to use deterministic
            algorithms (ops without a deterministic implementation will then
            raise at call time).
    """
    if env is not None:
        env.seed(seed)
        env.action_space.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if deterministic_torch:
        # The previous ``torch.set_deterministic()`` call passed no argument
        # (the function requires a bool) and that API is deprecated and has
        # since been removed; ``use_deterministic_algorithms`` replaces it.
        if hasattr(torch, "use_deterministic_algorithms"):
            torch.use_deterministic_algorithms(True)
        else:  # very old PyTorch releases
            torch.set_deterministic(True)


def wrap_env(
    env: gym.Env,
    state_mean: Union[np.ndarray, float] = 0.0,
    state_std: Union[np.ndarray, float] = 1.0,
    reward_scale: float = 1.0,
) -> gym.Env:
    """Wrap ``env`` so observations are standardised and rewards optionally scaled.

    Observations become ``(state - state_mean) / state_std``. A reward wrapper
    multiplying by ``reward_scale`` is added only when the scale is not 1.0.
    """
    normalized = gym.wrappers.TransformObservation(
        env, lambda obs: (obs - state_mean) / state_std
    )
    if reward_scale == 1.0:
        return normalized
    return gym.wrappers.TransformReward(normalized, lambda rew: reward_scale * rew)


class ReplayBuffer:
    """Fixed-capacity transition store backed by pre-allocated torch tensors.

    Besides (state, action, reward, next_state, done), each slot also holds an
    ``n_agent``-wide row of per-agent values. These start at 1, are returned
    by :meth:`sample`, and are overwritten through :meth:`add_agent_prob`.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        buffer_size: int,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        n_agent: int = 9,
    ):
        self._buffer_size = buffer_size
        self._pointer = 0
        self._size = 0

        def _zeros(width: int) -> torch.Tensor:
            return torch.zeros(
                (buffer_size, width), dtype=torch.float32, device=device
            )

        self._states = _zeros(state_dim)
        self._actions = _zeros(action_dim)
        self._rewards = _zeros(1)
        self._next_states = _zeros(state_dim)
        self._dones = _zeros(1)
        self._agent_prob = torch.ones(
            (buffer_size, n_agent), dtype=torch.float32, device=device
        )
        self._device = device

    def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
        """Copy ``data`` into a float32 tensor on the buffer's device."""
        return torch.tensor(data, dtype=torch.float32, device=self._device)

    def load_d4rl_dataset(self, data: Dict[str, np.ndarray]):
        """Fill the empty buffer from a D4RL-style dict of arrays.

        Raises:
            ValueError: if the buffer already holds data, or if the dataset
                is larger than the buffer capacity.
        """
        if self._size != 0:
            raise ValueError("Trying to load data into non-empty replay buffer")
        n_transitions = data["observations"].shape[0]
        if n_transitions > self._buffer_size:
            raise ValueError(
                "Replay buffer is smaller than the dataset you are trying to load!"
            )

        # (storage, dataset key, append a trailing axis?)
        layout = (
            (self._states, "observations", False),
            (self._actions, "actions", False),
            (self._rewards, "rewards", True),
            (self._next_states, "next_observations", False),
            (self._dones, "terminals", True),
        )
        for storage, key, add_axis in layout:
            column = data[key][..., None] if add_axis else data[key]
            storage[:n_transitions] = self._to_tensor(column)

        self._size += n_transitions
        self._pointer = min(self._size, n_transitions)

        print(f"Dataset size: {n_transitions}")

    def sample(self, batch_size: int) -> TensorBatch:
        """Draw ``batch_size`` transitions uniformly with replacement.

        Returns ``[states, actions, rewards, next_states, dones, agents_prob,
        indices]``; ``indices`` is the NumPy array of sampled slots.
        """
        upper = min(self._size, self._pointer)
        indices = np.random.randint(0, upper, size=batch_size)
        stores = (
            self._states,
            self._actions,
            self._rewards,
            self._next_states,
            self._dones,
            self._agent_prob,
        )
        return [store[indices] for store in stores] + [indices]

    def add_agent_prob(self, agents_prob, indices):
        """Overwrite the per-agent rows at ``indices`` with ``agents_prob``."""
        self._agent_prob[indices] = agents_prob

    def add_transition(self):
        # Use this method to add new data into the replay buffer during fine-tuning.
        raise NotImplementedError


# SAC Actor & Critic implementation
class VectorizedLinear(nn.Module):
    """``ensemble_size`` independent linear layers applied in one batched matmul.

    Shapes:
        input:  [ensemble_size, batch_size, in_features]
        weight: [ensemble_size, in_features, out_features]
        bias:   [ensemble_size, 1, out_features]
        output: [ensemble_size, batch_size, out_features]
    """

    def __init__(self, in_features: int, out_features: int, ensemble_size: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size

        weight_shape = (ensemble_size, in_features, out_features)
        bias_shape = (ensemble_size, 1, out_features)
        self.weight = nn.Parameter(torch.empty(weight_shape))
        self.bias = nn.Parameter(torch.empty(bias_shape))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise each member like ``nn.Linear`` would (Kaiming-uniform + uniform bias).

        NOTE: each weight slice is stored as (in, out), but torch's fan
        computation treats dim 1 of a 2-D tensor as fan_in. The effective
        fan_in here is therefore ``out_features``, not ``in_features``.
        """
        member = 0
        while member < self.ensemble_size:
            nn.init.kaiming_uniform_(self.weight[member], a=math.sqrt(5))
            member += 1

        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
        bound = 0 if fan_in <= 0 else 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.matmul(x, self.weight) + self.bias


class Actor(nn.Module):
    """Tanh-squashed Gaussian policy with EDAC-style initialisation.

    Log-probabilities are returned per action dimension (not summed), so a
    caller can sum over any grouping of dimensions (e.g. per agent).
    """

    def __init__(
        self, state_dim: int, action_dim: int, hidden_dim: int, max_action: float = 1.0
    ):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        # with separate layers works better than with Linear(hidden_dim, 2 * action_dim)
        self.mu = nn.Linear(hidden_dim, action_dim)
        self.log_sigma = nn.Linear(hidden_dim, action_dim)

        # init as in the EDAC paper
        for layer in self.trunk[::2]:
            torch.nn.init.constant_(layer.bias, 0.1)

        torch.nn.init.uniform_(self.mu.weight, -1e-3, 1e-3)
        torch.nn.init.uniform_(self.mu.bias, -1e-3, 1e-3)
        torch.nn.init.uniform_(self.log_sigma.weight, -1e-3, 1e-3)
        torch.nn.init.uniform_(self.log_sigma.bias, -1e-3, 1e-3)

        self.action_dim = action_dim
        self.max_action = max_action

    def _policy_dist(self, state: torch.Tensor) -> Normal:
        """Return the pre-tanh Gaussian ``N(mu(state), exp(log_sigma(state)))``."""
        hidden = self.trunk(state)
        mu, log_sigma = self.mu(hidden), self.log_sigma(hidden)
        # clipping params from EDAC paper, not as in SAC paper (-20, 2)
        log_sigma = torch.clip(log_sigma, -5, 2)
        return Normal(mu, torch.exp(log_sigma))

    def forward(
        self,
        state: torch.Tensor,
        deterministic: bool = False,
        need_log_prob: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Sample (or take the mean of) the policy and squash it with tanh.

        Returns:
            ``(action * max_action, log_prob)``. ``log_prob`` is the
            per-dimension, tanh-corrected log-density, or ``None`` when
            ``need_log_prob`` is False.
        """
        policy_dist = self._policy_dist(state)
        # reparameterised sample keeps gradients w.r.t. the policy parameters
        action = policy_dist.mean if deterministic else policy_dist.rsample()

        tanh_action, log_prob = torch.tanh(action), None
        if need_log_prob:
            # change of variables formula (SAC paper, appendix C, eq 21),
            # kept per dimension instead of summed over the last axis
            log_prob = policy_dist.log_prob(action)
            log_prob = log_prob - torch.log(1 - tanh_action.pow(2) + 1e-6)

        return tanh_action * self.max_action, log_prob

    @torch.no_grad()
    def act(self, state: np.ndarray, device: str) -> np.ndarray:
        """Act on a NumPy state: deterministic in eval mode, stochastic in train mode."""
        deterministic = not self.training
        state = torch.tensor(state, device=device, dtype=torch.float32)
        action = self(state, deterministic=deterministic)[0].cpu().numpy()
        return action

    @torch.no_grad()
    def get_prob(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Per-dimension log-density of ``action`` under the *pre-tanh* Gaussian.

        No tanh change-of-variables correction is applied. ``action`` is
        evaluated directly against ``N(mu, sigma)``.
        """
        return self._policy_dist(state).log_prob(action)


class VectorizedCritic(nn.Module):
    """Ensemble of ``num_critics`` Q-networks evaluated in one batched pass.

    Each member is an MLP with three ReLU hidden layers of width
    ``hidden_dim`` over the concatenated (state, action) input, built from
    :class:`VectorizedLinear` layers.
    """

    def __init__(
        self, state_dim: int, action_dim: int, hidden_dim: int, num_critics: int
    ):
        super().__init__()
        widths = [state_dim + action_dim, hidden_dim, hidden_dim, hidden_dim, 1]
        modules = []
        for depth, (d_in, d_out) in enumerate(zip(widths[:-1], widths[1:])):
            if depth > 0:
                modules.append(nn.ReLU())
            modules.append(VectorizedLinear(d_in, d_out, num_critics))
        self.critic = nn.Sequential(*modules)

        # EDAC-paper init: constant 0.1 biases, tiny uniform output layer.
        for linear in self.critic[::2]:
            torch.nn.init.constant_(linear.bias, 0.1)
        head = self.critic[-1]
        torch.nn.init.uniform_(head.weight, -3e-3, 3e-3)
        torch.nn.init.uniform_(head.bias, -3e-3, 3e-3)

        self.num_critics = num_critics

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Return Q-values of shape [num_critics, batch_size].

        A 2-D (batch-only) input is copied to every ensemble member; a 3-D
        input must already have ``num_critics`` as its leading dimension.
        """
        joint = torch.cat([state, action], dim=-1)
        if joint.dim() != 3:
            assert joint.dim() == 2
            joint = joint.unsqueeze(0).repeat_interleave(self.num_critics, dim=0)
        assert joint.dim() == 3
        assert joint.shape[0] == self.num_critics
        return self.critic(joint).squeeze(-1)


class EDAC:
    """Multi-agent EDAC trainer with one shared actor and per-agent critics.

    The single actor outputs the concatenation of all agents' actions. Its
    ``action_dim`` is split evenly into ``len(critics)`` per-agent chunks.
    Agent ``i`` owns critic ensemble ``critics[i]``, a target copy, and an
    entropy temperature ``exp(log_alphas[i])``.

    Critic targets add a bonus ``mi_w * mutual``, where ``mutual`` is:
    (mean over action dims of the density of the dataset action under agent
    ``i``'s slice of the policy), divided by the sum of the per-agent values
    stored in the replay buffer, clipped to [0, 1].
    """

    def __init__(
        self,
        actors: "Actor",
        actor_optimizers: torch.optim.Optimizer,
        critics: List["VectorizedCritic"],
        critic_optimizers: List[torch.optim.Optimizer],
        gamma: float = 0.99,
        tau: float = 0.005,
        eta: float = 1.0,
        alpha_learning_rate: float = 1e-4,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        mi_w: float = 1.0,
    ):
        # NOTE: despite the plural names, ``actors`` is one shared Actor and
        # ``actor_optimizers`` its single optimizer (see how they are used below).
        self.device = device
        self.n_agents = len(critics)

        self.actor = actors
        self.critics = critics
        self.mi_w = mi_w

        with torch.no_grad():
            self.target_critics = [deepcopy(critic) for critic in self.critics]

        self.actor_optimizer = actor_optimizers
        self.critic_optimizers = critic_optimizers

        self.tau = tau
        self.gamma = gamma
        self.eta = eta

        # Per-agent adaptive temperature; each agent controls
        # action_dim / n_agents action dimensions.
        self.target_entropy = -float(self.actor.action_dim / self.n_agents)
        self.log_alphas = [
            torch.tensor([0.0], dtype=torch.float32, device=self.device, requires_grad=True)
            for _ in range(self.n_agents)
        ]
        self.alpha_optimizer = torch.optim.Adam(self.log_alphas, lr=alpha_learning_rate)
        self.alphas = [log_alpha.exp().detach() for log_alpha in self.log_alphas]

    def _alpha_loss(self, state: torch.Tensor) -> torch.Tensor:
        """SAC temperature loss, averaged over batch and agents."""
        with torch.no_grad():
            action, action_log_prob = self.actor(state, need_log_prob=True)
            b = action.shape[0]
            # per-dimension log-probs -> per-agent log-probs [b, n_agents]
            action_log_prob = action_log_prob.reshape(b, self.n_agents, -1).sum(-1)
        log_alphas = torch.stack(self.log_alphas).reshape(1, self.n_agents)
        loss = (-log_alphas * (action_log_prob + self.target_entropy)).mean()
        return loss

    def _actor_loss(self, state: torch.Tensor) -> Tuple[torch.Tensor, float, float]:
        """Sum over agents of ``mean(alpha_i * log_pi_i - min_k Q_i^k)``.

        Returns:
            ``(loss, batch_entropy, q_value_std)``. The last two are plain
            Python floats summed over agents and used only for logging.
        """
        action, action_log_prob = self.actor(state, need_log_prob=True)
        b = action.shape[0]
        action = action.reshape(b, self.n_agents, -1)
        action_log_prob = action_log_prob.reshape(b, self.n_agents, -1).sum(-1)
        loss = torch.zeros(1).to(action)
        batch_entropy = 0.0
        q_value_std = 0.0
        for i, critic in enumerate(self.critics):
            q_value_dist = critic(state, action[:, i, :])
            assert q_value_dist.shape[0] == critic.num_critics
            q_value_min = q_value_dist.min(0).values
            # needed for logging
            q_value_std += q_value_dist.std(0).mean().item()
            batch_entropy += -action_log_prob[:, i].mean().item()

            assert action_log_prob[:, i].shape == q_value_min.shape
            loss += (self.alphas[i] * action_log_prob[:, i] - q_value_min).mean()

        return loss, batch_entropy, q_value_std

    def _critic_diversity_loss(
        self, state: torch.Tensor, action: torch.Tensor, critic
    ) -> torch.Tensor:
        """EDAC gradient-diversity penalty for one critic ensemble.

        Computes the mean pairwise cosine similarity between ensemble
        members' action-gradients, normalised by ``num_critics - 1``.
        """
        num_critics = critic.num_critics
        if num_critics < 2:
            # No member pairs exist; the normalisation below would be 0 / 0 = NaN.
            return state.new_zeros(())

        state = state.unsqueeze(0).repeat_interleave(num_critics, dim=0)
        action = (
            action.unsqueeze(0)
            .repeat_interleave(num_critics, dim=0)
            .requires_grad_(True)
        )
        q_ensemble = critic(state, action)

        q_action_grad = torch.autograd.grad(
            q_ensemble.sum(), action, retain_graph=True, create_graph=True
        )[0]
        q_action_grad = q_action_grad / (
            torch.norm(q_action_grad, p=2, dim=2).unsqueeze(-1) + 1e-10
        )
        # [num_critics, batch, action_dim] -> [batch, num_critics, action_dim]
        q_action_grad = q_action_grad.transpose(0, 1)

        # zero the diagonal so only cross-member similarities count
        masks = (
            torch.eye(num_critics, device=self.device)
            .unsqueeze(0)
            .repeat(q_action_grad.shape[0], 1, 1)
        )
        q_action_grad = q_action_grad @ q_action_grad.permute(0, 2, 1)
        q_action_grad = (1 - masks) * q_action_grad

        grad_loss = q_action_grad.sum(dim=(1, 2)).mean()
        grad_loss = grad_loss / (num_critics - 1)

        return grad_loss

    def _critic_loss(
        self,
        state: torch.Tensor,
        action: torch.Tensor,
        reward: torch.Tensor,
        next_state: torch.Tensor,
        done: torch.Tensor,
        agents_prob, mi_w
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Update every agent's critic ensemble (and its target) in place.

        Returns:
            ``(total_loss, actor_probs)``. ``total_loss`` is the detached sum
            of the per-agent critic losses. ``actor_probs`` holds one
            ``[batch, 1]`` tensor per agent with that agent's ``mutual_a``
            values.
        """
        with torch.no_grad():
            next_action, next_action_log_prob = self.actor(
                next_state, need_log_prob=True
            )
        b = next_action.shape[0]
        next_action = next_action.reshape(b, self.n_agents, -1)
        next_action_log_prob = next_action_log_prob.reshape(b, self.n_agents, -1).sum(-1)
        loss = torch.zeros(1).to(next_action)
        actor_probs = []
        # log-density of the dataset action under each agent's slice of the
        # shared policy: [b, n_agents, per-agent action_dim]
        actor_prob = self.actor.get_prob(state, action.repeat(1, self.n_agents))
        actor_prob = actor_prob.reshape(b, self.n_agents, -1)
        for i, critic in enumerate(self.critics):
            with torch.no_grad():
                q_next = self.target_critics[i](next_state, next_action[:, i, :]).min(0).values
                q_next = q_next - self.alphas[i] * next_action_log_prob[:, i]

                # NOTE(review): the 1e-8 sits inside exp(), so it does not act
                # as a floor on the density.
                mutual_a = torch.exp(actor_prob[:, i, :] + 1e-8).mean(-1).detach()
                mutual_a2 = agents_prob.sum(-1)
                mutual = mutual_a / (mutual_a2 + 1e-8)
                mutual = torch.clip(mutual, 0, 1)
                mutual = mutual.view(-1).detach().unsqueeze(-1)
                assert q_next.unsqueeze(-1).shape == done.shape == reward.shape == mutual.shape
                q_target = reward + mi_w * mutual + self.gamma * (1 - done) * q_next.unsqueeze(-1)

            q_values = critic(state, action)
            critic_loss = ((q_values - q_target.view(1, -1)) ** 2).mean(dim=1).sum(dim=0)
            diversity_loss = self._critic_diversity_loss(state, action, critic)

            critic_loss = critic_loss + self.eta * diversity_loss
            actor_probs.append(mutual_a.view(-1).detach().unsqueeze(-1))

            self.critic_optimizers[i].zero_grad()
            critic_loss.backward()
            self.critic_optimizers[i].step()
            with torch.no_grad():
                soft_update(self.target_critics[i], self.critics[i], tau=self.tau)
            # Only the value is needed (for logging); detaching avoids keeping
            # every agent's already-freed graph referenced from ``loss``.
            loss += critic_loss.detach()

        return loss, actor_probs

    def update(self, batch: "TensorBatch") -> Tuple[Dict[str, float], torch.Tensor]:
        """Run one alpha, actor and critic step on ``batch``.

        ``batch`` is ``[state, action, reward, next_state, done, agents_prob]``.

        Returns:
            ``(update_info, new_agent_probs)``. ``new_agent_probs`` has shape
            ``[batch, n_agents, 1]``.
        """
        state, action, reward, next_state, done, agents_prob = [
            arr.to(self.device) for arr in batch
        ]

        # Alpha update
        alpha_loss = self._alpha_loss(state)
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        self.alphas = [log_alpha.exp().detach() for log_alpha in self.log_alphas]

        # Actor update
        actor_loss, actor_batch_entropy, q_policy_std = self._actor_loss(state)
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Critic update (target critics are soft-updated inside _critic_loss)
        critic_loss, new_agent_probs = self._critic_loss(
            state, action, reward, next_state, done, agents_prob, self.mi_w
        )

        update_info = {
            "alpha_loss": alpha_loss.item(),
            "critic_loss": critic_loss.item(),
            "actor_loss": actor_loss.item(),
            "batch_entropy": actor_batch_entropy,
            "q_policy_std": q_policy_std,
        }

        return update_info, torch.stack(new_agent_probs, dim=1)

    def state_dict(self) -> Dict[str, Any]:
        """Collect networks, optimisers and temperatures into a checkpoint dict."""
        return {
            "actor": self.actor.state_dict(),
            "critics": [critic.state_dict() for critic in self.critics],
            "target_critics": [target.state_dict() for target in self.target_critics],
            "log_alphas": [log_alpha.item() for log_alpha in self.log_alphas],
            "actor_optim": self.actor_optimizer.state_dict(),
            "critic_optimizers": [opt.state_dict() for opt in self.critic_optimizers],
            "alpha_optim": self.alpha_optimizer.state_dict(),
        }

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Restore a checkpoint produced by :meth:`state_dict`."""
        self.actor.load_state_dict(state_dict["actor"])
        self.actor_optimizer.load_state_dict(state_dict["actor_optim"])
        self.alpha_optimizer.load_state_dict(state_dict["alpha_optim"])
        for i in range(self.n_agents):
            self.critics[i].load_state_dict(state_dict["critics"][i])
            self.target_critics[i].load_state_dict(state_dict["target_critics"][i])
            self.critic_optimizers[i].load_state_dict(state_dict["critic_optimizers"][i])
            self.log_alphas[i].data[0] = state_dict["log_alphas"][i]
            self.alphas[i] = self.log_alphas[i].exp().detach()

    @staticmethod
    def soft_update(target, source, tau):
        """Polyak-average ``source`` into ``target``: t <- tau * s + (1 - tau) * t."""
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)


@torch.no_grad()
def eval_actor(
    env: gym.Env, actor: Actor, device: str, n_episodes: int, seed: int, idx_agent, n_agent, save_videos, task
) -> np.ndarray:
    """Roll out one agent's slice of the shared actor and return episode returns.

    The actor's output is reshaped to ``[n_agent, -1]``; only row
    ``idx_agent`` is sent to ``env``. When ``save_videos`` is set, the first
    episode is rendered and written to ``./videos/agent_{idx_agent}_{task}.mp4``.

    Returns:
        Array of length ``n_episodes`` with the undiscounted episode returns.
    """
    env.seed(seed)
    actor.eval()
    episode_rewards = []
    frames = []
    try:
        for episode in range(n_episodes):
            record = save_videos and episode == 0
            state, done = env.reset(), False
            episode_reward = 0.0
            while not done:
                if record:
                    frames.append(env.render(mode="rgb_array"))
                action = actor.act(state, device)
                action = action.reshape(n_agent, -1)[idx_agent, :]
                state, reward, done, _ = env.step(action)
                episode_reward += reward
            episode_rewards.append(episode_reward)
    finally:
        # Always restore training mode, even if a rollout raises.
        actor.train()

    # Only touch the filesystem when there is actually something to write.
    if save_videos and frames:
        video_dir = "./videos"
        os.makedirs(video_dir, exist_ok=True)
        video_path = os.path.join(video_dir, f"agent_{idx_agent}_{task}.mp4")
        media.write_video(video_path, frames, fps=30)
        print(f"Saved video to: {video_path}")

    return np.array(episode_rewards)


def return_reward_range(dataset, max_episode_steps):
    """Return ``(min, max)`` episode return over the completed episodes in ``dataset``.

    An episode ends on a terminal flag or after ``max_episode_steps`` steps.
    A trailing unfinished episode still counts toward the length sanity
    check, but it is not a return sample. Raises ``ValueError`` if no
    episode completes.
    """
    episode_returns, episode_lengths = [], []
    running_return, running_length = 0.0, 0
    for reward, terminal in zip(dataset["rewards"], dataset["terminals"]):
        running_return += float(reward)
        running_length += 1
        if not (terminal or running_length == max_episode_steps):
            continue
        episode_returns.append(running_return)
        episode_lengths.append(running_length)
        running_return, running_length = 0.0, 0
    episode_lengths.append(running_length)
    assert sum(episode_lengths) == len(dataset["rewards"])
    return min(episode_returns), max(episode_returns)


def modify_reward(dataset, env_name, max_episode_steps=1000):
    """Rescale ``dataset["rewards"]`` in place according to the task family.

    For halfcheetah / hopper / walker2d, rewards are divided by the spread of
    episode returns and multiplied by ``max_episode_steps``. For antmaze,
    1 is subtracted from every reward. Other tasks are left untouched.
    """
    locomotion_tasks = ("halfcheetah", "hopper", "walker2d")
    if any(task in env_name for task in locomotion_tasks):
        min_ret, max_ret = return_reward_range(dataset, max_episode_steps)
        dataset["rewards"] /= max_ret - min_ret
        dataset["rewards"] *= max_episode_steps
        return
    if "antmaze" in env_name:
        dataset["rewards"] -= 1.0


@pyrallis.wrap()
def train(config: TrainConfig):
    """Train the shared actor and per-agent critic ensembles on a D4RL dataset.

    Datasets whose name contains ``"diverse"`` are read from
    ``./DiveOff/dataset/<env_name>.hdf5`` with 9 agents; otherwise the
    standard D4RL dataset is used with 5 agents. Metrics go to W&B.
    Checkpoints, when ``checkpoints_path`` is set, are written as follows:
    ``..._last.pt`` after every evaluation, and ``..._{epoch}.pt`` every 500
    epochs.
    """
    set_seed(config.train_seed, deterministic_torch=config.deterministic_torch)
    wandb_init(asdict(config))

    # data, evaluation, env setup
    eval_env = wrap_env(gym.make(config.env_name))
    state_dim = eval_env.observation_space.shape[0]
    action_dim = eval_env.action_space.shape[0]

    if "diverse" in config.env_name:
        num_agents = 9
        env = gym.make(config.env_name)
        dataset_path = "./DiveOff/dataset/" + config.env_name + ".hdf5"
        d4rl_dataset = d4rl.qlearning_dataset(
            env, dataset=env.get_dataset(h5path=dataset_path)
        )
    else:
        d4rl_dataset = d4rl.qlearning_dataset(eval_env)
        num_agents = 5

    if config.normalize_reward:
        modify_reward(d4rl_dataset, config.env_name)

    buffer = ReplayBuffer(
        state_dim=state_dim,
        action_dim=action_dim,
        buffer_size=config.buffer_size,
        device=config.device,
        n_agent=num_agents,
    )
    buffer.load_d4rl_dataset(d4rl_dataset)

    # One actor emits the concatenated actions of all agents.
    actor = Actor(state_dim, action_dim * num_agents, config.hidden_dim, config.max_action)
    actor.to(config.device)
    actor_optimizer = torch.optim.Adam(actor.parameters(), lr=config.actor_learning_rate)

    # One critic ensemble per agent, all starting from identical weights.
    first_critic = VectorizedCritic(state_dim, action_dim, config.hidden_dim, config.num_critics)
    first_critic.to(config.device)
    critics = [first_critic]
    critic_optimizers = [
        torch.optim.Adam(first_critic.parameters(), lr=config.critic_learning_rate)
    ]
    for _ in range(num_agents - 1):
        critic = VectorizedCritic(state_dim, action_dim, config.hidden_dim, config.num_critics)
        critic.to(config.device)
        critic.load_state_dict(first_critic.state_dict())
        critics.append(critic)
        critic_optimizers.append(
            torch.optim.Adam(critic.parameters(), lr=config.critic_learning_rate)
        )

    trainer = EDAC(
        actors=actor,
        actor_optimizers=actor_optimizer,
        critics=critics,
        critic_optimizers=critic_optimizers,
        gamma=config.gamma,
        tau=config.tau,
        eta=config.eta,
        alpha_learning_rate=config.alpha_learning_rate,
        device=config.device,
        mi_w=config.mi_w,
    )
    # saving config to the checkpoint
    if config.checkpoints_path is not None:
        print(f"Checkpoints path: {config.checkpoints_path}")
        os.makedirs(config.checkpoints_path, exist_ok=True)
        with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f:
            pyrallis.dump(config, f)

    total_updates = 0
    for epoch in trange(config.num_epochs, desc="Training"):
        # training
        for _ in trange(config.num_updates_on_epoch, desc="Epoch", leave=False):
            *batch_tensors, indices = buffer.sample(config.batch_size)
            update_info, agents_prob = trainer.update(batch_tensors)
            # [batch, num_agents, 1] -> [batch, num_agents]; squeeze only the
            # trailing axis so a batch of size 1 keeps its batch dimension.
            buffer.add_agent_prob(agents_prob.squeeze(-1), indices)

            if total_updates % config.log_every == 0:
                wandb.log({"epoch": epoch, **update_info})

            total_updates += 1

        # evaluation
        if epoch % config.eval_every == 0 or epoch == config.num_epochs - 1:
            eval_log = {"epoch": epoch}
            # Video recording (intended every 200 epochs) is disabled for now.
            save_videos = False

            # Evaluate each agent's slice of the shared actor
            for i in range(num_agents):
                eval_returns = eval_actor(
                    env=eval_env,
                    actor=actor,
                    n_episodes=config.eval_episodes,
                    seed=config.eval_seed,
                    device=config.device,
                    idx_agent=i,
                    n_agent=num_agents,
                    save_videos=save_videos,
                    task=config.env_name,
                )

                eval_log[f"eval/reward_mean_{i}"] = np.mean(eval_returns)
                eval_log[f"eval/reward_std_{i}"] = np.std(eval_returns)

                if hasattr(eval_env, "get_normalized_score"):
                    normalized_score = eval_env.get_normalized_score(eval_returns) * 100.0
                    eval_log[f"eval/normalized_score_mean_{i}"] = np.mean(normalized_score)
                    eval_log[f"eval/normalized_score_std_{i}"] = np.std(normalized_score)

            wandb.log(eval_log)
            if config.checkpoints_path is not None:
                torch.save(
                    trainer.state_dict(),
                    os.path.join(config.checkpoints_path, "MI(one_SameCriticW)_last.pt"),
                )

        # Periodic snapshot, independent of the eval schedule.
        if epoch % 500 == 0 and config.checkpoints_path is not None:
            torch.save(
                trainer.state_dict(),
                os.path.join(config.checkpoints_path, f"MI(one_SameCriticW)_{epoch}.pt"),
            )

    wandb.finish()


if __name__ == "__main__":
    # train is wrapped by pyrallis.wrap(), which builds the TrainConfig
    # argument (presumably from command-line arguments) before calling it.
    train()
