import itertools
import time

from pathlib import Path
from collections import defaultdict
from typing import (
    Any,
    Dict,
)

import numpy as np
import torch
import torch.nn.functional as F

from tqdm.auto import tqdm

from behavioral_cloning.utils import timeit


class Trainer:
    """Supervised training loop with periodic evaluation, logging and checkpoints.

    Each training step draws a batch from ``dataset.sample_batch()``, feeds
    ``batch["states"]`` through the model and minimises the cross-entropy
    against ``batch["actions"]``.

    Methods decorated with ``timeit`` (imported from
    ``behavioral_cloning.utils``) are presumably timed into ``self.log_time``,
    which ``_log_after_epoch`` reports under ``time/<name>`` -- verify against
    the decorator's implementation.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        dataset: Any,
        logger: Any,
        device="cpu",
        model_checkpoints=Path("checkpoints"),
        scheduler=None,
        eval_fns=None,
        clip_grad_norm=0.25,
    ):
        """Store the training components and reset the counters.

        Args:
            model: Network being trained; called as ``model(states)``.
            optimizer: Optimizer stepping the model parameters.
            dataset: Object providing ``sample_batch()`` (a dict of tensors)
                and ``__len__``.
            logger: Object providing ``store``, ``log_tabular``,
                ``dump_tabular`` and an ``epoch_dict`` mapping.
            device: Device every sampled batch is moved to.
            model_checkpoints: Directory for checkpoints (``Path`` or ``str``).
            scheduler: Optional LR scheduler, stepped after every train step.
            eval_fns: Optional callables ``fn(model) -> dict`` of metrics.
            clip_grad_norm: Maximum gradient norm; ``None`` disables clipping.
        """
        super().__init__()
        self.logger = logger
        self.model = model
        self.optimizer = optimizer
        self.dataset = dataset
        self.device = device
        # Coerce so plain string paths work with ``.mkdir`` and ``/`` later on.
        self.model_checkpoints = Path(model_checkpoints)
        self.scheduler = scheduler
        self.eval_fns = [] if eval_fns is None else eval_fns
        self.clip_grad_norm = clip_grad_norm
        self.log_time = defaultdict(list)

        self.start_time = time.time()
        self.total_it = 0
        self.log_it = 0

    def train(
        self,
        num_epochs: int = 40,
        num_train_steps: int = 10000,
        log_every: int = 1,
        save_every: int = 10,
    ) -> None:
        """Run the full training loop.

        An evaluation/logging pass is performed once before training starts.

        Args:
            num_epochs: Number of epochs to run.
            num_train_steps: Number of optimisation steps per epoch.
            log_every: Evaluate and dump logs every this many epochs.
            save_every: Write ``checkpoint_<epoch>.pt`` every this many epochs.
        """
        self.eval_epoch()
        self._log_after_epoch()

        # The bar spans one logging period; it is reset after each log dump.
        pbar = tqdm(total=num_train_steps * log_every)
        try:
            for epoch in range(num_epochs):
                self.total_it += num_train_steps

                self.train_epoch(num_train_steps, pbar)

                if (epoch + 1) % log_every == 0:
                    self.eval_epoch()
                    self._log_after_epoch()

                    pbar.refresh()
                    pbar.reset()

                if (epoch + 1) % save_every == 0:
                    self.model_checkpoints.mkdir(parents=True, exist_ok=True)
                    torch.save(
                        self.state_dict(),
                        self.model_checkpoints / f"checkpoint_{epoch + 1}.pt",
                    )
        finally:
            # Release the progress bar even when training is interrupted.
            pbar.close()

    @timeit
    def train_epoch(self, num_steps, pbar):
        """Run ``num_steps`` optimisation steps, advancing ``pbar`` once per step."""
        self.training_mode(True)
        for _ in range(num_steps):
            pbar.update()
            batch = self.dataset.sample_batch()
            batch = {k: v.to(self.device) for k, v in batch.items()}

            train_logs = self.train_step(batch)
            self.logger.store({f"training/{k}": v for k, v in train_logs.items()})

            if self.scheduler is not None:
                self.scheduler.step()

    @timeit
    def eval_epoch(self):
        """Run every evaluation function and store the resulting metrics.

        Metrics whose key contains ``"success_rate"`` are additionally averaged
        per third ``/``-separated key segment and stored as
        ``all_tasks/<segment>/mean_success_rate``.
        """
        eval_logs = {}

        self.training_mode(False)
        for eval_fn in self.eval_fns:
            eval_logs.update(eval_fn(self.get_model()))

        # Group success rates by the third key segment (the target_rew).
        grouped = defaultdict(list)
        for key, value in eval_logs.items():
            if "success_rate" not in key:
                continue
            segments = key.split("/")
            if len(segments) < 3:
                # No segment to group by: the metric is still logged as-is,
                # it just does not contribute to an aggregate.
                continue
            grouped[segments[2]].append(value)

        # Sorted so the aggregate keys are inserted in a deterministic order.
        for target_rew in sorted(grouped):
            mean_success = np.array(grouped[target_rew]).mean()
            eval_logs[f"all_tasks/{target_rew}/mean_success_rate"] = mean_success

        self.logger.store({f"evaluation/{k}": v for k, v in eval_logs.items()})

    def _log_after_epoch(self):
        """Emit one tabular log row and reset the per-period timing buffers."""
        self.logger.log_tabular("dataset size", len(self.dataset))
        self.logger.log_tabular("epoch", self.log_it)
        self.logger.log_tabular("total_env_steps", self.total_it)
        self.logger.log_tabular("time/total", time.time() - self.start_time)

        time_logs = {f"time/{k}": v for k, v in self.log_time.items()}
        self.logger.store(time_logs)
        for key in time_logs:
            self.logger.log_tabular(
                key, with_median=True, with_min_and_max=True, with_sum=True
            )
        # reset logs for the medians, means etc to be meaningful
        self.log_time = defaultdict(list)

        self.log_it += 1

        # Snapshot the keys first so the loop never walks a live view of
        # ``epoch_dict`` while the logger is being called.
        metric_keys = [
            key
            for key in list(self.logger.epoch_dict)
            if key.startswith(("training", "evaluation"))
        ]
        for key in metric_keys:
            self.logger.log_tabular(key, average_only=True)

        self.logger.dump_tabular()

    @timeit
    def train_step(self, batch):
        """Perform one optimisation step on ``batch``.

        Args:
            batch: Mapping with ``"states"`` (model input) and ``"actions"``
                (cross-entropy target).

        Returns:
            ``{"loss": <float>}`` with the loss value of this step.
        """
        states = batch["states"]
        actions = batch["actions"]

        logits = self.model(states)
        loss = F.cross_entropy(logits, actions)

        self.optimizer.zero_grad()
        loss.backward()
        # ``None`` disables clipping instead of failing inside clip_grad_norm_.
        if self.clip_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.clip_grad_norm
            )
        self.optimizer.step()

        return {"loss": loss.item()}

    def infinite_iterator(self, generator):
        """Yield items from ``generator()`` forever, restarting it when exhausted.

        Args:
            generator: Zero-argument callable returning a fresh iterable.

        Raises:
            RuntimeError: If a pass over ``generator()`` produces no items;
                there is nothing to yield and restarting would never help.
        """
        while True:
            produced_any = False
            for data in generator():
                produced_any = True
                yield data
            if not produced_any:
                raise RuntimeError(
                    "infinite_iterator: generator() produced no items"
                )

    def get_model(self):
        """Return the model handed to the evaluation functions."""
        return self.model

    @property
    def networks(self):
        """List of networks whose train/eval mode ``training_mode`` toggles."""
        return [self.model]

    def training_mode(self, mode):
        """Switch every network to training (``True``) or eval (``False``) mode."""
        for net in self.networks:
            net.train(mode)

    def state_dict(self) -> Dict[str, Any]:
        """Return a checkpointable snapshot of the trainer.

        Always contains ``"model"``, ``"optim"``, ``"total_it"`` and
        ``"log_it"``; ``"scheduler"`` is added when a scheduler exposing
        ``state_dict()`` is configured.
        """
        state = {
            "model": self.model.state_dict(),
            "optim": self.optimizer.state_dict(),
            "total_it": self.total_it,
            "log_it": self.log_it,
        }
        if self.scheduler is not None and hasattr(self.scheduler, "state_dict"):
            state["scheduler"] = self.scheduler.state_dict()
        return state

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Restore the trainer from a snapshot produced by ``state_dict``.

        Checkpoints that lack ``"log_it"`` or ``"scheduler"`` are accepted;
        the corresponding state is then left untouched.
        """
        self.model.load_state_dict(state_dict=state_dict["model"])
        self.optimizer.load_state_dict(state_dict=state_dict["optim"])
        self.total_it = state_dict["total_it"]
        self.log_it = state_dict.get("log_it", self.log_it)

        if (
            self.scheduler is not None
            and "scheduler" in state_dict
            and hasattr(self.scheduler, "load_state_dict")
        ):
            self.scheduler.load_state_dict(state_dict["scheduler"])