import torch
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data import DataLoader
import deepspeed
import argparse
from transformers import (
    GPTNeoConfig,
    GPT2Tokenizer,
    SchedulerType,
)
import tqdm
import os
import time
import copy
import gc
import numpy as np
import collections
from preprocess_data import MyDataset
from expand_model.modeling_gpt_neo import GPTNeoForCausalLM
from expand_model.expand import grow_ops
from expand_model.utils import get_scheduler_ex, compute_total_norm, LayerNormEx

import random
from datetime import datetime
from datetime import timedelta
from datetime import timezone

# Fixed UTC+8 offset used for human-readable timestamps in the training logs.
SHA_TZ = timezone(
    timedelta(hours=8),
    name='Asia/Shanghai',
)


def get_time():
    """Return the current time as an aware datetime in the ``SHA_TZ`` zone.

    ``datetime.now(tz)`` yields the same aware result as taking UTC "now" and
    converting it, without the deprecated ``datetime.utcnow()``.
    """
    return datetime.now(SHA_TZ)


parser = argparse.ArgumentParser()
parser.add_argument("--deepspeed_config", type=str, default="./ds_config.json")
parser.add_argument("--output_dir", type=str, default="./output")
# Optimizer hyper-parameters are real-valued. With type=int, argparse rejects
# command-line values such as "3e-4" or "0.95".
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--batch_size", type=int, default=256)
parser.add_argument("--weight_decay", type=float, default=1e-1)
parser.add_argument("--beta1", type=float, default=0.9)
parser.add_argument("--beta2", type=float, default=0.95)
parser.add_argument("--max_training_step", type=int, default=6000)
parser.add_argument("--save_interval", type=int, default=1000)
parser.add_argument("--eval_interval", type=int, default=100)
parser.add_argument("--log_interval", type=int, default=100)
parser.add_argument("--load", type=str, default=None)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Gradient Clipping")

# grow settings
parser.add_argument("--original_hidden_size", type=int, default=768)
parser.add_argument("--original_layer_num", type=int, default=12)
parser.add_argument("--original_head_num", type=int, default=12)
parser.add_argument("--original_intermediate_size", type=int, default=3072)

# NOTE(review): args.grow_time is not read anywhere in this script as shown;
# the growth length is recomputed from --per_step_grow_dim instead.
parser.add_argument("--grow_time", type=int, default=256, help="in how many steps to increase mask to 1")
parser.add_argument("--grow_step", type=int, default=1,
                    help="in which step to start growing")
# The script reads args.per_step_grow_dim, both to compute the number of growth
# steps and as the increment passed to grow_agent.increase_mask(), so it must be
# declared here. The default of 1 spreads the default 768 -> 1024 hidden-size
# growth over 256 steps, matching the --grow_time default.
parser.add_argument("--per_step_grow_dim", type=int, default=1,
                    help="increment passed to increase_mask each growth step; "
                         "(hidden_size_target - original_hidden_size) // per_step_grow_dim "
                         "sets the number of growth steps")

parser.add_argument("--hidden_size_target", type=int, default=1024)
parser.add_argument("--layer_target", type=int, default=24)
parser.add_argument("--head_target", type=int, default=16)
parser.add_argument("--intermediate_target", type=int, default=4096)

args = parser.parse_args()

# GPT-2 BPE tokenizer loaded from the "gpt2" pretrained files.
# NOTE(review): not referenced elsewhere in this script as shown.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# NOTE(review): there is no explicit deepspeed.init_distributed() call; process
# group setup is presumably left to deepspeed.initialize() — confirm.


def new_model(config, args):
    """Build a freshly initialised ``GPTNeoForCausalLM`` from ``config``.

    ``args`` is accepted so call sites can pass the CLI namespace, but it is
    not used.
    """
    return GPTNeoForCausalLM(config)


# Seed Python's, NumPy's and torch's RNGs (in that order) from --seed.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(args.seed)
del _seed_fn

# Training data: enough samples for max_training_step optimizer steps of
# batch_size samples each.
train_dataset = MyDataset(
    bin_id=26,
    num_sample=args.batch_size * args.max_training_step,
    mini_bin_id=0,
)

# Only local rank 0 runs evaluation, so only that rank builds the eval loader.
if args.local_rank == 0:
    eval_dataset = MyDataset(bin_id="eval", num_sample=400)
    eval_sampler = torch.utils.data.SequentialSampler(eval_dataset)
    eval_data_loader = DataLoader(
        eval_dataset,
        batch_size=4,
        shuffle=False,
        sampler=eval_sampler,
        pin_memory=True,
    )

# Configuration of the small model that training starts from.
config = GPTNeoConfig(
    num_hidden_layers=args.original_layer_num,
    hidden_size=args.original_hidden_size,
    intermediate_size=args.original_intermediate_size,
    num_heads=args.original_head_num,
    max_position_embeddings=1024,
)
# Independent copy that is later resized to the growth targets.
config_up = copy.deepcopy(config)

model = GPTNeoForCausalLM(config)
if args.load is not None:
    # map_location="cpu": by default torch.load restores tensors to the device
    # they were saved from, which would put every rank's copy on the saving
    # rank's GPU.
    sd = torch.load(os.path.join(args.load, "pytorch_model.bin"), map_location="cpu")
    # Checkpoints saved from a wrapped engine carry a "module." key prefix.
    # Strip it only where present, rather than cutting 7 characters off every
    # key unconditionally.
    prefix = "module."
    new_sd = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in sd.items()
    }
    model.load_state_dict(new_sd)

# Substrings of parameter names that are exempt from weight decay.
# GPT-Neo's LayerNorm parameters are named ln_1 / ln_2 / ln_f (the same names
# the post-growth optimizer excludes). Matching only "LayerNorm.weight"
# presumably left every LayerNorm weight decayed. "LayerNorm.weight" is kept so
# that anything it did match is still excluded.
no_decay = ["bias", "LayerNorm.weight", "ln_1.weight", "ln_2.weight", "ln_f.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": args.weight_decay,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = torch.optim.AdamW(
    optimizer_grouped_parameters,
    lr=args.learning_rate,
    weight_decay=args.weight_decay,
    betas=(args.beta1, args.beta2),
)

# Wrap the model and optimizer in a DeepSpeed engine. DeepSpeed also builds the
# training DataLoader from train_dataset. The returned optimizer and LR
# scheduler are discarded.
model, _, training_dataloader, _ = deepspeed.initialize(
    args=args,
    model=model,
    optimizer=optimizer,
    training_data=train_dataset,
)


# Number of growth steps needed to go from the original to the target hidden
# size when the mask advances by per_step_grow_dim each step. A non-positive
# step would divide by zero or produce a negative step count.
if args.per_step_grow_dim <= 0:
    raise ValueError("--per_step_grow_dim must be a positive integer, got {}".format(args.per_step_grow_dim))
grow_time = (args.hidden_size_target - args.original_hidden_size) // args.per_step_grow_dim

# NOTE(review): this second AdamW is built over the DeepSpeed engine's
# parameters but is not passed to deepspeed.initialize(), so model.step() does
# not update its state. In this script it is only handed to
# grow_agent.grow_opt() at growth time. Confirm that is intended.
# The parameter grouping is deliberately left unchanged so that an existing
# optimizer.bin still matches its param-group layout on load.
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": args.weight_decay,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = torch.optim.AdamW(
    optimizer_grouped_parameters,
    lr=args.learning_rate,
    weight_decay=args.weight_decay,
    betas=(args.beta1, args.beta2),
)
if args.load is not None:
    optimizer.load_state_dict(
        torch.load(os.path.join(args.load, "optimizer.bin"), map_location=torch.device('cpu'))
    )


def do_eval(model, data_loader=None, device="cuda"):
    """Return the mean per-batch loss of ``model`` over an evaluation set.

    Puts ``model`` in eval mode and runs it under ``torch.no_grad()``. The
    model is left in eval mode afterwards.

    Args:
        model: Callable model whose output exposes a scalar ``.loss`` when
            called with ``input_ids``, ``attention_mask`` and ``labels``.
        data_loader: Iterable of batches, each a mapping holding those three
            tensors. Defaults to the module-level ``eval_data_loader``, which
            only exists on local rank 0.
        device: Device the batch tensors are moved to (default ``"cuda"``,
            i.e. the current CUDA device).

    Returns:
        The average of the per-batch losses, as a Python float.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    if data_loader is None:
        data_loader = eval_data_loader

    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in data_loader:
            inputs = {
                key: batch[key].to(device)
                for key in ("input_ids", "attention_mask", "labels")
            }
            total_loss += model(**inputs).loss.item()
            num_batches += 1

    # Avoid a bare ZeroDivisionError when the eval set is empty.
    if num_batches == 0:
        raise ValueError("do_eval: the evaluation data loader yielded no batches")
    return total_loss / num_batches


# Growth helper from expand_model.expand. The training loop uses it for
# set_grow / grow_opt / mask handling.
grow_agent = grow_ops(model)

# Number of increase_mask calls made in the current growth phase.
grow_step_count = 0

growed = False  # set to True once the grown model has replaced `model`
train_loss_collect = []  # one training loss per micro step
eval_loss_collect = []  # one eval loss per evaluation

# Samples processed per micro step across all ranks:
# train_micro_batch_size_per_gpu * num_gpu.
# NOTE(review): hard-coded to 8 * 4; it must agree with ds_config and the
# number of launched processes.
sample_perstep = 8 * 4
# Micro steps per optimizer step. A value of 0 would make every
# `(step + 1) % acc_step` in the training loop raise ZeroDivisionError, so
# fail fast with a clear message instead.
acc_step = args.batch_size // sample_perstep
if acc_step < 1:
    raise ValueError(
        "--batch_size ({}) must be at least {} (micro batch per GPU * number of GPUs)".format(
            args.batch_size, sample_perstep
        )
    )

optimizer_step = 0
for step, batch in enumerate(training_dataloader):
    # True on the micro step that closes a gradient-accumulation window.
    # Logging, eval, checkpointing and growth are all keyed to these steps.
    at_boundary = (step + 1) % acc_step == 0
    if at_boundary:
        optimizer_step += 1

    model.train()

    inputs = {
        "input_ids": batch["input_ids"].cuda(),
        "attention_mask": batch["attention_mask"].cuda(),
        "labels": batch["labels"].cuda(),
    }
    loss = model(**inputs).loss
    train_loss_collect.append(loss.item())
    model.backward(loss)
    model.step()

    # Logging, evaluation and checkpointing run on local rank 0 only.
    if at_boundary and args.local_rank == 0:
        if args.log_interval > 0 and optimizer_step % args.log_interval == 0:
            print(get_time())
            print("step:{} train_loss:{}".format(optimizer_step, train_loss_collect[-1]))

        if args.eval_interval > 0 and optimizer_step % args.eval_interval == 0:
            eval_loss = do_eval(model)
            print("step:{} eval_loss:{}".format(optimizer_step, eval_loss))
            eval_loss_collect.append(eval_loss)

        # Guarded like the log/eval intervals: a non-positive save_interval
        # disables checkpointing instead of raising ZeroDivisionError.
        if args.save_interval > 0 and optimizer_step % args.save_interval == 0:
            output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(str(optimizer_step)))
            os.makedirs(output_dir, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))

    # One-time growth: swap in a larger model and optimizer at grow_step.
    if at_boundary and optimizer_step == args.grow_step:
        growed = True
        if args.local_rank == 0:
            del loss
            with torch.no_grad():
                loss = model(**inputs).loss
            print(f"loss before grow hidden size: {loss.item()}")
            print(f"model size: {grow_agent.count_parameters(model)}")
        begin_grow_time = time.time()
        config_up.hidden_size = args.hidden_size_target
        config_up.num_hidden_layers = args.layer_target
        config_up.intermediate_size = args.intermediate_target
        config_up.num_heads = args.head_target

        # Use a distinct name: assigning to `new_model` here would rebind the
        # module-level factory function of the same name.
        grown_model = new_model(config_up, args)
        grow_agent.set_grow(model, grown_model, "hidden_size",
                            args.hidden_size_target, grow_time, args)
        grow_agent.set_grow(model, grown_model, "heads",
                            args.head_target, grow_time, args)
        grow_agent.set_grow(model, grown_model, "layers",
                            args.layer_target, grow_time, args)
        no_decay = ["bias", "ln_1.weight", "ln_2.weight", "ln_f.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in grown_model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": args.weight_decay,
            },
            {
                "params": [p for n, p in grown_model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        new_optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate,
                                          weight_decay=args.weight_decay,
                                          betas=(args.beta1, args.beta2))

        grow_agent.mask_to_gpu(grown_model)

        # NOTE(review): `optimizer` here is the AdamW created after the first
        # deepspeed.initialize() call, which model.step() does not drive; its
        # state may be empty unless restored from --load. Confirm grow_opt is
        # meant to read it.
        grow_agent.grow_opt(model, grown_model, optimizer,
                            new_optimizer, args)
        grown_model, _, _, _ = deepspeed.initialize(
            args=args,
            model=grown_model,
            optimizer=new_optimizer,
        )

        grow_agent.mask_to_gpu(grown_model)

        # Drop the old engine/optimizer references before continuing.
        del model, optimizer, loss
        model, optimizer = grown_model, new_optimizer

        grow_step_count = 0
        if args.local_rank == 0:
            with torch.no_grad():
                loss = model(**inputs).loss
            print(f"loss after grow hidden size: {loss.item()}")
            print(f"model size: {grow_agent.count_parameters(model)}")
            del loss

    # While a growth phase is active, advance the masks once per optimizer
    # step, then finalise growth after grow_time steps.
    if not grow_agent.available_to_grow and at_boundary:
        if grow_step_count < grow_time:
            grow_agent.increase_mask(model, args.per_step_grow_dim)
            grow_step_count += 1
        else:
            grow_agent.end_grow(model)
            grow_step_count = 0
            if args.local_rank == 0:
                print("===growing ended at step " + str(optimizer_step) + "===")
            grow_agent.print_all_masks(model)

# Rank 0 persists the collected losses and reports how many optimizer steps ran.
if args.local_rank == 0:
    loss_dir = args.output_dir
    os.makedirs(loss_dir, exist_ok=True)
    # Files are opened in append mode, so existing contents are extended
    # rather than overwritten.
    with open(os.path.join(loss_dir, "train_loss.txt"), "a") as f:
        f.writelines(str(value) + "\n" for value in train_loss_collect)
    with open(os.path.join(loss_dir, "eval_loss.txt"), "a") as f:
        f.writelines(str(value) + "\n" for value in eval_loss_collect)
    # optimizer_step equals (step + 1) // acc_step. It replaces the undefined
    # name `accumulation_steps`, and unlike `step` it is defined even when the
    # dataloader yields nothing.
    print("finish {} steps".format(optimizer_step))

