import re
import random
import time
import os
import itertools
import tqdm
import argparse
from functools import partial

import torch
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader

import numpy as np

from utils.inference import generate
from utils.optimization import get_model_name, load_model, get_optimizer
from utils.optimization import GRPOLoss, get_per_token_logps
from utils.data import get_dataloader, collate_train, collate_train_vision
from utils.group import group_advantages
from utils.grpo import postprocess_examples
from utils.other import set_seed
from utils.evaluation import get_eval_metrics

# Command-line configuration. Everything below runs at import time and defines the
# module-level globals (args, device, data_dir, output_dir, eval_dir) used by the functions.
parser = argparse.ArgumentParser()

# Experiment
parser.add_argument("--save_name", type=str, default='')
parser.add_argument('--model_nickname', default="aya_8B", type=str)
parser.add_argument('--dataset_names', nargs="+", default=["yearMCQA"], type=str)
parser.add_argument('--num_questions', default=0, type=int)
parser.add_argument('--checkpoint', default='', type=str)
parser.add_argument('--epochs', default=1, type=int)

# Eval
parser.add_argument('--eval_num_questions', default=4096, type=int)
parser.add_argument('--eval_temperature', default=1.0, type=float)
# Both values are counts (samples per question, batches per generation step), so they
# are parsed as integers rather than floats.
parser.add_argument('--eval_num_samples', default=4, type=int)
parser.add_argument('--eval_num_prefill_batches', default=4, type=int)
parser.add_argument('--eval_max_new_tokens', default=1280, type=int)
parser.add_argument('--eval_every_n_questions', default=4096, type=int)

# Optimization
parser.add_argument('--train_batch_size', default=16, type=int)
parser.add_argument('--gradient_accumulation', default=1, type=int)
parser.add_argument("--optimizer", type=str, default='adam')
parser.add_argument("--learning_rate", type=str, default="5e-6")
parser.add_argument('--kl_weight', default=0.01, type=float)
parser.add_argument('--max_norm', default=1.0, type=float)
parser.add_argument('--min_length', default=64, type=int)

# rollout
parser.add_argument('--num_samples', default=32, type=int)
parser.add_argument('--temperature', default=1.0, type=float)
parser.add_argument('--rollouts_per_step', default=32, type=int)
parser.add_argument('--gen_batch_size', default=32, type=int)
parser.add_argument('--max_length', default=1024, type=int)
parser.add_argument('--max_prompt_length', default=320, type=int)
parser.add_argument("--method", type=str, default='drgrpo_mcq')

# misc
parser.add_argument('--seed', default=1337, type=int)
# Required: every path below is joined onto it, and os.path.join(None, ...) would
# otherwise fail with an opaque TypeError.
parser.add_argument('--workspace_dir', type=str, required=True)
parser.add_argument('--data_dir', default="./data/datasets", type=str)
parser.add_argument('--output_dir', default="./outputs/", type=str)
parser.add_argument('--hf_token', default="", type=str)
parser.add_argument("--compile", type=str, default="yes")

args = parser.parse_args()


def get_bool(s):
    """Return True only when *s* is "yes" (case-insensitive)."""
    return s.lower() == "yes"


args.compile = get_bool(args.compile)

# NOTE(security): eval() executes arbitrary Python taken from the command line. It is
# kept so existing invocations behave the same (presumably to allow expressions such as
# "1e-5*2" -- confirm); switch to float(...) if only numeric literals are needed.
args.learning_rate = eval(args.learning_rate)

# Rollouts for one training step are generated in equally sized prefill batches, so the
# per-step rollout count must be a positive multiple of the prefill batch size.
args.prefill_batch_size = min(args.rollouts_per_step, args.gen_batch_size)
if args.prefill_batch_size <= 0 or args.rollouts_per_step % args.prefill_batch_size != 0:
    parser.error(
        "--rollouts_per_step must be a positive multiple of "
        "min(--rollouts_per_step, --gen_batch_size)"
    )
args.num_prefill_batches = args.rollouts_per_step // args.prefill_batch_size

os.environ["TOKENIZERS_PARALLELISM"] = "false"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
cpu_device = torch.device("cpu")
set_seed(args.seed, deterministic=False)

print('dataset_names', args.dataset_names)

save_name = (
    f"{args.save_name}_{args.model_nickname}_{'_'.join(args.dataset_names)}"
    f"_ML{args.min_length}_{args.method}_L{args.max_length}"
    f"_BS{args.train_batch_size}_S{args.num_samples}_R{args.rollouts_per_step}"
)

data_dir = os.path.join(args.workspace_dir, args.data_dir)
# One run directory per (configuration, seed, start time), rooted in the workspace.
output_dir = os.path.join(
    args.workspace_dir,
    args.output_dir,
    f"{save_name}_s{args.seed}",
    str(int(time.time())),
)
# Derive eval_dir from the final, workspace-rooted output_dir so it is created inside
# the run directory instead of relative to the current working directory.
eval_dir = os.path.join(output_dir, 'eval')
os.makedirs(eval_dir, exist_ok=True)

print('output_dir', output_dir)
print('eval_dir', eval_dir)

for key, value in vars(args).items():
    print(f"{key}: {value}")


def main():
    """Build the models, optimizer and dataloaders, optionally resume, then train.

    Two copies of the same model are loaded: a frozen reference copy in eval mode and
    the trainable policy. When ``args.checkpoint`` is set, model/optimizer state and
    the progress counters are restored before the dataloaders are created.
    """
    model_name, is_vision = get_model_name(args.model_nickname)
    load_kwargs = dict(hf_token=args.hf_token, device=device, compile=args.compile)

    # Reference model: no gradients, eval mode.
    reference_model, _ = load_model(model_name, **load_kwargs)
    for param in reference_model.parameters():
        param.requires_grad = False
    reference_model.eval()

    # Trainable policy and its optimizer.
    model, tokenizer = load_model(model_name, **load_kwargs)
    optimizer = get_optimizer(model, args)

    if args.checkpoint:
        model, optimizer, resume_idx, num_seen_questions = load_checkpoint(model, optimizer)
        # Keep everything before the last underscore of the stored index
        # (presumably this drops a per-sample suffix -- confirm against the data loader).
        resume_idx = resume_idx.rsplit('_', 1)[0]
    else:
        resume_idx = None
        num_seen_questions = 0

    train_dataloader, eval_dataloaders = get_data(model_name, tokenizer, resume_idx, is_vision)
    train(
        model,
        reference_model,
        tokenizer,
        train_dataloader,
        eval_dataloaders,
        optimizer,
        is_vision,
        num_seen_questions=num_seen_questions,
    )



def evaluate(model, tokenizer, dataloader, num_samples, name, num_seen_questions, is_vision):
    """Generate completions for an evaluation dataloader and compute metrics.

    Batches are pulled from *dataloader* in groups of ``args.eval_num_prefill_batches``
    and passed to ``generate``; all generation outputs are merged and handed to
    ``get_eval_metrics``.

    Args:
        model: Model to evaluate; it is switched to eval mode and left there.
        tokenizer: Tokenizer forwarded to ``generate`` and ``get_eval_metrics``.
        dataloader: Evaluation dataloader; must support ``len()``.
        num_samples: Number of samples requested from ``generate``.
        name: Label used in the ``eval_{name}/time`` metric key.
        num_seen_questions: Training progress counter stored in the metrics.
        is_vision: Forwarded to ``generate``.

    Returns:
        The metrics dict from ``get_eval_metrics`` with the elapsed time and
        ``num_seen_questions`` added.

    Raises:
        ValueError: If ``args.eval_num_prefill_batches`` is smaller than 1.
    """
    print('Run evaluation')
    model.eval()

    batches_per_step = int(args.eval_num_prefill_batches)
    if batches_per_step < 1:
        raise ValueError("eval_num_prefill_batches must be at least 1")
    # Ceiling division: a trailing partial group of batches still counts as one step.
    total_steps = -(-len(dataloader) // batches_per_step)

    eval_iter = iter(dataloader)
    t = time.time()
    generation_outputs = dict()
    # The context manager closes the progress bar even if generation raises.
    with tqdm.tqdm(total=total_steps, desc="Eval", unit="step") as pbar_eval:
        while True:
            example_batches = list(itertools.islice(eval_iter, batches_per_step))
            if not example_batches:
                break

            with torch.inference_mode():
                _generation_outputs, _ = generate(model,
                                                  example_batches,
                                                  max_new_tokens=args.eval_max_new_tokens,
                                                  tokenizer=tokenizer,
                                                  temperature=args.eval_temperature,
                                                  num_samples=num_samples,
                                                  gen_batch_size=args.gen_batch_size,
                                                  pbar=True,
                                                  is_vision=is_vision
                                                  )
            generation_outputs.update(_generation_outputs)
            pbar_eval.update()

    # The second return value of get_eval_metrics is not used here.
    eval_metrics, _ = get_eval_metrics(generation_outputs, tokenizer)
    eval_metrics[f'eval_{name}/time'] = round(time.time() - t, 3)
    eval_metrics['num_seen_questions'] = num_seen_questions
    print('eval_metrics', eval_metrics)
    return eval_metrics



def train(model, reference_model, tokenizer, train_dataloader, eval_dataloaders, optimizer, is_vision, num_seen_questions=0):
    """Alternate rollout generation and policy updates until the dataloader is exhausted.

    Each outer iteration pulls ``args.num_prefill_batches`` prompt batches, generates
    rollouts with ``rollout``, shuffles them with a step-dependent seed, and runs one
    pass of gradient updates over them using ``GRPOLoss``. Every
    ``args.eval_every_n_questions`` seen questions a checkpoint is saved and all
    evaluation dataloaders are run. A final checkpoint is saved at the end.

    Args:
        model: Trainable policy model.
        reference_model: Frozen model used for the reference log-probabilities; it is
            moved to ``device`` for the forward pass and back to CPU around optimizer
            steps.
        tokenizer: Tokenizer forwarded to rollout, collation and evaluation.
        train_dataloader: Iterable of ``(examples, _)`` prompt batches.
        eval_dataloaders: Iterable of ``(dataloader, name, num_samples)`` tuples.
        optimizer: Optimizer for *model*.
        is_vision: Selects the vision collate function and is forwarded to helpers.
        num_seen_questions: Starting value of the progress counter (non-zero on resume).
    """
    objective = GRPOLoss(kl_weight=args.kl_weight)
    if args.compile:
        objective = torch.compile(objective, dynamic=True)

    train_iter = iter(train_dataloader)
    global_step = 0
    num_optimizer_steps = 0
    # Stays None when the dataloader yields nothing (e.g. resuming a finished run), so
    # the final save below can be skipped instead of hitting an unbound local.
    last_example_idx = None
    # Index of the next evaluation milestone, in units of eval_every_n_questions.
    last_eval = int(num_seen_questions / args.eval_every_n_questions) + 1
    while True:
        example_batches = list(itertools.islice(train_iter, int(args.num_prefill_batches)))
        if not example_batches:
            break

        num_seen_questions += sum(len(examples) for examples, _ in example_batches)

        # Rollouts are generated with the policy in eval mode.
        model.eval()
        examples = rollout(example_batches, model, tokenizer, is_vision)
        # Recorded before shuffling so it refers to the last example in rollout order.
        last_example_idx = examples[-1]['idx']
        print('last_example_idx', last_example_idx)

        # Step-dependent seed keeps the shuffle reproducible across runs.
        random.Random(args.seed + global_step).shuffle(examples)
        num_avg_action_tokens = np.mean([len(example['generated_ids']) for example in examples])

        _collate_func = collate_train_vision if is_vision else collate_train
        data_sampler = DataLoader(
            examples,
            batch_size=args.train_batch_size // args.gradient_accumulation,
            drop_last=False,
            collate_fn=partial(_collate_func, tokenizer=tokenizer),
            num_workers=1
        )

        kl_batch = 0

        optimizer.zero_grad()
        num_example_gradients = 0
        model.train()
        for i, batch in enumerate(data_sampler):
            inputs = batch['inputs'].to(device)
            advantages = batch['advantage'].to(device)
            completion_mask = batch['completion_mask'].to(device)
            num_logits_to_keep = completion_mask.shape[1]
            log_probs_old = batch['log_probs_old'].detach().to(device)

            with torch.no_grad():
                # The reference model is parked on CPU around optimizer steps.
                reference_model.to(device)
                log_probs_ref = get_per_token_logps(reference_model, inputs, num_logits_to_keep)
                log_probs_ref = log_probs_ref.detach()

            log_probs = get_per_token_logps(model, inputs, num_logits_to_keep)

            loss, kl = objective(log_probs,
                                 log_probs_old,
                                 log_probs_ref,
                                 completion_mask,
                                 advantages.unsqueeze(-1),
                                 norm_items_per_row=num_avg_action_tokens)

            loss.backward()
            num_example_gradients += len(advantages)

            kl_batch += kl.item()
            if (i + 1) % args.gradient_accumulation == 0:
                # Mean KL over the micro-batches of this accumulation window.
                kl_mean = kl_batch / args.gradient_accumulation
                grad_norm = clip_grad_norm_(model.parameters(), max_norm=args.max_norm)
                print(f"kl={kl_mean: .4f}, grad_norm={grad_norm: .4f}")
                kl_batch = 0

                reference_model.to('cpu')
                optimizer.step()
                num_optimizer_steps += 1
                optimizer.zero_grad()
                num_example_gradients = 0

        # Leftover gradients from an incomplete accumulation window are applied only
        # when they cover more than 32 examples; smaller remainders are discarded.
        # NOTE(review): the threshold 32 is hard-coded and this tail step is not
        # gradient-clipped, unlike the regular step above -- confirm both are intended.
        if num_example_gradients > 32:
            reference_model.to('cpu')
            optimizer.step()
            num_optimizer_steps += 1

        reference_model.to('cpu')
        optimizer.zero_grad()
        global_step += 1

        if num_seen_questions / args.eval_every_n_questions >= last_eval:
            save_model(model, optimizer, num_seen_questions, last_example_idx)
            for eval_dataloader, eval_name, eval_num_samples in eval_dataloaders:
                evaluate(model, tokenizer, eval_dataloader, eval_num_samples, name=eval_name, num_seen_questions=num_seen_questions, is_vision=is_vision)
            last_eval += 1

    if last_example_idx is None:
        # No rollout was produced, so there is no new state worth checkpointing.
        print('No training batches were produced; skipping final checkpoint.')
        return
    save_model(model, optimizer, num_seen_questions, last_example_idx)


def load_checkpoint(model, optimizer):
    """Restore model and optimizer state from the file at ``args.checkpoint``.

    Args:
        model: Model whose ``load_state_dict`` receives ``model_state_dict``.
        optimizer: Optimizer whose ``load_state_dict`` receives ``optimizer_state_dict``.

    Returns:
        ``(model, optimizer, last_example_idx, num_seen_questions)`` where the last two
        values are read from the checkpoint.
    """
    path = args.checkpoint
    print('load', path)
    # weights_only=False lets torch.load unpickle arbitrary objects, so only load
    # checkpoint files from a trusted source.
    payload = torch.load(path, map_location='cpu', weights_only=False)

    # NOTE(review): state-dict keys are used exactly as stored (an earlier
    # `_orig_mod.` prefix-stripping step is disabled), so presumably the checkpoint
    # must come from a model wrapped the same way as `model` -- confirm.
    model.load_state_dict(payload['model_state_dict'])
    optimizer.load_state_dict(payload['optimizer_state_dict'])

    progress = {field: payload[field] for field in ('num_seen_questions', 'last_example_idx')}
    for field, value in progress.items():
        print(field, value)
    return model, optimizer, progress['last_example_idx'], progress['num_seen_questions']

def save_model(model, optimizer, num_seen_questions, last_example_idx, max_checkpoints=5):
    """Write a checkpoint under ``output_dir/checkpoints`` and prune older ones.

    The checkpoint stores the model and optimizer state dicts together with
    ``num_seen_questions`` and ``last_example_idx``. After saving, checkpoint files in
    the directory are ranked by the question count in their file name and everything
    beyond the ``max_checkpoints`` highest counts is deleted.

    Args:
        model: Model providing ``state_dict()``.
        optimizer: Optimizer providing ``state_dict()``.
        num_seen_questions: Progress counter; also part of the file name.
        last_example_idx: Index of the last processed example, stored for resuming.
        max_checkpoints: Number of checkpoint files to keep.
    """
    save_dir = os.path.join(output_dir, 'checkpoints')
    os.makedirs(save_dir, exist_ok=True)

    target_path = os.path.join(save_dir, f"model_checkpoint_{num_seen_questions}.pt")
    torch.save(
        {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'num_seen_questions': num_seen_questions,
            'last_example_idx': last_example_idx,
        },
        target_path,
    )
    print(f"Saved checkpoint: {target_path}")

    # Collect (path, question_count) for every checkpoint file currently on disk.
    found = []
    for entry in os.listdir(save_dir):
        if not (entry.startswith("model_checkpoint_") and entry.endswith(".pt")):
            continue
        hit = re.search(r'model_checkpoint_(\d+)\.pt', entry)
        if not hit:
            continue
        found.append((os.path.join(save_dir, entry), int(hit.group(1))))

    # Highest question count first; whatever falls past the keep-window is stale.
    newest_first = sorted(found, key=lambda item: item[1], reverse=True)
    for stale_path, _ in newest_first[max_checkpoints:]:
        try:
            os.remove(stale_path)
            print(f"Removed old checkpoint: {stale_path}")
        except OSError as e:
            # Best effort: a failed deletion is reported but does not stop training.
            print(f"Error removing {stale_path}: {e}")

def get_data(model_name, tokenizer, last_example_idx, is_vision):
    """Create the training dataloader and one evaluation dataloader per dataset.

    Args:
        model_name: Forwarded to ``get_dataloader``.
        tokenizer: Forwarded to ``get_dataloader``.
        last_example_idx: Resume position for the training dataloader (``None`` when
            starting fresh); not passed to the evaluation dataloaders.
        is_vision: Forwarded to ``get_dataloader``.

    Returns:
        ``(train_dataloader, eval_dataloaders)`` where ``eval_dataloaders`` is a list of
        ``(dataloader, dataset_name, args.eval_num_samples)`` tuples, one per entry of
        ``args.dataset_names``.
    """
    # Arguments shared by the training and evaluation dataloaders.
    shared = dict(
        model_name=model_name,
        max_prompt_length=args.max_prompt_length,
        tokenizer=tokenizer,
        prefetch=args.num_prefill_batches * 2,
        is_vision=is_vision,
    )

    # Training: split 0 over all datasets at once, seeded from the command line.
    train_dataloader = get_dataloader(
        data_dir,
        args.dataset_names,
        split_id=0,
        seed=args.seed,
        batch_size=args.prefill_batch_size,
        epochs=args.epochs,
        num_questions=args.num_questions,
        last_example_idx=last_example_idx,
        **shared,
    )

    # Evaluation: split 1, one dataloader per dataset, single epoch, fixed seed 1337.
    eval_dataloaders = [
        (
            get_dataloader(
                data_dir,
                [dataset_name],
                split_id=1,
                seed=1337,
                batch_size=args.gen_batch_size,
                epochs=1,
                num_questions=args.eval_num_questions,
                is_eval=True,
                **shared,
            ),
            dataset_name,
            args.eval_num_samples,
        )
        for dataset_name in args.dataset_names
    ]
    return train_dataloader, eval_dataloaders


def rollout(example_batches, model, tokenizer, is_vision):
    """Sample completions for a set of prompt batches and attach scaled advantages.

    For every question, ``args.num_samples`` completions are kept, scored with
    ``postprocess_examples`` and turned into advantages with ``group_advantages``.
    The advantages of the whole rollout are then rescaled so that the positive ones
    and the negative ones each sum (in absolute value) to
    ``(args.num_samples / 2) * number_of_questions``, up to a ``1e-8`` stabiliser.

    Args:
        example_batches: Prompt batches forwarded to ``generate``.
        model: Model used for sampling.
        tokenizer: Tokenizer forwarded to ``generate`` and ``postprocess_examples``.
        is_vision: Forwarded to ``generate`` (also as ``return_inputs_embeds``) and to
            ``postprocess_examples``.

    Returns:
        Flat list of example dicts, each with an ``'advantage'`` entry set.
    """
    group_size = args.num_samples
    sampling_temperature = args.temperature
    assert sampling_temperature > 0.0

    with torch.inference_mode():
        generation_outputs, _ = generate(
            model,
            example_batches,
            max_new_tokens=args.max_length,
            tokenizer=tokenizer,
            temperature=sampling_temperature,
            num_samples=group_size,
            gen_batch_size=args.gen_batch_size,
            pbar=True,
            is_vision=is_vision,
            return_inputs_embeds=is_vision,
        )

    # Phase 1: score every group of completions before any advantage is computed.
    scored_groups = []
    for group in generation_outputs.values():
        group = group[:group_size]
        assert len(group) == group_size, len(group)
        rewards, _answers, _oracle_answer = postprocess_examples(group, tokenizer, is_vision, args.min_length)
        scored_groups.append((group, rewards))

    # Phase 2: per-group advantages, written onto the examples and flattened.
    outputs = []
    for group, rewards in scored_groups:
        group_adv = group_advantages(args.method, rewards, num_options=group[0]['num_options']).tolist()
        assert len(group_adv) == len(group)
        for sample, value in zip(group, group_adv):
            sample['advantage'] = value
            outputs.append(sample)

    # Rescale so different advantage definitions carry the same total positive and
    # negative mass and can be compared.
    positive_mass = sum(sample['advantage'] for sample in outputs if sample['advantage'] > 0.0)
    negative_mass = abs(sum(sample['advantage'] for sample in outputs if sample['advantage'] < 0.0))

    target_sum = (group_size / 2) * len(generation_outputs)
    positive_scale = target_sum / (positive_mass + 1e-8)
    negative_scale = target_sum / (negative_mass + 1e-8)
    for sample in outputs:
        # Zero advantages take the negative factor and stay zero.
        sample['advantage'] *= positive_scale if sample['advantage'] > 0.0 else negative_scale

    return outputs


# Script entry point: start training only when this file is executed directly.
# Argument parsing and directory setup above still run at import time.
if __name__ == "__main__":
    main()
