import ray
from torch.nn.parallel import DistributedDataParallel as DDP
from rl.algorithms.ppo import PPO
import torch.distributed as dist
from torch import optim
from torch import nn
from tqdm import tqdm
from typing import List, Dict, Any

from utils.train_utils import ExceptionCatcher
import torch.autograd
import torch as t


@ray.remote
class TrainerWorker:
    """Ray actor that runs one rank of DDP-wrapped PPO training.

    The actor starts empty (every attribute is ``None``); ``set_parameters``
    must be called first to move the models to ``cuda:0``, join the NCCL
    process group and build the ``PPO`` trainer. The other methods all go
    through ``self.ppo`` and therefore require that initialisation.
    """

    def __init__(self):
        # Everything is populated by ``set_parameters``.
        self.rank = None
        self.actor = None
        self.critic = None
        self.ppo_accumulate_steps = None
        self.ppo_train_steps_per_worker = None
        self.world_size = None
        self.nccl_port = None
        self.ppo = None

    def set_parameters(
        self,
        rank: int,
        actor,
        critic,
        ppo_batch_size,
        actor_lr,
        critic_lr,
        actor_routing_entropy_weight,
        actor_routing_diversity_weight,
        ppo_accumulate_steps,
        ppo_train_steps_per_worker,
        world_size: int,
        nccl_port: int,
    ):
        """Initialise this rank: place models on GPU, join NCCL, build PPO.

        Args:
            rank: Rank of this worker in the process group. Rank 0 shows the
                training progress bar and records per-step losses.
            actor: Actor module; moved to ``cuda:0`` and wrapped in DDP.
            critic: Critic module; moved to ``cuda:0`` and wrapped in DDP.
            ppo_batch_size: Forwarded to ``PPO`` as ``batch_size``.
            actor_lr: Forwarded to ``PPO`` as ``actor_learning_rate``.
            critic_lr: Forwarded to ``PPO`` as ``critic_learning_rate``.
            actor_routing_entropy_weight: Forwarded as ``routing_entropy_weight``.
            actor_routing_diversity_weight: Forwarded as ``routing_diversity_weight``.
            ppo_accumulate_steps: Micro-batches accumulated per optimizer
                step; must be >= 1.
            ppo_train_steps_per_worker: Optimizer steps per ``train_loop`` call.
            world_size: Total number of ranks in the process group.
            nccl_port: Port on ``localhost`` used for the TCP rendezvous.

        Raises:
            ValueError: If ``ppo_accumulate_steps`` < 1 (raised inside
                ``ExceptionCatcher``; whether it propagates depends on that
                helper).
        """
        with ExceptionCatcher() as _:
            # Fail fast: a value < 1 would otherwise surface later as a
            # ZeroDivisionError inside ``train_loop``.
            if ppo_accumulate_steps < 1:
                raise ValueError(
                    f"ppo_accumulate_steps must be >= 1, got {ppo_accumulate_steps!r}"
                )

            self.rank = rank
            self.world_size = world_size
            self.nccl_port = nccl_port
            self.ppo_accumulate_steps = ppo_accumulate_steps
            self.ppo_train_steps_per_worker = ppo_train_steps_per_worker

            # For multi-GPU training. NOTE(review): presumably Ray exposes the
            # actor's assigned GPU as device 0 via CUDA_VISIBLE_DEVICES — confirm.
            self.actor = actor.to("cuda:0")
            self.critic = critic.to("cuda:0")

            dist.init_process_group(
                backend="nccl",
                init_method=f"tcp://localhost:{nccl_port}",
                rank=rank,
                world_size=world_size,
            )

            self.ppo = PPO(
                DDP(self.actor),
                DDP(self.critic),
                optim.AdamW,
                nn.MSELoss(),
                shared_parameters_belong_to_optimizer="critic",
                batch_size=ppo_batch_size,
                actor_learning_rate=actor_lr,
                critic_learning_rate=critic_lr,
                gae_lambda=None,
                entropy_weight=0.01,
                routing_entropy_weight=actor_routing_entropy_weight,
                routing_diversity_weight=actor_routing_diversity_weight,
                discount=0.9,
            )

    def store_episodes(self, worker_episodes: List[Dict[str, Any]]):
        """Add episodes to this rank's PPO buffer.

        Args:
            worker_episodes: Dicts whose ``"episode"`` entry is passed to
                ``PPO.store_episode`` with ``concatenate_samples=False``.
        """
        with ExceptionCatcher() as _:
            for episode in worker_episodes:
                self.ppo.store_episode(episode["episode"], concatenate_samples=False)

    def get_parameters(self, save_optimizer_state: bool = False):
        """Return CPU copies of the model weights, optionally with optimizer state.

        Args:
            save_optimizer_state: When true, also return the actor and critic
                optimizer state dicts.

        Returns:
            ``(actor_state, critic_state)``, or
            ``(actor_state, critic_state, actor_optim_state, critic_optim_state)``
            when ``save_optimizer_state`` is true. Every tensor nested in the
            returned structures has been copied to CPU.
        """

        def to_cpu(obj):
            # Recursively copy tensors nested in dicts/lists/tuples to CPU;
            # other leaves (ints, floats, None, ...) are returned unchanged.
            if isinstance(obj, t.Tensor):
                return obj.cpu()
            if isinstance(obj, dict):
                return {k: to_cpu(v) for k, v in obj.items()}
            if isinstance(obj, list):
                return [to_cpu(v) for v in obj]
            if isinstance(obj, tuple):
                return tuple(to_cpu(v) for v in obj)
            return obj

        with ExceptionCatcher() as _:
            # ``.module`` unwraps DDP, so keys carry no "module." prefix.
            cpu_actor_state = to_cpu(self.ppo.actor.module.state_dict())
            cpu_critic_state = to_cpu(self.ppo.critic.module.state_dict())

            if not save_optimizer_state:
                return cpu_actor_state, cpu_critic_state

            return (
                cpu_actor_state,
                cpu_critic_state,
                to_cpu(self.ppo.actor_optim.state_dict()),
                to_cpu(self.ppo.critic_optim.state_dict()),
            )

    def train_loop(self, detect_anomaly: bool = True):
        """Run ``ppo_train_steps_per_worker`` optimizer steps on stored data.

        Each optimizer step accumulates gradients over
        ``ppo_accumulate_steps`` micro-batches from ``PPO.get_loss``, clips
        the actor and critic gradient norms to 4, and steps every optimizer.
        ``PPO.finish_update`` is called once at the end.

        Args:
            detect_anomaly: Passed to ``torch.autograd.set_detect_anomaly``
                (a global setting). Defaults to ``True`` to keep the previous
                behaviour; anomaly detection is a slow debugging mode, so
                pass ``False`` when it is not needed.

        Returns:
            Dict with:

            * ``"epoch_loss"``: mean of ``actor_loss + critic_loss`` over
              every micro-batch (0.0 if no micro-batch was run).
            * ``"step_actor_loss_list"`` / ``"step_critic_loss_list"``:
              per-optimizer-step means. These are filled on rank 0 only and
              are empty lists on other ranks.
        """
        with ExceptionCatcher() as _:
            torch.autograd.set_detect_anomaly(detect_anomaly)

            accumulate_steps = self.ppo_accumulate_steps
            train_steps = self.ppo_train_steps_per_worker
            total_loss = 0.0
            step_actor_loss_list = []
            step_critic_loss_list = []

            for _step in tqdm(
                range(train_steps),
                desc="PPO Training",
                total=train_steps,
                disable=self.rank != 0,
            ):
                for optimizer in self.ppo.optimizers:
                    optimizer.zero_grad()

                step_actor_loss = 0.0
                step_critic_loss = 0.0
                for _micro in range(accumulate_steps):
                    a_loss, c_loss = self.ppo.get_loss(concatenate_samples=False)
                    # Scale so the accumulated gradient is the mean over
                    # micro-batches. NOTE(review): every backward() here
                    # triggers a DDP all-reduce; ``no_sync()`` on all but the
                    # last micro-batch could save communication — verify it
                    # is safe with the shared actor/critic parameters first.
                    ((a_loss + c_loss) / accumulate_steps).backward()

                    a_value = a_loss.item()
                    c_value = c_loss.item()
                    step_actor_loss += a_value
                    step_critic_loss += c_value
                    # Track the *unscaled* loss; it is averaged exactly once
                    # below (the old code divided by accumulate_steps twice).
                    total_loss += a_value + c_value

                if self.rank == 0:
                    step_actor_loss_list.append(step_actor_loss / accumulate_steps)
                    step_critic_loss_list.append(step_critic_loss / accumulate_steps)

                nn.utils.clip_grad_norm_(self.ppo.actor.parameters(), 4)
                nn.utils.clip_grad_norm_(self.ppo.critic.parameters(), 4)
                for optimizer in self.ppo.optimizers:
                    optimizer.step()

            self.ppo.finish_update()

            num_micro_batches = train_steps * accumulate_steps
            epoch_loss = (
                total_loss / num_micro_batches if num_micro_batches > 0 else 0.0
            )

            return {
                "epoch_loss": float(epoch_loss),
                "step_actor_loss_list": step_actor_loss_list,
                "step_critic_loss_list": step_critic_loss_list,
            }
