import math
import warnings
import json
import re
from pathlib import Path
import numpy as np
from collections import defaultdict
from tqdm import tqdm
import torch
import torch.distributed as dist
from torch import nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader, WeightedRandomSampler, Dataset, Sampler
import datasets
from datasets import IterableDataset
from typing import Dict, List, Optional, Sequence
from dataclasses import dataclass
import transformers
from transformers import Trainer
from transformers.utils import ExplicitEnum, is_torch_tpu_available
from transformers.optimization import get_scheduler
from transformers.utils import logging
from transformers.trainer import is_sagemaker_mp_enabled
from transformers.modeling_utils import unwrap_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.trainer_utils import (
    has_length,
    denumpify_detensorize,
    EvalLoopOutput,
    enable_full_determinism,
    set_seed,
    get_last_checkpoint,
    PREFIX_CHECKPOINT_DIR,
    IntervalStrategy,
)
from transformers.trainer_pt_utils import find_batch_size, get_dataloader_sampler
from transformers.trainer_utils import seed_worker, has_length
from transformers.utils import is_datasets_available, is_peft_available


# Module-level logger obtained through transformers' logging utilities.
logger = logging.get_logger(__name__)
# Fill value used when padding `labels` in DataCollatorForDynamicWeighting.
IGNORE_INDEX = -100


class MultipleDataset(Dataset):
    """Map-style dataset that bundles several named sub-datasets.

    Items are addressed with a ``(dataset_idx, sample_idx)`` pair instead of a
    flat integer index; ``dataset_idx`` follows the iteration order of the
    mapping passed to the constructor.
    """

    def __init__(self, dataset_dict) -> None:
        """Store the names, members and lengths of ``dataset_dict``."""
        super().__init__()
        names, members = [], []
        for name, member in dataset_dict.items():
            names.append(name)
            members.append(member)
        self.data_names = names
        self.datasets = members
        self.dataset_lens = list(map(len, members))

    def __getitem__(self, index):
        """Return example ``sample_idx`` of sub-dataset ``dataset_idx``."""
        which, position = index
        member = self.datasets[which]
        return member[position]

    def __len__(self):
        """Return the total number of examples over all sub-datasets."""
        return sum(self.dataset_lens)

    def get_subset_len(self):
        """Return the list holding the length of every sub-dataset."""
        return self.dataset_lens


@dataclass
class DataCollatorForDynamicWeighting(object):
    """Batch examples and keep track of the domain each one belongs to.

    Every instance must provide ``input_ids``, ``labels`` and ``domain_ids``.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Right-pad the instances to a common length and stack them.

        ``input_ids`` are padded with the tokenizer's pad token id, ``labels``
        with ``IGNORE_INDEX``. The attention mask is true wherever
        ``input_ids`` differs from the pad token id.
        """
        pad_id = self.tokenizer.pad_token_id
        pad = torch.nn.utils.rnn.pad_sequence

        raw_ids = [instance["input_ids"] for instance in instances]
        raw_labels = [instance["labels"] for instance in instances]

        padded_ids = pad(
            [torch.tensor(seq) for seq in raw_ids],
            batch_first=True,
            padding_value=pad_id,
        )
        padded_labels = pad(
            [torch.tensor(seq) for seq in raw_labels],
            batch_first=True,
            padding_value=IGNORE_INDEX,
        )
        domains = [instance["domain_ids"] for instance in instances]

        return {
            "input_ids": padded_ids,
            "labels": padded_labels,
            "attention_mask": padded_ids.ne(pad_id),
            "domain_ids": torch.tensor(domains),
        }


class DynamicWeightedSampler(Sampler):
    """Sampler that draws ``(dataset_idx, sample_idx)`` pairs from several datasets.

    For every sample a dataset is first chosen with ``torch.multinomial``
    according to ``self.weights``. The example index inside that dataset is
    then either drawn uniformly at random or, with ``no_shuffle=True``, taken
    from a per-dataset pointer that advances by one and wraps around.

    Args:
        dataset_lens: Number of examples in each dataset.
        replacement: Forwarded to ``torch.multinomial`` when choosing datasets.
        uniform_init: Start from equal weights instead of weights proportional
            to the dataset sizes.
        no_shuffle: Walk through each dataset in order instead of sampling
            example indices at random.
    """

    def __init__(
        self,
        dataset_lens: Sequence[int],
        replacement: bool = True,
        uniform_init: bool = False,
        no_shuffle: bool = False,
    ) -> None:
        self.dataset_lens = dataset_lens
        self.num_samples = sum(dataset_lens)
        self.no_shuffle = no_shuffle
        if self.no_shuffle:
            # Next example to hand out for each dataset.
            self.pointers = [0] * len(dataset_lens)
        if uniform_init:
            self.weights = torch.tensor([1 / len(dataset_lens)] * len(dataset_lens))
        else:
            self.weights = torch.tensor(
                [dataset_len / self.num_samples for dataset_len in dataset_lens]
            )
        print("Init weights", self.weights)
        self.replacement = replacement

    def update_weights(self, weights: Sequence[float]) -> None:
        """Replace the dataset weights and redraw the dataset choices.

        Args:
            weights: One non-negative weight per dataset, either as a tensor
                or as any sequence ``torch.tensor`` can convert.
        """
        # torch.multinomial only accepts tensors; convert plain sequences so
        # the annotated ``Sequence[float]`` contract actually holds.
        if not isinstance(weights, torch.Tensor):
            weights = torch.tensor(weights, dtype=torch.float)
        self.weights = weights
        self.dataset_indices = torch.multinomial(
            self.weights, self.num_samples, replacement=self.replacement
        )
        # dist.get_rank() raises without an initialised process group, so a
        # non-distributed run is treated like rank 0.
        distributed = dist.is_available() and dist.is_initialized()
        if not distributed or dist.get_rank() == 0:
            print("Cur weights", self.weights)

    def __iter__(self):
        """Yield ``num_samples`` pairs of ``(dataset_idx, sample_idx)``."""
        self.dataset_indices = torch.multinomial(
            self.weights, self.num_samples, replacement=self.replacement
        )
        for position in range(self.num_samples):
            # Read the attribute on every step: update_weights() replaces it,
            # so new weights affect the remainder of a running iteration.
            dataset_idx = int(self.dataset_indices[position])
            if self.no_shuffle:
                sample_idx = self.pointers[dataset_idx]
                self.pointers[dataset_idx] = (sample_idx + 1) % self.dataset_lens[
                    dataset_idx
                ]
            else:
                sample_idx = torch.randint(self.dataset_lens[dataset_idx], (1,)).item()
            yield (dataset_idx, sample_idx)

    def __len__(self):
        """Return the number of pairs produced by one pass of ``__iter__``."""
        return self.num_samples


# PeftModel is imported only when peft is installed; compute_loss checks
# is_peft_available() again before referencing it.
if is_peft_available():
    from peft import PeftModel


class DynamicWeightTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        """Initialise the base ``Trainer`` and the per-domain bookkeeping.

        All arguments are forwarded to ``Trainer.__init__``. The training
        dataset must expose ``data_names`` (one name per domain).
        """
        super().__init__(*args, **kwargs)
        domain_names = self.train_dataset.data_names
        self.domain_list = domain_names
        # Losses from the previous evaluation; None until one has run.
        self.last_per_domain_losses = None
        print("domain_list", self.domain_list)
        num_domains = len(domain_names)
        # State read and updated by the "odm" branch of update_sample_weights.
        self.eps = 1 / num_domains
        self.prev_eps = None
        self.iteration = 0
        self._estimated_eval_loss = torch.zeros(num_domains)

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        Compute the loss for one batch while keeping ``domain_ids`` away from the model.

        ``domain_ids`` is removed from ``inputs`` for the forward pass and put back
        afterwards, because ``evaluation_loop`` reads ``inputs["domain_ids"]`` after
        the prediction step. Batches without ``domain_ids`` are handled as well.

        Args:
            model: The model to run.
            inputs: Mapping of model inputs; may contain ``labels`` and ``domain_ids``.
            return_outputs: When true, return ``(loss, outputs)`` instead of ``loss``.

        Raises:
            ValueError: If no label smoother is used and the model returns a dict
                without a ``"loss"`` entry.
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None

        # Only pop/restore when the key is present; the previous unconditional
        # restore raised UnboundLocalError for batches without domain_ids.
        has_domain_ids = "domain_ids" in inputs
        domain_ids = inputs.pop("domain_ids") if has_domain_ids else None
        outputs = model(**inputs)
        if has_domain_ids:
            inputs["domain_ids"] = domain_ids

        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            unwrapped_model = unwrap_model(model)
            if is_peft_available() and isinstance(unwrapped_model, PeftModel):
                model_name = unwrapped_model.base_model.model._get_name()
            else:
                model_name = unwrapped_model._get_name()
            # Causal LMs predict the next token, so their labels are shifted.
            if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        return (loss, outputs) if return_outputs else loss

    def get_train_dataloader(self) -> DataLoader:
        """
        Build the training dataloader around a ``DynamicWeightedSampler``.

        For map-style datasets a ``DynamicWeightedSampler`` is created from
        ``train_dataset.get_subset_len()``. The sampler (or ``None`` for an iterable
        dataset) is stored on ``self.sampler`` and the prepared dataloader on
        ``self.train_dataloader``.

        Returns:
            The dataloader after ``self.accelerator.prepare``.

        Raises:
            ValueError: If the trainer has no ``train_dataset``.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset
        # No unused-column removal is applied here: the collator is used as is.
        # NOTE(review): presumably so that `domain_ids` reaches the batch - confirm.
        data_collator = self.data_collator

        dataloader_params = {
            "batch_size": self._train_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
        }

        sampler = None
        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            # `domain_weight_uniform_init` and `no_shuffle` are read from
            # self.args; they are expected to be provided by the arguments class.
            sampler = DynamicWeightedSampler(
                self.train_dataset.get_subset_len(),
                replacement=True,
                uniform_init=self.args.domain_weight_uniform_init,
                no_shuffle=self.args.no_shuffle,
            )
            dataloader_params["sampler"] = sampler
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["worker_init_fn"] = seed_worker

        # Iterable datasets get no sampler; indexing dataloader_params["sampler"]
        # used to raise KeyError on that path.
        self.sampler = sampler
        self.train_dataloader = self.accelerator.prepare(
            DataLoader(train_dataset, **dataloader_params)
        )
        print("Sample", get_dataloader_sampler(self.train_dataloader))
        return self.train_dataloader

    def evaluation_loop(
        self,
        dataloader,
        description,
        prediction_loss_only=None,
        ignore_keys=None,
        metric_key_prefix="eval",
    ):
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Computes per-domain log-perplexity, uniformly averaged log-perplexity, and worst-case log-perplexity

        Every batch must carry `domain_ids`. Summed token losses, token counts and
        example counts are accumulated per domain, summed over all processes when
        `torch.distributed` is initialised, and normalised by the token count.
        Domains without any counted token are left out of the metrics. At the end
        the per-domain losses are passed to `update_sample_weights`.

        Returns:
            An `EvalLoopOutput` whose `predictions` and `label_ids` are None,
            `metrics` holds the prefixed per-domain and aggregate log-perplexities,
            and `num_samples` is the number of examples in the reported domains.
        """
        args = self.args

        if prediction_loss_only:
            # Logits and labels are required for the per-domain losses, so a
            # request for loss-only prediction is overridden.
            prediction_loss_only = None

        # if eval is called w/o train init deepspeed here
        # NOTE(review): `deepspeed_init` is not imported in this file, so this
        # branch raises NameError when taken - confirm the intended import and
        # call signature for the targeted transformers version.
        if args.deepspeed and not self.deepspeed:
            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
            # from the checkpoint eventually
            deepspeed_engine, _, _ = deepspeed_init(
                self, num_training_steps=0, resume_from_checkpoint=None, inference=True
            )
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine

        model = self._wrap_model(self.model, training=False, dataloader=dataloader)

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        batch_size = self.args.eval_batch_size

        logger.info(f"***** Running {description} *****")
        if has_length(dataloader):
            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
        else:
            logger.info("  Num examples: Unknown")
        logger.info(f"  Batch size = {batch_size}")

        model.eval()

        self.callback_handler.eval_dataloader = dataloader

        # NOTE(review): `pl` (here) and `xm` (below) are not imported in this
        # file; the TPU branches raise NameError as written.
        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(
                args.device
            )

        if args.past_index >= 0:
            self._past = None

        # Summed (not averaged) token losses, so totals can be added across
        # batches and processes before normalising.
        loss_fn = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX, reduction="sum")

        # Allocate on the trainer's device instead of hard-coding CUDA so that
        # evaluation also runs where no GPU is available.
        num_domains = len(self.domain_list)
        losses = torch.zeros(num_domains, device=args.device)
        tokencounts = torch.zeros(num_domains, device=args.device)
        examplecounts = torch.zeros(num_domains, device=args.device)

        # Main evaluation loop
        for inputs in tqdm(dataloader):
            # Prediction step
            loss, logits, labels = self.prediction_step(
                model, inputs, prediction_loss_only, ignore_keys=ignore_keys
            )
            domain_ids = inputs["domain_ids"].to(loss.device)

            if is_torch_tpu_available():
                xm.mark_step()

            if isinstance(logits, tuple):
                logits = logits[0]

            # compute losses per domain
            for domain_idx in range(num_domains):
                domain_mask = domain_ids == domain_idx
                num_in_domain = domain_mask.sum()
                examplecounts[domain_idx] += num_in_domain

                if num_in_domain > 0:
                    # Shift so that position t is scored against token t + 1.
                    domain_labels = labels[domain_mask][:, 1:].contiguous().view(-1)
                    domain_preds = logits[domain_mask]
                    domain_preds = (
                        domain_preds[:, :-1, :]
                        .contiguous()
                        .view(-1, domain_preds.size(-1))
                    )
                    losses[domain_idx] += loss_fn(domain_preds, domain_labels)
                    tokencounts[domain_idx] += (domain_labels != IGNORE_INDEX).sum()

        # Sum the statistics over all processes; skipped when no process group
        # exists, where all_reduce would raise.
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(losses)
            dist.all_reduce(tokencounts)
            dist.all_reduce(examplecounts)

        # Move the totals to Python floats once instead of per-domain .item() calls.
        loss_sums = losses.tolist()
        token_totals = tokencounts.tolist()
        example_totals = examplecounts.tolist()

        # Normalise; domains without any counted token are skipped.
        per_domain_losses = {}
        per_domain_examplecounts = {}
        for domain_idx, domain_name in enumerate(self.domain_list):
            if token_totals[domain_idx] > 0:
                per_domain_losses[domain_name] = (
                    loss_sums[domain_idx] / token_totals[domain_idx]
                )
                per_domain_examplecounts[domain_name] = example_totals[domain_idx]

        metrics = {
            f"{domain_name}:log_perplexity": domain_loss
            for domain_name, domain_loss in per_domain_losses.items()
        }
        metrics["uniform_avg_log_perplexity"] = np.mean(
            list(per_domain_losses.values())
        )
        metrics["worst_case_log_perplexity"] = np.amax(list(per_domain_losses.values()))

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        self.update_sample_weights(per_domain_losses)

        return EvalLoopOutput(
            predictions=None,
            label_ids=None,
            metrics=metrics,
            num_samples=sum(per_domain_examplecounts.values()),
        )

    def update_sample_weights(self, per_domain_losses):
        """
        Update the training sampler's domain weights from the latest evaluation losses.

        The first call only records ``per_domain_losses``. Later calls compute the
        loss difference to the previous evaluation, apply the rule selected by
        ``self.args.domain_weight_update_type``, normalise the result to sum to one
        and hand it to the sampler of ``self.train_dataloader``. An update type that
        matches no branch leaves the weights unchanged.

        Args:
            per_domain_losses: Mapping from domain name to its normalised loss. The
                values are used in iteration order.
                NOTE(review): this assumes the same domains, in the same order, as
                the sampler weights and the previous call - confirm with callers.
        """
        if self.last_per_domain_losses is None:
            self.last_per_domain_losses = per_domain_losses
            return

        c = 1e-4  # following Doremi (Xie et al., 2023)
        num_domains = len(per_domain_losses)
        update_type = self.args.domain_weight_update_type

        sampler = get_dataloader_sampler(self.train_dataloader)
        print(sampler)
        # assert isinstance(sampler, DynamicWeightedSampler)
        last_weights: torch.Tensor = sampler.weights

        # update weights
        new_weights = last_weights.clone().detach()
        per_domain_losses_tensor = torch.tensor(list(per_domain_losses.values()))
        last_per_domain_losses_tensor = torch.tensor(
            list(self.last_per_domain_losses.values())
        )
        if self.args.domain_weight_last_first:
            diff = last_per_domain_losses_tensor - per_domain_losses_tensor
        else:
            diff = per_domain_losses_tensor - last_per_domain_losses_tensor
        if self.args.domain_weight_norm:
            # Relative instead of absolute change.
            diff /= last_per_domain_losses_tensor

        if update_type == "exp":
            new_weights = last_weights * torch.exp(self.args.domain_weight_lr * diff)
        elif update_type == "doremi":
            updated_alpha = torch.log(last_weights) + self.args.domain_weight_lr * diff
            updated_alpha = F.softmax(updated_alpha, dim=0)
            # Mix with the uniform distribution using the smoothing constant c.
            new_weights = (1 - c) * updated_alpha + c / num_domains
        elif update_type == "bandit":
            updated_alpha = last_weights + self.args.domain_weight_lr * diff
            updated_alpha = F.softmax(updated_alpha, dim=0)
            new_weights = (1 - c) * updated_alpha + c / num_domains
        elif update_type == "delta_exp_softmax":
            new_weights = F.softmax(self.args.domain_weight_lr * diff, dim=0)
        elif update_type == "delta_exp_softmax_update":
            alpha = F.softmax(self.args.domain_weight_lr * diff, dim=0)
            new_weights = alpha * last_weights
        elif update_type == "delta_exp_update":
            alpha = torch.clamp(diff, min=0)
            alpha_total = alpha.sum()
            # When no difference is positive the sum is zero; dividing would
            # produce NaN weights, which torch.multinomial rejects. Keeping
            # alpha at zero leaves the weights unchanged after normalisation.
            if alpha_total > 0:
                alpha = alpha / alpha_total
            new_weights = last_weights * F.softmax(
                self.args.domain_weight_lr * alpha, dim=0
            )
        elif "odm" in update_type:
            # Plain "odm" uses the current losses as reward, other variants the difference.
            reward = per_domain_losses_tensor if update_type == "odm" else diff
            self.prev_eps = self.eps
            self.iteration += 1
            # global_step can still be 0 (evaluation outside of training);
            # clamp it to avoid a ZeroDivisionError.
            step = max(self.state.global_step, 1)
            self.eps = min(
                1 / num_domains,
                math.sqrt(math.log(num_domains) / (num_domains * step)),
            )
            print("eps", self.eps)
            # Exponential moving average of exp(reward).
            smoothing = self.args.domain_weight_smoothing_factor
            self._estimated_eval_loss = (
                smoothing * self._estimated_eval_loss
                + (1 - smoothing) * torch.exp(reward)
            )
            print("_estimated_eval_loss", self._estimated_eval_loss)
            total_estimated_rewards = torch.sum(
                self._estimated_eval_loss * self.prev_eps
            )
            print("total_estimated_rewards", total_estimated_rewards)
            scaling_factor = (1 - num_domains * self.eps) / total_estimated_rewards
            print("scaling_factor", scaling_factor)
            new_weights = (
                torch.exp(self._estimated_eval_loss * self.prev_eps) * scaling_factor
                + self.eps
            )
            print("new_weights", new_weights)

        # norm weights
        new_weights = new_weights / new_weights.sum()

        # When the weights have (almost) stopped changing, switch the
        # evaluation strategy to NO.
        # NOTE(review): the sampler is still updated below, and this method is
        # invoked from evaluation_loop, so this presumably ends further
        # updates - confirm that disabling evaluation is intended.
        if torch.allclose(last_weights, new_weights, atol=1e-4):
            self.args.evaluation_strategy = IntervalStrategy.NO

        # dist.get_rank() raises without an initialised process group, so a
        # non-distributed run is treated like rank 0.
        distributed = dist.is_available() and dist.is_initialized()
        if not distributed or dist.get_rank() == 0:
            print("Last losses:", self.last_per_domain_losses)
            print("Current losses:", per_domain_losses)
            print("Last weights:", last_weights)
            print("New weights:", new_weights)

        sampler.update_weights(new_weights)
        self.last_per_domain_losses = per_domain_losses
