from __future__ import division

import copy
import math
import os.path
import argparse
import random
import tqdm
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import matthews_corrcoef, f1_score
import torch
import numpy as np
from torch.optim.lr_scheduler import MultiStepLR
from torch.optim.sgd import SGD
from transformers import get_linear_schedule_with_warmup
from torch.optim import AdamW, Adam
from torch.utils.data import DataLoader
import torch.nn.functional as F
from myDataLoader import bucket_dataset
import random
from PGDModel import *
from utils import *

def freeze_pretrain(model):
    """Turn off gradients for the pretrained part of ``model``.

    Every parameter whose name (as yielded by ``model.named_parameters()``)
    contains the substring ``'bert'`` or ``'embed'`` gets
    ``requires_grad = False``.  All other parameters are left untouched.
    The model is modified in place and nothing is returned.
    """
    frozen_markers = ('bert', 'embed')
    for name, param in model.named_parameters():
        if any(marker in name for marker in frozen_markers):
            param.requires_grad = False


def set_seed(args):
    """Seed the Python, NumPy and PyTorch random generators with ``args.seed``.

    The three generators are seeded in that order, each with the same value.
    """
    seed = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)


def get_accuracy(preds, labels):
    """Return the fraction of predictions that equal their label.

    Args:
        preds: Sequence of predicted labels.
        labels: Sequence of reference labels; its length is the denominator.

    Returns:
        ``matches / len(labels)`` as a float, where ``matches`` counts the
        positions (paired with ``zip``, so only up to the shorter input) at
        which ``pred == label``.  Returns ``0.0`` when ``labels`` is empty
        instead of raising ``ZeroDivisionError``.
    """
    total = len(labels)
    # ``len(...) == 0`` rather than ``not labels`` so NumPy arrays work too.
    if total == 0:
        return 0.0
    # Count with ``1`` per hit so the result stays a plain Python int even
    # when the elements are NumPy scalars.
    matches = sum(1 for pred, gold in zip(preds, labels) if pred == gold)
    return matches / total


def get_test_perform(args, data_path, model, train_error=False):
    """Score ``model`` on a dataset and return its accuracy.

    Args:
        args: Run configuration.  ``args.task_name`` is validated here,
            ``args.train_data`` is read when ``train_error`` is true, and the
            whole object is forwarded to ``bucket_dataset``.
        data_path: Path handed to ``bucket_dataset`` when ``train_error`` is
            false.
        model: Module that, called on a batch, returns a 4-tuple whose second
            item is the logits and whose third item is the labels.  It must
            expose ``training``, ``eval()`` and ``train(mode)`` (as
            ``torch.nn.Module`` does).
        train_error: When true, evaluate on ``args.train_data`` instead of
            ``data_path``.

    Returns:
        The value of ``get_accuracy`` over all predictions and labels.

    Raises:
        KeyError: If ``args.task_name`` is not one of the supported tasks.
    """
    task2metric = {'SST': 'acc', 'QNLI': 'acc', }

    # Every supported task is scored with accuracy, so the table is only used
    # to reject unknown task names before any data is loaded.
    if args.task_name not in task2metric:
        raise KeyError(args.task_name)

    eval_path = args.train_data if train_error else data_path
    test_dataset = bucket_dataset(args, eval_path)

    # The identity collate_fn hands the raw list of samples to the model.
    test_dataloader = DataLoader(dataset=test_dataset, batch_size=16, collate_fn=lambda x: x)

    preds, true_labels = [], []

    # Evaluate in eval mode so dropout and similar layers are switched off,
    # then put the model back in whatever mode the caller had it in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in test_dataloader:
                _, logits, labels, _ = model(batch)
                pred_softmax = torch.nn.functional.softmax(logits, dim=1)
                pred = np.argmax(pred_softmax.detach().cpu().numpy(), axis=1)
                preds.extend(pred)
                true_labels.extend(labels.detach().cpu().numpy())
    finally:
        model.train(was_training)

    return get_accuracy(preds, true_labels)


def auto_tuning(args):
    """Train ``bert_clf`` with injected weight noise and an extra loss term.

    Outline of what the code does:

    1. Seed the RNGs, build the model and freeze its pretrained parameters.
    2. Call ``compute_K_sample_transformer`` for the gamma range ``[0.5, 10]``
       to obtain ``prior_list`` and ``K_list``.
    3. If ``args.refine_gamma`` is truthy, train a deep copy of the model
       (``model1``) for 50 epochs, then use the last clipped ``gamma1`` as both
       the lower and the upper gamma bound and recompute
       ``prior_list`` / ``K_list``.  ``model1`` is not used afterwards.
    4. Train ``model`` for ``args.max_epoch`` epochs.  While
       ``epoch < args.shift`` the extra term ``loss2`` is computed and
       back-propagated through ``kl_term_backward`` and the ``p`` / ``b``
       optimizers are stepped; from then on only ``opt1`` is stepped.

    Args:
        args: Run configuration.  Attributes read directly in this function:
            ``device``, ``train_data``, ``lr``, ``refine_gamma``,
            ``batch_size``, ``train_size``, ``shift`` and ``max_epoch``.
            The object is also forwarded to ``set_seed``, ``bert_clf``,
            ``compute_K_sample_transformer`` and ``bucket_dataset``.

    Returns:
        None.  ``model`` and the per-batch trace ``others`` are local and are
        not returned.

    NOTE(review): ``nn`` is not imported explicitly in this file; it is
    presumably provided by one of the ``import *`` lines -- confirm.
    """
    set_seed(args)
    model = bert_clf(args).to(args.device)
    freeze_pretrain(model)

    # Per-batch record of ``p.mean()`` from the main stage; never returned.
    others = []
    # NOTE(review): ``w0``, ``p``, ``layers`` and ``b`` from this first call
    # are all rebound below before being read.  ``initialization`` is opaque
    # here and may have side effects, so the call is left in place.
    w0, p, layers = initialization(model)

    b = nn.Parameter(torch.ones(1, device=args.device) * p.data.mean(), requires_grad=True)

    # NOTE(review): never read in this function.
    best_performance = 0.0

    # Initial clipping range for gamma1.
    min_gamma = .5
    max_gamma = 10
    prior_list, K_list = compute_K_sample_transformer(args, model, args.train_data, min_gamma, max_gamma)
    # The refinement stage runs on a copy, so ``model`` itself is not updated
    # by the optimizers created for that stage.
    model1 = copy.deepcopy(model)
    w0, p, layers = initialization(model1)

    # ``b`` is a single learnable scalar initialised to the mean of ``p``.
    b = nn.Parameter(torch.ones(1, device=args.device) * p.data.mean(), requires_grad=True)

    # Three optimizers: trainable model weights (fixed lr 1e-5 in this
    # stage), ``p`` and ``b``.
    paramters = [p for n, p in model1.named_parameters() if p.requires_grad]
    opt1 = AdamW(paramters, lr=1e-5)
    opt2 = AdamW([p], lr=args.lr)
    opt3 = AdamW([b], lr=args.lr)

    if args.refine_gamma:

        # Refinement stage: a fixed 50 epochs on ``model1``.
        for epoch in range(50):
            model1.train()
            # The dataset and loader are rebuilt at the start of every epoch.
            train_dataset = bucket_dataset(args, args.train_data)
            train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=lambda x: x)
            for batch_idx, batch in enumerate(train_dataloader):
                opt1.zero_grad()

                opt2.zero_grad()
                opt3.zero_grad()

                # noise injection and ||w-w0||^2
                # ``wdecay`` is computed before the noise is injected.
                wdecay = weight_decay(model1, w0)
                noises, noises_scaled = noise_injection(model1, p)

                # loss 1: the loss returned by the model on the noisy weights
                _, _, _, loss1 = model1(batch)

                kl = get_kl_term_with_b(wdecay, p, b)
                # With K = fun_K_auto_new(exp(b), ...) and n = args.train_size:
                #   gamma1 = K^-1 * sqrt(2 * (kl + 10) / (3 * n))
                # which is the value of gamma minimising loss2 below, before
                # it is clipped into [min_gamma, max_gamma].
                gamma1 = fun_K_auto_new(torch.exp(b), prior_list, K_list) ** (-1) * (
                        2 * (kl + 10) / args.train_size / 3) ** 0.5
                gamma1 = torch.clip(gamma1, max=max_gamma, min=min_gamma)
                #   loss2 = 3 * K^2 * gamma1 / 2 + (kl + 10) / (n * gamma1)
                loss2 = 3 * fun_K_auto_new(torch.exp(b), prior_list, K_list) ** 2 * gamma1 / 2 + (
                        kl + 10) / args.train_size / gamma1

                # backward
                # retain_graph=True keeps the graph alive for the second
                # backward pass performed inside kl_term_backward.
                loss1.backward(retain_graph=True)
                # NOTE(review): this guard uses args.shift although the stage
                # always runs 50 epochs; opt2/opt3 below step regardless.
                if epoch < args.shift:
                    #  if epoch < 0:
                    #      kl_term_backward_mean(loss2, model1, p, noises)
                    #  else:
                    kl_term_backward(loss2, model1, p, noises)

                # remove noises
                rm_injected_noises(model1, noises_scaled)

                opt1.step()

                opt2.step()
                opt3.step()

        # Use the last clipped gamma1 as both bounds, which makes the clip in
        # the main stage return exactly this value.
        # NOTE(review): ``gamma1`` is unbound here if the loader produced no
        # batch at all.
        min_gamma = gamma1.data
        max_gamma = min_gamma
        # Recomputed on the original ``model``, not on the refined ``model1``.
        prior_list, K_list = compute_K_sample_transformer(args, model, args.train_data, min_gamma.detach().cpu(),
                                                          max_gamma.detach().cpu())

    # Main stage: fresh w0 / p / b and fresh optimizers on ``model``.
    w0, p, layers = initialization(model)

    b = nn.Parameter(torch.ones(1, device=args.device) * p.data.mean(), requires_grad=True)

    paramters = [p for n, p in model.named_parameters() if p.requires_grad]
    opt1 = AdamW(paramters, lr=args.lr)
    opt2 = AdamW([p], lr=args.lr)
    opt3 = AdamW([b], lr=args.lr)
    for epoch in range(args.max_epoch):
        model.train()
        train_dataset = bucket_dataset(args, args.train_data)
        train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=lambda x: x)
        for batch_idx, batch in enumerate(train_dataloader):

            opt1.zero_grad()

            opt2.zero_grad()
            opt3.zero_grad()
            # noise injection and ||w-w0||^2
            # Noise is injected in every epoch, also once epoch >= args.shift.
            wdecay = weight_decay(model, w0)
            noises, noises_scaled = noise_injection(model, p)

            _, _, _, loss1 = model(batch)

            if epoch < args.shift:
                kl = get_kl_term_with_b(wdecay, p, b)
                # NOTE(review): ``K`` is not read afterwards; the two
                # expressions below call fun_K_auto_new again.
                K = fun_K_auto_new(torch.exp(b), prior_list, K_list)

                # Same gamma1 / loss2 formulas as in the refinement stage.
                gamma1 = fun_K_auto_new(torch.exp(b), prior_list, K_list) ** (-1) * (
                        2 * (kl + 10) / args.train_size / 3) ** 0.5
                gamma1 = torch.clip(gamma1, max=max_gamma, min=min_gamma)

                loss2 = 3 * fun_K_auto_new(torch.exp(b), prior_list, K_list) ** 2 * gamma1 / 2 + (
                        kl + 10) / args.train_size / gamma1
            else:
                # Placeholder only: loss2 is not used once epoch >= args.shift.
                loss2 = 0 * loss1

            # backward
            loss1.backward(retain_graph=True)
            if epoch < args.shift:
                kl_term_backward(loss2, model, p, noises)

            # remove noises
            rm_injected_noises(model, noises_scaled)

            opt1.step()

            # ``p`` and ``b`` are only updated while the extra term is active.
            if epoch < args.shift:
                opt2.step()
                opt3.step()

            others.append([p.mean().cpu().item()])

def adaptor_tuning(args):
    """Baseline training loop with optional Gaussian noise on the weights.

    The pretrained parameters are frozen and the remaining ones are trained
    for ``args.max_epoch`` epochs on ``args.train_data``.  When
    ``args.noise_level > 0`` every trainable parameter is perturbed with
    ``randn * noise_level`` before the forward pass, and the same noise is
    subtracted again after the optimizer step.

    Args:
        args: Run configuration.  Attributes read directly in this function:
            ``device``, ``lr``, ``weight_decay``, ``OPTIM``, ``max_epoch``,
            ``train_data``, ``batch_size`` and ``noise_level``.  The object is
            also forwarded to ``set_seed``, ``bert_clf`` and
            ``bucket_dataset``.

    Returns:
        None.  The trained model is local to this function.
    """
    set_seed(args)
    model = bert_clf(args).to(args.device)

    # The original called ``frezee_pretrain``, which this file does not
    # define; ``freeze_pretrain`` is the helper defined here and the one
    # ``auto_tuning`` uses.
    # NOTE(review): if a star import was meant to supply ``frezee_pretrain``
    # with different behavior, revisit this call.
    freeze_pretrain(model)

    trainable = [p for n, p in model.named_parameters() if p.requires_grad]

    # SGD only when explicitly requested; every other value selects AdamW.
    if args.OPTIM == "SGD":
        opt = SGD(trainable, lr=args.lr, weight_decay=args.weight_decay)
    else:
        opt = AdamW(trainable, lr=args.lr, weight_decay=args.weight_decay)

    inject_noise = args.noise_level > 0.0
    noise_scale = float(args.noise_level)

    for epoch in range(args.max_epoch):
        model.train()
        # The dataset and loader are rebuilt at the start of every epoch.
        train_dataset = bucket_dataset(args, args.train_data)
        train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=lambda x: x)
        for batch in train_dataloader:
            opt.zero_grad()

            # Perturb the trainable weights in place, remembering each noise
            # tensor so exactly the same values can be subtracted later.
            # Noise is still drawn on the CPU and then moved, so the random
            # stream matches the original implementation.
            noise_list = []
            if inject_noise:
                for param in trainable:
                    noise = torch.randn(param.data.size()).to(args.device) * noise_scale
                    param.data += noise
                    noise_list.append(noise)

            # Forward and backward on the (possibly noisy) weights.
            _, _, _, loss1 = model(batch)
            loss1.backward()
            # torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            opt.step()

            # Undo the perturbation; ``noise_list`` is empty when no noise was
            # injected, in which case this loop does nothing.
            for param, noise in zip(trainable, noise_list):
                param.data -= noise

if __name__ == "__main__":
    # Use the GPU when one is available, otherwise fall back to the CPU.
    default_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # One (flag, keyword-arguments) entry per command-line option, registered
    # in this order.
    option_specs = [
        ("--task_name", dict(type=str, default="SST")),
        ("--lr", dict(type=float, default=0.1)),
        ("--weight_decay", dict(type=float, default=0.0)),
        ("--max_epoch", dict(type=int, default=250)),
        ("--attention_dropout", dict(type=float, default=0.5)),
        ("--dropout", dict(type=float, default=0.0)),
        ("--seed", dict(type=int, default=42)),
        ("--pair_input", dict(type=str, default="False")),
        ("--INPUT_IDS", dict(type=str, default="input_ids")),
        ("--ATTENTION_MASK", dict(type=str, default="attention_mask")),
        ("--TOKEN_TYPE_IDS", dict(type=str, default="token_type_ids")),
        ("--LABEL", dict(type=str, default="label")),
        ("--TEXT", dict(type=str, default="text")),
        ("--max_length", dict(type=int, default=128)),
        ("--device", dict(default=default_device)),
        ("--train_data", dict(type=str, default="./data/SST/train.txt")),
        ("--test_data", dict(type=str, default="./data/SST/test.txt")),
        ("--batch_size", dict(type=int, default=8)),
        ("--num_labels", dict(type=int, default=2)),
        ("--train_size", dict(type=int, default=100)),
        ("--valid_size", dict(type=int, default=20)),
        ("--noise_level", dict(type=float, default=0.0)),
        ("--shift", dict(type=int, default=50,
                         help="The number of epochs to shift to noise injection.")),
        ("--K", dict(type=float, default=0.03, help="The default k for loss 2.")),
        ("--OPTIM", dict(type=str, default="ADAMW")),
        ("--refine_gamma", dict(type=int, default=0)),
        ("--method", dict(type=str, default="ours")),
    ]

    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    # "baseline" selects the plain loop; any other value runs auto_tuning.
    run_training = adaptor_tuning if args.method == "baseline" else auto_tuning
    run_training(args)

    
