import contextlib
import os
import shutil

import numpy as np
import pandas as pd
import torch
from simpletransformers.config.model_args import ClassificationArgs
from tqdm import tqdm

def _load_model_args(input_dir):
    """Return a fresh ``ClassificationArgs`` after calling its ``load(input_dir)``.

    ``input_dir`` is forwarded unchanged to ``ClassificationArgs.load``; which
    file(s) it reads there is up to simpletransformers.
    """
    restored = ClassificationArgs()
    restored.load(input_dir)
    return restored


class EarlyStopping:
    """Track a validation loss and flag when it has stopped improving.

    Each call compares ``score = -val_loss`` with the best score seen so far.
    A call is *not* an improvement when ``score < best_score + delta``; such a
    call increments ``counter``, and once ``counter >= patience`` the
    ``early_stop`` flag is set to ``True``. The first call and every improving
    call record the new best score. If ``save_mode`` is set, they also write the
    model's ``state_dict`` to ``<path>/checkpoint``. An improving call
    additionally resets ``counter`` to 0.

    Args:
        accelerator: Optional object providing ``print`` and ``unwrap_model``
            (presumably a Hugging Face ``accelerate.Accelerator``). When
            ``None``, the builtin ``print`` is used and the model is saved as-is.
        patience: Number of consecutive non-improving calls tolerated.
        verbose: If true, print a message every time a checkpoint is saved.
        delta: Minimum increase in score (i.e. decrease in loss) that counts
            as an improvement.
        save_mode: Whether improvements trigger a checkpoint save.
    """

    def __init__(self, accelerator=None, patience=7, verbose=False, delta=0, save_mode=True):
        self.accelerator = accelerator
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # `np.Inf` was removed in NumPy 2.0; `np.inf` works on every version.
        self.val_loss_min = np.inf
        self.delta = delta
        self.save_mode = save_mode

    def _print(self, message):
        """Print *message* through the accelerator if one was given, else via ``print``."""
        if self.accelerator is None:
            print(message)
        else:
            self.accelerator.print(message)

    def __call__(self, val_loss, model, path):
        """Record *val_loss*; on improvement optionally checkpoint *model* into *path*."""
        score = -val_loss
        if self.best_score is None:
            # First observation always becomes the baseline.
            self.best_score = score
            if self.save_mode:
                self.save_checkpoint(val_loss, model, path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self._print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            if self.save_mode:
                self.save_checkpoint(val_loss, model, path)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, path):
        """Save ``model.state_dict()`` to ``<path>/checkpoint`` and remember *val_loss*.

        When an accelerator is set, the model is passed through
        ``accelerator.unwrap_model`` first (presumably to strip DDP/compile
        wrappers so the saved keys match the bare model).
        """
        if self.verbose:
            self._print(
                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')

        if self.accelerator is not None:
            model = self.accelerator.unwrap_model(model)
        # NOTE(review): with multiple processes every rank writes this file;
        # consider guarding with the accelerator's main-process check.
        torch.save(model.state_dict(), os.path.join(path, 'checkpoint'))
        self.val_loss_min = val_loss


class dotdict(dict):
    """``dict`` subclass whose keys are also reachable as attributes.

    * ``d.key`` reads ``d['key']`` and yields ``None`` for a missing key
      (no ``AttributeError``).
    * ``d.key = v`` stores ``d['key'] = v``.
    * ``del d.key`` removes the key and raises ``KeyError`` if it is absent.
    """

    def __getattr__(self, name):
        return dict.get(self, name)

    def __setattr__(self, name, value):
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        dict.__delitem__(self, name)


class StandardScaler:
    """Fixed affine normalizer built from a precomputed ``mean`` and ``std``.

    ``transform`` computes ``(data - mean) / std``, and ``inverse_transform``
    computes ``data * std + mean``. Operands may be anything supporting these
    operators (scalars, NumPy arrays, tensors), with their usual broadcasting.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def transform(self, data):
        """Return *data* centered by ``mean`` and scaled by ``1 / std``."""
        centered = data - self.mean
        return centered / self.std

    def inverse_transform(self, data):
        """Undo :meth:`transform`: scale by ``std`` and shift back by ``mean``."""
        rescaled = data * self.std
        return rescaled + self.mean


def _autocast_if(enabled):
    """Return a CUDA autocast context when *enabled*, otherwise a no-op context."""
    return torch.autocast(device_type="cuda") if enabled else contextlib.nullcontext()


def vali(args, accelerator, model, vali_loader, criterion, ParrotModel):
    """Evaluate ``model`` on ``vali_loader`` using features from ``ParrotModel``.

    For every batch, ``ParrotModel.model`` is run on the inputs built by
    ``ParrotModel._get_inputs_dict``, and its last output (``memory_unpool``)
    is passed to ``model`` together with the labels and texts. The logits are
    scored against ``labels[:, 1:-1]``, dropping the first and last label
    positions.

    Args:
        args: Config object; reads ``use_graph``, ``use_amp`` and
            ``output_attention``.
        accelerator: Object exposing ``device`` and ``print`` (presumably an
            ``accelerate.Accelerator``).
        model: Module under evaluation. It is switched to eval mode and left
            that way.
        vali_loader: Iterable of batches accepted by
            ``ParrotModel._get_inputs_dict``.
        criterion: Loss called as ``criterion(logits (N, C) float32,
            targets (N,) int64)``.
        ParrotModel: Wrapper exposing ``model``, ``args.fp16`` and
            ``_get_inputs_dict``. Its ``model`` is temporarily put in eval mode.

    Returns:
        ``(mean of per-batch losses, fraction of positions whose argmax
        matches the target)``. Any padding positions are counted as well.
    """
    total_loss = []
    predict_labels = []
    true_labels = []

    model.eval()
    # Moving the module is loop-invariant, so do it once (`.to` returns the module).
    parrot = ParrotModel.model.to(accelerator.device)
    ParrotModel.model = parrot
    # The feature extractor must not run with dropout active during validation.
    # Restore its previous mode afterwards so a caller that trains it is unaffected.
    parrot_was_training = parrot.training
    parrot.eval()
    try:
        with torch.no_grad():
            for batch in tqdm(vali_loader):
                inputs_parrot = ParrotModel._get_inputs_dict(batch, accelerator.device)
                if 'graph' not in inputs_parrot:
                    inputs_parrot['graph'] = None

                with _autocast_if(ParrotModel.args.fp16):
                    parrot_outputs = parrot(**inputs_parrot)
                memory_unpool = parrot_outputs[-1]

                raw_labels = torch.as_tensor(inputs_parrot['labels'])
                # The model input keeps the original float32 -> bfloat16 conversion.
                inputs = {
                    "input_emb": memory_unpool.to(torch.bfloat16).to(accelerator.device),
                    "labels": raw_labels.float().to(torch.bfloat16).to(accelerator.device),
                    "paragraph_text": inputs_parrot['corpus_text'],
                    "rxn_text": inputs_parrot['rxn_text'],
                }
                if args.use_graph:
                    inputs["graph"] = inputs_parrot['graph'].to(torch.bfloat16).to(accelerator.device)

                with _autocast_if(args.use_amp):
                    logits = model(inputs)
                if args.output_attention:
                    logits = logits[0]

                # Targets come from the un-rounded labels. bfloat16 has only an
                # 8-bit significand, so class ids above 256 would be corrupted
                # (e.g. 257 -> 256) if read back from inputs["labels"].
                # NOTE(review): if the training loss reads targets from the
                # bfloat16 copy, it has the same issue.
                targets = raw_labels[:, 1:-1].long().to(accelerator.device).reshape(-1)
                logits = logits.reshape(-1, logits.shape[-1]).float()

                loss = criterion(logits, targets)
                total_loss.append(loss.item())
                # Keep only per-row argmax (N) rather than all logits (N x C).
                predict_labels.append(torch.argmax(logits, dim=1))
                true_labels.append(targets)
    finally:
        parrot.train(parrot_was_training)

    total_loss = np.average(total_loss)
    predict_labels = torch.cat(predict_labels, dim=0)
    true_labels = torch.cat(true_labels, dim=0)
    assert predict_labels.shape == true_labels.shape
    correct = (predict_labels == true_labels).sum().item()
    accuracy = correct / len(true_labels)
    accelerator.print(
        "Avg Test Accuracy: {0:.7f} ".format(accuracy),  f'for totally {predict_labels.shape} conditions')
    return total_loss, accuracy

