import copy
import os

import einops
import torch
from ml_logger import logger

import diffuser

from .arrays import apply_dict, batch_to_device, to_device, to_np
from .timer import Timer

from datetime import datetime

import importlib
import json


import torch.nn.utils.prune as prune

def cycle(dl):
    """Yield the items of ``dl`` endlessly, restarting iteration after each full pass.

    ``iter(dl)`` is re-invoked on every pass and nothing is cached, so an
    iterable that produces a new ordering per iteration yields that new
    ordering each epoch. If ``dl`` is empty this never yields and spins forever.
    """
    while True:
        yield from dl


class GenerateBuffer:
    """Bounded FIFO store of generated (state, action, reward, done) entries.

    The four components are kept in parallel Python lists. Once more than
    ``max_size`` entries have been added, the oldest ones are evicted so the
    buffer never holds more than ``max_size`` entries. Pass ``max_size=None``
    for an unbounded buffer.
    """

    def __init__(self, max_size=1000):
        self.s_buffer = []
        self.a_buffer = []
        self.r_buffer = []
        self.d_buffer = []
        self.max_size = max_size

    def add(self, s, a, r, d):
        """Append one entry, evicting the oldest entries if ``max_size`` is exceeded."""
        self.s_buffer.append(s)
        self.a_buffer.append(a)
        self.r_buffer.append(r)
        self.d_buffer.append(d)
        if self.max_size is None:
            return
        overflow = len(self.s_buffer) - self.max_size
        if overflow > 0:
            # Keep the lists as plain lists (callers may stack/index them) and
            # drop the same number of oldest entries from each so they stay aligned.
            for buf in (self.s_buffer, self.a_buffer, self.r_buffer, self.d_buffer):
                del buf[:overflow]

    def __len__(self):
        """Return the number of stored entries."""
        return len(self.s_buffer)
    

class Generator(object):
    """Generate synthetic multi-agent trajectories with a diffusion model.

    Dataset batches are repainted one agent at a time by the diffusion model.
    The resulting trajectories are labeled with actions (``inv_model``),
    next states (``dynamic_model``) and rewards (``predict_reward``). Samples
    whose dynamics error on the regenerated region is within
    ``accept_threshold`` are stored in ``self.gen_buffer``.
    """

    def __init__(
        self,
        diffusion_model,
        dataset,

        accept_threshold=0.02,
        generate_batch_size=100,
        generate_episode_nums=4000,

        env_type="smac",
        generate_device="cuda",
        bucket=None,
    ):
        """
        Args:
            diffusion_model: model exposing ``horizon``, ``history_horizon``,
                ``n_agents``, ``value_model``, ``inv_model``, ``dynamic_model``,
                ``predict_reward`` and ``repaint_conditional_sample``.
            dataset: dataset fed to a shuffling ``torch.utils.data.DataLoader``.
            accept_threshold: maximum mean dynamics error for a generated
                sample to be added to ``gen_buffer``.
            generate_batch_size: number of dataset trajectories per generation call.
            generate_episode_nums: ``max_size`` of ``gen_buffer``.
            env_type: if it starts with ``"smac"``, predicted actions are
                converted to one-hot via argmax.
            generate_device: device that generation tensors are created on.
            bucket: root directory of checkpoints read by ``load``.
        """
        super().__init__()
        self.model = diffusion_model
        self.dataset = dataset

        self.accept_threshold = accept_threshold
        self.generate_batch_size = generate_batch_size
        self.generate_episode_nums = generate_episode_nums

        self.env_type = env_type
        self.device = generate_device

        # Checkpoint root for load(); previously read but never assigned.
        self.bucket = bucket
        # Step stored in the last checkpoint loaded by load() (0 until then).
        self.step = 0

        self.gen_buffer = GenerateBuffer(max_size=self.generate_episode_nums)

        # Endless stream of shuffled batches.
        self.gen_dataloader = cycle(
            torch.utils.data.DataLoader(
                self.dataset,
                batch_size=self.generate_batch_size,
                num_workers=0,
                shuffle=True,
                pin_memory=True,
            )
        )

    def generate_sample(self, batch, returns, generate, rt_level):
        """Repaint a batch agent-by-agent and label the result with actions, states, rewards.

        Args:
            batch: dict with ``batch["cond"]["x"]`` (observations, indexed as
                ``b t a f``) and ``batch["cond"]["masks"]``.
            returns: return conditioning forwarded to the diffusion model.
            generate, rt_level: unused; kept for interface compatibility.

        Returns:
            Tuple ``(dyn_err, seq_samples, actions, rewards, mask)`` where
            ``dyn_err`` is the element-wise MSE between ``seq_samples[:, 1:]``
            and the dynamics-model prediction, averaged over the last dim;
            ``actions`` are one-hot for ``smac*`` envs and raw ``inv_model``
            output otherwise; ``mask`` is the element-wise product of the
            per-agent repaint masks (int tensor).
        """
        obs = batch["cond"]["x"]
        batch_size = obs.shape[0]
        history_horizon = self.model.history_horizon
        total_horizon = self.model.horizon + history_horizon

        # Attention is enabled only for steps from history_horizon onward.
        attention_masks = torch.zeros(
            (batch_size, total_horizon, self.model.n_agents, 1), device=self.device
        )
        attention_masks[:, history_horizon:] = 1.0

        # Conditioning trajectory: first history_horizon + 1 steps of obs, zeros after.
        cond_trajectories = torch.zeros(
            (batch_size, total_horizon, obs.shape[2], obs.shape[3]),
            device=self.device,
        )  # b t a f
        cond_trajectories[:, : history_horizon + 1] = obs[:, : history_horizon + 1]
        conditions = {
            "x": cond_trajectories,
            "masks": batch["cond"]["masks"],
        }

        # Time indices relative to the end of the history window (negative inside it).
        env_ts = torch.arange(total_horizon, device=self.device) - history_horizon
        env_ts = einops.repeat(env_ts, "t -> b t", b=batch_size)

        #### Generate sequentially using subgoal ####
        # Agents are processed in descending order of their value at the last
        # observed step.
        # NOTE(review): .squeeze() also drops a batch/agent dim of size 1;
        # presumably only a trailing value dim is meant to go — confirm.
        last_agent_values = self.model.value_model(obs[:, -1]).squeeze()
        agent_order = torch.argsort(last_agent_values, dim=-1, descending=True)
        agent_onehot = torch.eye(self.model.n_agents, device=self.device)

        ### subgoal ###
        # For every (sample, agent), the highest-value step after the history
        # window; the subgoal mask is 1 up to and including it, and at step 0.
        # It depends only on obs, so it is computed once outside the agent loop.
        agent_obs_values = self.model.value_model(obs).squeeze()
        agent_obs_values[:, :history_horizon, :] = -1e6
        max_time_idx = torch.max(agent_obs_values, dim=1)[1]
        time_range = (
            torch.arange(obs.shape[1], device=max_time_idx.device)
            .view(1, -1, 1, 1)
            .expand(batch_size, -1, obs.shape[2], obs.shape[3])
        )
        subgoal_mask = (
            time_range <= max_time_idx.view(batch_size, 1, obs.shape[2], 1)
        ).int()
        subgoal_mask[:, 0, :, :] = 1  # first state

        seq_samples = obs
        list_mask = []
        for rank in range(self.model.n_agents):
            chosen_agent = agent_order[:, rank]
            chosen_agent_onehot = agent_onehot[chosen_agent, :]

            # 0 on the chosen agent's entries, 1 elsewhere; the first
            # history_horizon + 1 steps and the subgoal prefix are forced to 1.
            # Presumably 1 marks entries repaint keeps from "x" — confirm.
            mask = 1 - chosen_agent_onehot.unsqueeze(1).unsqueeze(-1).expand_as(obs)
            mask[:, : history_horizon + 1, :, :] = 1
            mask = (mask.bool() | subgoal_mask.bool()).int()
            list_mask.append(mask)

            repaint_conditions = {
                "x": seq_samples,
                "masks": mask.to(bool),
            }
            seq_samples = self.model.repaint_conditional_sample(
                conditions,
                repaint_conditions,
                returns=returns,
                env_ts=env_ts,
                attention_masks=attention_masks,
            )
            # NOTE(review): when history_horizon > 0 this makes seq_samples
            # shorter than the next iteration's mask — confirm intended.
            seq_samples = seq_samples[:, history_horizon:]

        # Combined mask: element-wise product (logical AND) of all per-agent masks.
        mask = torch.ones_like(list_mask[0])
        for agent_mask in list_mask:
            mask *= agent_mask

        # TODO: action masking.
        pred_act = self.model.inv_model(
            torch.cat([seq_samples[:, :-1], seq_samples[:, 1:]], dim=-1)
        )
        if self.env_type.startswith("smac"):
            # argmax without keepdim avoids .squeeze() collapsing other
            # size-1 dims (e.g. a batch of one).
            pred_tot_act = torch.nn.functional.one_hot(
                torch.argmax(pred_act, dim=-1).to(torch.long),
                num_classes=pred_act.shape[-1],
            )
        else:
            pred_tot_act = pred_act

        state_action = torch.cat([seq_samples[:, :-1], pred_tot_act], dim=-1)
        pred_st = self.model.dynamic_model(state_action)
        pred_reward = self.model.predict_reward(state_action)
        batch_dynamic_acc = torch.nn.functional.mse_loss(
            seq_samples[:, 1:, :, :], pred_st, reduction="none"
        )

        return batch_dynamic_acc.mean(dim=-1), seq_samples, pred_tot_act, pred_reward, mask

    def generate_episodes(self):
        """Generate one batch of samples and add the accepted ones to ``gen_buffer``.

        For each sample, the dynamics error after the first
        ``history_horizon + 1`` steps is zeroed wherever the combined repaint
        mask is 1. The sample is accepted when the mean of the result is
        ``<= accept_threshold``. Done flags are stored as all zeros.
        """
        batch = next(self.gen_dataloader)
        batch = batch_to_device(batch, device=self.device)

        returns = torch.ones_like(batch["returns"]).to(device=self.device)
        with torch.no_grad():
            dynamic_acc, s, a, r, m = self.generate_sample(
                batch, returns, generate=True, rt_level=None
            )
        d = torch.zeros_like(r)

        # Computed once per batch (was recomputed for the whole batch per sample).
        mask = m[:, :, :, 0]
        subgoal_dynamic = (1 - mask[:, self.model.history_horizon + 1 :]) * dynamic_acc
        per_sample_error = subgoal_dynamic.flatten(start_dim=1).mean(dim=1)
        # A single host transfer instead of one per sample.
        accepted = (per_sample_error <= self.accept_threshold).tolist()
        for i, keep in enumerate(accepted):
            if keep:
                self.gen_buffer.add(s[i], a[i], r[i], d[i])

    def load(self):
        """Load model weights from ``<bucket>/<logger.prefix>/checkpoint/state.pt``.

        Sets ``self.step`` from the checkpoint's ``"step"`` entry and loads
        ``"model"`` into ``self.model``. Tensors are mapped onto
        ``self.device``.

        Raises:
            ValueError: if no ``bucket`` was configured.
        """
        bucket = getattr(self, "bucket", None)
        if bucket is None:
            raise ValueError(
                "Generator.load requires a checkpoint bucket; pass bucket=... to Generator."
            )
        loadpath = os.path.join(bucket, logger.prefix, "checkpoint/state.pt")
        data = torch.load(loadpath, map_location=self.device)

        self.step = data["step"]
        self.model.load_state_dict(data["model"])


   