import os
import time
import math
import torch
import numpy as np

from lang_hrl.utils.logger import Logger
from gym_minigrid.minigrid import MiniGridEnv

import lang_hrl
from lang_hrl.datasets import crafting_dataset
from functools import partial

class InverseModeling(object):
    """Supervised trainer for an inverse model.

    The wrapped network is called as ``network(obs, next_obs)`` and must
    return action logits; it is trained with cross-entropy against the
    actions stored in the dataset.
    """

    def __init__(self, env, network_class, network_kwargs=None, device="cpu",
                 optim_cls=torch.optim.Adam, eval_env=None,
                 dataset=None,
                 validation_dataset=None,
                 optim_kwargs=None,
                 batch_size=64,
                 grad_norm=None,
                 checkpoint=None,
                 dataset_fraction=1
                 ):
        """Build the network and optimizer, optionally restoring a checkpoint.

        Args:
            env: Environment; ``env.action_space.n`` is passed to
                ``network_class`` as its first positional argument.
            network_class: Callable that builds the network.
            network_kwargs: Extra keyword arguments for ``network_class``
                (defaults to no extra arguments).
            device: Torch device string; ``"auto"`` selects CUDA when available.
            optim_cls: Optimizer class, constructed over the network parameters.
            eval_env: Accepted for interface compatibility; not used by this class.
            dataset: Location of the training data. For Mazebase environments
                ``dataset + 'all_instructions'`` is unpickled to build the vocab.
            validation_dataset: Optional location of the validation data.
            optim_kwargs: Keyword arguments for ``optim_cls``
                (defaults to ``{'lr': 0.0001}``).
            batch_size: Batch size used by the data loaders in ``train``.
            grad_norm: Max gradient norm for clipping; clipping is skipped when falsy.
            checkpoint: Optional path of a checkpoint to restore via ``load``.
            dataset_fraction: Forwarded as ``fraction`` to
                ``InverseModelDataset.load`` for the training set.
        """
        # Fresh dicts per instance instead of shared mutable defaults.
        if network_kwargs is None:
            network_kwargs = {}
        if optim_kwargs is None:
            optim_kwargs = {'lr': 0.0001}

        self.env = env
        num_actions = env.action_space.n
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.batch_size = batch_size
        self.dataset = dataset
        self.validation_dataset = validation_dataset
        self.grad_norm = grad_norm
        self.dataset_fraction = dataset_fraction

        network_args = [num_actions]
        # We might need to parse the dataset type here.
        if isinstance(self.env, lang_hrl.envs.FullyObsLanguageWrapper):
            pass  # Nothing special to do here.
        elif isinstance(self.env, lang_hrl.envs.LanguageWrapper):
            pass
        elif isinstance(self.env.unwrapped, lang_hrl.envs.mazebase.MazebaseGame):
            # Now we have to create the vocab here.
            import pickle
            from lang_hrl.datasets import crafting_dataset
            # SECURITY: pickle.load can execute arbitrary code from the file;
            # only point `dataset` at locations you trust.
            with open(self.dataset + 'all_instructions', 'rb') as f:
                all_instructions = pickle.load(f)
            self.vocab, vocab_weights = crafting_dataset.build_vocabulary(all_instructions)
            network_args.append(self.vocab)
            network_args.append(vocab_weights)

        # Create the network and optimizer.
        self.network = network_class(*network_args, **network_kwargs).to(self.device)
        self.optim = optim_cls(self.network.parameters(), **optim_kwargs)

        if checkpoint:
            self.load(checkpoint, initial_lr=optim_kwargs.get('lr', 1e-4))

        self.criterion = torch.nn.CrossEntropyLoss()

    def predict(self, obs, next_obs, batched=False, is_tensor=False):
        """Predict the action taken between ``obs`` and ``next_obs``.

        Args:
            obs: Observation, either a dict of values or a single array/tensor.
            next_obs: Following observation, same structure as ``obs``.
            batched: Whether the inputs already carry a leading batch dimension.
            is_tensor: Whether the inputs are already torch tensors; when
                False, numpy arrays are converted and moved to ``self.device``.

        Returns:
            A tensor of argmax action indices when ``batched`` is True,
            otherwise a Python int for the single prediction.
        """
        if not is_tensor and isinstance(obs, dict):
            obs = {k: torch.from_numpy(v).to(self.device) if isinstance(v, np.ndarray) else v
                   for k, v in obs.items()}
            next_obs = {k: torch.from_numpy(v).to(self.device) if isinstance(v, np.ndarray) else v
                        for k, v in next_obs.items()}
        elif not is_tensor:
            obs = torch.from_numpy(obs).to(self.device)
            next_obs = torch.from_numpy(next_obs).to(self.device)

        if not batched:
            # Add the batch dimension to every tensor entry. The check must be
            # against the torch.Tensor type: torch.tensor is a factory function
            # and makes isinstance() raise TypeError.
            if isinstance(obs, dict):
                obs = {k: v.unsqueeze(0) if isinstance(v, torch.Tensor) else v
                       for k, v in obs.items()}
                next_obs = {k: v.unsqueeze(0) if isinstance(v, torch.Tensor) else v
                            for k, v in next_obs.items()}
            # NOTE(review): non-dict observations are passed through without a
            # batch dimension, as before - confirm whether they should be
            # unsqueezed as well.

        # Only the argmax is returned, so no autograd graph is needed.
        with torch.no_grad():
            logits = self.network(obs, next_obs)
            preds = torch.argmax(logits, dim=-1)

        if not batched:
            preds = preds[0].item()
        return preds

    def save(self, path, extension):
        """Write network and optimizer state to ``<path>/<extension>.pt``."""
        save_dict = {"network": self.network.state_dict(), "optim": self.optim.state_dict()}
        torch.save(save_dict, os.path.join(path, extension + ".pt"))

    def load(self, checkpoint, initial_lr=1e-4):
        """Restore network and optimizer state from a checkpoint file.

        Args:
            checkpoint: Path handed to ``torch.load``.
            initial_lr: Learning rate written into every optimizer param group
                after loading, overriding the value stored in the checkpoint.
        """
        print("LOADING CHECKPOINT:", checkpoint)
        checkpoint = torch.load(checkpoint, map_location=self.device)
        self.network.load_state_dict(checkpoint['network'])
        self.optim.load_state_dict(checkpoint['optim'])
        # make sure that we reset the learning rate in case we decide to not use scheduling for finetuning.
        for param_group in self.optim.param_groups:
            param_group['lr'] = initial_lr

    def _to_device(self, batch):
        """Move a tensor, or every tensor value of a dict, to ``self.device``."""
        if isinstance(batch, dict):
            return {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()}
        return batch.to(self.device)

    @staticmethod
    def _flatten_actions(actions):
        """Collapse ``(N, 1)`` action targets to ``(N,)``; other shapes pass through."""
        if len(actions.shape) == 2 and actions.shape[1] == 1:
            return actions.view(-1)
        return actions

    def _build_datasets(self):
        """Create the training dataset and, if configured, the validation dataset.

        Returns:
            ``(dataset, validation_dataset)``; the second item is None when no
            validation dataset was given.

        Raises:
            ValueError: If the environment type is not supported.
        """
        # NOTE(review): __init__ inspects self.env itself while this inspects the
        # wrapped self.env.env - confirm which level is intended. getattr keeps
        # an env without an `.env` attribute from failing before the Mazebase
        # branch is tried.
        inner_env = getattr(self.env, "env", None)
        validation_dataset = None
        if isinstance(inner_env, lang_hrl.envs.FullyObsLanguageWrapper) or isinstance(inner_env, lang_hrl.envs.LanguageWrapper):
            from lang_hrl.datasets.datasets import InverseModelDataset
            dataset = InverseModelDataset.load(self.dataset, fraction=self.dataset_fraction).to_tensor_dataset()
            if self.validation_dataset is not None:
                validation_dataset = InverseModelDataset.load(self.validation_dataset).to_tensor_dataset()
        elif isinstance(self.env.unwrapped, lang_hrl.envs.mazebase.MazebaseGame):
            from lang_hrl.datasets import crafting_dataset
            # The vocab was created in __init__ and already given to the network.
            dataset = crafting_dataset.CraftingInverseDataset(self.dataset, self.vocab)
            if self.validation_dataset is not None:
                validation_dataset = crafting_dataset.CraftingInverseDataset(self.validation_dataset, self.vocab)
        else:
            raise ValueError("Unsupported type of env given.")
        return dataset, validation_dataset

    def _evaluate(self, validation_dataloader):
        """Run one gradient-free pass over the validation loader.

        Returns:
            ``(summed_batch_loss, num_correct, num_preds)`` where ``num_preds``
            counts targets different from -100.
        """
        num_correct, num_preds, valid_ac_loss = 0, 0, 0.0
        with torch.no_grad():
            for valid_obs, valid_next_obs, valid_ac in validation_dataloader:
                # Dispatch on the validation batch itself, not on whatever the
                # last training batch happened to be.
                valid_obs = self._to_device(valid_obs)
                valid_next_obs = self._to_device(valid_next_obs)
                valid_ac = self._flatten_actions(valid_ac.to(self.device).long())

                logits = self.network(valid_obs, valid_next_obs)
                ac_loss = self.criterion(logits, valid_ac)
                pred = torch.argmax(logits, dim=-1)
                num_correct += (pred == valid_ac).sum().item()
                num_preds += (valid_ac != -100).sum().item()
                valid_ac_loss += ac_loss.item()
        return valid_ac_loss, num_correct, num_preds

    def train(self, path, total_steps, log_freq=100, eval_freq=5000, eval_ep=0):
        """Train the network for ``total_steps`` optimizer steps.

        Args:
            path: Directory for the logger output and the saved checkpoints.
            total_steps: Number of optimizer steps to run.
            log_freq: Training metrics are dumped every ``log_freq`` steps.
            eval_freq: Validation and checkpointing run every ``eval_freq`` steps.
            eval_ep: Accepted for interface compatibility; not used here.

        Raises:
            ValueError: If the environment type is not supported.
        """
        logger = Logger(path=path)

        print("[Lang RL] Training an inverse model with tunable parameters", sum(p.numel() for p in self.network.parameters() if p.requires_grad))

        dataset, validation_dataset = self._build_datasets()
        collate_fn = None

        dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate_fn)
        validation_dataloader = None
        if validation_dataset is not None:
            validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=self.batch_size, shuffle=False, collate_fn=collate_fn)

        step = 0
        num_epochs = 0
        start_time = time.time()
        ac_losses = []
        best_validation_loss = float('inf')
        while step < total_steps:

            for obs, next_obs, actions in dataloader:

                # Move tensors to the training device.
                obs = self._to_device(obs)
                next_obs = self._to_device(next_obs)
                # CrossEntropyLoss needs integer (long) class targets.
                actions = self._flatten_actions(actions.to(self.device).long())

                # Zero the optim and run the model
                self.optim.zero_grad()
                logits = self.network(obs, next_obs)
                loss = self.criterion(logits, actions)
                ac_losses.append(loss.item())
                # Compute the gradients
                loss.backward()
                if self.grad_norm:
                    torch.nn.utils.clip_grad_norm_(self.network.parameters(), self.grad_norm)
                self.optim.step()
                step += 1

                # Run the train logging
                if step % log_freq == 0:
                    if len(ac_losses) > 0:
                        logger.record("train/loss", np.mean(ac_losses))
                        ac_losses = []
                    logger.record("time/epochs", num_epochs)
                    # The key spelling is kept as-is so existing logs stay comparable.
                    logger.record("time/steps_per_seceonds", log_freq / (time.time() - start_time))
                    start_time = time.time()
                    logger.dump(step=step)

                # Run validation logging
                if step % eval_freq == 0:
                    self.network.eval()
                    non_empty_dump = False
                    if validation_dataloader is not None:
                        valid_ac_loss, num_correct, num_preds = self._evaluate(validation_dataloader)

                        validation_loss = valid_ac_loss / len(validation_dataloader)
                        # Save the best model according to our current objective
                        if validation_loss < best_validation_loss:
                            self.save(path, "best_model")
                            # Remember the new best so that later, worse
                            # evaluations do not overwrite the checkpoint.
                            best_validation_loss = validation_loss

                        if valid_ac_loss > 0.0:
                            logger.record("eval/loss", validation_loss)
                            logger.record("eval/accuracy", num_correct / num_preds)
                            non_empty_dump = True

                    else:
                        # There is no validation dataset. So, we should just save the model every eval freq.
                        self.save(path, "best_model")

                    # Every eval period also save the "final model"
                    self.save(path, "final_model")

                    if non_empty_dump:
                        logger.dump(step=step)
                    self.network.train()
                # End validation section

                if step >= total_steps:
                    break

            num_epochs += 1

        print("Finished.")
        self.save(path, "final_model")

