import torch as T
import torch.nn as nn
import torch.nn.functional as F
from optimizers import *
from torch.optim import *
from controllers import *
from seqeval.metrics import f1_score
from seqeval.scheme import IOB2



class seq_label_agent:
    """Training/evaluation wrapper around a sequence-labelling model.

    The agent owns the optimizer and learning-rate scheduler, runs the model
    on a batch, converts label ids to tag strings, and reports loss and F1.
    """

    # Batch entries moved to ``self.device`` when DataParallel is disabled.
    _TENSOR_KEYS = ("sequences_vec", "char_sequences_vec", "feats",
                    "labels_vec", "input_masks")

    def __init__(self, model, config, device):
        """Build the optimizer and scheduler for ``model``.

        Args:
            model: module exposing ``parameters()`` and ``named_parameters()``.
            config: dict read for ``"optimizer"``, ``"lr"``, ``"weight_decay"``,
                ``"different_betas"``, ``"DataParallel"`` and, optionally,
                ``"memory_lr"``. It is also handed to ``get_scheduler``.
            device: device that batch tensors are moved to in ``run``.
        """
        self.model = model
        self.parameters = [p for p in model.parameters() if p.requires_grad]

        # NOTE(security): ``eval`` resolves the optimizer class from a config
        # string; the config must come from a trusted source.
        optimizer = eval(config["optimizer"])

        if "memory_lr" in config:
            # Parameters whose name contains "memory_values" get their own
            # learning rate; all other trainable parameters use config["lr"].
            memory_params, regular_params = [], []
            for name, param in model.named_parameters():
                if not param.requires_grad:
                    continue
                if "memory_values" in name:
                    memory_params.append(param)
                else:
                    regular_params.append(param)
            params = [
                {'params': memory_params,
                 'weight_decay': config["weight_decay"],
                 'lr': config["memory_lr"]},
                {'params': regular_params,
                 'weight_decay': config["weight_decay"],
                 'lr': config["lr"]}]
        else:
            params = self.parameters

        optimizer_kwargs = {"lr": config["lr"],
                            "weight_decay": config["weight_decay"]}
        if config["different_betas"]:
            optimizer_kwargs["betas"] = (0, 0.999)
            optimizer_kwargs["eps"] = 1e-9

        # Construct the optimizer exactly once. torch optimizers write their
        # defaults into the param-group dicts they receive (via setdefault),
        # so building a second optimizer from the same dicts would silently
        # keep the first optimizer's betas/eps instead of the ones above.
        self.optimizer = optimizer(params, **optimizer_kwargs)

        self.scheduler, self.epoch_level_scheduler = get_scheduler(config, self.optimizer)
        self.config = config
        self.device = device
        self.DataParallel = config["DataParallel"]

        self.optimizer.zero_grad()

    def convert_labels(self, y):
        """Map nested label ids to label names via ``config["idx2labels"]``.

        Args:
            y: iterable of sequences of label ids.

        Returns:
            A list of lists holding the looked-up labels, in the same order.
        """
        idx2labels = self.config["idx2labels"]
        return [[idx2labels[label_id] for label_id in item] for item in y]

    # %%
    def run(self, batch, train=True):
        """Run the model on one batch and compute loss and F1.

        Args:
            batch: dict with the tensor entries listed in ``_TENSOR_KEYS`` plus
                ``"labels"`` and ``"sequences"``. When DataParallel is off the
                tensor entries are replaced in place by their on-device copies.
            train: put the model in training mode if true, eval mode otherwise.

        Returns:
            Dict with ``"display_items"`` (sequences, predicted and gold label
            names), ``"loss"`` (the loss object returned by the model) and
            ``"metrics"`` (``"loss"`` as a Python number and ``"F1"``).
        """
        self.model = self.model.train() if train else self.model.eval()

        if not self.DataParallel:
            # With DataParallel enabled the tensors are passed through as-is;
            # presumably the wrapper handles placement -- confirm with caller.
            for key in self._TENSOR_KEYS:
                batch[key] = batch[key].to(self.device)

        output_dict = self.model(batch)
        loss = output_dict["loss"]

        # Predictions are consumed as nested iterables of label ids, exactly
        # like batch["labels"].
        predictions = self.convert_labels(output_dict["predictions"])
        labels = self.convert_labels(batch["labels"])

        metrics = {"loss": loss.item(),
                   "F1": f1_score(labels, predictions, scheme=IOB2)}

        return {"display_items": {"sequences": batch["sequences"],
                                  "predictions": predictions,
                                  "labels": labels},
                "loss": loss,
                "metrics": metrics}

    # %%
    def backward(self, loss):
        """Backpropagate ``loss``."""
        loss.backward()

    # %%
    def step(self):
        """Apply one optimizer update.

        Clips the gradient norm when ``config["max_grad_norm"]`` is set (a
        missing key is treated like ``None``: no clipping), steps the
        optimizer, clears gradients, steps the scheduler unless it is an
        epoch-level one, and stores the learning rate of the last parameter
        group in ``config["current_lr"]``.
        """
        max_grad_norm = self.config.get("max_grad_norm")
        if max_grad_norm is not None:
            T.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)

        self.optimizer.step()
        self.optimizer.zero_grad()

        if (not self.epoch_level_scheduler) and self.scheduler is not None:
            self.scheduler.step()

        self.config["current_lr"] = self.optimizer.param_groups[-1]["lr"]

