from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForMaskedLM, DataCollatorForTokenClassification, set_seed
import torch
from torch.utils.data import DataLoader 
import time 
import wandb 
import os
from abc import ABC, abstractmethod
import argparse
import json
import math
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

# Process-wide settings: W&B mode and CUDA allocator configuration are passed
# through environment variables; the float32 matmul precision is set on torch.
os.environ.update(
    {
        "WANDB_MODE": "online",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    }
)
torch.set_float32_matmul_precision("high")

# Stage 1: a bare parser that only knows --config, so that the JSON file can be
# read first and used as the source of defaults for the full parser below.
parent_parser = argparse.ArgumentParser(add_help=False)
parent_parser.add_argument("--config", help="Path to a JSON file with default argument values")
args, _ = parent_parser.parse_known_args()

config = {}
if args.config:
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(args.config, encoding="utf-8") as f:
        config = json.load(f)
    # The content is unpacked into set_defaults(**config) below, which needs a
    # mapping; report anything else as a usage error instead of a TypeError.
    if not isinstance(config, dict):
        parent_parser.error(
            f"--config must contain a JSON object, got {type(config).__name__}"
        )

# Stage 2: the full parser (it inherits --config from the parent parser).
parser = argparse.ArgumentParser(parents=[parent_parser])
parser.add_argument("--model_name", help="Model",
                    choices=["bert-base-cased","bert-large-cased", "roberta-base", "roberta-large"],
                    default="roberta-large")
parser.add_argument("--dataset", help="Dataset name",
                    choices=["mrpc", "qqp"], default="mrpc")
parser.add_argument("--opt", help="Optimizer",
                    choices=["SGD", "NCRS", "RGF"], default="SGD")
parser.add_argument("--epochs", help="Number of epochs",
                    default=5, type=int)
parser.add_argument("--train_batch_size", help="Batch size of training data",
                    default=8, type=int)
parser.add_argument("--grad_accum_steps", help="Gradient accumulation steps",
                    default=1, type=int)
parser.add_argument("--val_batch_size", help="Batch size of validation data",
                    default=16, type=int)
parser.add_argument("--logging_steps", help="Number of update steps between two logs",
                    default=10, type=int)
parser.add_argument("--eval_steps", help="Number of evaluation steps",
                    default=100, type=int)
parser.add_argument("--lr", help="Maximum learning rate",
                    default=1e-5, type=float)
parser.add_argument("--mu", help="Perturbation parameter",
                    default=1e-3, type=float)
parser.add_argument("--warmup_ratio", help="Ratio of total number of steps for a linear warmup",
                    default=0.06, type=float)

# Values from the config file replace the built-in defaults; flags given on the
# command line still win over both.
# NOTE: argparse checks `choices` only for command-line values, not for defaults,
# so values coming from the config file are not validated here.
parser.set_defaults(**config)
args = parser.parse_args()

# Seed handed to transformers' set_seed; the ZO trainers also receive it as
# their base seed.
init_seed = 42
set_seed(init_seed)

# Default device for single-process runs (the DDP setup rebinds `device`).
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Running on device:", device)

# DDP setup: distributed mode is enabled when the RANK environment variable is
# set to anything other than -1.
ddp = int(os.environ.get('RANK', -1)) != -1
if ddp:
    # Rank information is read from the environment of each process.
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    # Each process uses the CUDA device indexed by its local rank; this rebinds
    # the module-level `device` chosen earlier.
    device = torch.device("cuda", ddp_local_rank)
    torch.cuda.set_device(device)
    dist.init_process_group(backend="nccl", device_id=device)
    # Wait until every process has reached this point before continuing.
    dist.barrier()
    # Only rank 0 prints training info and talks to W&B.
    master_process = ddp_rank == 0
else:
    # Single-process run: this process is the master.
    master_process = True

# Prompt configuration: the cloze template filled in for every example and the
# label words (the list index of a word is the class label it stands for).
# NOTE: this rebinds `config`, replacing the dictionary loaded from --config.
config = dict(
    template="{s1} {mask} , {s2}",
    label_words=[" No", " Yes"],
)

# Load the tokenizer and the masked-LM model of the selected checkpoint, and
# move the model to the target device
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
raw_model = AutoModelForMaskedLM.from_pretrained(args.model_name)
raw_model.to(device)

# Token id of each label word (Yes/No): the first id the tokenizer produces for
# the word when no special tokens are added
label_word_ids = [
    tokenizer.encode(label_word, add_special_tokens=False)[0]
    for label_word in config["label_words"]
]

# Loading and preprocessing the dataset
if args.dataset not in ("mrpc", "qqp"):
    # argparse does not enforce `choices` for values supplied through --config,
    # so fail here with a clear message rather than with a NameError later on.
    raise ValueError(f"Unsupported dataset: {args.dataset!r} (expected 'mrpc' or 'qqp')")

dataset = load_dataset("glue", args.dataset)
if args.dataset == "qqp":
    # QQP calls its two inputs question1/question2; rename them to the
    # sentence1/sentence2 columns that the preprocessing expects.
    dataset = dataset.rename_columns({"question1": "sentence1", "question2": "sentence2"})

# Keep at most the first 500 validation examples.
dataset["validation"] = dataset["validation"].select(range(min(500, dataset["validation"].num_rows)))

def preprocess(row):
    """Turn one sentence pair into a masked-LM example.

    The two sentences and the tokenizer's mask token are inserted into the
    prompt template and tokenized. The returned encoding gets a ``labels`` list
    that is -100 everywhere except at the first mask position, which holds the
    token id of the label word selected by ``row['label']``.
    """
    prompt = config["template"].format(
        s1=row['sentence1'],
        s2=row['sentence2'],
        mask=tokenizer.mask_token,
    )
    label = row['label']
    encoded = tokenizer(prompt)
    token_ids = encoded["input_ids"]

    # Only the masked slot carries a target; all other positions stay at -100.
    targets = [-100] * len(token_ids)
    mask_pos = token_ids.index(tokenizer.mask_token_id)
    targets[mask_pos] = label_word_ids[label]

    encoded["labels"] = targets
    return encoded

# Tokenize every split, then drop the raw columns
dataset = dataset.map(preprocess)
dataset = dataset.remove_columns(["idx", "sentence1", "sentence2", "label"])

if master_process:
    print(f"Training examples: {len(dataset['train'])}")

# One collator instance is shared by both loaders
collate_fn = DataCollatorForTokenClassification(tokenizer=tokenizer)

# Under DDP each split gets a non-shuffling DistributedSampler; otherwise no
# sampler is passed to the loaders
sampler_train = DistributedSampler(dataset["train"], shuffle=False) if ddp else None
sampler_val = DistributedSampler(dataset["validation"], shuffle=False) if ddp else None

loader_train = DataLoader(
    dataset["train"],
    batch_size=args.train_batch_size,
    shuffle=False,
    sampler=sampler_train,
    collate_fn=collate_fn,
)
loader_val = DataLoader(
    dataset["validation"],
    batch_size=args.val_batch_size,
    shuffle=False,
    sampler=sampler_val,
    collate_fn=collate_fn,
)

# Helper functions
def gen_batch(loader):
    """Yield batches from ``loader`` forever, restarting it after every full pass.

    Args:
        loader: A re-iterable collection of batches (e.g. a ``DataLoader``).

    Yields:
        The items of ``loader`` in order, cycling indefinitely.

    Raises:
        ValueError: If a pass over ``loader`` yields nothing. Without this check
            an empty loader would make ``next()`` loop forever.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise ValueError("gen_batch: loader yielded no batches")
# Endless stream of training batches; every trainer's train() pulls from it
# with next(it_train).
it_train = gen_batch(loader_train)

def evaluate():
    """Compute validation loss and accuracy over ``loader_val``.

    The loss is the example-weighted mean of the model's batch losses. A
    prediction is the label word with the highest logit at the mask positions.
    Under DDP the per-rank sums are added across ranks before dividing. The
    model is switched to eval mode for the pass and back to train mode after.

    Returns:
        Tuple ``(loss, accuracy)`` of Python floats.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0.0
    n_seen = 0

    candidate_ids = torch.tensor(label_word_ids, device=device)

    with torch.inference_mode():
        for batch in loader_val:
            batch = {name: tensor.to(device, non_blocking=True) for name, tensor in batch.items()}
            outputs = model(**batch)

            n_in_batch = batch["input_ids"].size(0)
            loss_sum += outputs.loss.item() * n_in_batch

            # Logits at the mask positions, restricted to the label words
            at_mask = batch["input_ids"] == tokenizer.mask_token_id
            label_logits = outputs.logits[at_mask][:, candidate_ids]
            preds = torch.argmax(label_logits, dim=-1)
            # Gold class: 1 if the target token is the second label word, else 0
            truth_ids = (batch["labels"][at_mask] == candidate_ids[1]).long()
            n_correct += (preds == truth_ids).sum().item()

            n_seen += n_in_batch

    metrics = torch.tensor([loss_sum, n_correct, n_seen], device=device)
    if ddp:
        dist.all_reduce(metrics, op=dist.ReduceOp.SUM)

    loss_total, correct_total, sample_total = metrics.tolist()
    final_loss = loss_total / sample_total
    final_acc = correct_total / sample_total

    model.train()
    return final_loss, final_acc

def log(train_loss, lr, val_loss, acc, time_taken, global_step):
    """Send one set of metrics to W&B at ``global_step`` and print a summary line."""
    payload = {
        "train_loss": train_loss,
        "lr": lr,
        "acc": acc,
        "val_loss": val_loss,
        "time/step_s": time_taken,
    }
    wandb.log(payload, step=global_step)

    summary = f"Step:{global_step}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f},  Acc: {acc:.4f}, LR: {lr:.2e}, Time: {time_taken:.2f}s"
    print(summary)

# Wrap the model in DistributedDataParallel for DDP runs; otherwise use it as is
if ddp:
    model = DDP(raw_model, device_ids=[ddp_local_rank])
else:
    model = raw_model

# Learning rate schedule
class CosineScheduler:
    """Linear warmup followed by a cosine decay from ``max_lr`` to ``min_lr``.

    Calling the scheduler with a 0-based step index ``it`` returns:

    * ``it < warmup_steps``: the linear ramp ``max_lr * (it + 1) / warmup_steps``;
    * ``warmup_steps <= it < total_steps``: a half-cosine from ``max_lr`` down
      to ``min_lr``;
    * ``it >= total_steps``: ``min_lr``.

    Args:
        max_lr: Peak value, reached on the last warmup step.
        min_lr: Value returned once the decay is over.
        total_steps: Number of steps covered by warmup plus decay.
        warmup_ratio: Fraction of ``total_steps`` spent in warmup (truncated to
            a whole number of steps).
    """

    def __init__(self, max_lr, min_lr, total_steps, warmup_ratio):
        self.max_lr = float(max_lr)
        self.min_lr = float(min_lr)
        self.total_steps = int(total_steps)
        self.warmup_steps = int(total_steps * warmup_ratio)

    def __call__(self, it):
        """Return the scheduled value for step ``it``."""
        if it < self.warmup_steps:
            return self.max_lr * (it + 1) / self.warmup_steps

        # `>=` (not `>`): at it == total_steps the cosine coefficient is exactly
        # zero, so the result is min_lr either way, but returning here avoids a
        # 0/0 division below when warmup_steps == total_steps.
        if it >= self.total_steps:
            return self.min_lr

        decay_ratio = (it - self.warmup_steps) / (self.total_steps - self.warmup_steps)
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))

        return self.min_lr + coeff * (self.max_lr - self.min_lr)

# Total number of updates: batches per epoch times epochs, divided by the number
# of batches accumulated per update
n_steps = (len(loader_train) * args.epochs) // args.grad_accum_steps

# The learning rate peaks at args.lr and decays to one tenth of it
lr_sched = CosineScheduler(
    max_lr=args.lr,
    min_lr=0.1 * args.lr,
    total_steps=n_steps,
    warmup_ratio=args.warmup_ratio,
)

class BpTrainer:
    """Trainer that updates the model with backpropagation and a torch optimizer.

    Args:
        model: Module called as ``model(**batch)``; its output must expose ``.loss``.
        optimizer: Optimizer stepped once per update; the learning rate of all
            its parameter groups is overwritten from ``lr_sched`` before each update.
        lr_sched: Callable mapping the 0-based update index to a learning rate.
    """

    def __init__(self, model, optimizer, lr_sched):
        self.model = model
        self.lr_sched = lr_sched
        self.optimizer = optimizer

    def train(self):
        """Run ``n_steps`` updates, evaluating and logging every ``args.logging_steps``.

        Each update accumulates gradients over ``args.grad_accum_steps`` batches
        taken from ``it_train`` and then performs a single optimizer step.
        """
        # torch.cuda.synchronize() fails on machines without CUDA, so only
        # synchronize (to get accurate step timings) when a GPU is available.
        use_cuda = torch.cuda.is_available()
        train_loss = 0.0
        global_step = 0
        while True:
            if use_cuda:
                torch.cuda.synchronize()
            start_time = time.time()
            lr = self.lr_sched(global_step)

            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr

            for _ in range(args.grad_accum_steps):
                x = next(it_train)
                x = {k: v.to(device, non_blocking=True) for k, v in x.items()}
                out = self.model(**x)
                # Scale each micro-batch loss so the accumulated gradient is
                # the mean over the accumulation window.
                loss = out.loss / args.grad_accum_steps
                train_loss += loss.detach().item()
                loss.backward()

            self.optimizer.step()
            self.optimizer.zero_grad(set_to_none=True)
            if use_cuda:
                torch.cuda.synchronize()
            time_taken = time.time() - start_time

            if global_step % args.logging_steps == 0:
                # train_loss is the sum over the updates since the last log; at
                # step 0 that is a single update, so it is not divided.
                if global_step != 0:
                    train_loss /= args.logging_steps
                # evaluate() runs on every rank (it reduces across ranks under
                # DDP); only the master process logs the result.
                val_loss, acc = evaluate()
                if master_process:
                    log(train_loss, lr, val_loss, acc, time_taken, global_step)
                train_loss = 0.0

            global_step += 1
            if global_step >= n_steps:
                break

class ZOTrainer(ABC):
    """Base class for zeroth-order trainers that perturb parameters with seeded noise.

    Args:
        model: Module whose trainable parameters are perturbed in place.
        lr_sched: Callable mapping the 0-based step index to a learning rate.
        init_seed: Base seed, stored for subclasses to derive per-step seeds.
    """

    def __init__(self, model, lr_sched, init_seed):
        self.model = model
        self.lr_sched = lr_sched
        self.init_seed = init_seed

    def perturb_params(self, mu, seed, projected_grad=None):
        """Add scaled, seeded Gaussian noise to every trainable parameter in place.

        The noise ``z`` is regenerated from ``seed`` on each call, so calling
        again with the same seed and the negated scale undoes a perturbation
        (up to floating-point rounding). The noise is drawn on the device of
        the first trainable parameter.

        Args:
            mu: Step size. Without ``projected_grad`` the update is ``p += mu * z``.
            seed: Seed of the noise generator.
            projected_grad: Optional scalar (number or 0-dim tensor). When it is
                given, including when it is zero, the update is
                ``p -= projected_grad * mu * z``.
        """
        params = [p for p in self.model.parameters() if p.requires_grad]
        if not params:
            return

        # Test against None, not truthiness: a projected gradient of exactly 0
        # must produce a zero update rather than fall back to `+mu * z`.
        # Converting to float once also keeps `alpha` a plain Python number.
        if projected_grad is None:
            scale = mu
        else:
            scale = -float(projected_grad) * mu

        noise_device = params[0].device
        g = torch.Generator(device=noise_device).manual_seed(seed)
        with torch.no_grad():
            for p in params:
                r = torch.randn(size=p.shape, device=noise_device, generator=g, dtype=p.dtype)
                p.add_(r, alpha=scale)

        # Synchronizing is only possible (and only needed) on CUDA devices.
        if noise_device.type == "cuda":
            torch.cuda.synchronize()

    @abstractmethod
    def train(self):
        """Run the training loop; implemented by each concrete trainer."""
        pass

class RGFTrainer(ZOTrainer):
    """RGF zeroth-order trainer: two-point forward-difference update.

    Every step evaluates the loss at the current parameters and at parameters
    perturbed by ``mu * z`` (``z`` is seeded Gaussian noise), forms
    ``projected_grad = (loss_perturbed - loss) / mu`` and moves the parameters
    by ``-lr * projected_grad * z``.

    Args:
        model: Module called as ``model(**batch)``; its output must expose ``.loss``.
        lr_sched: Callable mapping the 0-based step index to a learning rate.
        mu_sched: Callable mapping the 0-based step index to the perturbation size.
        init_seed: Base seed; step ``k`` uses ``init_seed + k`` for its noise.
    """

    def __init__(self, model, lr_sched, mu_sched, init_seed):
        super().__init__(model, lr_sched, init_seed)
        self.mu_sched = mu_sched

    def train(self):
        """Run ``n_steps`` updates, evaluating and logging every ``args.logging_steps``."""
        # torch.cuda.synchronize() fails on machines without CUDA, so only
        # synchronize (to get accurate step timings) when a GPU is available.
        use_cuda = torch.cuda.is_available()
        train_loss = 0.0
        global_step = 0

        while True:
            if use_cuda:
                torch.cuda.synchronize()
            start_time = time.time()
            # Use the seed given to this trainer rather than a module global.
            seed = self.init_seed + global_step
            lr = self.lr_sched(global_step)
            mu = self.mu_sched(global_step)
            x = next(it_train)
            x = {k: v.to(device, non_blocking=True) for k, v in x.items()}
            losses = torch.empty(2, device=device)

            # losses[0]: loss at the current parameters;
            # losses[1]: loss after adding mu * z.
            self.model.eval()
            for i in range(2):
                if i == 1:
                    self.perturb_params(mu, seed)
                with torch.inference_mode():
                    losses[i] = self.model(**x).loss

            train_loss += losses[0].item()
            # Same seed, negated scale: remove the probe perturbation again.
            self.perturb_params(-mu, seed)

            if ddp:
                dist.all_reduce(losses, op=dist.ReduceOp.AVG)

            # A plain float keeps the scale passed on to perturb_params a Python
            # number. A projected gradient of exactly zero means a zero update,
            # so the parameter sweep is skipped in that case.
            projected_grad = ((losses[1] - losses[0]) / mu).item()
            if projected_grad != 0.0:
                self.perturb_params(lr, seed, projected_grad=projected_grad)
            if use_cuda:
                torch.cuda.synchronize()
            time_taken = time.time() - start_time

            if global_step % args.logging_steps == 0:
                # train_loss is the sum over the steps since the last log; at
                # step 0 that is a single step, so it is not divided.
                if global_step != 0:
                    train_loss /= args.logging_steps
                # evaluate() runs on every rank; only the master process logs.
                val_loss, acc = evaluate()
                if master_process:
                    log(train_loss, lr, val_loss, acc, time_taken, global_step)
                train_loss = 0.0

            global_step += 1
            if global_step >= n_steps:
                break

class NCRSTrainer(ZOTrainer):
    """NCRS zeroth-order trainer: keep a random step only if it lowers the loss.

    Every step draws seeded Gaussian noise ``z``, evaluates the loss at the
    current parameters and at parameters moved by ``lr * z``, and applies the
    move only when the second loss is smaller than the first.

    Args:
        model: Module called as ``model(**batch)``; its output must expose ``.loss``.
        lr_sched: Callable mapping the 0-based step index to the step size.
        init_seed: Base seed; step ``k`` uses ``init_seed + k`` for its noise.
    """

    def __init__(self, model, lr_sched, init_seed):
        super().__init__(model, lr_sched, init_seed)

    def train(self):
        """Run ``n_steps`` updates, evaluating and logging every ``args.logging_steps``."""
        # torch.cuda.synchronize() fails on machines without CUDA, so only
        # synchronize (to get accurate step timings) when a GPU is available.
        use_cuda = torch.cuda.is_available()
        train_loss = 0.0
        global_step = 0

        while True:
            if use_cuda:
                torch.cuda.synchronize()
            start_time = time.time()
            # Use the seed given to this trainer rather than a module global.
            seed = self.init_seed + global_step
            lr = self.lr_sched(global_step)
            losses = torch.empty(2, device=device)
            x = next(it_train)
            x = {k: v.to(device, non_blocking=True) for k, v in x.items()}

            # losses[0]: loss at the current parameters;
            # losses[1]: loss after the candidate move lr * z.
            self.model.eval()
            for i in range(2):
                if i == 1:
                    self.perturb_params(lr, seed)
                with torch.inference_mode():
                    losses[i] = self.model(**x).loss

            train_loss += losses[0].item()
            # Same seed, negated scale: undo the candidate move.
            self.perturb_params(-lr, seed)

            if ddp:
                dist.all_reduce(losses, op=dist.ReduceOp.AVG)

            # Re-apply the move only if it reduced the (rank-averaged) loss.
            if losses[1] < losses[0]:
                self.perturb_params(lr, seed)

            if use_cuda:
                torch.cuda.synchronize()
            time_taken = time.time() - start_time

            if global_step % args.logging_steps == 0:
                # train_loss is the sum over the steps since the last log; at
                # step 0 that is a single step, so it is not divided.
                if global_step != 0:
                    train_loss /= args.logging_steps
                # evaluate() runs on every rank; only the master process logs.
                val_loss, acc = evaluate()
                if master_process:
                    log(train_loss, lr, val_loss, acc, time_taken, global_step)
                train_loss = 0.0

            global_step += 1
            if global_step >= n_steps:
                break

if master_process:
    wandb.init(project="intrinsic-dim", name=f"{args.model_name}-{args.opt}", group=args.dataset, config=config)

trainer = None
match args.opt:
    case "SGD":
        optimizer = torch.optim.SGD(model.parameters())
        trainer = BpTrainer(model, optimizer, lr_sched)
    case "NCRS":
        trainer = NCRSTrainer(model, lr_sched, init_seed)
    case "RGF":
        mu_sched = CosineScheduler(max_lr=args.mu, min_lr=args.mu, total_steps=n_steps, warmup_ratio=0)
        trainer = RGFTrainer(model, lr_sched, mu_sched, init_seed)

print(f"Starting training with {args.opt}")
trainer.train()

print("Done")
if ddp:
    dist.destroy_process_group()

wandb.finish()