from collections import OrderedDict
from typing import Any, List

import numpy as np
import torch
from pytorch_lightning import LightningModule
from torch.nn import functional as F
 
from models.components.mingpt import GPT, GPTConfig 
from models.utils import (
    get_exp_return_dmc,
    get_min_action_dmc,
    top_k_logits,
)


class DTLitModule(LightningModule):
    """LightningModule for a return-conditioned (Decision-Transformer-style) GPT agent.

    Training and validation feed offline trajectory batches
    ``(obs, actions, rtg, timesteps, rewards)`` through ``GPT`` and optimize the
    mean of the loss terms the network returns. Interactive evaluation
    (``test_step``, and ``validation_epoch_end`` when the ``bc`` hparam is set)
    rolls the policy out in a DeepMind Control Suite environment built with
    ``dmc2gym``.

    Only DMC setups (``dmc=True``, continuous actions) and ``agent_type="gpt"``
    are supported.

    Lightning docs:
        https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
    """

    def __init__(
        self,
        game,
        dmc,
        domain,
        task,
        agent_type,
        model_type,
        timestep=10000,
        n_embd=128,
        lr=6e-4,
        forward=False,
        inverse=False,
        reward=False,
        rand_inverse=False,
        rand_mask_size=1,
        freeze_encoder=False,
        patch_size=14,
        context_length=30,
        betas=(0.9, 0.95),
        weight_decay=0.1,
        n_layer=6,
        n_head=8,
        pred_layers=1,
        stack_size=4,
        **kwargs,
    ):
        """Build the GPT network from the hyperparameters.

        All arguments (including extra ``**kwargs``) are stored in
        ``self.hparams``. Several methods read options that are not named
        parameters and must therefore arrive via ``**kwargs``, e.g.
        ``rtg_layers``, ``bc_layers``, ``unsupervise``, ``mask_obs_size``,
        ``forward_weight``, ``rand_attn_only``, ``epochs``, ``seed``, ``bc``
        and ``eval_epochs``.

        Raises:
            NotImplementedError: if ``dmc`` is false or ``agent_type`` is not
                ``"gpt"``.
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # it also ensures init params will be stored in ckpt
        self.save_hyperparameters()

        if not self.hparams.dmc:
            raise NotImplementedError("only DMC environments (dmc=True) are supported")
        channels = 3  # dmc2gym envs are created with from_pixels=True (RGB frames)
        vocab_size = get_min_action_dmc(self.hparams.domain)

        if self.hparams.agent_type != "gpt":
            raise NotImplementedError(f"agent type not supported: {self.hparams.agent_type!r}")

        # Two token positions per context step (presumably one state and one
        # action token per timestep -- confirm against GPT).
        block_size = self.hparams.context_length * 2
        mconf = GPTConfig(
            vocab_size,
            block_size,
            max_timestep=self.hparams.timestep,
            channels=channels,
            model_type=self.hparams.model_type,
            n_layer=self.hparams.n_layer,
            n_head=self.hparams.n_head,
            n_embd=self.hparams.n_embd,
            cont_action=self.hparams.dmc,
            pred_layers=self.hparams.pred_layers,
            rtg_layers=self.hparams.rtg_layers,
            bc_layers=self.hparams.bc_layers,
        )
        self.net = GPT(mconf)
        print(self.net)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Register this module's CLI options in a ``"DTModel"`` argument group.

        Returns ``parent_parser`` so calls can be chained.
        """
        parser = parent_parser.add_argument_group("DTModel")

        parser.add_argument("--agent_type", type=str, default="gpt")
        parser.add_argument(
            "--model_type", type=str, default="reward_conditioned", choices=["reward_conditioned", "naive"]
        )
        parser.add_argument("--n_embd", type=int, default=128)
        parser.add_argument("--lr", type=float, default=6e-4)
        parser.add_argument("--unsupervise", default=False, action="store_true")
        parser.add_argument("--forward", default=False, action="store_true")
        parser.add_argument("--inverse", default=False, action="store_true")
        parser.add_argument("--reward", default=False, action="store_true")
        parser.add_argument("--rand_inverse", default=False, action="store_true")
        parser.add_argument("--freeze_encoder", default=False, action="store_true")
        parser.add_argument("--rand_attn_only", default=False, action="store_true")
        parser.add_argument("--rand_mask_size", type=int, default=30)
        parser.add_argument("--mask_obs_size", type=int, default=0)
        parser.add_argument("--n_layer", type=int, default=6)
        parser.add_argument("--n_head", type=int, default=8)

        # weights
        parser.add_argument("--forward_weight", type=float, default=1.0)

        # layers
        parser.add_argument("--rtg_layers", type=int, default=1)
        parser.add_argument("--bc_layers", type=int, default=1)
        parser.add_argument("--pred_layers", type=int, default=1)
        # STGPT
        parser.add_argument("--patch_agg", type=str, default="last")
        parser.add_argument("--patch_size", type=int, default=14)

        # DMC
        parser.add_argument("--dmc", default=False, action="store_true")

        return parent_parser

    def load_my_checkpoint(self, path, no_action=False, strict=True, no_action_head=False):
        """Copy matching parameters from the checkpoint at ``path`` into this module.

        Only keys present both in the checkpoint's ``"state_dict"`` and in this
        module are copied; each tensor is cloned and moved to the device of the
        entry it replaces.

        Args:
            path: checkpoint file containing a ``"state_dict"`` entry.
            no_action: skip action-related weights (keys containing
                ``reward_conditioned_head``, ``naive_head``,
                ``inverse_pred_head``, ``action_encoder`` or ``tok_emb``).
            strict: forwarded to ``load_state_dict``.
            no_action_head: skip only ``reward_conditioned_head`` weights.
        """
        action_key_parts = (
            "reward_conditioned_head",
            "naive_head",
            "inverse_pred_head",
            "rand_inverse_pred_head",
            "action_encoder",
            "tok_emb",
        )
        # Load onto the CPU so checkpoints saved on GPU also load on CPU-only
        # hosts; each tensor is moved to its target device below.
        checkpoint_state = torch.load(path, map_location="cpu")["state_dict"]
        model_state = self.state_dict()
        for key, value in checkpoint_state.items():
            if no_action and any(part in key for part in action_key_parts):
                continue
            if no_action_head and "reward_conditioned_head" in key:
                continue
            if key in model_state:
                model_state[key] = value.clone().to(model_state[key].device)

        self.load_state_dict(model_state, strict=strict)

    def _mask_sizes(self, stage):
        """Return ``(rand_mask_size, mask_obs_size)`` for the current epoch.

        A negative hparam switches that size to a linear schedule over
        ``hparams.epochs`` (reaching ``context_length`` resp.
        ``context_length * 0.5`` at the last epoch, floored at 1); scheduled
        values are logged under ``<stage>/mask_size`` / ``<stage>/mask_obs_size``.
        Non-negative hparams are returned unchanged.
        """
        hp = self.hparams
        # Multiply before dividing (as originally written): the order affects
        # float rounding and therefore the int() truncation.
        if hp.rand_mask_size < 0:
            rand_mask_size = max(1, int(hp.context_length * (self.current_epoch + 1) / hp.epochs))
            self.log(f"{stage}/mask_size", rand_mask_size, on_step=False, on_epoch=True, prog_bar=False)
        else:
            rand_mask_size = hp.rand_mask_size

        if hp.mask_obs_size < 0:
            mask_obs_size = max(1, int(hp.context_length * 0.5 * (self.current_epoch + 1) / hp.epochs))
            self.log(f"{stage}/mask_obs_size", mask_obs_size, on_step=False, on_epoch=True, prog_bar=False)
        else:
            mask_obs_size = hp.mask_obs_size

        return rand_mask_size, mask_obs_size

    def _log_losses(self, all_losses, stage, on_step):
        """Log every loss term and their mean under ``<stage>/``; return the mean."""
        avg_loss = 0
        for name, loss in all_losses.items():
            self.log(f"{stage}/{name}", loss, on_step=on_step, on_epoch=True, prog_bar=False)
            avg_loss += loss
        avg_loss /= len(all_losses)
        self.log(f"{stage}/avg_loss", avg_loss, on_step=False, on_epoch=True, prog_bar=False)
        return avg_loss

    def training_step(self, batch: Any, batch_idx: int):
        """Compute the training loss for one batch.

        ``batch`` is unpacked as ``(obs, actions, rtg, timesteps, rewards)``.
        Action targets are omitted when ``hparams.unsupervise`` is set.

        Returns:
            ``{"loss": mean of all loss terms, "logits": network output}``.
        """
        obs, actions, rtg, ts, rewards = batch
        targets = None if self.hparams.unsupervise else actions
        rand_mask_size, mask_obs_size = self._mask_sizes("train")

        logits, all_losses = self.net(
            obs,
            actions,
            targets,
            rtg,
            ts,
            rewards,
            pred_forward=self.hparams.forward,
            pred_inverse=self.hparams.inverse,
            pred_reward=self.hparams.reward,
            pred_rand_inverse=self.hparams.rand_inverse,
            rand_mask_size=rand_mask_size,
            mask_obs_size=mask_obs_size,
            forward_weight=self.hparams.forward_weight,
            rand_attn_only=self.hparams.rand_attn_only,
        )

        avg_loss = self._log_losses(all_losses, "train", on_step=True)
        # Lightning backpropagates the "loss" entry.
        return {"loss": avg_loss, "logits": logits}

    def validation_step(self, batch: Any, batch_idx: int):
        """Compute and log validation losses for one batch (same layout as training).

        NOTE(review): unlike ``training_step``, ``forward_weight`` is not
        passed here, so the network's default applies -- confirm intended.
        """
        obs, actions, rtg, ts, rewards = batch
        targets = None if self.hparams.unsupervise else actions
        rand_mask_size, mask_obs_size = self._mask_sizes("val")

        logits, all_losses = self.net(
            obs,
            actions,
            targets,
            rtg,
            ts,
            rewards,
            pred_forward=self.hparams.forward,
            pred_inverse=self.hparams.inverse,
            pred_reward=self.hparams.reward,
            pred_rand_inverse=self.hparams.rand_inverse,
            rand_mask_size=rand_mask_size,
            mask_obs_size=mask_obs_size,
            rand_attn_only=self.hparams.rand_attn_only,
        )

        avg_loss = self._log_losses(all_losses, "val", on_step=False)
        return {"loss": avg_loss, "logits": logits}

    def validation_epoch_end(self, outputs: List[Any]):
        """When ``hparams.bc`` is set, roll out the policy and log its mean return.

        NOTE(review): presumably every process runs these rollouts under
        multi-process training; consider ``rank_zero_only`` if that matters.
        """
        if self.hparams.bc:
            eval_return, _ = self.get_return(self.hparams.eval_epochs)
            self.log("val/interactive_reward", eval_return, on_step=False, on_epoch=True, prog_bar=False)

    def test_step(self, batch: Any, batch_idx: int):
        """Run ``hparams.eval_epochs`` interactive episodes and log mean/std return.

        The batch contents are ignored; the rollout runs once per test batch.
        """
        eval_return, std_return = self.get_return(self.hparams.eval_epochs)
        self.log("test/interactive_reward", eval_return, on_step=False, on_epoch=True)
        self.log("test/std", std_return, on_step=False, on_epoch=True)
        return {"loss": eval_return}

    def _obs_to_tensor(self, obs):
        """Scale a raw pixel observation by 1/255 into a float32 tensor on
        ``self.device`` with two leading singleton (batch, time) dimensions."""
        return torch.from_numpy(obs).type(torch.float32).to(self.device).div_(255).unsqueeze(0).unsqueeze(0)

    def get_return_dmc(self, epochs):
        """Roll out the policy for ``epochs`` episodes in a dmc2gym pixel env.

        The env uses ``hparams.domain``/``hparams.task`` (84x84 pixels,
        ``frame_skip=4``) and is seeded with ``hparams.seed``. With
        ``model_type == "reward_conditioned"`` the agent is conditioned on the
        target return from ``get_exp_return_dmc``; otherwise on 0.

        Returns:
            ``(mean, std)`` of the summed per-episode rewards.
        """
        import dmc2gym

        env = dmc2gym.make(
            domain_name=self.hparams.domain,
            task_name=self.hparams.task,
            visualize_reward=False,
            from_pixels=True,
            height=84,
            width=84,
            frame_skip=4,
            version=2,
        )
        cont_length = self.hparams.context_length
        T_rewards = []
        try:
            env.seed(self.hparams.seed)

            if self.hparams.model_type == "reward_conditioned":
                ret = get_exp_return_dmc(self.hparams.domain, self.hparams.task)
            else:
                ret = 0

            for i in range(epochs):
                # Exactly one reset per episode, so the first action is computed
                # from the observation of the episode that is actually stepped.
                state = self._obs_to_tensor(env.reset())
                rtgs = [ret]
                # first state is from env, first rtg is target return, and first timestep is 0
                sampled_action = self.sample(
                    state,
                    1,
                    actions=None,
                    rtgs=torch.tensor(rtgs, dtype=torch.float32).to(self.device).unsqueeze(0).unsqueeze(-1),
                    timesteps=torch.zeros((1, 1, 1), dtype=torch.int64).to(self.device),
                )

                j = 0
                reward_sum = 0
                all_states = state
                actions = []
                while True:
                    action = np.clip(sampled_action.cpu().numpy()[0], -1, 1).astype(np.float32)
                    actions.append(sampled_action)
                    next_obs, reward, done, _ = env.step(action)
                    reward_sum += reward
                    j += 1

                    if done:
                        T_rewards.append(reward_sum)
                        break

                    state = self._obs_to_tensor(next_obs)
                    # sample() only uses the last `context_length` states, so keep
                    # just those instead of re-copying the whole episode each step.
                    all_states = torch.cat([all_states, state], dim=1)[:, -cont_length:]

                    rtgs.append(rtgs[-1] - reward)
                    # rtgs/actions are cut to the context length inside sample();
                    # the timestep is just the current (clamped) step index.
                    past_actions = torch.cat(actions, dim=0)
                    sampled_action = self.sample(
                        all_states,
                        1,
                        actions=past_actions.to(self.device).unsqueeze(0),
                        rtgs=torch.tensor(rtgs, dtype=torch.float32).to(self.device).unsqueeze(0).unsqueeze(-1),
                        timesteps=(
                            min(j, self.hparams.timestep) * torch.ones((1, 1, 1), dtype=torch.int64).to(self.device)
                        ),
                    )
                print("episode", i, action, reward_sum)
        finally:
            # Release the simulator even if a rollout raises.
            env.close()

        T_rewards = np.array(T_rewards)
        eval_return = T_rewards.mean()
        std_return = T_rewards.std()
        print("target return: %d, eval return: %.1f +- %.1f" % (ret, eval_return, std_return))

        return eval_return, std_return

    def get_return(self, epochs):
        """Evaluate the policy interactively; returns ``(mean_return, std_return)``.

        Raises:
            NotImplementedError: for non-DMC setups.
        """
        if self.hparams.dmc:
            return self.get_return_dmc(epochs)
        raise NotImplementedError("interactive evaluation is only implemented for DMC")

    @torch.no_grad()
    def sample(self, x, steps, temperature=1.0, sample=False, top_k=None, actions=None, rtgs=None, timesteps=None):
        """Predict the next action from the trajectory context (no gradients).

        ``x`` (states), ``actions`` and ``rtgs`` are each cropped to their last
        ``context_length`` entries along dim 1 before the network call.

        For DMC (continuous actions) the network output at the last position,
        ``logits[:, -1, :]``, is returned. Otherwise the last-position logits
        are divided by ``temperature``, optionally top-k filtered, softmaxed,
        and either sampled (``sample=True``) or arg-maxed; the resulting index
        tensor has shape ``(batch, 1)``.

        NOTE(review): ``x`` is replaced by the prediction on every iteration,
        so ``steps > 1`` would feed an action back in as the state input; all
        callers in this file use ``steps=1``.
        """
        cont_length = self.hparams.context_length
        for _ in range(steps):
            x_cond = x if x.size(1) <= cont_length else x[:, -cont_length:]  # crop context if needed
            if actions is not None:
                actions = (
                    actions if actions.size(1) <= cont_length else actions[:, -cont_length:]
                )  # crop context if needed
            rtgs = rtgs if rtgs.size(1) <= cont_length else rtgs[:, -cont_length:]  # crop context if needed
            logits, _ = self.net(x_cond, actions=actions, targets=None, rtgs=rtgs, timesteps=timesteps)
            if self.hparams.dmc:
                x = logits[:, -1, :]
            else:
                # pluck the logits at the final step and scale by temperature
                logits = logits[:, -1, :] / temperature
                # optionally crop probabilities to only the top k options
                if top_k is not None:
                    logits = top_k_logits(logits, top_k)
                # apply softmax to convert to probabilities
                probs = F.softmax(logits, dim=-1)
                # sample from the distribution or take the most likely
                if sample:
                    ix = torch.multinomial(probs, num_samples=1)
                else:
                    _, ix = torch.topk(probs, k=1, dim=-1)
                x = ix

        return x

    def test_epoch_end(self, outputs: List[Any]):
        """No aggregation needed; metrics are logged in ``test_step``."""
        pass

    def configure_optimizers(self):
        """Delegate optimizer construction to the network.

        Uses ``GPT.configure_naive_optimizer`` when ``hparams.freeze_encoder``
        is set, otherwise ``GPT.configure_optimizers``.
        """
        if self.hparams.freeze_encoder:
            return self.net.configure_naive_optimizer(self.hparams)
        return self.net.configure_optimizers(self.hparams)

    def load_state_dict(self, state_dict: "OrderedDict[str, torch.Tensor]", strict: bool = True, **kwargs):
        """Load ``state_dict``; for ``agent_type == "initgpt"`` drop positional/mask
        buffers first and load non-strictly.

        Extra keyword arguments are forwarded to ``nn.Module.load_state_dict``.

        NOTE(review): ``__init__`` only builds a network for ``agent_type ==
        "gpt"``, so the ``"initgpt"`` branch looks like legacy code.
        """
        if self.hparams.agent_type == "initgpt":
            # pop(..., None): tolerate checkpoints that lack some of these keys.
            for key in ("net.pos_emb", "net.mask", "net.inverse_mask"):
                state_dict.pop(key, None)
            strict = False
        return super().load_state_dict(state_dict, strict, **kwargs)
