import copy
import re
import sys
from typing import Optional, List, Tuple, Dict

from datasets import Dataset
from torch.nn import CrossEntropyLoss
from torch.utils.data import SequentialSampler, BatchSampler, DataLoader, RandomSampler
from transformers import DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM, AutoTokenizer

from config_evaluator import Lazy
from logger import Logger, TqdmLogger
from sip.utils import scale_grad
from sip.meta_loading import load_tsv

import torch
import numpy as np


import random
from sip.eval_utils import MovingAvg

import gc

from sip.utils import get_optimizer


def compute_dl(model, batch):
    """
    Return the summed cross-entropy of ``batch["labels"]`` under ``model``, in nats.

    The model is called with the whole batch as keyword arguments and its
    ``logits`` are scored against the labels. Positions labelled -100 are
    ignored. The result is a scalar tensor summed (not averaged) over all
    scored positions.

    :param model: callable returning an object with a ``logits`` attribute
        whose last dimension is the vocabulary size.
    :param batch: mapping of model inputs; must contain ``"labels"``.
    :return: scalar tensor, total code length of the labels in nats.
    """
    criterion = CrossEntropyLoss(ignore_index=-100, reduction="sum")
    logits = model(**batch).logits
    vocab = logits.size(-1)
    # flatten every position into one row so the loss sums over the whole batch
    return criterion(logits.view(-1, vocab), batch["labels"].view(-1))


def uniform_dl(model, batch):
    """
    Return the code length, in nats, of ``batch["labels"]`` under a uniform distribution.

    Every label position not equal to -100 costs ``log(V)`` nats, where ``V`` is
    the ``out_features`` of the model's output embedding layer.

    :param model: model exposing ``get_output_embeddings()``.
    :param batch: mapping that contains a ``"labels"`` tensor.
    :return: scalar tensor, ``(#non-ignored labels) * log(V)``.
    """
    n_targets = (batch["labels"] != -100).sum()
    nats_per_token = np.log(model.get_output_embeddings().out_features)
    return n_targets * nats_per_token


def _make_loader(data, batch_size, tokenizer, shuffle):
    """
    Build a DataLoader over ``data`` whose batches are collated with ``DataCollatorForSeq2Seq``.

    :param data: dataset to iterate.
    :param batch_size: examples per batch (the last batch may be smaller).
    :param tokenizer: tokenizer handed to the collator.
    :param shuffle: use a RandomSampler (training) instead of a SequentialSampler
        (transmission / evaluation).
    :return: a ``DataLoader`` driven by a ``BatchSampler``.
    """
    sampler = RandomSampler(data) if shuffle else SequentialSampler(data)
    batches = BatchSampler(sampler, batch_size=batch_size, drop_last=False)
    return DataLoader(data, collate_fn=DataCollatorForSeq2Seq(tokenizer), batch_sampler=batches)


def estimate_dl(model, dataset,
                block_endings: list[float],
                max_epochs: int,
                tol: float,
                batch_size: int,
                tokenizer,
                device: str = "cuda:0",
                optimizer: Lazy[torch.optim.Optimizer] = None,
                logger: Optional[Logger] = None,
                grad_scale: Optional[float] = None,
                optimizer_groups: Optional[List[Tuple[str, Dict]]] = None,
                num_accumulation_steps: int = 1,
                randomize_order: bool = True
                ):
    """
    Estimate the prequential (online) description length of ``dataset`` under ``model``.

    The dataset is split into consecutive blocks. Each block is "transmitted" (its
    code length computed) with the current model. The model is then reset to its
    initial weights and trained on all data up to and including that block, and the
    result is used to encode the next block. The first block, and any block
    transmitted before any training has happened, is encoded with a uniform code.

    :param model: seq2seq model; moved to ``device`` and trained in place.
    :param dataset: tokenized ``datasets.Dataset`` with ``labels``.
    :param block_endings: cumulative fractions of the dataset; each becomes the
        end index ``int(frac * len(dataset))`` of a block. Examples past the last
        ending are never transmitted.
    :param max_epochs: upper bound on training epochs per block.
    :param tol: training on a block stops early once the epoch's mean per-batch
        training loss falls below this value.
    :param batch_size: batch size for both transmission and training.
    :param tokenizer: tokenizer used by the padding collator.
    :param device: torch device string.
    :param optimizer: optimizer factory forwarded to ``get_optimizer``.
    :param logger: metrics/progress logger; a default ``Logger()`` if None.
    :param grad_scale: forwarded to ``scale_grad`` before each optimizer step
        (semantics defined in ``sip.utils``).
    :param optimizer_groups: parameter groups forwarded to ``get_optimizer``.
    :param num_accumulation_steps: number of backward passes per optimizer step.
    :param randomize_order: shuffle the dataset (with the global ``random`` state)
        before splitting it into blocks.
    :return: ``(model, dl, final_dl)``. ``dl`` is the cumulative prequential code
        length in nats. ``final_dl`` is the code length in nats of the whole
        (possibly shuffled) dataset under the final trained model.
    """
    if logger is None:
        logger = Logger()

    model = model.to(device)
    copy_manager = CopyManager(model)

    if randomize_order:
        rows = dataset.to_list()
        random.shuffle(rows)
        dataset = Dataset.from_list(rows)

    block_endings = [int(frac * len(dataset)) for frac in block_endings]
    pointer = 0
    dl = 0
    cumulative_uniform_dl = 0
    has_been_trained = False
    for block_id, block_ending in enumerate(block_endings):
        # --- Transmit the current block with the current model ---
        data_part = Dataset.from_dict(dataset[pointer: block_ending])
        transmit_data_loader = _make_loader(data_part, batch_size, tokenizer, shuffle=False)
        new_dl = 0
        baseline_dl = 0
        model.eval()
        with torch.no_grad():
            for batch in logger.progress_bar(transmit_data_loader):
                batch = {k: v.to(device) for k, v in batch.items()}
                uniform = uniform_dl(model, batch).detach().cpu().numpy()
                if has_been_trained:
                    new_dl += compute_dl(model, batch).detach().cpu().numpy()
                else:
                    # No model has been fit yet: fall back to the uniform code.
                    new_dl += uniform
                baseline_dl += uniform

        dl += new_dl
        cumulative_uniform_dl += baseline_dl
        logger.log_metrics("block", {"cumul_dl": dl, "new_dl": new_dl, "new_uniform_dl": baseline_dl,
                                     "cumul_uniform_dl": cumulative_uniform_dl})

        # --- Train on all data seen so far to encode the next block ---
        pointer = block_ending
        data_part = Dataset.from_dict(dataset[: pointer])
        if len(data_part) == 0:
            continue
        has_been_trained = True
        train_data_loader = _make_loader(data_part, batch_size, tokenizer, shuffle=True)

        # Drop the previous optimizer (and its state) before building a new one.
        optimizer_obj = None
        gc.collect()

        model = copy_manager.get_a_copy()
        # load_state_dict restores weights but not .grad; without this, a partial
        # accumulation window left over from the previous block would leak into the
        # first update of this block.
        model.zero_grad(set_to_none=True)

        optimizer_obj = get_optimizer(model, optimizer, optimizer_groups)

        model.train()
        step_idx = 0
        epoch = 0
        converged = False
        while epoch < max_epochs and not converged:
            epoch += 1
            loss = 0
            overall_loss = 0
            for batch in logger.progress_bar(train_data_loader):
                step_idx += 1
                batch = {k: v.to(device) for k, v in batch.items()}
                r = model(**batch)
                new_loss = r.loss.detach().cpu().numpy()
                overall_loss += new_loss
                loss += new_loss
                # NOTE(review): the loss is not divided by num_accumulation_steps, so
                # accumulated gradients are a sum, not a mean — confirm this is intended.
                r.loss.backward()
                if step_idx % num_accumulation_steps == 0:
                    scale_grad(model, grad_scale)
                    optimizer_obj.step()
                    optimizer_obj.zero_grad()
                    logger.log_metrics(f"dl_train_block_{block_id}", {"loss": loss})
                    loss = 0

            # Converged once this epoch's mean per-batch training loss is below tol.
            converged = overall_loss / len(train_data_loader) < tol

    # Release the initial-weights snapshot before the final pass to lower peak memory.
    copy_manager = None
    gc.collect()
    torch.cuda.empty_cache()

    # Code length of the whole dataset under the final model, for reference.
    final_loader = _make_loader(dataset, batch_size, tokenizer, shuffle=False)
    new_dl = 0
    model.eval()
    with torch.no_grad():
        for batch in logger.progress_bar(final_loader):
            batch = {k: v.to(device) for k, v in batch.items()}
            new_dl += compute_dl(model, batch).detach().cpu().numpy()
    logger.log_metrics("dl", {"dl_final_model": new_dl})

    return model, dl, new_dl


def prepare_task_dataset_no_loader(path: str, tokenizer: AutoTokenizer, lenient: bool = False) -> Dataset:
    """
    Read an ``input\\toutput`` TSV file and return it as a tokenized ``Dataset``.

    Each row's ``input`` is tokenized into the model inputs. Its ``output`` is
    tokenized as the target and stored under ``labels``. The raw text columns are
    dropped from the result.

    :param path: path of the TSV file, read through ``load_tsv``.
    :param tokenizer: tokenizer instance used for both inputs and targets.
    :param lenient: forwarded to ``load_tsv``.
    :return: dataset with the tokenizer's output fields plus ``labels``.
    """
    columns = ["input", "output"]
    raw = {name: [] for name in columns}
    for row in load_tsv(path, "input\toutput", lenient=lenient):
        for name, value in zip(columns, row):
            raw[name].append(value)

    def encode(example):
        encoded = tokenizer(example["input"])
        if "output" in example:
            encoded["labels"] = tokenizer(text_target=example["output"])["input_ids"]
        return encoded

    return Dataset.from_dict(raw).map(encode, batched=False, remove_columns=list(columns))

class CopyManager:
    """
    Snapshot a model's state once and restore it on demand.

    The snapshot is a deep copy of ``model.state_dict()`` taken at construction
    time. ``get_a_copy`` loads that snapshot back into the *same* model object
    (in place) and returns it; no new model instance is created. Gradients
    (``.grad``) are not part of the state dict and are left untouched.
    """

    def __init__(self, model: torch.nn.Module):
        self.model = model
        # Deep copy so that later in-place training updates cannot alter the snapshot.
        self.backup_dict = copy.deepcopy(model.state_dict())

    def get_a_copy(self):
        """
        Reset the wrapped model to the snapshot and return it.

        :return: the model passed to ``__init__``, with the snapshotted state loaded.
        """
        target = self.model
        target.load_state_dict(self.backup_dict)
        # Reclaim garbage (e.g. a discarded optimizer) before the model is reused.
        gc.collect()
        torch.cuda.empty_cache()
        return target

if __name__ == "__main__":
    # Local import: only this demo entry point needs os.
    import os

    model = AutoModelForSeq2SeqLM.from_pretrained("charsiu/g2p_multilingual_byT5_tiny_16_layers")
    tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")

    # open() does not expand "~"; resolve it here so the path works however the
    # TSV loader opens the file.
    data_path = os.path.expanduser("~/phd/artificial_tasks/meta_adapters/data/10_task.tsv")
    dataset = prepare_task_dataset_no_loader(data_path, tokenizer)

    model, total_dl, last_dl = estimate_dl(
        model,
        dataset,
        block_endings=[0.001, 0.1, 0.2, 0.3, 0.9],
        max_epochs=15,
        tol=0.2,
        batch_size=4,
        tokenizer=tokenizer,
        device="cpu",
        logger=TqdmLogger(print=True),
    )

    print(total_dl)
    print(last_dl)







