import json
import os
from datetime import datetime
from functools import partial

import numpy as np
import torch, torchopt, nltk, time
import torch.nn.functional as F
from datasets import load_dataset
from optree import tree_map
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
from posteriors import model_to_function, extract_requires_grad_and_func, fvp
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DistilBertForQuestionAnswering, AutoTokenizer
from transformers import DefaultDataCollator

from utils import cg, ggnvp, adjust_damping, _add, _mul, thermo_solve_fvp

def preprocess_function(examples):
    """Tokenize question/context pairs and attach answer-span token labels.

    Uses the module-level ``tokenizer``. Each question is stripped of
    surrounding whitespace, paired with its context, truncated on the
    context side only and padded to 384 tokens.

    Args:
        examples: Batched mapping with ``"question"``, ``"context"`` and
            ``"answers"`` entries; only the first ``answer_start`` / ``text``
            of each answer is used.

    Returns:
        The tokenizer output (without ``offset_mapping``) extended with
        ``start_positions`` and ``end_positions`` lists of token indices.
        Both are 0 for an example whose kept context lies entirely before
        or entirely after the answer characters.
    """
    stripped_questions = [question.strip() for question in examples["question"]]
    inputs = tokenizer(
        stripped_questions,
        examples["context"],
        max_length=384,
        truncation="only_second",
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = inputs.pop("offset_mapping")
    answers = examples["answers"]
    start_positions, end_positions = [], []

    for row, spans in enumerate(offset_mapping):
        gold = answers[row]
        answer_begin = gold["answer_start"][0]
        answer_stop = answer_begin + len(gold["text"][0])
        segment_ids = inputs.sequence_ids(row)

        # Context tokens are the contiguous run whose sequence id equals 1.
        first_ctx = 0
        while segment_ids[first_ctx] != 1:
            first_ctx += 1
        cursor = first_ctx
        while segment_ids[cursor] == 1:
            cursor += 1
        last_ctx = cursor - 1

        # Kept context does not reach the answer characters: label (0, 0).
        answer_outside = (
            spans[first_ctx][0] > answer_stop or spans[last_ctx][1] < answer_begin
        )
        if answer_outside:
            start_positions.append(0)
            end_positions.append(0)
            continue

        # Walk forward to the last token that starts at or before the answer.
        token = first_ctx
        while token <= last_ctx and spans[token][0] <= answer_begin:
            token += 1
        start_positions.append(token - 1)

        # Walk backward to the first token that ends at or after the answer.
        token = last_ctx
        while token >= first_ctx and spans[token][1] >= answer_stop:
            token -= 1
        end_positions.append(token + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

# Load the tokenizer and model

model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = DistilBertForQuestionAnswering.from_pretrained(model_name)

# LoRA adapter configuration: rank-2 adapters on the four linear layers
# named below, for the question-answering task type.
peft_config = LoraConfig(
    task_type=TaskType.QUESTION_ANS,
    inference_mode=False,
    r=2,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_lin", "k_lin", "v_lin", "out_lin"],
)

# Wrap the base model with the adapters and report the trainable parameters.
peft_model = get_peft_model(model, peft_config)
peft_model.print_trainable_parameters()

# Load dataset
def prepare_data(batch_size, *, split="train[:1000]", test_size=0.2, seed=None):
    """Build train and test DataLoaders over a tokenized slice of SQuAD.

    Args:
        batch_size: Number of examples per batch for both loaders.
        split: ``datasets`` split expression selecting the SQuAD examples
            to load. Defaults to the first 1000 training examples.
        test_size: Fraction of the loaded examples held out for testing.
        seed: Optional seed forwarded to ``train_test_split`` so the split
            can be reproduced. ``None`` keeps the library's default
            (unseeded) behaviour.

    Returns:
        A ``(train_data, test_data)`` tuple of ``DataLoader`` objects whose
        batches are collated with ``DefaultDataCollator``.
    """
    data = load_dataset("squad", split=split)
    data = data.train_test_split(test_size=test_size, seed=seed)

    # Tokenize and label every example; the raw text columns are dropped so
    # only the fields produced by preprocess_function remain.
    tokenized_dataset = data.map(
        preprocess_function,
        batched=True,
        remove_columns=data["train"].column_names,
    )

    data_collator = DefaultDataCollator()

    train_data = DataLoader(
        tokenized_dataset["train"],
        batch_size=batch_size,
        collate_fn=data_collator,
    )
    test_data = DataLoader(
        tokenized_dataset["test"],
        batch_size=batch_size,
        collate_fn=data_collator,
    )
    return train_data, test_data

# Mini-batch size used for both DataLoaders; the training loop also divides
# GGN-vector products by this value.
batch_size = 32
# Build the loaders once at import time; they are reused for every optimizer and epoch.
train_data, test_data = prepare_data(batch_size)


def loss_fn(params, input_ids, attention_mask, start, end, model_fun):
    """Return the scalar span loss: start cross-entropy plus end cross-entropy.

    Args:
        params: Parameters passed as the first argument to ``model_fun``.
        input_ids: Forwarded to ``model_fun`` as ``input_ids``.
        attention_mask: Forwarded to ``model_fun`` as ``attention_mask``.
        start: Target start indices, flattened before use.
        end: Target end indices, flattened before use.
        model_fun: Callable returning an object with ``start_logits`` and
            ``end_logits`` attributes.

    Returns:
        Sum of the two mean-reduced cross-entropy losses.
    """
    prediction = model_fun(params, input_ids=input_ids, attention_mask=attention_mask)

    def span_loss(logits, targets):
        # Collapse leading dimensions so each row is one set of position scores.
        num_positions = logits.size(-1)
        return F.cross_entropy(logits.view(-1, num_positions), targets.view(-1))

    return span_loss(prediction.start_logits, start) + span_loss(prediction.end_logits, end)

def loss_fn_ngd(params, input_ids, attention_mask, start, end, model_fun):
    """Return the unreduced span loss, one value per flattened example.

    Same computation as the scalar span loss but with ``reduction='none'``,
    so the start and end cross-entropies are added element-wise.

    Args:
        params: Parameters passed as the first argument to ``model_fun``.
        input_ids: Forwarded to ``model_fun`` as ``input_ids``.
        attention_mask: Forwarded to ``model_fun`` as ``attention_mask``.
        start: Target start indices, flattened before use.
        end: Target end indices, flattened before use.
        model_fun: Callable returning an object with ``start_logits`` and
            ``end_logits`` attributes.

    Returns:
        1-D tensor of per-example ``start_loss + end_loss`` values.
    """
    prediction = model_fun(params, input_ids=input_ids, attention_mask=attention_mask)

    def per_example_loss(logits, targets):
        num_positions = logits.size(-1)
        return F.cross_entropy(
            logits.view(-1, num_positions), targets.view(-1), reduction='none'
        )

    return per_example_loss(prediction.start_logits, start) + per_example_loss(prediction.end_logits, end)

# Per-optimizer histories, keyed by optimizer name; each entry is reset to an
# empty list when that optimizer's run starts in the training loop.
train_accs: dict = {}
train_losses: dict = {}
test_accs: dict = {}
test_losses: dict = {}
# Wall-clock seconds spent on the training batches of each epoch, per optimizer.
times: dict = {}
# Optimizers to compare; learning_rates is indexed by each optimizer's
# position in this list, so the two lists must have matching lengths.
optimizers = ["adam"]
learning_rates = [1e-3]
# Iteration budget handed to cg (as maxiter) and thermo_solve_fvp (as iterations).
maxiter = 20
# Damping passed to the cg / thermo_solve_fvp solvers; reassigned from
# adjust_damping after each batch when lm_damping is True.
damping = 1
lm_damping = False
# Forwarded to adjust_damping as its factor argument.
factor = 1.05
# Each seed is passed to torch.manual_seed before a full pass over `optimizers`.
seeds = [0]
# Written into the config dict by save_data.
eps = 1e-8
num_train_epochs = 5
# Moment coefficients passed to the torchopt Adam/AdamW constructors.
betas = (0., 0.)
# Both forwarded to thermo_solve_fvp.
average_regularization = False
step = 0.01
# Only written into the config dict by save_data in the visible code.
momentum = 0.
# Prefer CUDA when available.
# NOTE(review): `device` is not referenced again in the visible code, so the
# model and batches do not appear to be moved to it -- confirm.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Functional form of the PEFT model, then (presumably) the subset of its named
# parameters with requires_grad=True plus a function that takes only that subset.
peft_model_func = model_to_function(peft_model)
sub_params, sub_model_fun = extract_requires_grad_and_func(dict(peft_model.named_parameters()), peft_model_func)
# Placeholder; reassigned with the current batch loss inside the training loop.
loss = 0

# Train every optimizer listed in `optimizers` once per seed, recording the
# per-batch train loss, per-batch test loss and per-epoch training time.
for seed in seeds:
    torch.manual_seed(seed)
    # The whole run sits under no_grad; gradients come from torch.func.grad,
    # which (per its documentation) is not disabled by an enclosing no_grad.
    with torch.no_grad():
        for i, optimizer_name in enumerate(optimizers):
            # Select the optimizer. The natural-gradient variants all feed
            # their preconditioned gradient through the same Adam transform;
            # "ngd-f" previously had no optimizer assigned at all.
            if optimizer_name == "adam":
                optimizer = torchopt.adam(lr=learning_rates[i], betas=betas, eps=eps)
            elif optimizer_name == "adamw":
                optimizer = torchopt.adamw(lr=learning_rates[i], betas=betas, eps=eps, weight_decay=0)
            elif optimizer_name in ("ngd-f", "ngd-ggn", "ggn-thermo"):
                optimizer = torchopt.adam(lr=learning_rates[i], betas=betas, eps=eps)
            else:
                raise ValueError(f"Unknown optimizer: {optimizer_name!r}")

            # Reinitialize model.
            # NOTE(review): get_peft_model is called again on the same base
            # `model` object; confirm this really restores the initial
            # adapter weights between optimizer runs.
            peft_model = get_peft_model(model, peft_config)
            sub_params, sub_model_fun = extract_requires_grad_and_func(dict(peft_model.named_parameters()), peft_model_func)
            opt_state = optimizer.init(sub_params)
            # Warm start for the linear solvers; carried across batches.
            x0 = None

            train_accs[optimizer_name] = []
            train_losses[optimizer_name] = []
            test_accs[optimizer_name] = []
            test_losses[optimizer_name] = []
            times[optimizer_name] = []

            print(f"\nOptimizer {optimizer_name}", "starting")
            for epoch in range(num_train_epochs):
                start_time = time.time()
                epoch_train_loss, epoch_test_loss = 0, 0
                # Train and test batches are counted separately so each
                # epoch average is divided by its own number of batches.
                num_train_batches, num_test_batches = 0, 0
                pbar = tqdm(train_data, desc=f"Epoch {epoch+1}")
                for batch in pbar:
                    input_ids = batch['input_ids']
                    attention_mask = batch['attention_mask']
                    start, end = batch['start_positions'], batch['end_positions']

                    if optimizer_name in ("adam", "adamw"):
                        # First-order methods use the plain loss gradient.
                        grad = torch.func.grad(loss_fn, 0)(sub_params, input_ids, attention_mask, start, end, sub_model_fun)

                    elif optimizer_name == "ngd-f":
                        def partial_loss_fn(params):
                            return loss_fn_ngd(params, input_ids, attention_mask, start, end, sub_model_fun)

                        def partial_fvp(v):
                            return fvp(partial_loss_fn, (sub_params,), (v,), normalize=True)[1]

                        grad_0 = torch.func.grad(loss_fn, 0)(sub_params, input_ids, attention_mask, start, end, sub_model_fun)

                        if x0 is None:
                            x0 = grad_0
                        # Solve the damped system with CG, starting from the previous solution.
                        grad, _ = cg(partial_fvp, grad_0, x0=x0, maxiter=maxiter, damping=damping, tol=1e-2)
                        x0 = grad

                    elif optimizer_name in ("ngd-ggn", "ggn-thermo"):
                        # Both variants share the same GGN-vector product and
                        # differ only in the linear solver applied to it.
                        def loss_ggn(outputs):
                            start_logits = outputs[0]
                            end_logits = outputs[1]
                            start_loss = F.cross_entropy(start_logits.view(-1, start_logits.size(-1)), start.view(-1))
                            end_loss = F.cross_entropy(end_logits.view(-1, end_logits.size(-1)), end.view(-1))
                            return start_loss + end_loss

                        def forward(params):
                            output = sub_model_fun(params, input_ids=input_ids, attention_mask=attention_mask)
                            return torch.stack([output.start_logits, output.end_logits])

                        def partial_ggnvp(v):
                            return tree_map(lambda x: x/batch_size, ggnvp(forward, loss_ggn, (sub_params,), (v,), normalize=False)[1])

                        grad_0 = torch.func.grad(loss_fn, 0)(sub_params, input_ids, attention_mask, start, end, sub_model_fun)
                        if x0 is None:
                            x0 = grad_0
                        if optimizer_name == "ngd-ggn":
                            grad, _ = cg(partial_ggnvp, grad_0, x0=x0, maxiter=maxiter, damping=damping, tol=1e-2)
                        else:
                            grad = thermo_solve_fvp(partial_ggnvp, grad_0, x0=x0, iterations=maxiter, step=step, damping=damping, average_regularization=average_regularization)
                        x0 = grad

                    # Loss at the pre-update parameters, used for logging.
                    loss = loss_fn(sub_params, input_ids, attention_mask, start, end, sub_model_fun)
                    pbar.set_description(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
                    updates, opt_state = optimizer.update(grad, opt_state)
                    if lm_damping:
                        # NOTE(review): grad_0 and partial_ggnvp only exist for
                        # the GGN variants, so lm_damping=True is only usable
                        # with "ngd-ggn" / "ggn-thermo" -- confirm intent.
                        delta_loss = loss_fn(_add(sub_params, _mul(-1,grad)), input_ids, attention_mask, start, end, sub_model_fun) - loss
                        damping = adjust_damping(damping, delta_loss, grad, grad_0, partial_ggnvp, factor=factor)
                    sub_params = torchopt.apply_updates(sub_params, updates)
                    # Only the train history gets the train loss (it used to
                    # be appended to test_losses as well).
                    train_losses[optimizer_name].append(loss.item())
                    epoch_train_loss += loss.item()
                    num_train_batches += 1

                # Epoch time covers the training batches only.
                times[optimizer_name].append(time.time() - start_time)
                for test_batch in test_data:
                    input_ids = test_batch['input_ids']
                    attention_mask = test_batch['attention_mask']
                    start, end = test_batch['start_positions'], test_batch['end_positions']
                    test_loss = loss_fn(sub_params, input_ids, attention_mask, start, end, sub_model_fun)
                    test_losses[optimizer_name].append(test_loss.item())
                    epoch_test_loss += test_loss.item()
                    num_test_batches += 1

                # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
                epoch_train_loss /= max(num_train_batches, 1)
                epoch_test_loss /= max(num_test_batches, 1)

                print(
                    f"Epoch {epoch+1} - Train Loss: {epoch_train_loss:.4f}, Test Loss: {epoch_test_loss:.4f}"
                )

# Loss histories (each keyed by optimizer name) bundled for saving to disk.
results = dict(train_losses=train_losses, test_losses=test_losses)

def save_data():
    """Write the run configuration and loss histories to a timestamped directory.

    Creates ``data/<YYYY-MM-DD_HH-MM-SS>/`` (local time) and stores the
    module-level hyperparameters as JSON in ``config.txt`` and the
    module-level ``results`` dict in ``results.npy``.
    """
    dir_name = 'data/' + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    os.makedirs(dir_name, exist_ok=True)

    # Record the full optimizer and learning-rate lists. The previous version
    # referenced an undefined `learning_rate` and only the last loop value of
    # `optimizer_name`.
    config = {
        "optimizers": optimizers,
        "learning_rates": learning_rates,
        "maxiter": maxiter,
        "damping": damping,
        "lm_damping": lm_damping,
        "factor": factor,
        "seeds": seeds,
        "eps": eps,
        "num_train_epochs": num_train_epochs,
        "betas": betas,
        "batch_size": batch_size,
        "step": step,
        "average_regularization": average_regularization,
        "momentum": momentum,
    }

    # Writing the dictionary to a text file as JSON
    filename = dir_name + '/config.txt'
    with open(filename, 'w') as file:
        json.dump(config, file, indent=4)

    print(f"Configuration saved to {filename}")

    # `results` is a dict, so np.save stores it as a pickled object array;
    # reading it back requires np.load(..., allow_pickle=True).
    np.save(dir_name + '/results.npy', results)

# Persist the configuration and collected loss histories once training has finished.
save_data()