from utils.algorithm import ActionInfo, Mode, Config, Params, ReportInfo, Info
from torch import nn
import torch
from utils.algorithm import Algorithm
from typing import List, Tuple, Any, Optional, Callable, Dict, cast
import numpy as np
from utils.common import get_device
from agent.bert import BertTransformer
from torch import optim
from tqdm import tqdm
from pathlib import Path
from agent.dataprocess import transform, normalize, extract_observation, sample
from os import path
from mlpmodel import MLPModel
from sklearn.cluster import KMeans
from itertools import product
import joblib
from utils.focal_loss import FocalLoss
import math
from collections import defaultdict
from robert.agent.commands import COMMANDS
from collections import deque
from args import DEVICE3
from mlpmodel import MLPModel
import time

# Observations and actions are both exchanged as plain torch tensors; these
# aliases only document intent in type annotations.
Observation = Action = torch.Tensor

# States fed to the policy use the same tensor type as observations.
State = Observation
# Device every tensor created by this module is placed on (taken from args).
DEVICE = DEVICE3


class Robert(Algorithm):
    """Sequence policy that predicts k-means action centres plus offsets.

    A model (``BertTransformer`` when ``model_kind == "transformer"``, an
    ``MLPModel`` otherwise) reads a normalised sequence of
    ``pre_steps + 1 + post_steps`` states and, for every position, outputs
    log-probabilities over ``num_centers`` k-means action centres together
    with one offset per centre. Actions are recovered by sampling a centre
    and adding its predicted offset (see ``_recover_actions``).

    Hyper-parameters are read from ``self.p`` (presumably filled in by
    ``Algorithm.__init__`` from ``hyperparams`` -- confirm in the base class).
    """

    def __init__(
        self,
        name: str,
        config: Config,
        hyperparams: Optional[Params] = None,
    ):
        super(Robert, self).__init__(name, config, hyperparams)
        self.state_dim = self.p["state_dim"]
        self.act_dim = self.p["action_dim"]
        self.lr = self.p["lr"]
        self.mini_batch_size = self.p["batch_size"]
        # Sequence layout: pre_steps past states, the current state, then
        # post_steps command states; the action sequence is one shorter.
        self.pre_steps = self.p["pre_steps"]
        self.post_steps = self.p["post_steps"]
        self.state_seq_len = self.pre_steps + self.post_steps + 1
        self.action_seq_len = self.state_seq_len - 1
        self.train_iters = self.p["train_iters"]
        self.logdir = self.p["log_dir"]
        self.seed = self.p["seed"]
        self.num_centers = self.p["num_centers"]
        self.warmup_steps = self.p["warmup_steps"]
        # Train on every action position of the sequence.
        self.action_slice = slice(None)
        self.action_slice_len = self.action_seq_len

        # Normalisation statistics for states and actions.
        self.state_seq_mean = self.p["state_seq_mean"]
        self.state_seq_std = self.p["state_seq_std"]
        self.action_mean = self.p["action_mean"]
        self.action_std = self.p["action_std"]
        # Learning-rate schedule, masking curriculum and evaluation settings.
        self.decay_begins = self.p["decay_begins"]
        self.decay_ends = self.p["decay_ends"]
        self.no_mask = self.p["no_mask"]
        self.model_kind = self.p["model_kind"]
        self.mask_begin = self.p["mask_begin"]
        self.mask_end = self.p["mask_end"]
        self.eval_ratio = self.p["eval_ratio"]

        # A single generator for the whole run: re-seeding on every
        # get_masks() call would reproduce the exact same mask pattern at
        # every training step.
        self._mask_rng = np.random.default_rng(self.seed)

        self.kmeans = KMeans(n_clusters=self.num_centers, random_state=self.seed)
        self.trfm = (
            BertTransformer(self.state_dim, self.act_dim)
            if self.model_kind == "transformer"
            else MLPModel(
                self.state_dim, self.act_dim * 2, self.act_dim, self.state_seq_len
            )
        )
        # Number of trainable parameters, in millions.
        self.n_params = sum(
            param.numel() for param in self.trfm.parameters() if param.requires_grad
        ) / int(1e6)
        print(f"num. of trainable parameters is: {self.n_params}M")
        self.opt = optim.AdamW(
            self.trfm.parameters(),
            lr=self.lr,
            weight_decay=0.1,
            betas=(0.9, 0.95),
        )

        self.zeros_states = torch.zeros(
            (1, self.state_seq_len, self.state_dim),
            device=DEVICE,
            dtype=torch.float32,
        )

        def _schedule(t):
            # Multiplicative LR factor for LambdaLR (http://tpcg.io/_M9QZZM):
            # linear warm-up 1e-4 -> 1, constant 1, cosine decay 1 -> 1e-3
            # over [decay_begins, decay_ends], then constant 1e-3.
            if t == self.warmup_steps:
                print("stop warmup")
            if t <= self.warmup_steps:
                if self.warmup_steps <= 0:
                    # No warm-up configured (LambdaLR evaluates t=0 at
                    # construction): start at the factor warm-up ends on
                    # instead of dividing by zero.
                    return 1.0
                return (1 - 1e-4) * t / self.warmup_steps + 1e-4

            if self.warmup_steps < t < self.decay_begins:
                return 1

            if self.decay_begins <= t <= self.decay_ends:
                span = self.decay_ends - self.decay_begins
                # A zero-length decay window would otherwise divide by zero.
                frac = (t - self.decay_begins) / span if span > 0 else 0.0
                return 1e-3 + 0.5 * (1 - 1e-3) * (1 + math.cos(frac * math.pi))

            return 1e-3

        self.shdler = torch.optim.lr_scheduler.LambdaLR(
            self.opt,
            _schedule,
        )
        self.center_loss = FocalLoss(reduction="mean")
        self.offset_loss = nn.MSELoss()

        # Fit the action clusters, then drop the raw actions to free memory.
        self._kmeans(self.p["train_raw_actions"])
        self.p["train_raw_actions"] = None

        # Best (lowest) evaluation grade seen so far.
        self.best_grades = 999
        # Sentinel value written into command positions hidden from the model.
        self.mask_tensor = (
            torch.ones((self.state_dim,), dtype=torch.float32, device=DEVICE) * -999
        )

    @torch.no_grad()
    def take_action(self, mode: Mode, state: State, env, mask=None) -> Action:
        """Sample an action for the current position of each state sequence.

        Args:
            mode: must be ``"eval"``.
            state: un-normalised states of shape
                ``(batch, state_seq_len, state_dim)``.
            env: unused here; kept for interface compatibility.
            mask: optional boolean mask forwarded to the model as ``masks``.

        Returns:
            ``(batch, act_dim)`` actions, de-normalised with
            ``action_mean``/``action_std`` and clipped to ``[-1, 1]``.
        """
        assert mode == "eval"
        self.trfm.eval()
        state = state.to(DEVICE)
        assert len(state.shape) == 3
        assert state.shape[1:] == (self.state_seq_len, self.state_dim)

        act_probs, act_offsets = self.trfm(
            normalize(transform(state), self.state_seq_mean, self.state_seq_std),
            masks=mask,
        )

        # Predictions at the "current state" position drive the action.
        act_probs = act_probs[:, self.pre_steps]
        act_offsets = act_offsets[:, self.pre_steps]

        acts = self._recover_actions(
            action_logits=act_probs, action_offsets=act_offsets
        )
        acts = acts * self.action_std + self.action_mean
        return acts.clip(-1, 1)

    def get_masks(self, p, L):
        """Build a random boolean mask over ``L`` positions per batch row.

        Args:
            p: training progress. Masking starts once ``p > mask_begin``.
                The ratio then ramps linearly to 1 at ``mask_end``, unless
                ``no_mask >= 0.01`` fixes it to ``no_mask``.
            L: number of maskable positions (callers pass ``post_steps``).

        Returns:
            ``(masks, ratio)``. ``masks`` is a ``(mini_batch_size, L)`` bool
            tensor on ``DEVICE`` (``True`` = masked), or ``None`` when masking
            is disabled (``no_mask >= 0.99``) or has not started yet.
            ``ratio`` is the per-row masked count divided by ``L - 1`` (0 when
            no position can be masked).
        """
        if self.no_mask >= 0.99:
            return None, 0
        if p <= self.mask_begin:
            return None, 0

        if self.no_mask >= 0.01:
            # A fixed ratio overrides the curriculum.
            mask_ratio = self.no_mask
        elif p <= self.mask_end:
            # Linear ramp from 0 at mask_begin to 1 at mask_end
            # (http://tpcg.io/_M9QZZM); here mask_begin < p <= mask_end, so
            # the denominator is positive.
            mask_ratio = min(
                1, (p - self.mask_begin) / (self.mask_end - self.mask_begin)
            )
        else:
            mask_ratio = 1

        # Only the first L - 1 entries of each row's random permutation are
        # eligible, so at least one of the L positions always stays visible.
        maskable = max(L - 1, 0)
        num_masked = math.ceil(mask_ratio * maskable)
        perms = self._mask_rng.permuted(
            np.tile(np.arange(L, dtype=np.int64), (self.mini_batch_size, 1)),
            axis=-1,
        )
        to_be_masked = torch.as_tensor(
            perms[:, :maskable][:, :num_masked], device=DEVICE, dtype=torch.long
        )

        masks = torch.zeros((self.mini_batch_size, L), dtype=torch.bool, device=DEVICE)
        rows = torch.arange(self.mini_batch_size, device=DEVICE).unsqueeze(1)
        masks[rows, to_be_masked] = True
        assert torch.all(masks.sum(dim=-1) == num_masked)
        assert masks.dtype == torch.bool
        if maskable == 0:
            return masks, 0
        return masks, masks.sum().item() / self.mini_batch_size / maskable

    def _manual_train(self, info: Dict[str, Any]):
        """Run one optimisation step on a sampled mini-batch.

        Args:
            info: dict with ``"train_states"`` of shape
                ``(N, idx, state_seq_len, state_dim)``, ``"train_actions"`` of
                shape ``(N, idx, action_seq_len, act_dim)`` and
                ``"progress"`` (fed to ``get_masks``).
        """
        self.trfm.train()
        states = info["train_states"]
        actions = info["train_actions"]
        progress = info["progress"]

        assert len(states.shape) == len(actions.shape) == 4
        state_size, state_seq_idx, state_seq, state_dim = states.shape
        action_size, action_seq_idx, action_seq, action_dim = actions.shape

        assert state_size == action_size
        assert state_seq_idx == action_seq_idx
        assert state_seq == action_seq + 1
        assert state_dim == self.state_dim
        assert action_dim == self.act_dim

        sampled_states, sampled_actions = sample(states, actions, self.mini_batch_size)
        sampled_states = sampled_states.to(DEVICE)
        sampled_actions = sampled_actions.to(DEVICE)
        assert sampled_states.shape == (self.mini_batch_size, state_seq, state_dim)
        assert sampled_actions.shape == (self.mini_batch_size, action_seq, action_dim)

        masks, mask_ratio = self.get_masks(progress, self.post_steps)

        if masks is not None:
            assert masks.shape == (self.mini_batch_size, self.post_steps)
            # History and current state are never masked; only the
            # post_steps command positions are.
            masks = torch.cat(
                (
                    torch.zeros(
                        (self.mini_batch_size, 1 + self.pre_steps),
                        dtype=torch.bool,
                        device=DEVICE,
                    ),
                    masks,
                ),
                dim=-1,
            )

        sampled_states = normalize(
            transform(sampled_states), self.state_seq_mean, self.state_seq_std
        )
        sampled_actions = sampled_actions[:, self.action_slice].reshape(
            (-1, self.act_dim)
        )

        # Target centre index and residual offset for every action.
        act_ctr_idx = torch.as_tensor(
            self.kmeans.predict(sampled_actions.numpy(force=True)),
            dtype=torch.long,
            device=DEVICE,
        )
        assert act_ctr_idx.shape == (self.mini_batch_size * self.action_slice_len,)

        act_offset = sampled_actions - self.action_centers[act_ctr_idx]
        assert act_offset.shape == (
            self.mini_batch_size * self.action_slice_len,
            self.act_dim,
        )

        pred_act_ctrs_logits, pred_act_offsets = self.trfm(sampled_states, masks)
        pred_act_ctrs_logits = pred_act_ctrs_logits[:, self.action_slice].reshape(
            (-1, self.num_centers)
        )
        pred_act_offsets = pred_act_offsets[:, self.action_slice].reshape(
            (-1, self.num_centers, self.act_dim)
        )

        assert pred_act_ctrs_logits.shape == (
            self.mini_batch_size * self.action_slice_len,
            self.num_centers,
        )
        assert pred_act_offsets.shape == (
            self.mini_batch_size * self.action_slice_len,
            self.num_centers,
            self.act_dim,
        )

        # Sampled reconstruction, only used for the logged MAE.
        recovered_actions = self._recover_actions(
            pred_act_ctrs_logits, pred_act_offsets
        )

        # The offset loss only supervises the offset of the target centre.
        pred_act_offsets = pred_act_offsets[
            torch.arange(self.mini_batch_size * self.action_slice_len), act_ctr_idx
        ]

        center_loss = self.center_loss(pred_act_ctrs_logits, act_ctr_idx.detach())
        offset_loss = self.offset_loss(pred_act_offsets, act_offset.detach())

        self.opt.zero_grad()
        (center_loss + 5 * offset_loss).backward()
        torch.nn.utils.clip_grad_norm_(self.trfm.parameters(), 0.75)
        self.opt.step()
        self.shdler.step()

        self.add_scalars(
            dict(
                center_loss=center_loss.item(),
                offset_loss=offset_loss.item(),
                mask_ratio=mask_ratio,
                mae_loss=(recovered_actions - sampled_actions).abs().mean().item(),
            ),
            "train",
        )

        self.trained_steps += self.mini_batch_size
        self.trfm.eval()

    def _recover_actions(
        self, action_logits: torch.Tensor, action_offsets: torch.Tensor
    ):
        """Sample a centre per row and return ``centre + its predicted offset``.

        Args:
            action_logits: ``(rows, num_centers)`` log-probabilities (their
                ``exp`` must sum to 1 per row).
            action_offsets: ``(rows, num_centers, act_dim)`` offsets.

        Returns:
            ``(rows, act_dim)`` detached actions in the clustered action space.
        """
        action_logits = action_logits.detach()
        action_offsets = action_offsets.detach()
        assert len(action_logits.shape) == 2
        assert len(action_offsets.shape) == 3
        assert action_logits.shape[-1:] == (self.num_centers,)
        assert action_offsets.shape[-2:] == (self.num_centers, self.act_dim)

        action_probs = action_logits.exp()
        assert torch.all((1 - action_probs.sum(dim=-1)).abs() <= 1e-4)
        dist = torch.distributions.Categorical(probs=action_probs)
        sampled_act_idx = dist.sample()

        return (
            self.action_centers[sampled_act_idx]
            + action_offsets[torch.arange(action_offsets.size(0)), sampled_act_idx]
        )

    def manual_train(self, info: Dict[str, Any]):
        """Public entry point for one training step; see ``_manual_train``."""
        self._manual_train(info)

    def pretrain(self, info: Info):
        """Persist run-constant artefacts (statistics, clusters) once."""
        self._save_once(f"{self.logdir}/save_once")
        print("pretrain finish")

    @torch.no_grad()
    def eval(self, info):
        """Roll out every evaluation command and log per-command tracking errors.

        ``info`` must provide ``"env"`` and the ``"no_masked_commands"``,
        ``"half_masked_commands"`` and ``"full_masked_commands"`` dicts, which
        are merged with ``COMMANDS``. Each command name must end in
        ``-no-mask``, ``-half-mask`` or ``-full-mask``; that suffix selects
        its section. A command is either a tensor of state offsets or a
        ``(offsets, visible_indices)`` tuple, where all positions except
        ``visible_indices`` are hidden from the model.

        The score of a command is the mean absolute error between achieved
        and commanded states over visible positions. Section means and the
        overall mean (plus the best overall mean so far) are logged.
        """
        start_time = time.perf_counter()
        env = info["env"]
        action_spec = env.action_spec()

        section_grades = {"no-mask": [], "full-mask": [], "half-mask": []}
        all_commands = {
            **COMMANDS,
            **info["no_masked_commands"],
            **info["half_masked_commands"],
            **info["full_masked_commands"],
        }

        # Number of achieved states that get scored.
        eval_len = self.post_steps * self.eval_ratio
        # Command length needed: the last model call (at scored step
        # eval_len - 1) looks post_steps positions ahead.
        cmd_len = eval_len + self.post_steps - 1

        for name, _command in all_commands.items():
            time_step = env.reset()
            obs_history = []
            section_name = (
                (name[cast(str, name).index("-") + 1 :]) if "-" in name else ""
            )
            assert section_name in ["full-mask", "half-mask", "no-mask"]

            if not isinstance(_command, tuple):
                command = _command[:cmd_len]
                masks = None
            else:
                command = _command[0][:cmd_len]
                mask_idx = _command[1]
                mask_idx = mask_idx[mask_idx < cmd_len]
                assert mask_idx.dtype == torch.int64 and len(mask_idx.shape) == 1
                # True = hidden; the listed indices stay visible.
                masks = torch.ones((cmd_len,), dtype=torch.bool)
                masks[mask_idx] = False

            # Commands are offsets relative to the initial observation.
            coord_commands = (
                torch.as_tensor(
                    extract_observation(time_step),
                    dtype=torch.float32,
                    device=DEVICE,
                )
                + command
            )
            if masks is not None:
                masked_coord_commands = coord_commands.clone()
                masked_coord_commands[masks] = self.mask_tensor
            else:
                masked_coord_commands = coord_commands

            assert command.shape == (cmd_len, self.state_dim)

            def act(t):
                # Zero actions until enough history exists to fill the window.
                if len(obs_history) < self.pre_steps:
                    return np.zeros(action_spec.shape)

                offset = len(obs_history) - self.pre_steps
                current = torch.as_tensor(
                    extract_observation(t),
                    dtype=torch.float32,
                    device=DEVICE,
                ).unsqueeze(0)
                # Explicit start index: obs_history[-0:] would return the
                # whole list when pre_steps == 0.
                window = obs_history[offset:]
                history = (
                    torch.stack(window, dim=0)
                    if window
                    else current.new_zeros((0, self.state_dim))
                )
                assert len(history) == self.pre_steps

                input_corr = masked_coord_commands[offset : offset + self.post_steps]
                assert input_corr.shape == (self.post_steps, self.state_dim)

                if masks is not None:
                    _masks = torch.cat(
                        (
                            torch.zeros((self.pre_steps + 1,), dtype=torch.bool),
                            masks[offset : offset + self.post_steps],
                        ),
                        dim=0,
                    ).unsqueeze(0)
                    assert _masks.shape == (1, self.state_seq_len), _masks.shape
                else:
                    _masks = None

                return self.take_action(
                    "eval",
                    torch.cat((history, current, input_corr), dim=0).unsqueeze(0),
                    env,
                    mask=_masks,
                ).numpy(force=True)[0]

            while len(obs_history) < self.pre_steps + eval_len:
                assert not time_step.last()

                old_timestep = time_step
                a = act(time_step)
                time_step = env.step(a)
                obs_history.append(
                    torch.as_tensor(
                        extract_observation(old_timestep),
                        dtype=torch.float32,
                        device=DEVICE,
                    )
                )
            obs_history.append(
                torch.as_tensor(
                    extract_observation(time_step),
                    dtype=torch.float32,
                    device=DEVICE,
                )
            )

            assert len(obs_history) == self.pre_steps + eval_len + 1

            achieved = torch.stack(obs_history[len(obs_history) - eval_len :])
            assert len(achieved) == eval_len
            targets = coord_commands[:eval_len]
            if masks is not None:
                visible = torch.logical_not(masks[:eval_len])
                masked_achieved = achieved[visible] - targets[visible]
            else:
                masked_achieved = achieved - targets

            score = masked_achieved.abs().mean().item()
            self.add_scalars(
                {name: score},
                "eval",
            )
            section_grades[section_name].append(score)

        assert sum(len(sg) for sg in section_grades.values()) == len(all_commands)
        self.add_scalars(
            {k: np.mean(v) for k, v in section_grades.items()},
            "eval/section",
        )

        final_grades = np.mean([np.mean(sg) for sg in section_grades.values()])

        if final_grades < self.best_grades:
            self.best_grades = final_grades

        self.add_scalars(
            {"grades": final_grades, "best_grades": self.best_grades}, "eval/summary"
        )
        print(f"--- eval finished: {time.perf_counter() - start_time} seconds ---")

    def _kmeans(self, actions):
        """Fit ``self.kmeans`` on a subsample of ``actions`` and cache the centres.

        ``actions`` is reshaped to ``(-1, act_dim)``. At most 20k rows are
        used for fitting. They are chosen by a permutation seeded from
        ``self.seed``, so the subsample (and therefore the centres) is
        reproducible. Centres are stored in ``self.action_centers`` on
        ``DEVICE``.
        """
        all_actions = actions.reshape((-1, self.act_dim))
        action_length = all_actions.size(0)
        print(f"length of actions in kmeans: {action_length}")
        reduced_action_length = min(int(2e4), action_length)
        generator = (
            None
            if self.seed is None
            else torch.Generator().manual_seed(int(self.seed))
        )
        subset = torch.randperm(action_length, generator=generator)[
            :reduced_action_length
        ]
        self.kmeans.fit(all_actions[subset].numpy(force=True))

        self.action_centers = torch.as_tensor(
            self.kmeans.cluster_centers_, dtype=torch.float32, device=DEVICE
        )
        assert self.action_centers.shape == (self.num_centers, self.act_dim)

    def _save_once(self, dir: str):
        """Save normalisation statistics, action centres and the k-means model."""
        Path(dir).mkdir(exist_ok=True, parents=True)
        torch.save(self.state_seq_mean, f"{dir}/state_seq_mean.pth")
        torch.save(self.state_seq_std, f"{dir}/state_seq_std.pth")
        torch.save(self.action_centers, f"{dir}/action_centers.pth")
        torch.save(self.action_mean, f"{dir}/action_mean.pt")
        torch.save(self.action_std, f"{dir}/action_std.pt")
        joblib.dump(self.kmeans, f"{dir}/kmeans.joblib")
        print(f"success save_once to {dir}")
        self.add_text(
            "num. of params.",
            str(sum(param.numel() for param in self.trfm.parameters()) / int(1e6))
            + "M",
        )

    def save(self, times: int):
        """Save the model weights under ``<log_dir>/<times>/``."""
        save_dir = path.join(self.logdir, str(times))
        Path(save_dir).mkdir(exist_ok=True, parents=True)
        torch.save(self.trfm.state_dict(), f"{save_dir}/bert_transformer.pth")
        print(f"success save to {save_dir}")

    def load(self, dir: str, map_location):
        """Load weights, statistics, centres and k-means from a single ``dir``.

        NOTE: ``joblib.load`` (and ``torch.load`` without ``weights_only``)
        unpickle arbitrary objects -- only load directories you trust.
        """
        self.trfm.load_state_dict(
            torch.load(f"{dir}/bert_transformer.pth", map_location)
        )
        self.state_seq_mean = torch.load(f"{dir}/state_seq_mean.pth", map_location)
        self.state_seq_std = torch.load(f"{dir}/state_seq_std.pth", map_location)
        self.action_centers = torch.load(f"{dir}/action_centers.pth", map_location)
        self.action_mean = torch.load(f"{dir}/action_mean.pt", map_location)
        self.action_std = torch.load(f"{dir}/action_std.pt", map_location)
        self.kmeans = joblib.load(f"{dir}/kmeans.joblib")
