import json
import os
import torch
import numpy as np
from typing import Union
from buffers.buffer_continuous import ReplayBuffer
from algos.misodice_discrete import DemoDICE as DemoDICEDiscrete
from algos.misodice_continuous import DemoDICE as DemoDICEContinuous
from configs import Args, MAMUJOCO_ENV_NAMES, SMACV1_ENV_NAMES, SMACV2_ENV_NAMES


class Trainer:
    """Optimise a DemoDICE policy and periodically evaluate its actor.

    The trainer owns one Adam optimizer per sub-network of the policy
    (``cost``, ``critic`` and ``actor``), writes evaluation results to
    ``<logdir>/results.json`` and prints progress to stdout.

    ``n_agents``, ``obs_scale`` and ``obs_shift`` are only assigned inside
    :meth:`train`, so :meth:`load_env` and :meth:`eval` must not be called
    before :meth:`train` has copied them from the buffer.
    """

    def __init__(self, policy: Union[DemoDICEDiscrete, DemoDICEContinuous], config: Args):
        """Store hyper-parameters, create the log directory and build optimizers.

        Args:
            policy: Policy exposing ``cost``, ``critic`` and ``actor``
                sub-modules and a ``compute_dice_loss`` method.
            config: Run configuration; the fields read here are ``device``,
                ``critic_lr``, ``actor_lr``, ``n_minibatches``, ``env_name``,
                ``seed``, ``exsize``, ``algo``, ``alpha`` and ``use_llm``.
        """
        self.config = config
        self.device = config.device
        self.critic_lr = config.critic_lr
        self.actor_lr = config.actor_lr
        self.n_minibatches = config.n_minibatches
        self.env_name = config.env_name
        self.seed = config.seed
        self.exsize = config.exsize

        # The algorithm name doubles as a log sub-directory; runs with a
        # non-default coefficient get a suffix so they do not overwrite
        # each other's results.
        self.algo = config.algo
        if self.algo == "bc":
            self.algo = f"{self.algo}_beta{config.alpha:.1f}"
        if self.algo == "misodice" and config.alpha != 0.05:
            self.algo = f"{self.algo}_alpha{config.alpha:.2f}"

        env_tag = f"{self.env_name}_llm" if config.use_llm else self.env_name
        self.logdir = f"logs/{self.algo}/{env_tag}/seed{self.seed}/exsize{self.exsize}"
        os.makedirs(self.logdir, exist_ok=True)
        # Maps "step_<k>" to the rollout results of the k-th evaluation.
        self.results = {}

        self.policy = policy.to(self.device)
        # The cost network shares the critic learning rate.
        self.cost_optimizer = torch.optim.Adam(self.policy.cost.parameters(), lr=self.critic_lr)
        self.critic_optimizer = torch.optim.Adam(self.policy.critic.parameters(), lr=self.critic_lr)
        self.actor_optimizer = torch.optim.Adam(self.policy.actor.parameters(), lr=self.actor_lr)

    def update(self, init_states, init_obs, expert_transition, union_transition):
        """Run one optimisation step on the cost, critic and actor networks.

        Returns:
            Tuple ``(cost_loss, nu_loss, pi_loss)`` as Python floats.
        """
        self.cost_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        self.actor_optimizer.zero_grad()

        # NOTE(review): no backward() call is made here, so this presumably
        # relies on compute_dice_loss() back-propagating each loss itself
        # before returning -- confirm against the DemoDICE implementation.
        cost_loss, nu_loss, pi_loss = self.policy.compute_dice_loss(
            init_states, init_obs, expert_transition, union_transition
        )

        self.cost_optimizer.step()
        self.critic_optimizer.step()
        self.actor_optimizer.step()

        return cost_loss.item(), nu_loss.item(), pi_loss.item()

    def load_env(self, actor):
        """Create the environment and rollout worker matching ``self.env_name``.

        The environment and rollout modules are imported lazily so that only
        the backend of the selected environment family has to be importable.

        Args:
            actor: Actor network handed to the rollout worker.

        Returns:
            Tuple ``(env, rollout_worker)``.

        Raises:
            NotImplementedError: If ``self.env_name`` is not listed in any of
                the known environment name collections.
        """
        if self.env_name in MAMUJOCO_ENV_NAMES:
            from envs.mamujoco.env import MaMujocoWrapper as EnvWrapper
            from rollouts.rollout_continuous import RolloutWorkerContinuous as RolloutWorker
        elif self.env_name in SMACV1_ENV_NAMES:
            from envs.smacv1.env import SMACWrapper as EnvWrapper
            from rollouts.rollout_discrete import RolloutWorkerDiscrete as RolloutWorker
        elif self.env_name in SMACV2_ENV_NAMES:
            from envs.smacv2.env import SMACWrapper as EnvWrapper
            from rollouts.rollout_discrete import RolloutWorkerDiscrete as RolloutWorker
        else:
            raise NotImplementedError(f"Unsupported environment: {self.env_name!r}")

        rollout_worker = RolloutWorker(actor, self.n_agents, device=self.device)
        env = EnvWrapper(self.env_name, self.seed)
        return env, rollout_worker

    def write_results(self):
        """Dump all evaluation results collected so far to ``results.json``."""
        with open(os.path.join(self.logdir, "results.json"), "w", encoding="utf-8") as f:
            json.dump(self.results, f, indent=4)

    def eval(self, actor, step=0):
        """Roll out ``actor`` for 32 episodes, then record and print the returns.

        The results are stored under ``"step_<step>"`` and the whole results
        file is rewritten after every evaluation.

        Args:
            actor: Actor network to evaluate.
            step: Evaluation index used in the results key and the log line.
        """
        env, rollout_worker = self.load_env(actor)
        try:
            rollout_worker.obs_scale = self.obs_scale.to(self.device)
            rollout_worker.obs_shift = self.obs_shift.to(self.device)

            results = rollout_worker.rollout(env, num_episodes=32, verbose=False)
            self.results[f"step_{step}"] = results

            self.write_results()

            returns = results["returns"]
            returns_mean = np.mean(returns)
            returns_std = np.std(returns)
            print(f"Step: {step} - Evaluation results: {returns_mean:.2f} ± {returns_std:.2f}")
        finally:
            # Release the environment even when the rollout or logging fails.
            env.close()

    @staticmethod
    def _step_intervals(total_steps, n_evals):
        """Return ``(eval_step, log_step)``, each clamped to at least 1.

        Without the clamp, ``total_steps < n_evals`` (or an evaluation
        interval shorter than 4 steps) produces a zero interval, and the
        modulo tests in :meth:`train` raise ``ZeroDivisionError``.
        """
        eval_step = max(1, total_steps // max(1, n_evals))
        log_step = max(1, eval_step // 4)
        return eval_step, log_step

    def train(self, buffer: ReplayBuffer, n_epochs=100, n_evals=100):
        """Train the policy on ``buffer`` with interleaved evaluations.

        The buffer is sampled for ``n_epochs + 1`` passes. Losses are printed
        four times per evaluation interval, and the actor is evaluated every
        ``(n_epochs * n_minibatches) // n_evals`` update steps (at least every
        step). Training returns early right after an evaluation whose index
        exceeds ``n_evals``.

        Args:
            buffer: Replay buffer providing ``n_agents``, ``obs_scale``,
                ``obs_shift`` and a ``sample`` iterator of minibatches.
            n_epochs: Number of epochs used to size the evaluation schedule.
            n_evals: Target number of evaluation intervals.
        """
        print("Training...")

        self.n_agents = buffer.n_agents
        self.obs_scale = buffer.obs_scale
        self.obs_shift = buffer.obs_shift

        total_steps = n_epochs * self.n_minibatches
        eval_step, log_step = self._step_intervals(total_steps, n_evals)

        global_step = 0
        losses = None  # Nothing to report before the first update.
        for _ in range(n_epochs + 1):
            batches = buffer.sample(self.n_minibatches, device=self.device)
            for union_init_states, union_init_obs, expert_transition, union_transition in batches:
                if global_step % log_step == 0:
                    print("Step:", global_step, "Losses:", losses)
                if global_step % eval_step == 0:
                    eval_index = global_step // eval_step
                    self.eval(self.policy.actor, eval_index)
                    if eval_index > n_evals:
                        return

                losses = self.update(union_init_states, union_init_obs, expert_transition, union_transition)
                global_step += 1