import math
import logging
import time

from tqdm import tqdm
import numpy as np
import json

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.dataloader import DataLoader

# Module-level logger (checkpoint saves and test-loss reports go through it).
logger = logging.getLogger(name=__name__)

from mingpt.utils import sample, target_return_mapping, get_normalized_score
import atari_py
from collections import deque
import random
import cv2
import torch
from PIL import Image


class TrainerConfig:
    """Hyper-parameter bag for ``Trainer``.

    Class attributes hold the defaults; every keyword passed to the
    constructor is set as an instance attribute of the same name, overriding
    a default or adding a new setting (e.g. ``game``, ``model_type``).
    """

    # --- optimisation ---
    max_epochs = 10
    batch_size = 64
    learning_rate = 3e-4
    betas = (0.9, 0.95)
    grad_norm_clip = 1.0
    weight_decay = 0.1  # only applied on matmul weights

    # --- learning-rate schedule: linear warmup, then cosine decay floored at 10% ---
    lr_decay = False
    warmup_tokens = 375e6  # GPT-3 paper value; may not be a good default elsewhere
    final_tokens = 260e9  # end of the cosine schedule (LR multiplier is floored at 0.1)

    # --- checkpointing, data loading, evaluation ---
    ckpt_path = None
    num_workers = 0  # DataLoader worker processes
    num_games_to_use_for_eval = 10  # episodes averaged per evaluation

    def __init__(self, **kwargs):
        for name in kwargs:
            setattr(self, name, kwargs[name])


class Trainer:
    """Train a return-conditioned GPT policy offline and evaluate it in an Atari emulator.

    After every training epoch the model is checkpointed to
    ``<ckpt_path>model_<epoch>.pth`` and evaluated with :meth:`get_returns`.
    Per-epoch results plus timing statistics are written to
    ``<ckpt_path>results.json`` once training finishes.
    """

    # Target return used to condition evaluation rollouts when
    # ``config.model_type == "reward_conditioned"``, keyed by ``config.game``.
    _EVAL_TARGET_RETURNS = {
        "Breakout": 90,
        "Seaquest": 1150,
        "Qbert": 14000,
        "Pong": 20,
    }

    def __init__(self, model, train_dataset, test_dataset, config):
        self.model = model
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.config = config
        self.training_times = []  # wall-clock seconds per training epoch
        self.eval_times = []  # wall-clock seconds per get_returns() call

        # take over whatever gpus are on the system
        self.device = "cpu"
        if torch.cuda.is_available():
            self.device = torch.cuda.current_device()
            self.model = torch.nn.DataParallel(self.model).to(self.device)

    def _unwrap_model(self, model=None):
        """Return ``model`` (default ``self.model``) with any DataParallel wrapper removed."""
        model = self.model if model is None else model
        # DataParallel wrappers keep the raw model object in the .module attribute
        return model.module if hasattr(model, "module") else model

    def _rtg_tensor(self, rtgs):
        """Pack a list of returns-to-go into a ``(1, len(rtgs), 1)`` long tensor on ``self.device``."""
        return (
            torch.tensor(rtgs, dtype=torch.long)
            .to(self.device)
            .unsqueeze(0)
            .unsqueeze(-1)
        )

    def _target_return(self):
        """Return the target return that conditions evaluation for the configured model/game.

        Raises:
            NotImplementedError: unknown ``config.model_type`` or, for
                ``"reward_conditioned"``, a game without a configured target.
        """
        model_type = self.config.model_type
        if model_type == "naive":
            return 0
        if model_type == "reward_conditioned":
            try:
                return self._EVAL_TARGET_RETURNS[self.config.game]
            except KeyError:
                raise NotImplementedError(
                    f"no evaluation target return configured for game {self.config.game!r}"
                ) from None
        raise NotImplementedError(f"unsupported model_type {model_type!r}")

    def save_checkpoint(self, raw_model, epoch):
        """Save the model's ``state_dict`` to ``<ckpt_path>model_<epoch>.pth``.

        Args:
            raw_model: model to save; may be DataParallel-wrapped. ``None``
                falls back to ``self.model``.
            epoch: epoch index embedded in the file name.
        """
        target = self._unwrap_model(raw_model)
        logger.info("saving %s", self.config.ckpt_path)
        torch.save(
            target.state_dict(),
            self.config.ckpt_path + "model_" + str(epoch) + ".pth",
        )

    def train(self):
        """Run ``config.max_epochs`` training epochs, checkpointing and evaluating after each."""
        model, config = self.model, self.config
        raw_model = self._unwrap_model(model)
        optimizer = raw_model.configure_optimizers(config)

        def run_epoch(split, epoch_num=0):
            """One pass over the ``split`` dataset; returns the mean loss for non-train splits."""
            start_time_epoch = time.time()

            is_train = split == "train"
            model.train(is_train)
            data = self.train_dataset if is_train else self.test_dataset
            loader = DataLoader(
                data,
                shuffle=True,
                # pinned host memory only helps host->GPU copies
                pin_memory=self.device != "cpu",
                batch_size=config.batch_size,
                num_workers=config.num_workers,
            )

            losses = []
            pbar = (
                tqdm(enumerate(loader), total=len(loader))
                if is_train
                else enumerate(loader)
            )
            for it, (x, y, r, t) in pbar:
                # place data on the correct device
                x = x.to(self.device)
                y = y.to(self.device)
                r = r.to(self.device)
                t = t.to(self.device)

                # forward the model
                with torch.set_grad_enabled(is_train):
                    # NOTE(review): y is passed twice — presumably as actions and
                    # as targets; confirm against the model's forward().
                    logits, loss = model(x, y, y, r, t)
                    # collapse all losses if they are scattered on multiple gpus
                    loss = loss.mean()
                    losses.append(loss.item())

                if is_train:
                    # backprop and update the parameters
                    model.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), config.grad_norm_clip
                    )
                    optimizer.step()

                    # decay the learning rate based on our progress
                    if config.lr_decay:
                        # tokens processed this step (labels >= 0, i.e. not -100);
                        # kept as a Python int rather than a device tensor
                        self.tokens += int((y >= 0).sum())
                        if self.tokens < config.warmup_tokens:
                            # linear warmup
                            lr_mult = float(self.tokens) / float(
                                max(1, config.warmup_tokens)
                            )
                        else:
                            # cosine learning rate decay, floored at 10%
                            progress = float(
                                self.tokens - config.warmup_tokens
                            ) / float(
                                max(1, config.final_tokens - config.warmup_tokens)
                            )
                            lr_mult = max(
                                0.1, 0.5 * (1.0 + math.cos(math.pi * progress))
                            )
                        lr = config.learning_rate * lr_mult
                        for param_group in optimizer.param_groups:
                            param_group["lr"] = lr
                    else:
                        lr = config.learning_rate

                    # report progress (use this epoch's own index, not the caller's loop variable)
                    pbar.set_description(
                        f"epoch {epoch_num+1} iter {it}: train loss {loss.item():.5f}. lr {lr:e}"
                    )

            if not is_train:
                test_loss = float(np.mean(losses))
                logger.info("test loss: %f", test_loss)
                return test_loss

            self.training_times.append(time.time() - start_time_epoch)

        self.tokens = 0  # counter used for learning rate decay

        results_dict = {}

        for epoch in range(config.max_epochs):
            run_epoch("train", epoch_num=epoch)
            self.save_checkpoint(model, epoch)

            eval_return = self.get_returns(self._target_return())

            normalized_score = get_normalized_score(self.config.game, eval_return)
            # NOTE(review): this reported target comes from target_return_mapping,
            # which may differ from _EVAL_TARGET_RETURNS used to condition rollouts.
            target_return = target_return_mapping[self.config.game]
            print("target return: %d, eval return: %d" % (target_return, eval_return))
            results_dict[epoch] = {
                "eval_return": eval_return,
                "target_return": target_return,
                "normalized_score": normalized_score,
            }
        results_dict["training_time"] = np.average(self.training_times)
        results_dict["eval_time"] = np.average(self.eval_times)
        results_dict["training_params"] = sum(
            p.numel() for p in model.parameters() if p.requires_grad
        )

        with open(self.config.ckpt_path + "results.json", "w") as fp:
            json.dump(results_dict, fp, indent=2)
        print("Results saved here: ", self.config.ckpt_path)

    def get_returns(self, ret):
        """Evaluate the model for ``config.num_games_to_use_for_eval`` episodes.

        Args:
            ret: target return that seeds the returns-to-go sequence.

        Returns:
            The mean undiscounted episode return across the evaluated games.
        """
        eval_start_time = time.time()

        raw_model = self._unwrap_model()  # works with or without DataParallel
        self.model.train(False)
        args = Args(self.config.game.lower(), self.config.seed)
        env = Env(args)
        env.eval()

        episode_returns = []
        try:
            for _ in range(self.config.num_games_to_use_for_eval):
                episode_returns.append(self._play_episode(env, raw_model, ret))
        finally:
            # always release the env and restore training mode, even on error
            env.close()
            self.model.train(True)

        eval_return = sum(episode_returns) / self.config.num_games_to_use_for_eval
        print("target return: %d, eval return: %d" % (ret, eval_return))
        self.eval_times.append(time.time() - eval_start_time)
        return eval_return

    def _play_episode(self, env, raw_model, ret):
        """Roll out one episode conditioned on target return ``ret``; return its total reward."""
        state = env.reset()
        state = state.type(torch.float32).to(self.device).unsqueeze(0).unsqueeze(0)
        rtgs = [ret]
        # first state is from env, first rtg is target return, and first timestep is 0
        sampled_action = sample(
            raw_model,
            state,
            1,
            temperature=1.0,
            sample=True,
            actions=None,
            rtgs=self._rtg_tensor(rtgs),
            timesteps=torch.zeros((1, 1, 1), dtype=torch.int64).to(self.device),
        )

        # No second reset here: the first action was sampled from exactly this state.
        j = 0
        all_states = state
        actions = []
        reward_sum = 0
        while True:
            action = sampled_action.cpu().numpy()[0, -1]
            actions += [sampled_action]
            state, reward, done = env.step(action)
            reward_sum += reward
            j += 1

            if done:
                return reward_sum

            state = state.unsqueeze(0).unsqueeze(0).to(self.device)
            all_states = torch.cat([all_states, state], dim=0)

            rtgs += [rtgs[-1] - reward]
            # all_states has all previous states and rtgs has all previous rtgs
            # (will be cut to block_size in utils.sample); timestep is just the
            # current timestep, clipped to config.max_timestep
            sampled_action = sample(
                raw_model,
                all_states.unsqueeze(0),
                1,
                temperature=1.0,
                sample=True,
                actions=torch.tensor(actions, dtype=torch.long)
                .to(self.device)
                .unsqueeze(1)
                .unsqueeze(0),
                rtgs=self._rtg_tensor(rtgs),
                timesteps=(
                    min(j, self.config.max_timestep)
                    * torch.ones((1, 1, 1), dtype=torch.int64).to(self.device)
                ),
            )


class Env:
    """Atari emulator wrapper that returns a stack of recent preprocessed frames.

    Each frame is the ALE grayscale screen resized to 84x84 and divided by
    255; the last ``args.history_length`` frames are stacked along dim 0.
    ``step`` repeats the chosen action four times and max-pools the frames
    seen on the last two repeats. In training mode a lost life is reported
    as terminal.
    """

    def __init__(self, args):
        self.device = args.device
        self.ale = atari_py.ALEInterface()
        ale = self.ale
        ale.setInt("random_seed", args.seed)
        ale.setInt("max_num_frames_per_episode", args.max_episode_length)
        # No sticky actions, no emulator-side frame skip or colour averaging:
        # repetition and pooling are done explicitly in step().
        ale.setFloat("repeat_action_probability", 0)
        ale.setInt("frame_skip", 0)
        ale.setBool("color_averaging", False)
        # The ROM must be loaded only after all options above are set.
        ale.loadROM(atari_py.get_game_path(args.game))
        # Agent action index -> entry of the ALE minimal action set.
        self.actions = dict(enumerate(ale.getMinimalActionSet()))
        self.lives = 0  # life counter (used for life-loss terminals in training)
        self.life_termination = False  # True when the last "done" was only a lost life
        self.window = args.history_length  # number of frames stacked per observation
        self.state_buffer = deque([], maxlen=args.history_length)
        self.training = True  # mirrors the model's training mode

    def _get_state(self):
        grey = self.ale.getScreenGrayscale()
        small = cv2.resize(grey, (84, 84), interpolation=cv2.INTER_LINEAR)
        return torch.tensor(small, dtype=torch.float32, device=self.device).div_(255)

    def _reset_buffer(self):
        blank_frames = (
            torch.zeros(84, 84, device=self.device) for _ in range(self.window)
        )
        self.state_buffer.extend(blank_frames)

    def reset(self):
        """Start a new episode (or continue after a lost life) and return the frame stack."""
        if not self.life_termination:
            # Full reset: clear the history, restart the game, then perform
            # up to 30 random no-ops (raw action 0 is assumed to be a no-op).
            self._reset_buffer()
            self.ale.reset_game()
            for _ in range(random.randrange(30)):
                self.ale.act(0)
                if self.ale.game_over():
                    self.ale.reset_game()
        else:
            # Only a life was lost: keep playing after a single no-op.
            self.life_termination = False
            self.ale.act(0)
        self.state_buffer.append(self._get_state())
        self.lives = self.ale.lives()
        return torch.stack(list(self.state_buffer), 0)

    def step(self, action):
        """Apply ``action`` four times; return ``(frame_stack, reward, done)``."""
        pooled = torch.zeros(2, 84, 84, device=self.device)
        total_reward = 0
        terminal = False
        ale_action = self.actions.get(action)
        for repeat in range(4):
            total_reward += self.ale.act(ale_action)
            if repeat >= 2:
                # keep the frames of the last two repeats for max-pooling
                pooled[repeat - 2] = self._get_state()
            terminal = self.ale.game_over()
            if terminal:
                break
        self.state_buffer.append(pooled.max(dim=0).values)
        if self.training:
            remaining = self.ale.lives()
            # remaining > 0 guard: e.g. Q*bert
            if self.lives > remaining > 0:
                self.life_termination = not terminal  # only when not truly over
                terminal = True
            self.lives = remaining
        return torch.stack(list(self.state_buffer), 0), total_reward, terminal

    def train(self):
        """Use loss of life as the terminal signal."""
        self.training = True

    def eval(self):
        """Use the standard game-over terminal signal."""
        self.training = False

    def action_space(self):
        return len(self.actions)

    def render(self):
        rgb = self.ale.getScreenRGB()
        bgr = rgb[:, :, ::-1]
        cv2.imshow("screen", bgr)
        cv2.waitKey(1)

    def close(self):
        cv2.destroyAllWindows()


class Args:
    """Minimal argument bundle consumed by ``Env`` for evaluation rollouts."""

    def __init__(self, game, seed, device=None):
        """
        Args:
            game: ROM name passed to ``atari_py.get_game_path`` (the trainer
                passes the lower-cased ``config.game``).
            seed: ALE random seed.
            device: torch device for observation tensors. Defaults to CUDA
                when it is available and to CPU otherwise, so evaluation
                also works on CPU-only machines.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.seed = seed
        # Passed to ALE setInt("max_num_frames_per_episode", ...), so keep it an int.
        self.max_episode_length = int(108e3)
        self.game = game
        self.history_length = 4  # frames stacked per observation
