import torch
import pandas as pd
from transformers import BertModel, AutoTokenizer, BertForSequenceClassification
from torch.utils.data import Dataset, DataLoader
import os
from collections import Counter
from tqdm import tqdm
import torch.nn as nn
from torch.optim import AdamW
from sklearn.metrics import f1_score
import os
import numpy as np
from utils.dataloader import TextDataset, get_label_dict
from utils.utils import get_baseline_optimizer, get_pmixup_optimizer
import torch.nn.functional as F
from models.tmix import *
from transformers.utils import logging

# Only let transformers emit log records at level 40 or above
# (40 is the numeric value of ERROR in Python's standard logging levels).
logging.set_verbosity(40)


def pmixup_evaluate(model, val_dataloader):
    """Evaluate a mixup classifier on a validation set.

    The model is called as ``model(x=input_ids)`` and is expected to return
    a ``(batch, num_labels)`` tensor of logits (the same layout the training
    loop feeds to ``log_softmax(dim=1)``).

    Args:
        model: Module to evaluate. It is switched to eval mode and left
            there; the caller is responsible for calling ``model.train()``.
        val_dataloader: Iterable of batches exposing ``'input_ids'`` and
            ``'labels'`` tensors. ``input_ids`` is squeezed on dim 1 before
            being passed to the model; labels are treated as integer class
            indices.

    Returns:
        Tuple ``(mean_loss, accuracy, macro_f1)`` where ``mean_loss`` is the
        hard-label cross-entropy averaged over batches. ``(0.0, 0.0, 0.0)``
        is returned when the dataloader yields no samples.
    """
    # Follow the model's device instead of assuming CUDA is present.
    device = next(model.parameters()).device

    total = 0
    num_corrects = 0
    val_loss = 0.0
    y_pred = []
    y_true = []
    with torch.no_grad():
        model.eval()
        print("------Evaluation------")
        for batch in tqdm(val_dataloader):
            # A single squeeze drops the singleton axis; squeezing twice
            # would also collapse the sequence axis when its length is 1.
            input_ids = batch['input_ids'].squeeze(1).to(device)
            output = model(x=input_ids)
            # Flatten and co-locate labels with the logits so comparison and
            # loss run on one device.
            labels = batch['labels'].view(-1).long().to(output.device)

            val_loss += F.cross_entropy(output, labels).item()
            preds = torch.argmax(output, dim=-1)
            num_corrects += (preds == labels).sum().item()
            total += preds.numel()
            # Plain Python ints keep sklearn from handling 0-dim tensors.
            y_pred.extend(preds.tolist())
            y_true.extend(labels.tolist())

    if total == 0:
        # Nothing to score; avoid dividing by zero below.
        return 0.0, 0.0, 0.0

    val_f1 = f1_score(y_true, y_pred, average="macro")
    return val_loss / len(val_dataloader), (num_corrects / total), val_f1


def baseline_evaluate(model, val_dataloader):
    """Evaluate a sequence-classification model on a validation set.

    Each batch is forwarded as ``model(**batch)`` after squeezing
    ``batch['input_ids']`` on dim 1. The output must expose ``.logits``; when
    it also carries a loss (labels are part of the batch), that loss is
    averaged into the returned validation loss.

    Args:
        model: Module to evaluate. It is switched to eval mode and left
            there; the caller is responsible for calling ``model.train()``.
        val_dataloader: Iterable of dict-like batches containing at least
            ``'input_ids'`` and ``'labels'``. Batches are passed through
            without an explicit device move, exactly as the training loop
            does. NOTE(review): this presumably relies on the DataParallel
            wrapper to place inputs -- confirm for other wrappers.

    Returns:
        Tuple ``(mean_loss, accuracy, macro_f1)``. ``mean_loss`` is the
        per-batch model loss averaged over batches (0.0 if the model never
        returns a loss). ``(0.0, 0.0, 0.0)`` is returned when the dataloader
        yields no samples.
    """
    total = 0
    num_corrects = 0
    val_loss = 0.0
    y_pred = []
    y_true = []
    with torch.no_grad():
        model.eval()
        print("------Evaluation------")
        for batch in tqdm(val_dataloader):
            batch['input_ids'] = batch['input_ids'].squeeze(1)
            output = model(**batch)

            # Under DataParallel the loss is gathered as one value per
            # replica, hence the mean.
            batch_loss = getattr(output, "loss", None)
            if batch_loss is not None:
                val_loss += batch_loss.mean().item()

            preds = torch.argmax(output.logits, dim=-1)
            # Compare on a single device in one vectorised op instead of a
            # per-element Python loop.
            labels = batch['labels'].view(-1).to(preds.device)
            num_corrects += (preds == labels).sum().item()
            total += preds.numel()
            # Plain Python ints keep sklearn from handling 0-dim tensors.
            y_pred.extend(preds.tolist())
            y_true.extend(labels.tolist())

    if total == 0:
        # Nothing to score; avoid dividing by zero below.
        return 0.0, 0.0, 0.0

    val_f1 = f1_score(y_true, y_pred, average="macro")
    return val_loss / len(val_dataloader), (num_corrects / total), val_f1


def run_baseline(args, label_dict, dataframe, dataset_name, val_df, feature, condition=None, text_column='text',
                 label_column='label'):
    """Fine-tune a BERT sequence classifier and keep the best checkpoint.

    Trains for ``args.num_epochs`` epochs, evaluates with
    ``baseline_evaluate`` after each epoch, and saves the model's
    ``state_dict`` whenever validation accuracy improves.

    Args:
        args: Namespace providing ``max_length``, ``batch_size``,
            ``model_name``, ``lr`` and ``num_epochs``.
        label_dict: Label mapping; its length is the number of classes.
        dataframe: Training data handed to ``TextDataset``.
        dataset_name: Sub-directory of ``../model_weights`` for checkpoints.
        val_df: Validation data handed to ``TextDataset``.
        feature: If falsy, the checkpoint is written as
            ``baseline_model.pt``; otherwise it is used in the file name
            together with ``condition``.
        condition: Second component of the checkpoint file name when
            ``feature`` is set.
        text_column: Name of the text column passed to ``TextDataset``.
        label_column: Name of the label column passed to ``TextDataset``.

    Returns:
        None. The only outputs are the checkpoint files.
    """
    use_cuda = torch.cuda.is_available()
    # "cpu" must be a string; a bare ``cpu`` name raised NameError on
    # machines without CUDA.
    device = torch.device("cuda" if use_cuda else "cpu")

    train_dataset = TextDataset(dataframe, label_dict, text_column, label_column, args.max_length)
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)

    val_dataset = TextDataset(val_df, label_dict, text_column, label_column, args.max_length)
    val_dataloader = DataLoader(val_dataset, batch_size=128, shuffle=False)

    model = BertForSequenceClassification.from_pretrained(args.model_name, num_labels=len(label_dict))
    model = torch.nn.DataParallel(model).to(device)
    optimizer = get_baseline_optimizer(model, args.lr)
    # Mixed precision only makes sense on CUDA; disable it explicitly on CPU.
    scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

    save_dir = f'../model_weights/{dataset_name}'
    best_acc = 0
    for epoch in range(args.num_epochs):
        model.train()
        model.zero_grad()
        for batch in tqdm(train_dataloader):
            batch['input_ids'] = batch['input_ids'].squeeze(1)
            with torch.cuda.amp.autocast(enabled=use_cuda):
                output = model(**batch)
                loss = output['loss']
            # DataParallel gathers one loss per replica; reduce to a scalar.
            scaler.scale(loss.mean()).backward()
            scaler.step(optimizer)
            scaler.update()
            model.zero_grad()

        _, val_acc, _ = baseline_evaluate(model, val_dataloader)
        if best_acc < val_acc:
            best_acc = val_acc

            # exist_ok avoids the check-then-create race.
            os.makedirs(save_dir, exist_ok=True)
            if not feature:
                torch.save(model.state_dict(), f'../model_weights/{dataset_name}/baseline_model.pt')
            elif condition:
                torch.save(model.state_dict(), f'../model_weights/{dataset_name}/model_{feature}_{condition}.pt')
            # NOTE(review): with ``feature`` set but ``condition`` falsy no
            # checkpoint is written -- confirm this is intended.


def run_pmixup(args, label_dict, train_df, dataset, val_df, sample_size, feature, text_column="text",
               label_column="label"):
    """Train a ``MixText`` model with hidden-state mixup and keep the best checkpoint.

    For every batch the inputs are paired with a random permutation of
    themselves, the model mixes the two at ``mix_layer`` with weight ``l``,
    and the loss is the cross-entropy against the equally mixed one-hot
    targets. After each epoch the model is scored with ``pmixup_evaluate``
    and its ``state_dict`` is saved when validation accuracy improves.

    Args:
        args: Namespace providing ``max_length``, ``batch_size``, ``lr`` and
            ``num_epochs``.
        label_dict: Label mapping; its length is the number of classes.
        train_df: Training data handed to ``TextDataset``.
        dataset: Sub-directory of ``../model_weights`` for checkpoints.
        val_df: Validation data handed to ``TextDataset``.
        sample_size: Used only in the checkpoint file name.
        feature: Used only in the checkpoint file name.
        text_column: Name of the text column passed to ``TextDataset``.
        label_column: Name of the label column passed to ``TextDataset``.

    Returns:
        None. The only outputs are the checkpoint files.
    """
    # "cpu" must be a string; a bare ``cpu`` name raised NameError on
    # machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_dataset = TextDataset(train_df, label_dict, text_column, label_column, args.max_length)
    # drop_last keeps every training batch at the full batch size.
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)

    num_labels = len(label_dict)
    model = MixText(num_labels=num_labels, mix_option=True).to(device)
    model = nn.DataParallel(model)

    optimizer = get_pmixup_optimizer(model, args.lr)

    val_dataset = TextDataset(val_df, label_dict, text_column, label_column, args.max_length)
    val_dataloader = DataLoader(val_dataset, batch_size=128, shuffle=False)

    alpha = 16
    mix_layer_set = [7, 9, 12]
    save_dir = f'../model_weights/{dataset}'
    best_acc = 0
    for epoch in range(args.num_epochs):
        # The mixing weight and mixing layer are drawn once per epoch and
        # shared by all batches of that epoch.
        l = np.random.beta(alpha, alpha)
        # Keep the un-permuted sample dominant in the mix.
        l = max(l, 1 - l)
        # mix_layer_set is 1-based; the value passed to the model is 0-based.
        mix_layer = np.random.choice(mix_layer_set, 1)[0] - 1

        model.train()
        for batch in tqdm(train_dataloader):
            input_ids = batch['input_ids']
            batch_size = input_ids.size(0)
            # Permutation that pairs each sample with its mixup partner.
            idx = torch.randperm(batch_size)
            partner_ids = torch.index_select(input_ids.cpu(), dim=0, index=idx)

            outputs = model(
                x=input_ids.squeeze(1).to(device),
                x2=partner_ids.squeeze(1).to(device),
                l=l,
                mix_layer=mix_layer,
            )

            # One-hot targets, permuted with the same idx so inputs and
            # labels are mixed consistently.
            labels = batch['labels'].to(device)
            targets_x = torch.zeros(batch_size, num_labels, device=device).scatter_(1, labels.view(-1, 1), 1)
            out_labs = torch.index_select(targets_x, dim=0, index=idx.to(device))
            mixed_target = l * targets_x + (1 - l) * out_labs

            # Soft-target cross-entropy.
            loss = -torch.mean(torch.sum(F.log_softmax(outputs, dim=1) * mixed_target, dim=1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        _, val_acc, _ = pmixup_evaluate(model, val_dataloader)
        if val_acc > best_acc:
            best_acc = val_acc
            # torch.save does not create missing parent directories.
            os.makedirs(save_dir, exist_ok=True)
            torch.save(model.state_dict(), f'../model_weights/{dataset}/pmixup_model_{feature}_{sample_size}.pt')
