import numpy as np
import torch
from torch.nn import functional as F

from tqdm.auto import tqdm

from ddlm.modeling.diffusion import (
    DiffusionTransformer,
)
from ddlm.time.low_discrepency_sampling import get_t

# Per-task mapping from predicted class index to the label string written in a
# GLUE submission file. STS-B is absent: it is a regression task and its scores
# are written as numbers by write_test instead.
task_to_labels = {
    **{task: {0: "0", 1: "1"} for task in ("cola", "sst2", "mrpc", "qqp")},
    **{
        task: {0: "entailment", 1: "neutral", 2: "contradiction"}
        for task in ("mnli", "mnli-mm")
    },
    **{task: {0: "entailment", 1: "not_entailment"} for task in ("qnli", "rte")},
}

# Submission file name for each task; the insertion order is also the order of
# the task files inside the submission archive (see arcnames below).
task_to_filename = {
    "cola": "CoLA.tsv",
    "sst2": "SST-2.tsv",
    "mrpc": "MRPC.tsv",
    "qqp": "QQP.tsv",
    "mnli": "MNLI-m.tsv",
    "mnli-mm": "MNLI-mm.tsv",
    "qnli": "QNLI.tsv",
    "rte": "RTE.tsv",
    "stsb": "STS-B.tsv",
}

# Archive member names, one per task, in task_to_filename order.
arcnames = list(task_to_filename.values())
# Extra files that zipdir always adds to the archive next to the task files.
dummy_files = ["AX.tsv", "WNLI.tsv"]

def run_one_step(
    model: DiffusionTransformer,
    input_ids,
    attention_mask,
    accelerator,
    timestamps: float = 0.0,
):
    """Run one fully-conditioned forward pass of ``model`` (no loss computed).

    Every position is marked as conditioning and neither the loss nor the
    denoised output is requested, so callers read ``outputs.logits`` directly.

    Args:
        model: the diffusion transformer to call.
        input_ids: token id tensor; its ``size(0)`` sets how many timestamps are
            built and its shape sets the shape of the conditioning mask.
        attention_mask: padding mask, forwarded to the model so padded positions
            do not influence the outputs.
        accelerator: object whose ``device`` hosts the timestamp tensor.
        timestamps: diffusion time shared by every example in the batch.

    Returns:
        Whatever ``model(...)`` returns; callers use its ``.logits``.
    """
    # One diffusion time per example, all equal to ``timestamps``.
    u_stamps = torch.zeros(input_ids.size(0), device=accelerator.device) + timestamps

    # Condition on every token: nothing is masked out for denoising here.
    conditioning_mask = torch.ones_like(input_ids, dtype=torch.bool)

    outputs = model(
        timestamps=u_stamps.detach(),
        input_ids=input_ids,
        conditioning_mask=conditioning_mask,
        # Previously accepted but silently dropped, so padding was attended.
        attention_mask=attention_mask,
        output_denoised=False,
        drop_conditioning_embeddings=False,
        self_conditioning=False,
        output_loss=False,
        t_max=0,
        weighted_loss=False,
    )
    return outputs


def train_one_epoch(
    model, loader, optimizer, accelerator, num_labels, wandb=None, timestamp=0.0, is_regression: bool = False
):
    """Train ``model`` for one pass over ``loader`` using a prediction head.

    The head is the logits at sequence position 0, restricted to the first
    ``num_labels`` entries. Classification uses cross-entropy and records the
    argmax class; regression flattens the head, uses MSE and records the raw
    scores.

    Args:
        model: model passed to ``run_one_step``.
        loader: iterable of batches with ``"input_ids"``, ``"attention_mask"``
            and ``"labels"`` entries; must support ``len()`` (for the bar).
        optimizer: optimizer stepped once per batch.
        accelerator: provides ``backward(loss)`` and ``device``.
        num_labels: number of head logits to read.
        wandb: optional logger; ``wandb.log({"loss": ...})`` is called per batch.
        timestamp: diffusion time forwarded to ``run_one_step``.
        is_regression: select the MSE/regression path instead of cross-entropy.

    Returns:
        ``(predictions, labels)`` as flat Python lists over the whole epoch.
    """
    predictions = []
    labels = []
    pbar = tqdm(loader, total=len(loader), leave=False)
    for batch in pbar:
        optimizer.zero_grad()
        outputs = run_one_step(
            model=model,
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            accelerator=accelerator,
            timestamps=timestamp,
        )
        logits_cls = outputs.logits[:, 0, :num_labels]
        # Index access works for plain dicts as well as BatchEncoding objects.
        targets = batch["labels"]
        if is_regression:
            scores = logits_cls.view(-1)
            # mse_loss requires matching dtypes; cast defensively in case the
            # labels arrive as e.g. float64.
            loss = F.mse_loss(scores, targets.to(scores.dtype))
        else:
            loss = F.cross_entropy(logits_cls, targets)
        accelerator.backward(loss)
        optimizer.step()

        if is_regression:
            c_predictions = logits_cls.view(-1).detach().cpu().tolist()
        else:
            c_predictions = logits_cls.detach().argmax(-1).cpu().tolist()
        predictions.extend(c_predictions)
        labels.extend(targets.cpu().tolist())

        loss_value = loss.item()
        pbar.set_description(f"loss: {loss_value:.3f}")
        if wandb is not None:
            wandb.log({"loss": loss_value})
    return predictions, labels


def run_diffusion(
    idx,
    batch,
    model,
    t_min,
    t_max,
    accelerator,
    optimizer,
    target_token_idx: int = 1,
):
    """Run one loss-producing diffusion forward pass of ``model`` on ``batch``.

    Per-example times come from ``get_t(batch_size, idx)`` and are mapped
    affinely to ``t * (t_max - t_min) + t_min``. The model is called with
    self-conditioning on and ``output_loss=True``.

    Args:
        idx: step index passed to ``get_t``.
        batch: mapping with ``"input_ids"`` and ``"attention_mask"``, and
            optionally ``"conditioning_mask"``.
        model: the diffusion model.
        t_min, t_max: bounds of the time rescaling; ``t_max`` is also forwarded
            to the model.
        accelerator: provides ``device`` for the time tensor.
        optimizer: its gradients are zeroed before the forward pass.
        target_token_idx: position left unconditioned when ``batch`` has no
            ``"conditioning_mask"`` of its own.

    Returns:
        The model outputs; callers read ``.loss`` and ``.logits``.
    """
    input_ids = batch["input_ids"]
    timestamps = get_t(input_ids.size(0), idx).to(accelerator.device)
    u_stamps = timestamps * (t_max - t_min) + t_min

    if "conditioning_mask" in batch:
        conditioning_mask = batch["conditioning_mask"]
    else:
        # Condition on every token except the target position, which the model
        # must denoise. (This mask used to be built but never used, and a batch
        # without its own mask raised KeyError.)
        conditioning_mask = torch.ones_like(input_ids, dtype=torch.bool)
        conditioning_mask[:, target_token_idx] = False

    optimizer.zero_grad()
    self_cond = True
    zero_cond = False
    outputs = model(
        timestamps=u_stamps.detach(),
        input_ids=input_ids,
        conditioning_mask=conditioning_mask,
        output_denoised=False,
        drop_conditioning_embeddings=zero_cond,
        self_conditioning=self_cond,
        output_loss=True,
        attention_mask=batch["attention_mask"],
        t_max=t_max,
    )
    return outputs


def train_one_epoch_diffusion(
    model,
    loader,
    optimizer,
    accelerator,
    num_labels,
    name_to_label,
    wandb=None,
    t_min=1,
    t_max=10,
    target_idx=1,
):
    """Train ``model`` with the diffusion objective for one pass over ``loader``.

    For each batch, ``run_diffusion`` produces the loss, which is
    backpropagated and stepped. A prediction is read from the logits at
    sequence position ``target_idx``: among the vocabulary indices listed as
    keys of ``name_to_label``, the one with the highest logit wins, and its
    position in ``list(name_to_label)`` is recorded.

    Args:
        model, optimizer, accelerator: as in ``train_one_epoch``.
        loader: iterable of batches with ``"labels"`` plus whatever
            ``run_diffusion`` reads; must support ``len()``.
        num_labels: unused; kept for signature parity with ``train_one_epoch``.
        name_to_label: its keys are used as vocabulary indices of the candidate
            label tokens. NOTE(review): the recorded prediction is the key's
            position, which equals the class id only if the keys are ordered
            by class — confirm against the caller.
        wandb: optional logger; ``wandb.log({"loss": ...})`` is called per batch.
        t_min, t_max: diffusion time range forwarded to ``run_diffusion``.
        target_idx: sequence position holding the label token; also forwarded
            to ``run_diffusion`` as ``target_token_idx``.

    Returns:
        ``(predictions, labels)`` as flat lists of ints over the epoch.
    """
    predictions = []
    labels = []
    # Vocabulary indices of the candidate label tokens.
    t_idxs = list(name_to_label.keys())
    pbar = tqdm(enumerate(loader), total=len(loader), leave=False)
    for idx, batch in pbar:
        optimizer.zero_grad()
        outputs = run_diffusion(
            idx=idx,
            model=model,
            batch=batch,
            accelerator=accelerator,
            t_min=t_min,
            t_max=t_max,
            optimizer=optimizer,
            target_token_idx=target_idx,
        )
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()

        # Keep only the candidate-label logits at the target position instead
        # of a full vocabulary row per example, and take the argmax per batch.
        # This replaces re-scanning every stored prediction on each batch.
        label_logits = outputs.logits[:, target_idx].detach()[:, t_idxs]
        predictions.extend(label_logits.argmax(-1).cpu().tolist())
        labels.extend(batch["labels"].cpu().tolist())

        loss_value = loss.item()
        pbar.set_description(f"loss: {loss_value:.3f}")
        if wandb is not None:
            wandb.log({"loss": loss_value})
    return predictions, labels


@torch.no_grad()
def sample_diffusion(model, num_steps, t_max: int = 10, t_min: int = 1):
    """Placeholder for multi-step diffusion sampling at inference time.

    This stub was previously also named ``run_diffusion``. Because it was defined
    after the training step of that name, it rebound the name at import time.
    ``train_one_epoch_diffusion`` then called this stub with ``idx=...`` and
    failed with a TypeError. It is renamed so the training step stays reachable.

    Raises:
        NotImplementedError: always; sampling has not been written yet.
    """
    raise NotImplementedError("multi-step diffusion sampling is not implemented yet")


@torch.no_grad()
def validate(
    model, loader, accelerator, num_labels, is_regression: bool = False, timestamp: float = 0.0
):
    """Collect predictions of ``model`` over ``loader`` without tracking gradients.

    Mirrors the prediction path of ``train_one_epoch``: the head is the logits
    at sequence position 0, restricted to the first ``num_labels`` entries.

    Args:
        model: model passed to ``run_one_step``.
        loader: iterable of batches with ``"input_ids"``, ``"attention_mask"``
            and ``"labels"``; must support ``len()`` (for the progress bar).
        accelerator: forwarded to ``run_one_step``.
        num_labels: number of head logits to read.
        is_regression: record flattened raw scores instead of argmax classes.
        timestamp: diffusion time forwarded to ``run_one_step``. Keep it equal to
            the ``timestamp`` used for training. The default 0.0 matches the
            original behavior.

    Returns:
        ``(predictions, labels)`` as flat Python lists.
    """
    predictions = []
    labels = []
    for batch in tqdm(loader, total=len(loader), leave=False):
        outputs = run_one_step(
            model=model,
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            accelerator=accelerator,
            timestamps=timestamp,
        )
        logits_cls = outputs.logits[:, 0, :num_labels].detach()
        if is_regression:
            predictions.extend(logits_cls.view(-1).cpu().tolist())
        else:
            predictions.extend(logits_cls.argmax(-1).cpu().tolist())
        # Index access works for plain dicts as well as BatchEncoding objects.
        labels.extend(batch["labels"].cpu().tolist())
    return predictions, labels


def write_test(path, task_name, results):
    """Write test-set predictions for ``task_name`` as a GLUE-style TSV file.

    Each value of ``results`` is handled by type:
    - Scalar values (Python ints/floats and numpy scalars, e.g. metrics) are
      printed as ``"<key>: <value>"``.
    - Any other value is treated as an iterable of predictions. It is written to
      ``{path}/{task_to_filename[task_name]}`` as one tab-separated
      index/prediction line per item.

    For ``"stsb"`` each score is clipped to [0, 5] and written with three
    decimals. For the other tasks each class index is translated through
    ``task_to_labels[task_name]``.

    NOTE(review): every non-scalar entry reopens the same file in ``"w"`` mode,
    so only the last such entry survives; confirm ``results`` holds just one.
    NOTE(review): Hugging Face's ``run_glue.py`` starts these files with an
    "index", "prediction" header row; confirm whether the GLUE server needs it.
    """
    for result_key, result in results.items():
        # Previously only builtin float was recognized. Ints and numpy scalars
        # such as np.float32 fell through to enumerate() and crashed.
        if isinstance(result, (int, float, np.number)):
            print(f"{result_key}: {result!s}")
            continue
        with open(
            f"{path}/{task_to_filename[task_name]}", "w", encoding="utf-8"
        ) as writer:
            for index, item in enumerate(result):
                if task_name == "stsb":
                    # STS-B similarity scores live in [0, 5].
                    score = max(0, min(item, 5))
                    writer.write(f"{index}\t{score:3.3f}\n")
                else:
                    label = task_to_labels[task_name][item]
                    writer.write(f"{index}\t{label}\n")

def zipdir(filenames, ziph):
    """Add prediction files plus the fixed placeholder files to a zip archive.

    Args:
        filenames: paths of the prediction files, paired positionally with the
            module-level ``arcnames``; pairing stops at the shorter of the two.
        ziph: zipfile handle (anything with ``write(filename, arcname=...)``);
            it is neither opened nor closed here.
    """
    entries = zip(filenames, arcnames)
    for src_path, member_name in entries:
        ziph.write(src_path, arcname=member_name)
    # Placeholder files are archived without an explicit arcname.
    for placeholder in dummy_files:
        ziph.write(placeholder)

