import os
import sys
import math
from supp.proxy.attack_utils import run_gbda, run_pez
import numpy as np
import torch
import argparse
from itertools import product
from torch.utils.data import DataLoader
from transformers.modeling_outputs import SequenceClassifierOutput, SequenceClassifierOutputWithPast
from transformers import (
    get_scheduler
)
from accelerate import Accelerator
from accelerate.utils import DummyScheduler
from tqdm import tqdm

sys.path.append("../../reward")
from utils import load_model, load_process_data, get_tokenizer

# Tasks for which main() does not load a "test_hard" split: an empty
# placeholder is stored instead of a loader and hard-set evaluation is skipped.
NO_HARD_SET = [
    "power",
    "harmless",
    "hh",
    "shp",
    "paired_moral_stories",
    "cm_dialog",
    "virtue_dialog",
    "oasst",
    "oa_safety",
    "hh_shp_oa_train80",
    "hh_shp_oa_train20",
    "hh_shp_oa_train80_pc1",
    "hh_shp_oa_train80_pc10",
]

# Maps the --adv_attack command-line choice to the attack routine imported from
# supp.proxy.attack_utils. "none" maps to None; the rest of the file checks
# `run_attack is not None` before perturbing inputs.
ATTACK_FUNC = {
    "pez": run_pez,
    "gbda": run_gbda,
    "none": None,
}
def _prepare_loaders(loaders):
    """Pass every entry of ``loaders`` through ``accelerator.prepare`` and return a sequence."""
    prepared = accelerator.prepare(*loaders)
    # With exactly one argument prepare() hands back the bare object rather
    # than a 1-tuple, so re-wrap it to keep indexing and zip() working.
    return [prepared] if len(loaders) == 1 else prepared


def main(args):
    """Run ``args.nruns`` rounds of training and/or evaluation, then optionally save.

    Relies on the module-level ``accelerator`` and ``run_attack`` objects that
    the ``__main__`` section creates before calling this function.

    Args:
        args: Parsed command-line namespace. ``max_train_steps`` and
            ``num_warmup_steps`` are written onto it, and ``nepochs`` is
            recomputed from the prepared dataloader length.

    Returns:
        The constant tuple ``(0, 0)``.
    """
    data_dir = os.path.abspath("/data/private_models/xx_models/data/ranking_datasets")
    for _run in range(args.nruns):
        model, optimizer = load_model(args, load_path=args.load_path, accelerator=accelerator)

        train_sets, test_loaders, hard_test_loaders = [], [], []
        with accelerator.main_process_first():
            for task in args.train_tasks:
                train_sets.append(load_process_data(args, data_dir, task, "train", args.model))
            train_data = torch.utils.data.ConcatDataset(train_sets)
            if args.debug_train_samples:
                # Debug mode: keep a random 10% of the training examples.
                subset_idx = np.random.choice(len(train_data), len(train_data) // 10, replace=False)
                train_data = torch.utils.data.Subset(train_data, subset_idx)

            accelerator.print(f"Total train samples: {len(train_data)}")
            # flatten() stacks both columns of each batch along dim 0 before the
            # forward pass, so loaders use half of args.batch_size.
            train_dataloader = DataLoader(train_data, batch_size=args.batch_size // 2, shuffle=True)

            for task in args.test_tasks:
                test_data = load_process_data(args, data_dir, task, "test", args.model)
                test_loaders.append(DataLoader(test_data, batch_size=args.batch_size // 2, shuffle=False))
                if task not in NO_HARD_SET:
                    test_hard_data = load_process_data(args, data_dir, task, "test_hard", args.model)
                    hard_test_loaders.append(
                        DataLoader(test_hard_data, batch_size=args.batch_size // 2, shuffle=False)
                    )
                else:
                    # Placeholder keeps hard_test_loaders aligned with args.test_tasks.
                    hard_test_loaders.append([])

        # Step counts for the scheduler are based on the un-prepared dataloader.
        num_update_steps_per_epoch = len(train_dataloader)
        args.max_train_steps = args.nepochs * num_update_steps_per_epoch
        args.num_warmup_steps = int(args.max_train_steps * args.warmup_ratio)

        if (
            accelerator.state.deepspeed_plugin is None
            or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
        ):
            lr_scheduler = get_scheduler(
                name=args.lr_scheduler_type,
                optimizer=optimizer,
                num_warmup_steps=args.num_warmup_steps,
                num_training_steps=args.max_train_steps,
            )
        else:
            # The DeepSpeed config defines its own scheduler; hand Accelerate a placeholder.
            lr_scheduler = DummyScheduler(
                optimizer, total_num_steps=args.max_train_steps, warmup_num_steps=args.num_warmup_steps
            )

        model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            model, optimizer, train_dataloader, lr_scheduler
        )
        test_loaders = _prepare_loaders(test_loaders)
        hard_test_loaders = _prepare_loaders(hard_test_loaders)

        # The dataloader length may differ after prepare(), so recompute the
        # total step count and, from it, the number of epochs.
        num_update_steps_per_epoch = len(train_dataloader)
        args.max_train_steps = args.nepochs * num_update_steps_per_epoch
        args.nepochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

        if args.do_train:
            accelerator.print("sanity check test acc")
            evaluate(model, test_loaders[0], 1)
            for epoch in range(1, args.nepochs + 1):
                accelerator.print('Epoch', epoch)
                train(
                    model, optimizer, train_dataloader, lr_scheduler, epoch,
                    verbose=args.verbose, num_adv_steps=args.num_adv_steps,
                )
                for task, loader, hard_loader in zip(args.test_tasks, test_loaders, hard_test_loaders):
                    accelerator.print(f"{task} acc")
                    group = 1
                    evaluate(model, loader, group)
                    if run_attack is not None:
                        evaluate(model, loader, group, adv=True, num_adv_steps=args.num_adv_steps)
                    if task not in NO_HARD_SET:
                        accelerator.print(f"{task} hard acc")
                        evaluate(model, hard_loader, group)
        elif args.do_eval:
            # Clean accuracy needs no attack, so report it even when
            # --adv_attack none was selected; only the adversarial pass is gated.
            for task, loader in zip(args.test_tasks, test_loaders):
                accelerator.print("====Eval Task: ", task, "========")
                evaluate(model, loader, 1)
            if run_attack is not None:
                for task, loader in zip(args.test_tasks, test_loaders):
                    accelerator.print("====Eval Adv Task: ", task, "========")
                    evaluate(model, loader, 1, adv=True, num_adv_steps=args.num_adv_steps)

    if args.save and accelerator.is_main_process:
        model_suffix = args.model.split("/")[-1]  # discard organization and slash
        if not args.out_file:
            # NOTE(review): the "gbda" tag is hard-coded and does not follow
            # --adv_attack; confirm whether it should.
            task_tag = "_".join(args.train_tasks + ["gbda"])
            save_path = "{}_{}_{}_{}_{}.pkl".format(
                model_suffix, task_tag, args.learning_rate, args.batch_size, args.nepochs
            )
        else:
            save_path = args.out_file
        print("SAVING to", save_path)
        # unwrap_model() strips any distributed wrapper and also works when the
        # model was never wrapped, unlike accessing `.module` directly.
        torch.save(accelerator.unwrap_model(model).state_dict(), save_path)

    return 0, 0

def flatten(tensor):
    """Concatenate index 0 and index 1 of dim 1 along dim 0.

    For an input of shape ``(B, 2, ...)`` the result has shape ``(2B, ...)``:
    all ``tensor[:, 0]`` rows first, followed by all ``tensor[:, 1]`` rows.
    """
    first_half, second_half = tensor[:, 0], tensor[:, 1]
    return torch.cat((first_half, second_half), dim=0)

def unflatten(tensor):
    """Split dim 0 in half and stack the two halves along a new dim 1.

    Inverse of ``flatten`` for two-column inputs: a ``(2B, ...)`` tensor
    becomes ``(B, 2, ...)``, with the first half at index 0 of dim 1.
    """
    midpoint = tensor.shape[0] // 2
    leading, trailing = tensor[:midpoint], tensor[midpoint:]
    return torch.stack((leading, trailing), dim=1)

def multiproc(inputs):
    """Call the module-level ``run_attack`` with arguments packed in a sequence.

    ``inputs[0:4]`` are forwarded positionally; ``inputs[4:9]`` become the
    ``num_optim_tokens``, ``num_steps``, ``eval_steps``, ``mode`` and
    ``maximize`` keyword arguments, in that order.
    """
    positional = (inputs[0], inputs[1], inputs[2], inputs[3])
    return run_attack(
        *positional,
        num_optim_tokens=inputs[4],
        num_steps=inputs[5],
        eval_steps=inputs[6],
        mode=inputs[7],
        maximize=inputs[8],
    )

def train(model, optimizer, train_dataloader, lr_scheduler, epoch, log_interval=100, verbose=False, num_adv_steps=100):
    """Train ``model`` for one pass over ``train_dataloader`` with a pairwise loss.

    Each batch unpacks to ``(input_ids, attention_mask, labels)``; the labels
    are not used. The loss is ``BCEWithLogitsLoss`` on
    ``score[:, 0] - score[:, 1]`` against an all-ones target, i.e. the item at
    index 0 of each pair is trained to score higher than the item at index 1.

    Uses the module-level ``accelerator``, ``run_attack``, ``tokenizer`` and
    ``args`` set up by the ``__main__`` section.

    Args:
        model: Prepared model, called as ``model(input_ids, attention_mask=...)``.
        optimizer: Optimizer stepped once per batch.
        train_dataloader: Iterable of training batches.
        lr_scheduler: Scheduler stepped once per batch.
        epoch: Epoch number, used only in the log line.
        log_interval: Log every this many steps when ``verbose`` is set.
        verbose: Whether to print periodic loss lines.
        num_adv_steps: Passed to the attack as ``num_steps`` and ``eval_steps``.
    """
    model.train()
    criterion = torch.nn.BCEWithLogitsLoss()
    ntrain_steps = len(train_dataloader)

    for step, batch in tqdm(enumerate(train_dataloader), total=ntrain_steps):
        b_input_ids, b_input_mask, _b_labels = batch

        if run_attack is not None:
            # Both attacks are computed on the unmodified batch before either
            # column is overwritten. Index 0 is attacked with maximize=False and
            # index 1 with maximize=True (presumably lowering / raising the
            # model score respectively - confirm against attack_utils).
            a0 = run_attack(
                b_input_ids[:, 0], b_input_mask[:, 0], model, tokenizer,
                num_optim_tokens=args.num_optim_tokens, num_steps=num_adv_steps,
                eval_steps=num_adv_steps, mode='random_replace', maximize=False,
            )
            a1 = run_attack(
                b_input_ids[:, 1], b_input_mask[:, 1], model, tokenizer,
                num_optim_tokens=args.num_optim_tokens, num_steps=num_adv_steps,
                eval_steps=num_adv_steps, mode='random_replace', maximize=True,
            )
            b_input_ids[:, 0] = a0
            b_input_ids[:, 1] = a1

        b_input_ids = flatten(b_input_ids)
        b_input_mask = flatten(b_input_mask)

        with accelerator.accumulate(model):
            output = model(b_input_ids, attention_mask=b_input_mask)
            # Structured outputs (ModelOutput / tuple) carry the logits at index
            # 0; a raw tensor already is the logits, as evaluate() also assumes.
            if not torch.is_tensor(output):
                output = output[0]
            output = unflatten(output)
            diffs = output[:, 0] - output[:, 1]
            logits = diffs.squeeze(dim=1)
            # Create the target on the logits' device instead of hard-coding CUDA.
            target = torch.ones(logits.shape[0], device=logits.device)
            loss = criterion(logits, target)
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        if step % log_interval == 0 and step > 0 and verbose:
            accelerator.print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, step, ntrain_steps, 100. * step / ntrain_steps, loss.item()))

def evaluate(model, dataloader, group=1, adv=False, num_adv_steps=100):
    """Compute pairwise accuracy of ``model`` over ``dataloader`` and print it.

    A pair counts as correct when the score at index 0 exceeds the score at
    index 1. With ``group > 1`` the reported value is instead the fraction of
    consecutive, non-overlapping groups of ``group`` pairs that are all correct.

    Uses the module-level ``accelerator`` and, when ``adv`` is set, the
    module-level ``run_attack``, ``tokenizer`` and ``args``.

    Args:
        model: Prepared model, called as ``model(input_ids, attention_mask=...)``.
        dataloader: Evaluation loader; must expose ``.dataset`` and ``len()``.
        group: Group size for the all-correct metric (1 means plain accuracy).
        adv: Whether to perturb both columns with ``run_attack`` first.
        num_adv_steps: Passed to the attack as ``num_steps`` and ``eval_steps``.

    Returns:
        The accuracy as computed by ``np.mean``.
    """
    model.eval()
    correct_chunks = []
    classifier_outputs = (SequenceClassifierOutput, SequenceClassifierOutputWithPast)
    progress = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Evaluating {'adv' if adv else 'normal'}")

    for _step, (b_input_ids, b_input_mask, _b_labels) in progress:
        if adv:
            # Index 0 is overwritten before the attack on index 1 is computed.
            attacked_first = run_attack(
                b_input_ids[:, 0], b_input_mask[:, 0], model, tokenizer,
                num_optim_tokens=args.num_optim_tokens, num_steps=num_adv_steps,
                eval_steps=num_adv_steps, mode='random_replace', maximize=False,
            )
            b_input_ids[:, 0] = attacked_first
            attacked_second = run_attack(
                b_input_ids[:, 1], b_input_mask[:, 1], model, tokenizer,
                num_optim_tokens=args.num_optim_tokens, num_steps=num_adv_steps,
                eval_steps=num_adv_steps, mode='random_replace', maximize=True,
            )
            b_input_ids[:, 1] = attacked_second

        flat_ids = flatten(b_input_ids)
        flat_mask = flatten(b_input_mask)

        with torch.no_grad():
            scores = model(flat_ids, attention_mask=flat_mask)
        # Classifier-style outputs wrap the logits at index 0; anything else is
        # used as returned by the model.
        if isinstance(scores, classifier_outputs):
            scores = scores[0]

        paired = unflatten(scores)
        margin = (paired[:, 0] - paired[:, 1]).squeeze(dim=1).float()
        margin = accelerator.gather(margin).detach().cpu().numpy()
        correct_chunks.append(margin > 0)

    # Keep at most one entry per dataset example (presumably the gather can
    # return extra padded samples across processes - confirm).
    cors = np.concatenate(correct_chunks)[:len(dataloader.dataset)]
    acc = np.mean(cors)
    if group > 1:
        n_groups = len(cors) // group
        em_cors = [np.sum(cors[group * i:group * (i + 1)]) == group for i in range(n_groups)]
        acc = np.mean(em_cors)

    accelerator.print('{} Test Acc {:.3f}'.format('Adv' if adv else 'Normal', acc))
    return acc


if __name__ == "__main__":
    # Command-line entry point. The names bound here (args, accelerator,
    # tokenizer, run_attack) are module-level globals that main(), train(),
    # evaluate() and multiproc() read directly.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", "-m", type=str, default="bert-base-uncased")
    parser.add_argument("--load_path", type=str, default=None)
    parser.add_argument("--train_tasks", nargs="+", default=["util"])
    parser.add_argument("--test_tasks", nargs="+", default=["util"])
    parser.add_argument("--do_train", action="store_true")
    parser.add_argument("--do_eval", action="store_true")
    parser.add_argument("--ngpus", "-n", type=int, default=2)
    parser.add_argument("--nepochs", "-e", type=int, default=2)
    parser.add_argument("--batch_size", "-b", type=int, default=16)
    parser.add_argument("--max_length", "-t", type=int, default=64)
    parser.add_argument("--weight_decay", "-w", type=float, default=0.01)
    parser.add_argument("--learning_rate", "-l", type=float, default=1e-5)
    parser.add_argument(
        "--lr_scheduler_type", default="cosine",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument("--warmup_ratio", type=float, default=0.03)
    parser.add_argument("--verbose", "-v", action="store_true")
    parser.add_argument("--nruns", "-r", type=int, default=1)
    parser.add_argument("--save", "-s", action="store_true")
    parser.add_argument("--out_file", type=str, default="")
    parser.add_argument("--custom_tokens", "-ct", nargs="+")
    parser.add_argument("--custom_prefix", type=str, default="")
    parser.add_argument("--add_prefix", "-p", action="store_true")
    parser.add_argument("--dropout", "-d", type=float, default=0.0)
    parser.add_argument("--freeze_base", "-fb", action="store_true")
    parser.add_argument("--debug_train_samples", action="store_true")
    parser.add_argument("--adv_attack", type=str, default="pez", choices=["pez", "gbda", "none"])
    parser.add_argument("--num_adv_steps", type=int, default=100)
    parser.add_argument("--num_optim_tokens", type=int, default=64)
    parser.add_argument("--adv_ratio", type=float, default=1.0)
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()

    accelerator = Accelerator()
    accelerator.print("REPLICAS: ", accelerator.num_processes)

    tokenizer = get_tokenizer(args.model)

    run_attack = ATTACK_FUNC[args.adv_attack.lower()]
    if run_attack is None:
        accelerator.print("WARNING: Training without adv attack")

    # Seed NumPy (used for the debug subset) and torch (used by the shuffling
    # DataLoader) so that --seed covers both sources of randomness.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    main(args)




