import os
from time import time

import torch
from torch.distributions.categorical import Categorical
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from env_list import env_action_dict

# Shared mixed-precision context: float16 autocast for CUDA ops. It is entered
# around the world-model forward pass in WMEnv.get_context.
ctx = torch.amp.autocast(
    device_type='cuda',
    dtype=torch.float16,
)

class WMEnv(nn.Module):
    """Batched environment whose transitions are produced by a learned model.

    ``reset`` seeds every batch row with a short window of recorded data read
    from disk. ``step`` then extends the stored history by one time step: the
    reward and done heads classify the transition, and the model's ``enc_head``
    produces the next observation encoding from a context vector.

    Attributes:
        encs: Encoding history, one row per environment (``None`` until
            ``reset``).
        actions: Action history as int64 indices (``None`` until ``reset``).
        rewards: Reward history as int64 values (``None`` until ``reset``).
        dones: Done-flag history as int64 values (``None`` until ``reset``).
        kv_cache: Cache tensor handed to the model; ``None`` until the first
            ``get_context`` call after a ``reset``.
        device: Device passed to the last ``reset`` (``None`` before that).
    """

    # Number of frames stored in each on-disk training array.
    COLLECT_FRAMES = 90000
    # Number of recorded frames used to seed a rollout in ``reset``.
    BURN_IN = 16
    # Number of most recent frames/rewards fed to the reward and done heads.
    HEAD_HISTORY = 4
    # Maximum number of time steps handed to the model in ``get_context``.
    CONTEXT_LENGTH = 144
    # Environments whose predicted reward class index is shifted down by one.
    SHIFTED_REWARD_ENVS = ("Pong", "Boxing")

    def __init__(self, model, reward_head, done_head, env_name):
        """Store the networks (switched to eval mode) and set up empty state.

        Args:
            model: World model; must expose ``config``, ``enc_head`` and
                ``compute_context_vector_kvc``.
            reward_head: Module called as ``reward_head(enc, act, rewards)``.
            done_head: Module called as ``done_head(enc, act, rewards)``.
            env_name: Key into ``env_action_dict``; also decides whether the
                predicted reward index is shifted (see ``step``).
        """
        super().__init__()
        self.model = model
        self.reward_head = reward_head
        self.done_head = done_head
        for network in (self.model, self.reward_head, self.done_head):
            network.eval()

        self.env_name = env_name
        self.A = len(env_action_dict[env_name])

        # Rollout state; populated by ``reset``.
        self.encs = None
        self.actions = None
        self.rewards = None
        self.dones = None
        # Defined here so the instance is in a consistent "not reset" state
        # instead of lacking these attributes entirely.
        self.kv_cache = None
        self.device = None

        # Seed derived from the reversed trailing digits of the current time.
        seed = int(f'{time():.10f}'[-9:][::-1])
        self.generator = np.random.RandomState(seed)

    def reset(self, device, env_names):
        """Start a new batch of rollouts from recorded burn-in windows.

        Args:
            device: Device on which the history tensors are created.
            env_names: Iterable of environment names; one batch row is loaded
                per entry, in order.
        """
        self.kv_cache = None

        enc_list = []
        act_list = []
        reward_list = []
        done_list = []
        for env_name in env_names:
            enc, act, reward, done = self._load_burn_in(env_name, device)
            enc_list.append(enc)
            act_list.append(act)
            reward_list.append(reward)
            done_list.append(done)

        self.encs = torch.stack(enc_list)
        self.actions = torch.stack(act_list)
        self.rewards = torch.stack(reward_list)
        self.dones = torch.stack(done_list)
        self.device = device

    def _load_burn_in(self, env_name, device):
        """Read one random ``BURN_IN``-frame window of recorded data.

        Returns ``(enc, act, reward, done)``: ``BURN_IN`` encodings and the
        ``BURN_IN - 1`` actions, reward signs and done flags that follow the
        first frame of the window.
        """
        burn_in = self.BURN_IN
        n_frames = self.COLLECT_FRAMES
        input_dim = self.model.config.input_dim
        data_dir = '../{}/{}'.format('data' if input_dim == 512 else 'data_big', env_name)

        def open_array(kind, dtype, shape):
            # NOTE(review): the files are mapped as raw arrays from offset 0,
            # so no .npy header is skipped -- confirm they are written raw.
            path = os.path.join(data_dir, '{}_train.npy'.format(kind))
            return np.memmap(path, dtype=dtype, mode='r', shape=shape)

        enc_map = open_array('encodings', np.float32, (n_frames, input_dim))
        end = self.generator.randint(burn_in, n_frames)
        start = end - burn_in
        # np.array(...) copies the slice out of the read-only memory map.
        enc = torch.tensor(np.array(enc_map[start:end]), device=device)

        act_map = open_array('actions', np.uint8, (n_frames,))
        act = torch.tensor(np.array(act_map[start + 1:end]), device=device).to(torch.int64)

        reward_map = open_array('rewards', np.float32, (n_frames,))
        reward = torch.tensor(np.array(reward_map[start + 1:end]), device=device)
        # Only the sign of the recorded reward is kept.
        reward = torch.sign(reward).to(torch.int64)

        done_map = open_array('dones', bool, (n_frames,))
        done = torch.tensor(np.array(done_map[start + 1:end]), device=device).to(torch.int64)

        assert len(enc) == act.size(0) + 1
        assert len(enc) == burn_in
        return enc, act, reward, done

    def step(self, actions):
        """Advance every environment by one predicted transition.

        Args:
            actions: 1-D int64 tensor with one action index per batch row.

        Returns:
            ``(rewards, dones)``: the class indices chosen by ``argmax`` over
            the reward and done head outputs, one per batch row. For the
            environments in ``SHIFTED_REWARD_ENVS`` the reward index is
            reduced by one.
        """
        self.actions = torch.cat((self.actions, actions.unsqueeze(-1)), dim=-1)

        history = self.HEAD_HISTORY
        input_enc = self.encs[:, -history:].view(len(self.encs), -1)
        input_act = F.one_hot(actions, self.A).float()
        input_rewards = self.rewards[:, -history:].float()

        # NOTE(review): the heads and enc_head are not run under no_grad, so
        # the stored history may keep an autograd graph -- confirm intended.
        rewards = torch.argmax(self.reward_head(input_enc, input_act, input_rewards), dim=-1)
        if self.env_name in self.SHIFTED_REWARD_ENVS:
            rewards -= 1
        dones = torch.argmax(self.done_head(input_enc, input_act, input_rewards), dim=-1)

        self.rewards = torch.cat((self.rewards, rewards.unsqueeze(-1)), dim=-1)
        self.dones = torch.cat((self.dones, dones.unsqueeze(-1)), dim=-1)

        context_vecs = self.get_context(option=0)
        next_encs = self.model.enc_head(context_vecs)
        self.encs = torch.cat((self.encs, next_encs.unsqueeze(1)), dim=1)
        return rewards, dones

    def get_context(self, option):
        """Return the model's context vector for the last position.

        option 0: Predict next obs (enc)
        option 1: Predict rewards
        option 2: Predict actions

        The most recent ``CONTEXT_LENGTH`` encodings are always used. Actions
        are one step shorter for option 2, and rewards are one step shorter
        for options 1 and 2. Each reward is passed to the model as
        ``reward + 1``, with a further ``+ 3`` where the done flag is non-zero.

        Returns:
            ``x[:, -1, :]`` of the tensor returned by
            ``model.compute_context_vector_kvc``.
        """
        assert option in (0, 1, 2)
        config = self.model.config
        context_length = self.CONTEXT_LENGTH

        encs = self.encs[:, -context_length:]
        n_acts = context_length if option in (0, 1) else context_length - 1
        acts = self.actions[:, -n_acts:]

        # Slice the window first so only the needed steps are combined,
        # rather than rebuilding tokens for the whole stored history.
        n_rewards = context_length if option == 0 else context_length - 1
        tokens = self.rewards[:, -n_rewards:] + 1
        done_window = self.dones[:, -n_rewards:]
        rewards = torch.where(done_window != 0, tokens + 3, tokens)

        if option == 0:
            assert encs.size(1) == acts.size(1) == rewards.size(1)
        elif option == 1:
            assert encs.size(1) == acts.size(1) == rewards.size(1) + 1
        else:
            assert encs.size(1) == acts.size(1) + 1 == rewards.size(1) + 1

        if self.kv_cache is None:
            cache_shape = (config.n_layer, 2, len(encs), context_length * 3, config.n_embd)
            self.kv_cache = torch.zeros(cache_shape, device=self.device)

        # The first cache position whose row (layer 0, first entry of the
        # second axis, batch row 0) averages exactly zero is taken as the next
        # free slot. Assumes the model writes non-zero-mean values into used
        # slots -- TODO confirm.
        free_slots = torch.where(self.kv_cache[0, 0, 0, :].mean(-1) == 0)[0]
        assert len(free_slots) != 0, 'no zero-mean slot left in kv_cache'
        cnt = free_slots.min().item()

        with ctx:
            with torch.no_grad():
                x, self.kv_cache = self.model.compute_context_vector_kvc(
                    encs, acts, rewards, self.kv_cache, cnt)

        return x[:, -1, :]