import argparse
import deepspeed

parser = argparse.ArgumentParser(description='sp')

# Mandatory string arguments: model / data / output locations.
for _required_flag in ('--basepath', '--trainpath', '--testpath', '--savedir', '--config_path'):
    parser.add_argument(_required_flag, type=str, required=True)

# Optional arguments, declared in the same order as before so --help is unchanged.
for _flag, _kwargs in (
    ('--continue_eagle_path', {'type': str, 'default': None}),
    ('--confidence_loss_type', {'type': str, 'default': "rmsle"}),
    ('--cache_path', {'type': str, 'default': "outputs/cache.pt"}),
    ("--local_rank", {'type': int, 'default': -1, 'help': "local_rank for distributed training on gpus"}),
    ("--max_len", {'type': int, 'default': 4096}),
    ("--dtype", {'type': str, 'default': "float16"}),
):
    parser.add_argument(_flag, **_kwargs)

# Let DeepSpeed register its own flags (this script later reads args.deepspeed_config).
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()
import json
import re
import sys
import os
import shutil


# Path to the DeepSpeed JSON config. deepspeed.add_config_arguments registers
# the flag but does not make it required, so fail early with a clear message
# instead of a TypeError from open(None).
deepspeed_config = args.deepspeed_config
if deepspeed_config is None:
    parser.error("--deepspeed_config is required (path to the DeepSpeed JSON config)")
with open(deepspeed_config) as f:
    ds_config = json.load(f)

# Run-level training settings; the per-GPU micro batch size is taken from the
# DeepSpeed config so the DataLoaders agree with what DeepSpeed expects.
train_config = {
    "bs": ds_config["train_micro_batch_size_per_gpu"],
    "num_epochs": 20,
    "num_workers": 2,
    # Mirror --max_len (used for dataset building) instead of a separate hardcoded 4096.
    "max_len": args.max_len,
    "config_path": args.config_path,
}

from safetensors import safe_open
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
import torch
from cnets import padding

from accelerate.utils import set_seed

# Permit TF32 tensor-core math for CUDA matmuls (faster, reduced mantissa precision).
torch.backends.cuda.matmul.allow_tf32 = True
# Fixed global seed via accelerate's helper for reproducible runs.
set_seed(0)
from cnets import Model
from configs import EConfig
from data_utils import build_dataset_confidence_rank
from datasets import load_dataset
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from tqdm import tqdm
# import accelerate
import numpy as np
from transformers import PreTrainedTokenizerBase, get_linear_schedule_with_warmup


# Only the global-rank-0 process logs to Weights & Biases. deepspeed.comm is not
# initialised yet at this point (that happens in deepspeed.initialize further
# down, where `global_rank` is assigned), so read the RANK environment variable
# set by the distributed launcher; a plain single-process run defaults to 0.
if int(os.environ.get("RANK", "0")) == 0:
    import wandb

    # wandb.log raises unless a run has been started first.
    wandb.init(config=train_config)


def build_dataset_rank(
        tokenizer, datapath
):
    """Load a ShareGPT-style JSON dataset and tokenize it with a loss mask.

    NOTE(review): this function is not called anywhere in this script;
    build_dataset_confidence_rank (from data_utils) is used instead.

    Args:
        tokenizer: Hugging Face tokenizer; must support apply_chat_template.
            Its pad_token_id may be overwritten with unk_token_id (see below).
        datapath: path(s) passed to ``load_dataset('json', data_files=...)``.

    Returns:
        A ``datasets.Dataset`` (torch format, shuffled with seed 42) with the
        columns ``input_ids``, ``loss_mask`` and ``attention_mask``. Each row's
        tensors carry a leading dimension of 1 (``[None, :]``). All original
        columns are removed.
    """
    ds = load_dataset('json', data_files=datapath)
    ds = ds['train']
    ds = ds.shuffle(seed=42)
    ds1 = ds
    original_columns1 = ds1.column_names
    num_proc = 8

    def preprocess_function(examples):
        # Batched map function: turns each conversation into input_ids plus a
        # mask where 0 marks tokens excluded from the loss (system / user parts).
        new_examples = {
            "attention_mask": [],
            "input_ids": [],
            "loss_mask": []
        }
        for i in range(len(examples['id'])):
            messages = [
                {"role": "system",
                 "content": "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."},
            ]
            convroles = ["user", "assistant"]
            # Map ShareGPT speaker tags to chat-template roles; any other tag raises KeyError.
            roles = {"human": "user", "gpt": "assistant"}
            source = examples['conversations'][i]
            if not source:
                continue
            if roles[source[0]["from"]] != "user":
                # Skip the first one if it is not from human
                source = source[1:]
            for j, sentence in enumerate(source):
                role = roles[sentence["from"]]
                # Turns must strictly alternate user/assistant.
                assert role == convroles[j % 2], f"{i}"
                # if sentence["from"]=="gpt":
                #     sentence["value"]=" "+sentence["value"]
                messages.append(
                    {"role": role, "content": sentence["value"]}
                )
            conversation = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False,
            )

            # Side effect: mutates the shared tokenizer. Note this condition is
            # also true when pad_token_id == 0, not only when it is None.
            if not tokenizer.pad_token_id:
                tokenizer.pad_token_id = tokenizer.unk_token_id

            input_ids = tokenizer(
                conversation,
                return_tensors="pt",
                max_length=2048,
                add_special_tokens=False,
            ).input_ids[0]
            loss_mask = torch.ones_like(input_ids)
            # print(i)

            # Separators are Llama-3 chat-template strings; other templates will
            # not split correctly.
            sep = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

            # Unused.
            total_len = len(input_ids)

            sep2 = "<|eot_id|><|start_header_id|>user<|end_header_id|>"
            turns = conversation.split(sep2)

            # Fold the system-prompt prefix into the first user turn.
            # NOTE(review): raises IndexError if sep2 never occurs in the text.
            turns[1] = turns[0] + sep2 + turns[1]
            turns = turns[1:]

            # Position 0 (presumably the BOS token) never contributes to the loss.
            cur_len = 1
            loss_mask[:cur_len] = 0
            # NOTE(review): this loop rebinds the outer `i`; the outer range loop
            # is unaffected, but `i` no longer names the example index below.
            for i, turn in enumerate(turns):
                if turn == "":
                    break
                # Tokenized with default add_special_tokens, so the count may
                # include a BOS token — presumably what the hardcoded offsets
                # below compensate for.
                turn_len = len(tokenizer(turn).input_ids)

                parts = turn.split(sep)
                if len(parts) != 2:
                    break
                parts[0] += sep
                # NOTE(review): an earlier comment here said '"-2" is hardcoded for
                # the Llama tokenizer to make the offset correct'; the code subtracts
                # 1 here and the first turn subtracts a further 2 below. These offsets
                # look tuned to one tokenizer/template — verify before reuse.
                instruction_len = len(tokenizer(parts[0]).input_ids) - 1

                # Ignore the user instructions
                if i == 0:
                    loss_mask[cur_len: cur_len + instruction_len - 2] = 0
                else:
                    loss_mask[cur_len - 3: cur_len + instruction_len + 1] = 0
                cur_len += turn_len
                if i != 0:
                    cur_len += 3
                # cur_len+=2

                # if i != 0 and not tokenizer.legacy:
                #     # The legacy and non-legacy modes handle special tokens differently
                #     cur_len -= 1

            # Anything after the last processed turn is excluded from the loss.
            loss_mask[cur_len:] = 0
            attention_mask = torch.ones_like(loss_mask)

            # new_examples["conversation"].append(conversation)
            # [None, :] adds a leading dim of 1 so the collator can torch.cat rows.
            new_examples["input_ids"].append(input_ids[None, :])
            new_examples["loss_mask"].append(loss_mask[None, :])
            new_examples["attention_mask"].append(attention_mask[None, :])

        return new_examples

    ds1 = ds1.map(
        preprocess_function,
        batched=True,
        num_proc=num_proc,
        remove_columns=original_columns1,
        load_from_cache_file=False
    )


    ds1.set_format(type="torch")
    return ds1


class DataCollatorWithPadding:
    """Collate per-sample tensors into one batch by right-padding with zeros.

    Every feature is a dict of tensors that already carry a leading dimension
    (rows are joined with ``torch.cat`` along dim 0). Sequences are padded along
    dim 1 to the longest ``input_ids`` in the batch. Padding keeps the input's
    dtype and is allocated on the default device.
    """

    @staticmethod
    def _append_zeros(tensor, filler_shape):
        """Concatenate a zero tensor of ``filler_shape`` after ``tensor`` along dim 1."""
        filler = torch.zeros(*filler_shape, dtype=tensor.dtype)
        return torch.cat((tensor, filler), dim=1)

    def paddingtensor(self, intensors, N):
        """Pad a (B, n, S) tensor to (B, N, S) with trailing zeros."""
        rows, length, width = intensors.shape
        return self._append_zeros(intensors, (rows, N - length, width))

    def paddingtensor2D(self, intensors, N):
        """Pad a (B, n) tensor to (B, N) with trailing zeros."""
        rows, length = intensors.shape
        return self._append_zeros(intensors, (rows, N - length))

    def paddingtensor3D(self, intensors, N):
        """Pad a (B, n, D) tensor to (B, N, D); same operation as ``paddingtensor``."""
        return self.paddingtensor(intensors, N)

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        longest = max(feature['input_ids'].shape[1] for feature in features)

        def collate(key, pad_fn):
            return torch.cat([pad_fn(feature[key], longest) for feature in features])

        batch = {
            "input_ids": collate('input_ids', self.paddingtensor2D),
            "attention_mask": collate('attention_mask', self.paddingtensor2D),
            "loss_mask": collate('loss_mask', self.paddingtensor2D),
        }
        # confidence is either (B, n) or (B, n, D); the first sample decides which padder to use.
        if features[0]['confidence'].ndim == 2:
            confidence_pad = self.paddingtensor2D
        else:
            confidence_pad = self.paddingtensor3D
        batch["confidence"] = collate('confidence', confidence_pad)
        return batch


tokenizer = AutoTokenizer.from_pretrained(args.basepath)

model_type = AutoConfig.from_pretrained(args.basepath).architectures[0]
print(f"model_type: {model_type}")

# Base architectures this script can train against. They previously ran
# byte-identical setup code in separate branches; one shared path avoids drift.
SUPPORTED_MODEL_TYPES = ("LlamaForCausalLM", "Qwen3ForCausalLM")
if model_type not in SUPPORTED_MODEL_TYPES:
    raise ValueError(f"Unknown model type: {model_type}")


def _first_token_id(text):
    """Return the id of the first token of ``text`` (no special tokens added).

    Prints a warning when ``text`` is not exactly one token: only the first id
    is kept, so the caller would otherwise silently get a fragment of ``text``.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    if len(token_ids) != 1:
        print(f"warning: {text!r} tokenizes to {len(token_ids)} ids {token_ids}; using the first")
    return token_ids[0]


# Think-span markers consumed by build_dataset_confidence_rank.
tokenizer.start_think_text = '<think>'
tokenizer.stop_think_text = '</think>'
tokenizer.start_think_id = _first_token_id(tokenizer.start_think_text)
tokenizer.stop_think_id = _first_token_id(tokenizer.stop_think_text)
traindataset = build_dataset_confidence_rank(tokenizer, args.trainpath, args.max_len)
testdataset = build_dataset_confidence_rank(tokenizer, args.testpath, args.max_len)

config = EConfig.from_pretrained(train_config["config_path"])

# Constructor options for cnets.Model, gathered in one place.
model_kwargs = {
    "path": args.basepath,
    "dtype": args.dtype,
    "tokenizer": tokenizer,
    "load_emb": True,
    "load_head": True,
    "model_type": model_type,
    "continue_eagle_path": args.continue_eagle_path,
    "confidence_loss_type": args.confidence_loss_type,
}
model = Model(config, **model_kwargs)
print(f"EAModel: {model}")

# scandata runs only for a fresh start; it is skipped when --continue_eagle_path is given.
if args.continue_eagle_path is None:
    model.scandata(args.trainpath, args.basepath, args.cache_path)


criterion = nn.SmoothL1Loss(reduction="none")
num_epochs = train_config["num_epochs"]

model_engine, optimizer, _, _ = deepspeed.initialize(
    args=args,
    model=model,
    model_parameters=model.parameters(),
)

# Process topology, queried after deepspeed.initialize has set up communication.
global_rank = deepspeed.comm.get_rank()
rank = deepspeed.comm.get_local_rank()
world_size = deepspeed.comm.get_world_size()

os.makedirs(args.savedir, exist_ok=True)


def _distributed_loader(dataset, shuffle):
    """Return ``(sampler, loader)`` that shards ``dataset`` across ranks."""
    dist_sampler = DistributedSampler(dataset, num_replicas=world_size, rank=global_rank, shuffle=shuffle)
    loader = DataLoader(
        dataset,
        batch_size=train_config["bs"],
        sampler=dist_sampler,
        num_workers=4,
        pin_memory=True,
        collate_fn=DataCollatorWithPadding(),
    )
    return dist_sampler, loader


sampler, test_loader = _distributed_loader(testdataset, shuffle=False)
train_sampler, train_loader = _distributed_loader(traindataset, shuffle=True)
print(f"train_loader: {train_loader} len(train_loader)={len(train_loader)} len(traindataset)={len(traindataset)} len(test_loader)={len(test_loader)} len(testdataset)={len(testdataset)}")

def find_max_state_with_file(directory, filename="zero_to_fp32.py"):
    """Locate the newest resumable ``state_<N>`` sub-directory of ``directory``.

    Only entries whose name is exactly ``state_`` followed by digits, that are
    directories, and that contain ``filename`` qualify. Names such as
    ``state_12_backup`` are ignored rather than being misread as ``state_12``.

    Args:
        directory: directory to scan. A missing directory means "no checkpoint".
        filename: marker file that must exist inside a qualifying sub-directory.

    Returns:
        ``(path, next_epoch)`` where ``path`` is ``f"{directory}/<subdir>"`` for
        the qualifying entry with the largest N and ``next_epoch`` is N + 1;
        ``(None, 0)`` when nothing qualifies.
    """
    try:
        entries = os.listdir(directory)
    except FileNotFoundError:
        return None, 0

    best_value, best_name = -1, None
    for subdir in entries:
        # fullmatch: a prefix match would accept e.g. "state_12_backup".
        match = re.fullmatch(r"state_(\d+)", subdir)
        if not match:
            continue
        subdir_path = os.path.join(directory, subdir)
        if not (os.path.isdir(subdir_path) and os.path.exists(os.path.join(subdir_path, filename))):
            continue
        value = int(match.group(1))
        if value > best_value:
            best_value, best_name = value, subdir

    if best_name is None:
        return None, 0
    # Use the real entry name so zero-padded names (state_05) map to existing paths.
    return f"{directory}/{best_name}", best_value + 1


# Resume from the newest DeepSpeed checkpoint found under --savedir, if any;
# start_epoch is the epoch after the one that checkpoint was saved at (0 otherwise).
checkpoint_path, start_epoch = find_max_state_with_file(args.savedir)
if checkpoint_path is not None:
    print(f"load from {checkpoint_path}")
    model_engine.load_checkpoint(checkpoint_path)

epsilon: float = 1e-6
closs_ratio: float = 0.5


def _all_reduce_mean(values):
    """Return the mean of ``values`` averaged over all ranks, as a Python float.

    Collective call (deepspeed.comm.all_reduce): every rank must call it the
    same number of times and in the same order.
    """
    local_mean = torch.tensor(values).cuda().mean()
    deepspeed.comm.all_reduce(local_mean, op=deepspeed.comm.ReduceOp.AVG)
    return local_mean.item()


def _run_model(data):
    """Forward one collated batch; returns ``(plosses, vlosses, closses, acces)``."""
    return model_engine(input_ids=data["input_ids"].to(rank),
                        attention_mask=data["attention_mask"].to(rank),
                        loss_mask=data["loss_mask"],
                        confidence=data["confidence"])


for epoch in range(start_epoch, num_epochs):
    train_sampler.set_epoch(epoch + 1)
    print(f"Now training epoch {epoch}")

    # ---------------- training ----------------
    model.train()
    epoch_acces = [[] for _ in range(model.length)]
    epoch_plosses = [[] for _ in range(model.length)]

    for data in tqdm(train_loader):
        # NOTE(review): DeepSpeed's engine.step() normally clears gradients itself;
        # confirm this per-micro-batch zero_grad is compatible with
        # gradient_accumulation_steps > 1 in the DeepSpeed config.
        model.zero_grad()

        plosses, vlosses, closses, acces = _run_model(data)

        # Later prediction positions are down-weighted geometrically (0.8 ** i).
        ploss = sum(0.8 ** i * p for i, p in enumerate(plosses))
        loss = ploss

        if global_rank == 0:
            logdict = {"train/lr": optimizer.optimizer.param_groups[0]["lr"]}

        if closses is not None:
            if isinstance(closses, list):
                confidence_mse, progress_mse, remain_rmsle = closses
                with torch.no_grad():
                    confidence_mse_val = confidence_mse.item()
                    progress_mse_val = progress_mse.item()
                    remain_rmsle_val = remain_rmsle.item()
                    # Rescale both MSE terms to the magnitude of the RMSLE term.
                    confidence_weight = remain_rmsle_val / (confidence_mse_val + epsilon) if confidence_mse_val > epsilon else 1.0
                    progress_weight = remain_rmsle_val / (progress_mse_val + epsilon) if progress_mse_val > epsilon else 1.0

                closs = confidence_weight * confidence_mse + progress_weight * progress_mse + remain_rmsle

                if global_rank == 0:
                    logdict["train/confidence_mse"] = confidence_mse_val
                    logdict["train/progress_mse"] = progress_mse_val
                    logdict["train/remain_rmsle"] = remain_rmsle_val
            else:
                closs = closses

            with torch.no_grad():
                class_loss_val = ploss.float().item()
                conf_loss_val = closs.float().item()
                if global_rank == 0:
                    logdict["train/confidence_loss"] = conf_loss_val
                # Scales the confidence term to about closs_ratio times the magnitude of ploss.
                dynamic_weight = (class_loss_val / (conf_loss_val + epsilon)) if conf_loss_val > epsilon else 1.0
            loss = ploss + (dynamic_weight * closs_ratio) * closs
        else:
            closs = None

        # Per-step debug line, emitted once instead of once per rank.
        if global_rank == 0:
            print(f"ploss={ploss} closs={closs} loss={loss} dtype={loss.dtype}")

        model_engine.backward(loss)
        model_engine.step()

        if global_rank == 0:
            for i, p in enumerate(plosses):
                logdict[f"train/ploss_{i}"] = p.item()
            for i, acc in enumerate(acces):
                logdict[f"train/acc_{i}"] = acc
            wandb.log(logdict)
        # Append in place; rebuilding every list on each step was O(steps^2).
        for bucket, acc in zip(epoch_acces, acces):
            bucket.append(acc)
        for bucket, p in zip(epoch_plosses, plosses):
            bucket.append(p.item())

    torch.cuda.empty_cache()

    for i, bucket in enumerate(epoch_acces):
        acc_i = _all_reduce_mean(bucket)
        if global_rank == 0:
            wandb.log({f"train/epochacc_{i}": acc_i})
            print(f"Train Epoch [{epoch + 1}/{num_epochs}], position {i},  Acc: {acc_i:.2f}")

    for i, bucket in enumerate(epoch_plosses):
        loss_i = _all_reduce_mean(bucket)
        if global_rank == 0:
            wandb.log({f"train/epochploss_{i}": loss_i})
            print(f"Train Epoch [{epoch + 1}/{num_epochs}], position {i}, pLoss: {loss_i:.2f}")

    # ---------------- evaluation ----------------
    # Evaluate in eval mode; model.train() at the top of the next epoch restores training mode.
    model.eval()
    epoch_acces = [[] for _ in range(model.length)]
    epoch_plosses = [[] for _ in range(model.length)]
    if args.confidence_loss_type == "confidence_progress_remain":
        closs_names = ["confidence_mse", "progress_mse", "remain_rmsle"]
    else:
        closs_names = ["closs"]
    epoch_closses = [[] for _ in closs_names]

    with torch.no_grad():
        for data in tqdm(test_loader):
            plosses, vlosses, closses, acces = _run_model(data)
            for bucket, acc in zip(epoch_acces, acces):
                bucket.append(acc)
            for bucket, p in zip(epoch_plosses, plosses):
                bucket.append(p.item())
            # closses may be None (the training branch above allows it), a list
            # of per-term losses, or a single tensor.
            if closses is not None:
                closs_parts = closses if isinstance(closses, (list, tuple)) else [closses]
                for bucket, c in zip(epoch_closses, closs_parts):
                    bucket.append(c.item())

    for i, bucket in enumerate(epoch_acces):
        acc_i = _all_reduce_mean(bucket)
        if global_rank == 0:
            wandb.log({f"test/epochacc_{i}": acc_i})
            print(f"Test Epoch [{epoch + 1}/{num_epochs}], position {i},  Acc: {acc_i:.2f}")

    # Empty confidence buckets (model produced no closses) are skipped. Emptiness
    # should be identical on every rank (same model, same batch count per rank
    # from DistributedSampler), so the all_reduce calls stay matched.
    if args.confidence_loss_type == "confidence_progress_remain":
        for i, (name, bucket) in enumerate(zip(closs_names, epoch_closses)):
            if not bucket:
                continue
            loss_i = _all_reduce_mean(bucket)
            if global_rank == 0:
                wandb.log({f"test/epoch_{name}": loss_i})
                print(f"Test Epoch [{epoch + 1}/{num_epochs}], position {i}, {name} cLoss: {loss_i:.2f}")
    elif epoch_closses[0]:
        loss_i = _all_reduce_mean(epoch_closses[0])
        if global_rank == 0:
            wandb.log({"test/epochcloss": loss_i})
            print(f"Test Epoch [{epoch + 1}/{num_epochs}], cLoss: {loss_i:.2f}")

    for i, bucket in enumerate(epoch_plosses):
        loss_i = _all_reduce_mean(bucket)
        if global_rank == 0:
            wandb.log({f"test/epochploss_{i}": loss_i})
            print(f"Test Epoch [{epoch + 1}/{num_epochs}], position {i}, pLoss: {loss_i:.2f}")

    # ---------------- saving ----------------
    save_dir = f"{args.savedir}/state_{epoch}"
    print(f"Epoch [{epoch + 1}/{num_epochs}] completed")
    print(f"Saving 16bit model to path {save_dir}")
    model_engine.save_16bit_model(save_dir, exclude_frozen_parameters=True)

    # Only rank 0 writes config.json. Otherwise every rank writes the same file
    # concurrently, and non-zero ranks may run before save_dir exists.
    if global_rank == 0:
        try:
            with open(args.config_path, "r") as f:
                saved_config = json.load(f)
            saved_config["architectures"] = ["Eagle3LlamaForCausalLM"]
            saved_config["model_type"] = "llama"
            saved_config["confidence_loss_type"] = args.confidence_loss_type
            os.makedirs(save_dir, exist_ok=True)
            with open(f"{save_dir}/config.json", "w") as f:
                json.dump(saved_config, f, indent=4)
        except Exception as e:  # best effort: never abort training over the config copy
            print(f"copy config.json failed: {e}")

    # Full (resumable) DeepSpeed checkpoint every 5 epochs; collective, so all ranks call it.
    if epoch % 5 == 0:
        model_engine.save_checkpoint(save_dir=save_dir)