import os
import time
import math
import torch
import numpy as np

from lang_hrl.utils.logger import Logger
from lang_hrl.utils.evaluate import eval_policy

from gym_minigrid.minigrid import MiniGridEnv

import lang_hrl
from functools import partial

def compute_curl_loss(anchor, target, proj_mat, actions, within_seq=False):
    """Compute a CURL/ATC-style contrastive (InfoNCE) loss.

    Logits are the bilinear similarities ``anchor @ proj_mat @ target^T``. The
    positive for each anchor row is the target row with the same index; every
    other target row acts as a negative. Rows whose ``actions`` entry is -100
    are excluded from the loss through ``ignore_index`` (they still appear as
    negatives for the other rows).

    Args:
        anchor: (N, D) or (B, S, D) anchor embeddings.
        target: target embeddings with the same shape as ``anchor``.
        proj_mat: (D, D) bilinear projection matrix.
        actions: (N,) or (B, S) tensor; entries equal to -100 mark padding.
        within_seq: if True and the inputs are 3-D, contrast only within each
            sequence (an S x S logit matrix per batch element) instead of across
            the whole flattened batch.

    Returns:
        Scalar cross-entropy loss averaged over the non-ignored rows.
    """
    if within_seq and len(anchor.shape) == 3:  # We are in the sequence case
        B, S, D = anchor.shape
        # Compute CURL loss for each sequence individually using bmm
        proj_mat = proj_mat.unsqueeze(0).expand(B, -1, -1)
        Wz = torch.bmm(proj_mat, target.transpose(1, 2))
        logits = torch.bmm(anchor, Wz)
        # Subtract the row max for numerical stability; softmax is shift invariant.
        logits = logits - torch.max(logits, 2)[0].unsqueeze(-1)  # Shape (B, S, S)
        with torch.no_grad():
            # ``repeat`` (not ``expand``) so each batch row owns its memory: masking
            # one sequence's padding must not mask the same position in every sequence.
            labels = torch.arange(S, dtype=torch.long, device=logits.device).repeat(B, 1)
            # Labels now have the same (B, S) shape as actions.
            labels[actions == -100] = -100
        # Reshape everything to be flat
        logits = logits.reshape(-1, logits.size(-1))
        labels = labels.reshape(-1)

        return torch.nn.functional.cross_entropy(logits, labels, ignore_index=-100)

    if len(anchor.shape) == 3:
        # Flatten everything
        B, S, D = anchor.shape
        anchor = anchor.reshape(B * S, D)
        target = target.reshape(B * S, D)
        actions = actions.reshape(B * S)

    # Regular case: contrast every anchor row against every target row.
    Wz = torch.matmul(proj_mat, target.T)
    logits = torch.matmul(anchor, Wz)
    logits = logits - torch.max(logits, 1)[0][:, None]
    with torch.no_grad():
        labels = torch.arange(logits.shape[0], dtype=torch.long, device=logits.device)
        labels[actions == -100] = -100
    return torch.nn.functional.cross_entropy(logits, labels, ignore_index=-100)

def compute_byol_loss(anchor, target, actions):
    """BYOL regression loss ``2 - 2 * cos(anchor, target)`` averaged over valid rows.

    Args:
        anchor: (N, C) or (B, S, C) online-network predictions.
        target: target-network projections with the same shape as ``anchor``.
        actions: (N,) or (B, S) tensor; rows whose entry is -100 are padding and
            are excluded from both the sum and the averaging denominator. This
            matches how the cross-entropy losses treat ``ignore_index=-100``.

    Returns:
        Scalar loss. Returns 0 when every row is padding.
    """
    # Flatten everything if it is not already
    if len(anchor.shape) == 3:
        B, S, C = anchor.shape
        anchor = anchor.reshape(B * S, C)
        target = target.reshape(B * S, C)
        actions = actions.reshape(B * S)
    anchor = torch.nn.functional.normalize(anchor, dim=-1, p=2)
    target = torch.nn.functional.normalize(target, dim=-1, p=2)
    loss = 2 - 2 * (anchor * target).sum(dim=-1)
    valid = actions != -100
    # Zero out padded rows and divide by the number of valid rows only, so the loss
    # scale does not depend on how much padding a batch contains. The clamp avoids
    # 0/0 when a batch is entirely padding.
    loss = loss.masked_fill(~valid, 0.0)
    return loss.sum() / valid.sum().clamp(min=1)

def ema_parameters(net, target_net, tau):
    """Polyak-average the parameters of ``net`` into ``target_net`` in place.

    Every target parameter becomes ``tau * online + (1 - tau) * target``. Only
    ``parameters()`` are blended; module buffers are not touched.
    """
    online_params = net.parameters()
    tracked_params = target_net.parameters()
    for online, tracked in zip(online_params, tracked_params):
        blended = tau * online.data + (1 - tau) * tracked.data
        tracked.data.copy_(blended)

class BehaviorCloningMultiple(object):
    """Behavior-cloning trainer that draws one batch from each of several datasets per step.

    Each optimizer step pulls a batch from every training dataloader. For each
    batch it computes:

    * a behavior-cloning cross-entropy loss, unless ``pretraining`` is set;
    * an optional language loss, for the first dataset only;
    * an optional unsupervised representation loss ("atc" or "byol") against an
      EMA copy of the network.

    It then averages the per-dataset losses into a single update.
    """

    def __init__(self, env, network_class, network_kwargs=None, device="cpu",
                 optim_cls=torch.optim.Adam, eval_env=None,
                 datasets=None,
                 validation_dataset=None,
                 optim_kwargs=None,
                 batch_size=64,
                 grad_norm=None,
                 schedule=None,
                 schedule_kwargs=None,
                 pretraining=False,
                 lang_coeff=1,
                 unsup_coeff=1,
                 unsup_within_seq=False,
                 unsup_type=None,
                 unsup_ema_tau=0.005,
                 unsup_ema_update_freq=2,
                 checkpoint=None,
                 dataset_fraction=1,
                 ):
        """Build the network, optimizer, optional EMA target network and LR schedule.

        Args:
            env: training environment; its type selects the dataset format.
            network_class: policy network constructor, called with
                ``(num_actions, [vocab, vocab_weights])`` plus ``network_kwargs``.
            network_kwargs: extra keyword arguments for ``network_class``
                (default: none).
            device: torch device string, or "auto" to pick CUDA when available.
            optim_cls: optimizer constructor.
            eval_env: optional second environment used for rollout evaluation.
            datasets: list of dataset paths; one dataloader is built per path.
            validation_dataset: optional validation dataset path.
            optim_kwargs: optimizer keyword arguments (default ``{'lr': 0.0001}``).
            batch_size: batch size for every dataloader.
            grad_norm: if set, clip the gradient norm to this value.
            schedule: learning-rate schedule name; only "gpt" is implemented.
            schedule_kwargs: schedule parameters (the "gpt" schedule reads
                'warmup'; default: none).
            pretraining: if True, skip the behavior-cloning loss.
            lang_coeff: weight of the language loss; 0 disables it.
            unsup_coeff: weight of the unsupervised loss; 0 disables it and
                skips building the EMA network.
            unsup_within_seq: passed to ``compute_curl_loss`` as ``within_seq``.
            unsup_type: "atc" or "byol".
            unsup_ema_tau: EMA blending factor for the target network.
            unsup_ema_update_freq: optimizer steps between EMA updates.
            checkpoint: optional checkpoint path to load.
            dataset_fraction: fraction of each training dataset to load.
        """
        # Avoid shared mutable default arguments; the effective defaults are unchanged.
        if network_kwargs is None:
            network_kwargs = {}
        if optim_kwargs is None:
            optim_kwargs = {'lr': 0.0001}
        if schedule_kwargs is None:
            schedule_kwargs = {}

        self.env = env
        self.eval_env = eval_env
        num_actions = env.action_space.n
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.batch_size = batch_size
        self.datasets = datasets
        self.validation_dataset = validation_dataset
        self.grad_norm = grad_norm
        self.pretraining = pretraining
        self.lang_coeff = lang_coeff
        self.unsup_within_seq = unsup_within_seq
        self.unsup_coeff = unsup_coeff
        self.unsup_ema_update_freq = unsup_ema_update_freq
        self.unsup_ema_tau = unsup_ema_tau
        self.unsup_type = unsup_type
        self.dataset_fraction = dataset_fraction

        network_args = [num_actions]
        # Some environment types need extra constructor arguments for the network.
        if isinstance(self.env, lang_hrl.envs.FullyObsLanguageWrapper):
            pass  # Nothing special to do here.
        elif isinstance(self.env, lang_hrl.envs.LanguageWrapper):
            pass
        elif isinstance(self.env.unwrapped, lang_hrl.envs.mazebase.MazebaseGame):
            # The crafting (mazebase) network needs a vocabulary built from all instructions.
            import pickle
            # NOTE: pickle.load executes arbitrary code; only load trusted dataset files.
            with open(self.datasets[0] + 'all_instructions', 'rb') as f:
                all_instructions = pickle.load(f)
            from lang_hrl.datasets import crafting_dataset
            self.vocab, vocab_weights = crafting_dataset.build_vocabulary(all_instructions)
            network_args.append(self.vocab)
            network_args.append(vocab_weights)

        # Create the network and optimizer.
        self.network = network_class(*network_args, **network_kwargs).to(self.device)
        self.optim = optim_cls(self.network.parameters(), **optim_kwargs)

        if checkpoint:
            self.load(checkpoint, initial_lr=optim_kwargs['lr'] if 'lr' in optim_kwargs else 1e-4)

        self.criterion = torch.nn.CrossEntropyLoss()

        if self.unsup_coeff > 0.0:  # create the target networks etc.
            self.ema_network = network_class(*network_args, **network_kwargs).to(self.device)
            # Start the EMA target from the online weights (including any loaded
            # checkpoint). Otherwise, with a small tau, it would track an unrelated
            # random initialization for thousands of steps.
            self.ema_network.load_state_dict(self.network.state_dict())
            for param in self.ema_network.parameters():
                param.requires_grad = False

        if schedule:
            assert 'lr' in optim_kwargs, "Must specify a learning rate"
            self.initial_lr = optim_kwargs['lr']
            self.scheduler = schedule
            self.scheduler_kwargs = schedule_kwargs
        else:
            self.scheduler = None

    def _obs_to_device(self, obs):
        """Move a batch of observations (dict of tensors or a single tensor) to ``self.device``.

        Dict values that are not tensors are passed through unchanged.
        """
        if isinstance(obs, dict):
            return {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in obs.items()}
        return obs.to(self.device)

    def _get_language_labels(self, obs):
        """Return ``obs['label']`` when language supervision applies, else ``None``.

        Only dict observations can carry a 'label' entry. ``'label' in tensor``
        would raise, so non-dict observations always yield ``None``.
        """
        if isinstance(obs, dict) and 'label' in obs and self.lang_coeff > 0:
            return obs['label']
        return None

    def _compute_unsup_loss(self, obs, anchor_logits, actions):
        """Compute the unsupervised loss between the online network and its EMA target.

        Args:
            obs: observation batch already on ``self.device``.
            anchor_logits: third output of the online network for ``obs``.
            actions: action tensor; entries equal to -100 mark padding.

        Returns:
            Scalar loss tensor. For "byol" this is the symmetrized sum of both
            directions; for "atc" it is the CURL loss through ``network.unsup_proj``.

        Raises:
            ValueError: if ``self.unsup_type`` is not "atc" or "byol".
        """
        with torch.no_grad():
            _, _, target_logits = self.ema_network(obs, labels=None, is_target=True)
            target_logits = target_logits.detach()
        if self.unsup_type == "atc":
            return compute_curl_loss(anchor_logits, target_logits, self.network.unsup_proj, actions,
                                     within_seq=self.unsup_within_seq)
        if self.unsup_type == "byol":
            loss_one = compute_byol_loss(anchor_logits, target_logits, actions)
            # Symmetrize: the online network's "target" branch predicts the EMA network's "anchor" branch.
            _, _, anchor_logits_rev = self.network(obs, labels=None, is_target=True)
            with torch.no_grad():
                _, _, target_logits_rev = self.ema_network(obs, labels=None, is_target=False)
            target_logits_rev = target_logits_rev.detach()
            loss_two = compute_byol_loss(anchor_logits_rev, target_logits_rev, actions)
            return loss_one + loss_two
        raise ValueError("Unknown unsup_type: {}".format(self.unsup_type))

    def predict(self, obs, deterministic=True, history=None):
        """Return an action for a single (unbatched) numpy observation.

        Numpy arrays (or dict values that are numpy arrays) are converted to
        tensors on ``self.device`` with a leading batch dimension. If the
        network defines its own ``predict``, that method is used. Otherwise the
        first network output is taken as action logits: the argmax is returned
        when ``deterministic``, else a sample from the softmax distribution.
        """
        if isinstance(obs, dict):
            obs = {k: torch.from_numpy(v).to(self.device).unsqueeze(0) if isinstance(v, np.ndarray) else v for k, v in obs.items()}
            if history is not None:
                history = {k: torch.from_numpy(v).to(self.device).unsqueeze(0) if isinstance(v, np.ndarray) else v for k, v in history.items()}
        else:
            obs = torch.from_numpy(obs).to(self.device).unsqueeze(0)
            if history is not None:
                history = torch.from_numpy(history).to(self.device).unsqueeze(0)

        if hasattr(self.network, "predict"):
            # Support custom predict methods if a network has it
            return self.network.predict(obs, deterministic=deterministic, history=history)
        else:
            logits, _, _ = self.network(obs)  # Remove any aux loss
            logits = logits[0]  # remove the batch dim
            if deterministic:
                return torch.argmax(logits).item()
            else:
                probs = torch.softmax(logits, dim=-1)
                dist = torch.distributions.categorical.Categorical(probs)
                return dist.sample().item()

    def save(self, path, extension):
        """Save network and optimizer state dicts to ``<path>/<extension>.pt``."""
        save_dict = {"network": self.network.state_dict(), "optim": self.optim.state_dict()}
        torch.save(save_dict, os.path.join(path, extension + ".pt"))

    def load(self, checkpoint, initial_lr=1e-4, strict=True):
        """Load a checkpoint written by ``save`` and reset the learning rate.

        Args:
            checkpoint: path to the ``.pt`` file.
            initial_lr: learning rate written into every optimizer param group.
            strict: passed to ``load_state_dict``; the optimizer state is loaded
                only when strict.
        """
        print("LOADING CHECKPOINT:", checkpoint)
        state = torch.load(checkpoint, map_location=self.device)
        self.network.load_state_dict(state['network'], strict=strict)
        if strict:
            # Only load the optimizer state dict if we are being strict.
            self.optim.load_state_dict(state['optim'])
        # make sure that we reset the learning rate in case we decide to not use scheduling for finetuning.
        for param_group in self.optim.param_groups:
            param_group['lr'] = initial_lr

    def train(self, path, total_steps, log_freq=100, eval_freq=5000, eval_ep=0, validation_metric="loss", use_eval_mode=True):
        """Run the training loop for ``total_steps`` optimizer steps.

        Each step draws one batch from every training dataloader. An "epoch"
        ends as soon as any dataloader is exhausted, at which point all
        iterators are recreated.

        Args:
            path: directory for logs and checkpoints ("best_model.pt", "final_model.pt").
            total_steps: number of optimizer steps to run.
            log_freq: steps between training-loss log dumps.
            eval_freq: steps between evaluation / validation / checkpointing passes.
            eval_ep: episodes rolled out in ``env`` and ``eval_env`` per
                evaluation; 0 disables rollouts.
            validation_metric: "loss" or "accuracy"; the validation metric that
                selects "best_model".
            use_eval_mode: if True, put the network in eval mode during evaluation.
        """
        logger = Logger(path=path)

        print("[Lang RL] Training a model with tunable parameters", sum(p.numel() for p in self.network.parameters() if p.requires_grad))
        # Build datasets and the collate function for the environment type.
        # NOTE(review): __init__ checks ``self.env`` for these wrapper types while this
        # method checks ``self.env.env``; presumably the env gains another wrapper
        # between construction and training -- confirm with callers.
        if isinstance(self.env.env, lang_hrl.envs.FullyObsLanguageWrapper):
            from lang_hrl.datasets.datasets import BehaviorCloningDataset
            collate_fn = None
            datasets = []
            for dataset in self.datasets:
                datasets.append(BehaviorCloningDataset.load(dataset, fraction=self.dataset_fraction).to_tensor_dataset())
            if self.validation_dataset is not None:
                validation_dataset = BehaviorCloningDataset.load(self.validation_dataset).to_tensor_dataset()

        elif isinstance(self.env.env, lang_hrl.envs.LanguageWrapper):
            from lang_hrl.datasets.datasets import BabyAITrajectoryDataset
            collate_fn = lang_hrl.datasets.datasets.traj_collate_fn
            datasets = []
            for dataset in self.datasets:
                datasets.append(BabyAITrajectoryDataset.load(dataset, fraction=self.dataset_fraction))
            if self.validation_dataset is not None:
                validation_dataset = BabyAITrajectoryDataset.load(self.validation_dataset)

        elif isinstance(self.env.unwrapped, lang_hrl.envs.mazebase.MazebaseGame):
            from lang_hrl.datasets import crafting_dataset
            if self.unsup_coeff > 0.0:
                skip = 1  # Set skip to 1, used to be 3.
            else:
                skip = -1
            datasets = []
            for dataset in self.datasets:
                # Must have created the vocab in __init__; it was already given to the network.
                datasets.append(crafting_dataset.CraftingDataset(dataset, self.vocab, dataset_fraction=self.dataset_fraction, skip=skip))
            if self.validation_dataset is not None:
                validation_dataset = crafting_dataset.CraftingDataset(self.validation_dataset, self.vocab, dataset_fraction=self.dataset_fraction, skip=skip)
            collate_fn = partial(crafting_dataset.collate_fn, vocab_size=len(self.vocab))
        else:
            raise ValueError("Unknown environment type passed in.")

        dataloaders = []
        for dataset in datasets:
            dataloaders.append(torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate_fn))

        if self.validation_dataset is not None:
            validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=self.batch_size, shuffle=False, collate_fn=collate_fn)

        step = 0
        num_epochs = 0
        start_time = time.time()
        bc_losses = []
        lang_losses = []
        unsup_losses = []
        best_validation_accuracy = -1
        best_validation_loss = float('inf')
        while step < total_steps:

            iterators = [iter(dataloader) for dataloader in dataloaders]
            reset_iterators = False
            while not reset_iterators:

                self.optim.zero_grad()
                losses = []
                for iterator_idx, iterator in enumerate(iterators):
                    try:
                        obs, actions = next(iterator)
                    except StopIteration:
                        # One dataloader is exhausted: discard this partial step and start a new epoch.
                        reset_iterators = True
                        break

                    obs = self._obs_to_device(obs)
                    actions = actions.to(self.device).long()  # Must convert to long

                    # Only the first dataset provides language supervision.
                    language_labels = None if iterator_idx > 0 else self._get_language_labels(obs)

                    logits, lang_loss, anchor_logits = self.network(obs, labels=language_labels, is_target=False)

                    loss = 0
                    if language_labels is not None and lang_loss is not None and self.lang_coeff > 0:
                        lang_losses.append(lang_loss.item())
                        loss = self.lang_coeff * lang_loss

                    if anchor_logits is not None and self.unsup_coeff > 0.0:
                        unsup_loss = self._compute_unsup_loss(obs, anchor_logits, actions)
                        unsup_losses.append(unsup_loss.item())
                        loss = loss + self.unsup_coeff * unsup_loss

                    if not self.pretraining:
                        if len(logits.shape) > 2:
                            logits = logits.reshape(-1, logits.size(-1))
                            actions = actions.reshape(-1)
                        bc_loss = self.criterion(logits, actions)
                        bc_losses.append(bc_loss.item())
                        loss = bc_loss + loss

                    losses.append(loss)

                if reset_iterators:
                    continue

                # Every dataloader produced a batch: average their losses and step.
                total_loss = sum(losses) / len(losses)
                total_loss.backward()
                if self.grad_norm:
                    torch.nn.utils.clip_grad_norm_(self.network.parameters(), self.grad_norm)
                self.optim.step()
                step += 1

                # Update the learning rates
                # So far only the GPT scheduler from minGPT is used.
                if self.scheduler and self.scheduler == "gpt":
                    if step < self.scheduler_kwargs['warmup']:
                        lr_mult = float(step) / float(max(1, self.scheduler_kwargs['warmup']))
                    else:
                        progress = float(step - self.scheduler_kwargs['warmup']) / \
                                   float(max(1, total_steps - self.scheduler_kwargs['warmup']))
                        lr_mult = max(0.1, 0.5 * (1 + math.cos(math.pi * progress)))
                    lr = self.initial_lr * lr_mult
                    for param_group in self.optim.param_groups:
                        param_group['lr'] = lr
                elif self.scheduler is not None:
                    raise ValueError("Used unimplemented scheduler.")

                # Run the train logging
                if step % log_freq == 0:
                    if len(bc_losses) > 0:
                        logger.record("train/bc_loss", np.mean(bc_losses))
                        bc_losses = []
                    if len(lang_losses) > 0:
                        logger.record("train/lang_losses", np.mean(lang_losses))
                        lang_losses = []
                    if len(unsup_losses) > 0:
                        logger.record("train/unsup_loss", np.mean(unsup_losses))
                        # Reset like the other buffers so each dump is a per-window mean.
                        unsup_losses = []
                    logger.record("time/epochs", num_epochs)
                    logger.record("time/steps_per_seceonds", log_freq / (time.time() - start_time))
                    start_time = time.time()
                    logger.dump(step=step)

                # Run validation logging
                if step % eval_freq == 0:
                    if use_eval_mode:
                        self.network.eval()
                    non_empty_dump = False
                    if eval_ep > 0 and self.env is not None and not self.pretraining:
                        with torch.no_grad():
                            mean_reward, std_reward, mean_length, success_rate = eval_policy(self.env, self, eval_ep)
                        logger.record("eval/mean_reward", mean_reward)
                        logger.record("eval/reward_std", std_reward)
                        logger.record("eval/mean_length", mean_length)
                        logger.record("eval/success_rate", success_rate)
                        non_empty_dump = True

                    if eval_ep > 0 and self.eval_env is not None and not self.pretraining:
                        with torch.no_grad():
                            mean_reward, std_reward, mean_length, success_rate = eval_policy(self.eval_env, self, eval_ep)
                        logger.record("eval_env/mean_reward", mean_reward)
                        logger.record("eval_env/reward_std", std_reward)
                        logger.record("eval_env/mean_length", mean_length)
                        logger.record("eval_env/success_rate", success_rate)
                        non_empty_dump = True

                    if self.validation_dataset:
                        with torch.no_grad():
                            num_correct, num_preds, valid_bc_loss, valid_lang_loss = 0, 0, 0.0, []
                            valid_unsup_loss = []
                            for valid_obs, valid_ac in validation_dataloader:
                                valid_obs = self._obs_to_device(valid_obs)
                                valid_ac = valid_ac.to(self.device).long()
                                valid_lang_labels = self._get_language_labels(valid_obs)

                                logits, lang_loss, anchor_logits = self.network(valid_obs, labels=valid_lang_labels, is_target=False)

                                if valid_lang_labels is not None and lang_loss is not None and self.lang_coeff > 0:
                                    valid_lang_loss.append(lang_loss.item())

                                if anchor_logits is not None and self.unsup_coeff > 0.0:
                                    unsup_loss = self._compute_unsup_loss(valid_obs, anchor_logits, valid_ac)
                                    valid_unsup_loss.append(unsup_loss.item())

                                if not self.pretraining:
                                    if len(logits.shape) > 2:
                                        logits = logits.reshape(-1, logits.size(-1))
                                        valid_ac = valid_ac.reshape(-1)
                                    bc_loss = self.criterion(logits, valid_ac)
                                    pred = torch.argmax(logits, axis=-1)
                                    # Padded targets (-100) never equal an argmax, so they never count as correct.
                                    num_correct += (pred == valid_ac).sum().item()
                                    num_preds += (valid_ac != -100).sum().item()
                                    valid_bc_loss += bc_loss.item()

                        if not self.pretraining:
                            validation_loss = valid_bc_loss / len(validation_dataloader)
                            validation_accuracy = num_correct / num_preds if num_preds > 0 else 0.0
                        else:
                            validation_loss = 0
                            # Skip empty lists: np.mean([]) is NaN, which never compares below
                            # best_validation_loss and would therefore block checkpointing.
                            if self.lang_coeff > 0 and valid_lang_loss:
                                validation_loss += self.lang_coeff * np.mean(valid_lang_loss)
                            if self.unsup_coeff > 0 and valid_unsup_loss:
                                validation_loss += self.unsup_coeff * np.mean(valid_unsup_loss)
                            validation_accuracy = 0.0

                        if validation_loss < best_validation_loss:
                            best_validation_loss = validation_loss
                            if validation_metric == "loss":
                                self.save(path, "best_model")
                        if validation_accuracy > best_validation_accuracy:
                            best_validation_accuracy = validation_accuracy
                            if validation_metric == "accuracy":
                                self.save(path, "best_model")

                        if valid_bc_loss > 0.0:
                            logger.record("eval/bc_loss", valid_bc_loss / len(validation_dataloader))
                            non_empty_dump = True
                        if validation_accuracy > 0.0:
                            logger.record("eval/accuracy", validation_accuracy)
                            non_empty_dump = True
                        if len(valid_lang_loss) > 0:
                            logger.record("eval/lang_loss", np.mean(valid_lang_loss))
                            non_empty_dump = True
                        if len(valid_unsup_loss) > 0:
                            logger.record("eval/unsup_loss", np.mean(valid_unsup_loss))
                            non_empty_dump = True
                        print("Finished Eval!")
                    else:
                        # There is no validation dataset. So, we should just save the model every eval freq.
                        self.save(path, "best_model")

                    # Every eval period also save the "final model"
                    self.save(path, "final_model")

                    if non_empty_dump:
                        logger.dump(step=step, dump_csv=True)  # Dump to csv after the eval runs
                    self.network.train()
                # End validation section

                # Update the EMA target network. This runs AFTER evaluation on purpose,
                # so the logged unsup loss reflects the target that training actually saw.
                if self.unsup_coeff > 0.0 and step % self.unsup_ema_update_freq == 0:
                    ema_parameters(self.network, self.ema_network, self.unsup_ema_tau)

                if step >= total_steps:
                    break

            num_epochs += 1

        print("Finished.")
        self.save(path, "final_model")

