import numpy as np

from algo.utils.wsre import wsre
from .utils.buffer import ReplayMemory
from .utils.disc import DiscTrainer
from .utils.sac import SACTrainer

class WassersteinTrainer:
    """Train one SAC policy per mode with an added per-step shaping reward.

    The trainer owns one ``SACTrainer`` and one ``ReplayMemory`` per mode,
    plus a single ``DiscTrainer`` with its own replay memory of
    ``(label, observation)`` pairs.

    An episode is driven through ``start_episode`` -> (``act`` / ``record``)*
    -> ``end_episode``. At the end of the episode every stored reward is
    rewritten as ``reward * rc + shaping_reward * src`` before the transition
    is pushed into the replay memory of the current mode.
    """

    def __init__(self, obs_shape, action_space, rc, src, args) -> None:
        """Build the per-mode trainers/buffers and the discriminator.

        Args:
            obs_shape: Forwarded to ``SACTrainer`` and ``DiscTrainer``.
            action_space: Forwarded to ``SACTrainer``.
            rc: Multiplier applied to the environment reward.
            src: Multiplier applied to the shaping reward.
            args: Configuration object; this class reads ``buffer_limit``,
                ``num_modes``, ``max_episode_len``, ``batch_size``,
                ``disc_batch_size`` and ``scenario`` from it.
        """
        self.memories = [ReplayMemory(args.buffer_limit) for _ in range(args.num_modes)]
        self.disc_memory = ReplayMemory(args.buffer_limit)
        self.trainers = [SACTrainer(obs_shape, action_space, args) for _ in range(args.num_modes)]
        self.disc_trainer = DiscTrainer(obs_shape, args)
        self.total_episodes = 0
        self.step_t = 0
        self.args = args
        self.rc = rc
        self.src = src

        # Per-episode state. Created here so that calling record() or
        # end_episode() before start_episode() cannot leave the object
        # half-updated because of a missing attribute.
        self.current_id = None
        self.current_trainer = None
        self.current_memory = None
        self.trajectory = []
        self.obs_list = []
        self.reward_list = []

    def start_episode(self, label):
        """Select mode ``label`` and clear all per-episode state.

        Args:
            label: Index into ``self.trainers`` / ``self.memories``.
        """
        self.current_id = label
        self.current_trainer = self.trainers[label]
        self.current_memory = self.memories[label]
        self.trajectory = []
        self.obs_list = []
        self.reward_list = []
        # Reset the counter as well: if the previous episode was abandoned
        # without end_episode(), a stale count would no longer match the
        # freshly emptied trajectory.
        self.step_t = 0

    def act(self, obs):
        """Query the current mode's policy for an action.

        Returns:
            ``(action, logprob)`` as produced by the current trainer's
            ``act``; when the action has more than one dimension only its
            first row is returned.
        """
        a, logprob = self.current_trainer.act(obs)
        if len(a.shape) > 1:
            a = a[0]
        return a, logprob

    def record(self, obs, action, logprob, reward, new_obs, mask):
        """Append one transition to the episode currently being collected."""
        self.trajectory.append([obs, action, logprob, reward, new_obs, mask])
        self.obs_list.append(obs)
        self.reward_list.append(reward)
        self.step_t += 1

    def _shaping_rewards(self, state_batch, num_steps):
        """Return the per-step shaping reward for the finished episode.

        For every mode other than the current one a candidate reward vector
        is computed with ``wsre`` against states dumped from that mode's
        replay memory; modes whose memory holds no more than
        ``max_episode_len`` entries contribute an all-zero vector. The
        candidate with the smallest sum is returned.

        NOTE(review): a mode without enough data contributes a sum of 0 and
        therefore wins the argmin whenever the other sums are positive --
        confirm that this is intended.
        """
        candidates = []
        for mode in range(self.args.num_modes):
            if mode == self.current_id:
                continue
            srs = np.zeros(num_steps)
            # Only compare against modes that can supply a full target batch.
            if len(self.memories[mode]) > self.args.max_episode_len:
                target_state_batch = list(self.memories[mode].dump(self.args.max_episode_len))[0]
                srs = wsre(state_batch, target_state_batch)
            candidates.append(srs)

        if not candidates:
            # Single-mode setup: there is no other mode to compare against,
            # and np.argmin would raise on an empty sequence.
            return np.zeros(num_steps)

        totals = [np.sum(srs) for srs in candidates]
        return candidates[int(np.argmin(totals))]

    def end_episode(self):
        """Finalize the episode and push it into the replay memories.

        Returns:
            A tuple ``(original_episode_return, original_episode_sr,
            episode_return, avg_acc)``: the sum of the raw rewards, the sum
            of the selected shaping rewards, the sum of the rewritten
            rewards, and the mean of ``exp(score)`` where ``score`` comes
            from ``disc_trainer.score`` (presumably log-probabilities of the
            current label -- confirm against ``DiscTrainer``).

        Raises:
            ValueError: If no transition was recorded for this episode.
        """
        num_steps = len(self.trajectory)
        assert self.step_t == num_steps, "step counter out of sync with trajectory"
        if num_steps == 0:
            raise ValueError("end_episode() called before any transition was recorded")

        labels = np.array([self.current_id] * num_steps)
        # The stacked observations double as the state batch for wsre, so
        # the remaining trajectory columns never need to be stacked.
        state_batch = np.stack(self.obs_list, axis=0)
        score = self.disc_trainer.score(state_batch, labels)
        avg_acc = np.mean(np.exp(score))

        sr = self._shaping_rewards(state_batch, num_steps)
        original_episode_return = np.sum(self.reward_list)
        original_episode_sr = np.sum(sr)

        episode_return = 0.0
        for i, transition in enumerate(self.trajectory):
            # Rewrite the reward in place; reward_list keeps the raw value.
            transition[3] = transition[3] * self.rc + sr[i] * self.src
            episode_return += transition[3]
            self.current_memory.push(transition)
            self.disc_memory.push((labels[i], transition[0]))

        self.step_t = 0
        self.total_episodes += 1
        self.trajectory = []
        self.obs_list = []
        self.reward_list = []
        return original_episode_return, original_episode_sr, episode_return, avg_acc

    def update_policy(self, updates):
        """Run one SAC update for the current mode on a sampled batch.

        Args:
            updates: Forwarded unchanged to the trainer's
                ``update_parameters``.

        Returns:
            ``(c1_loss, c2_loss, p_loss, ent_loss, alpha)`` as returned by
            the trainer.
        """
        (
            state_batch,
            action_batch,
            logprob_batch,
            reward_batch,
            next_state_batch,
            mask_batch,
        ) = self.current_memory.sample(batch_size=self.args.batch_size)
        batch = (state_batch, action_batch, logprob_batch, reward_batch, next_state_batch, mask_batch)
        c1_loss, c2_loss, p_loss, ent_loss, alpha = self.current_trainer.update_parameters(batch, updates)
        return c1_loss, c2_loss, p_loss, ent_loss, alpha

    def can_update_policy(self):
        """Return True once the current memory holds more than one batch."""
        return len(self.current_memory) > self.args.batch_size

    def update_disc(self):
        """Run one discriminator update on a sampled batch and return its loss."""
        label_batch, state_batch = self.disc_memory.sample(batch_size=self.args.disc_batch_size)
        d_loss = self.disc_trainer.update_parameters((label_batch, state_batch))
        return d_loss

    def can_update_disc(self):
        """Return True once the discriminator memory holds more than one batch."""
        return len(self.disc_memory) > self.args.disc_batch_size

    def save_models(self):
        """Save every per-mode trainer, using the mode index as the suffix."""
        for i, trainer in enumerate(self.trainers):
            trainer.save_model(env_name=self.args.scenario, suffix=str(i))