from collections import OrderedDict
from typing import Any, List

import numpy as np
import torch
from pytorch_lightning import LightningModule
from torch.nn import functional as F


from models.components.mingpt import GPT, GPTConfig 
from models.utils import (
    get_exp_return_dmc,
    get_min_action_dmc,
    top_k_logits,
)


class MultiTaskDTLitModule(LightningModule):
    """LightningModule that trains a ``GPT`` network on DMC trajectory batches.

    The module builds a ``GPT`` from ``models.components.mingpt`` and, in each
    training/validation step, feeds it a batch of
    ``(obs, actions, rtg, ts, rewards, task_ids)``, logs every loss the
    network returns and optimizes their unweighted mean.

    Besides the explicit ``__init__`` arguments, the following hyperparameters
    are read from ``self.hparams`` and therefore must be supplied through
    ``**kwargs`` (``add_model_specific_args`` registers most of them):

        - ``rtg_layers``, ``bc_layers``: forwarded to ``GPTConfig``.
        - ``unsupervise``: when true, no action targets are passed to the net.
        - ``mask_obs_size``: observation mask size (negative = scheduled).
        - ``forward_weight``: forwarded to the network in ``training_step``.
        - ``epochs``: only needed when a mask size is negative (scheduled).
        - ``source_envs`` (optional): mapping whose keys are domain names; when
          present the action vocabulary is the sum over all those domains.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
    """

    def __init__(
        self,
        game,
        dmc,
        domain,
        task,
        agent_type,
        model_type,
        timestep=10000,
        n_embd=128,
        lr=6e-4,
        forward=False,
        inverse=False,
        reward=False,
        rand_inverse=False,
        rand_mask_size=1,
        freeze_encoder=False,
        patch_size=14,
        context_length=30,
        betas=(0.9, 0.95),
        weight_decay=0.1,
        n_layer=6,
        n_head=8,
        pred_layers=1,
        stack_size=4,
        **kwargs,
    ):
        """Store the hyperparameters and build the network.

        All arguments are saved via ``save_hyperparameters()``; several of
        them (e.g. ``lr``, ``betas``, ``weight_decay``) are not used in this
        class directly -- presumably they are consumed by the network's
        optimizer configuration, which receives ``self.hparams``.

        Raises:
            NotImplementedError: if ``dmc`` is falsy or ``agent_type`` is not
                ``"gpt"``.
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # it also ensures init params will be stored in ckpt
        self.save_hyperparameters()

        if not self.hparams.dmc:
            raise NotImplementedError

        channels = 3
        if hasattr(self.hparams, "source_envs"):
            # Multi-task setting: the vocabulary is the sum of the per-domain sizes.
            self.domains = self.hparams.source_envs.keys()
            vocab_size = sum(get_min_action_dmc(domain) for domain in self.domains)
        else:
            vocab_size = get_min_action_dmc(self.hparams.domain)

        if self.hparams.agent_type != "gpt":
            # The original `assert "<message>"` was always true and let an
            # unsupported agent type through without ever creating `self.net`.
            raise NotImplementedError(f"agent type {self.hparams.agent_type!r} not supported")

        block_size = self.hparams.context_length * 2
        mconf = GPTConfig(
            vocab_size,
            block_size,
            max_timestep=self.hparams.timestep,
            channels=channels,
            model_type=self.hparams.model_type,
            n_layer=self.hparams.n_layer,
            n_head=self.hparams.n_head,
            n_embd=self.hparams.n_embd,
            cont_action=self.hparams.dmc,
            pred_layers=self.hparams.pred_layers,
            rtg_layers=self.hparams.rtg_layers,
            bc_layers=self.hparams.bc_layers,
        )
        self.net = GPT(mconf)
        print(self.net)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Register the model's command-line options on ``parent_parser``.

        The options are added to a new ``"DTModel"`` argument group.

        Returns:
            The same ``parent_parser`` that was passed in.
        """
        parser = parent_parser.add_argument_group("DTModel")

        parser.add_argument("--agent_type", type=str, default="gpt")
        parser.add_argument(
            "--model_type", type=str, default="reward_conditioned", choices=["reward_conditioned", "naive"]
        )
        parser.add_argument("--n_embd", type=int, default=128)
        parser.add_argument("--lr", type=float, default=6e-4)
        parser.add_argument("--unsupervise", default=False, action="store_true")
        parser.add_argument("--forward", default=False, action="store_true")
        parser.add_argument("--inverse", default=False, action="store_true")
        parser.add_argument("--reward", default=False, action="store_true")
        parser.add_argument("--rand_inverse", default=False, action="store_true")
        parser.add_argument("--freeze_encoder", default=False, action="store_true")
        # NOTE: this CLI default (30) differs from the `__init__` default (1).
        parser.add_argument("--rand_mask_size", type=int, default=30)
        parser.add_argument("--mask_obs_size", type=int, default=0)
        parser.add_argument("--n_layer", type=int, default=6)
        parser.add_argument("--n_head", type=int, default=8)

        # weights
        parser.add_argument("--forward_weight", type=float, default=1.0)

        # layers
        parser.add_argument("--rtg_layers", type=int, default=1)
        parser.add_argument("--bc_layers", type=int, default=1)
        parser.add_argument("--pred_layers", type=int, default=1)
        # STGPT
        parser.add_argument("--patch_agg", type=str, default="last")
        parser.add_argument("--patch_size", type=int, default=14)

        # DMC
        parser.add_argument("--dmc", default=False, action="store_true")

        return parent_parser

    def _scheduled_mask_size(self) -> int:
        """Return the epoch-dependent mask size used when a mask size is negative.

        The value is ``context_length * (current_epoch + 1) / epochs`` truncated
        to an int, with a lower bound of 1.
        """
        return max(1, int(self.hparams.context_length * (self.current_epoch + 1) / self.hparams.epochs))

    def _shared_step(self, batch: Any, stage: str, **net_kwargs):
        """Run the network on one batch and log its losses under ``stage/``.

        Args:
            batch: ``(obs, actions, rtg, ts, rewards, task_ids)``; ``task_ids``
                is not used here.
            stage: metric prefix, ``"train"`` or ``"val"``.
            **net_kwargs: extra keyword arguments forwarded to the network.

        Returns:
            ``{"loss": mean of all returned losses, "logits": logits}``.

        Raises:
            ValueError: if the network returns no losses, since there would
                be nothing to average or optimize.
        """
        obs, actions, rtg, ts, rewards, _task_ids = batch
        targets = None if self.hparams.unsupervise else actions

        # A negative configured size selects the epoch-based schedule.
        if self.hparams.rand_mask_size < 0:
            rand_mask_size = self._scheduled_mask_size()
            # rand_mask_size = min(30, int((self.current_epoch / 3 + 1) * 10))
            self.log(f"{stage}/mask_size", rand_mask_size, on_step=False, on_epoch=True, prog_bar=False)
        else:
            rand_mask_size = self.hparams.rand_mask_size

        if self.hparams.mask_obs_size < 0:
            # Same schedule, but capped at half of the context length.
            mask_obs_size = min(self.hparams.context_length // 2, self._scheduled_mask_size())
            self.log(f"{stage}/mask_obs_size", mask_obs_size, on_step=False, on_epoch=True, prog_bar=False)
        else:
            mask_obs_size = self.hparams.mask_obs_size

        logits, all_losses = self.net(
            obs,
            actions,
            targets,
            rtg,
            ts,
            rewards,
            pred_forward=self.hparams.forward,
            pred_inverse=self.hparams.inverse,
            pred_reward=self.hparams.reward,
            pred_rand_inverse=self.hparams.rand_inverse,
            rand_mask_size=rand_mask_size,
            mask_obs_size=mask_obs_size,
            **net_kwargs,
        )

        if not all_losses:
            raise ValueError(
                f"the network returned no losses for the {stage} step; "
                "enable action targets or at least one auxiliary objective"
            )

        for name, loss in all_losses.items():
            self.log(f"{stage}/{name}", loss, on_step=True, on_epoch=True, prog_bar=False)
        avg_loss = sum(all_losses.values()) / len(all_losses)
        self.log(f"{stage}/avg_loss", avg_loss, on_step=False, on_epoch=True, prog_bar=False)

        return {"loss": avg_loss, "logits": logits}

    def training_step(self, batch: Any, batch_idx: int):
        """Compute and log the training losses for one batch.

        Returns:
            Dict with ``"loss"`` (mean of all losses, used for backpropagation)
            and ``"logits"``.
        """
        # we can return here dict with any tensors
        # and then read it in some callback or in `training_epoch_end()`
        # remember to always return loss from `training_step()` or else backpropagation will fail!
        return self._shared_step(batch, "train", forward_weight=self.hparams.forward_weight)

    def validation_step(self, batch: Any, batch_idx: int):
        """Compute and log the validation losses for one batch.

        Metrics are logged under ``val/`` (the scheduled mask sizes used to be
        logged under ``train/`` here, colliding with the training metrics).

        Returns:
            Dict with ``"loss"`` (mean of all losses) and ``"logits"``.
        """
        # NOTE(review): unlike training, `forward_weight` is not forwarded here,
        # so the network's own default applies -- confirm this is intended.
        return self._shared_step(batch, "val")

    def configure_optimizers(self):
        """Delegate optimizer construction to the network.

        Uses ``net.configure_naive_optimizer`` when ``freeze_encoder`` is set
        and ``net.configure_optimizers`` otherwise; both receive ``self.hparams``.
        """
        if self.hparams.freeze_encoder:
            return self.net.configure_naive_optimizer(self.hparams)
        return self.net.configure_optimizers(self.hparams)

    def load_state_dict(self, state_dict: "OrderedDict[str, torch.Tensor]", strict: bool = True):
        """Load ``state_dict``, dropping some entries for the ``"initgpt"`` agent type.

        For ``agent_type == "initgpt"`` the entries ``net.pos_emb``,
        ``net.mask`` and ``net.inverse_mask`` are removed (if present) and the
        load is performed non-strictly. The caller's ``state_dict`` is not
        modified.

        Returns:
            Whatever the parent ``load_state_dict`` returns.
        """
        if self.hparams.agent_type == "initgpt":
            # Work on a shallow copy so the caller's dict is left intact; keep
            # the `_metadata` attribute that torch attaches to state dicts.
            metadata = getattr(state_dict, "_metadata", None)
            state_dict = OrderedDict(state_dict)
            if metadata is not None:
                state_dict._metadata = metadata
            for key in ("net.pos_emb", "net.mask", "net.inverse_mask"):
                # Tolerate checkpoints that do not contain these entries.
                state_dict.pop(key, None)
            strict = False
        return super().load_state_dict(state_dict, strict)
