import torch
from any_precision import AnyPrecisionForCausalLM_train_whole3456
from transformers import AutoTokenizer, LlamaForCausalLM
from torch.utils.data import DataLoader
from tqdm import tqdm
from datasets import load_dataset
import threading
import argparse
import gc
import random
import copy
import os
import math
import torch.optim as optim
import argparse

# Command-line interface. --model_path and --save_path are required because the
# script uses both unconditionally; without them it would fail later with an
# obscure TypeError instead of a usage message.
parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, default=1e-2,
                    help="AdamW learning rate.")
parser.add_argument("--alpha", type=float, default=1e+0,
                    help="Weight of the squared penalty (average bits - targ_bits)^2.")
parser.add_argument("--epoch", type=int, default=5,
                    help="Number of passes over the training subset.")
parser.add_argument("--targ_bits", type=float, default=3.5,
                    help="Target parameter-weighted average bit width.")
parser.add_argument("--min_prec", type=int, default=3,
                    help="Lowest precision passed to the model.")
parser.add_argument("--max_prec", type=int, default=6,
                    help="Highest precision passed to the model.")
parser.add_argument("--force_high_prec", type=int,
                    help="High side of a forced (low, high) precision pair; "
                         "only used together with --force_low_prec.")
parser.add_argument("--force_low_prec", type=int,
                    help="Low side of a forced (low, high) precision pair; "
                         "only used together with --force_high_prec.")
parser.add_argument("--lr_decay", type=float, default=1.0,
                    help="Learning rate is divided by this at the start of every epoch after the first.")
parser.add_argument("--alpha_decay", type=float, default=1.0,
                    help="alpha is divided by this at the start of every epoch after the first.")
parser.add_argument("--init_targ", action="store_true",
                    help="Initialize thresholds from targ_bits instead of randomly.")
parser.add_argument("--maxmem", type=float, default=6.0,
                    help="Upper bound on per-module bit width; values below 6 load "
                         "per-module limits from --maxmem_dir.")

parser.add_argument("--model_path", type=str, required=True,
                    help="Quantized model path (also used to load the tokenizer).")
parser.add_argument("--save_path", type=str, required=True,
                    help="Directory that receives the output .pt file.")
parser.add_argument("--maxmem_dir", type=str,
                    help="Directory containing maxmem_<maxmem>.pt; required when --maxmem < 6.")


args = parser.parse_args()

# Llama3-8b
model_path = args.model_path
save_path = args.save_path
# Relative weight size of each projection type, used to weight the average bit
# width. The commented numbers are the corresponding Llama3-8B dimensions; the
# weights are those numbers divided by 1024.
name_to_mem = {
    "q_proj" : 4,#4096,
    "k_proj" : 1,#1024,
    "v_proj" : 1,#1024,
    "o_proj" : 4,#4096,
    "gate_proj" : 14,#14336,
    "up_proj" : 14,#14336,
    "down_proj" : 14,#14336,
}

context_length = 512   # tokens per training sample (shorter chunks are dropped)
dataset_length = 1000  # number of training samples taken from the tokenized data
lr = args.lr
alpha = args.alpha
epoch = args.epoch
targ_bits = args.targ_bits

min_prec = args.min_prec
max_prec = args.max_prec
maxmem = args.maxmem

# Validate with explicit raises (asserts are stripped under `python -O`) and do it
# before creating any output directory.
if targ_bits > maxmem:
    raise RuntimeError(f"targ_bits({targ_bits}) > maxmem({maxmem})")

if not max_prec > min_prec:
    raise RuntimeError(f"max_prec({max_prec}) must be greater than min_prec({min_prec})")
# An integer target t yields the precision pair (t-1, t+1), which would fall
# outside [min_prec, max_prec] at either end of the range.
if targ_bits == float(max_prec) or targ_bits == float(min_prec):
    raise RuntimeError(f"targ_bits({targ_bits}) must differ from min_prec({min_prec}) and max_prec({max_prec})")

os.makedirs(save_path, exist_ok=True)


layer_count = 32  # number of decoder layers (Llama3-8B)

print(f"lr={lr}, alpha={alpha}, epoch={epoch}, prec={min_prec}-{max_prec}, targbits={targ_bits}")

torch.random.manual_seed(0)
random.seed(0)

tokenizer = AutoTokenizer.from_pretrained(model_path)

# Projection names per sub-block; their concatenation fixes the module order used
# throughout (threshold init, max-mem list indexing, saved threshold array).
sa_proj_list = ["q_proj", "k_proj", "v_proj", "o_proj"]
m1_proj_list = ["gate_proj", "up_proj"]
m2_proj_list = ["down_proj"]

def name2module(model, layer, name):
    """Look up one projection submodule of a decoder layer.

    Attention projections (q/k/v/o_proj) are read from ``self_attn`` and MLP
    projections (gate/up/down_proj) from ``mlp`` of
    ``model.model.model.layers[layer]``.

    Raises:
        RuntimeError: if ``name`` is not one of the seven projection names.
    """
    if name in ("q_proj", "k_proj", "v_proj", "o_proj"):
        container = "self_attn"
    elif name in ("gate_proj", "up_proj", "down_proj"):
        container = "mlp"
    else:
        raise RuntimeError("Unknown Module")

    decoder_layer = model.model.model.layers[layer]
    return getattr(getattr(decoder_layer, container), name)

# Initial threshold value for every (layer, proj name). Random in [0, 1) by
# default; with --init_targ it is derived from targ_bits.
# NOTE(review): presumably a sigmoid-space value (the thresholds are later reset
# to logit(rand)) that the model converts internally -- TODO confirm.
th_init_dict = {}
n_modules = layer_count * len(sa_proj_list + m1_proj_list + m2_proj_list)
# Per-module maximum bit width: 6 everywhere unless a precomputed list is loaded.
max_mem_list = [6] * n_modules
if maxmem < 6.0:
    if args.maxmem_dir is None:
        raise RuntimeError(f"--maxmem_dir is required when maxmem({maxmem}) < 6.0")
    # NOTE: torch.load unpickles the file; only load files you produced yourself.
    max_mem_list = torch.load(f"{args.maxmem_dir}/maxmem_{maxmem}.pt")
    if len(max_mem_list) < n_modules:
        raise RuntimeError(
            f"maxmem_{maxmem}.pt has {len(max_mem_list)} entries, expected {n_modules}"
        )

max_mem_dict = {}

module_i = 0
for l in range(layer_count):
    for n in sa_proj_list+m1_proj_list+m2_proj_list:
        module_maxmem = max_mem_list[module_i]
        # Always consume one random draw per module so later random numbers do
        # not depend on --init_targ.
        th_init_dict[(l,n)] = random.random()
        # module_maxmem > 3 avoids a division by zero for 3-bit-only modules.
        if args.init_targ and targ_bits < module_maxmem and module_maxmem > 3:
            # NOTE(review): numerator uses min_prec but the denominator hard-codes 3;
            # identical for the default min_prec=3 -- confirm for other values.
            th_init_dict[(l,n)] = (targ_bits-min_prec) / (module_maxmem-3)

        max_mem_dict[(l,n)] = module_maxmem
        module_i += 1


model = AnyPrecisionForCausalLM_train_whole3456.from_quantized(model_path, precisions=[p for p in range(min_prec, max_prec+1)], 
                                                     th_init_dict=th_init_dict, max_mem_dict=max_mem_dict)
model = model.cuda()

# Freeze everything; only the thresholds created afterwards by create_th() are
# expected to be trainable (the optimizer filters on requires_grad).
for p in model.parameters():
    p.requires_grad = False

model.create_th()

dataset = load_dataset('allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')

# Total size weight over all modules; normalizes the weighted bit sum to an average.
param_sum = layer_count * sum(name_to_mem[n] for n in sa_proj_list+m1_proj_list+m2_proj_list)

# Integer precision pair around targ_bits, e.g. 3.5 -> (3, 4); an integer
# target t gives (t-1, t+1).
possible_prec_list = [(math.ceil(targ_bits)-1, math.floor(targ_bits)+1)]

# Both force flags must be set for the override to apply.
if args.force_high_prec and args.force_low_prec:
    possible_prec_list = [(args.force_low_prec, args.force_high_prec)]

# Filled by Evaluator.evaluate with the learned per-module thresholds and saved
# at the end of the script.
min_th_arr = []

class Evaluator:
    """Trains per-module precision thresholds (``th``) on a C4 text subset.

    Despite its name, ``evaluate`` runs an optimization loop. For every
    (low, high) pair in the module-level ``possible_prec_list`` it minimizes
    the LM loss plus ``alpha * (weighted average bits - targ_bits)**2``. It then
    records the learned thresholds in the module-level ``min_th_arr``.
    """

    def __init__(self, dataset, tokenizer, device):
        """Tokenize ``dataset`` into fixed-length samples and build a DataLoader.

        Args:
            dataset: dataset with a ``"text"`` column (as used below).
            tokenizer: tokenizer used to encode the text.
            device: device that input batches are moved to in ``evaluate``.
        """
        self.tokenizer = tokenizer
        self.device = device

        # Remove empty texts from the dataset to avoid NaN losses.
        dataset = dataset.filter(lambda e: len(e['text']) >= 1)

        def tokenize(element):
            # Split each text into context_length windows; keep only full windows.
            outputs = tokenizer(
                element["text"],
                truncation=True,
                max_length=context_length,
                return_overflowing_tokens=True,
                return_length=True,
            )
            input_batch = [
                input_ids
                for length, input_ids in zip(outputs["length"], outputs["input_ids"])
                if length == context_length
            ]
            return {"input_ids": input_batch}

        tokenized_dataset = dataset.map(
            tokenize, batched=True, remove_columns=dataset.column_names
        )

        training_dataset = tokenized_dataset.select(range(dataset_length))
        training_dataset.set_format(type='torch')
        self.dataloader = DataLoader(training_dataset, batch_size=1)

    @staticmethod
    def _expected_bits(th, maxmem):
        """Return the expected bit width of one module given its raw threshold.

        The module interpolates linearly between adjacent precision levels
        3..maxmem as s = sigmoid(th) goes from 0 to 1. Every linear segment has
        slope 1 in the shifted variable, so the piecewise form collapses to
        ``3 + (maxmem - 3) * s``. The closed form also avoids data-dependent
        Python branches on (possibly CUDA) tensors. A module capped at 3 bits
        contributes a constant 3.0.

        Raises:
            RuntimeError: if ``maxmem`` is not 3, 4, 5 or 6.
        """
        if maxmem == 3:
            return 3.0
        if maxmem == 4 or maxmem == 5 or maxmem == 6:
            return 3 + (maxmem - 3) * torch.sigmoid(th)
        raise RuntimeError(f"Unknown maxmem_now {maxmem}")

    def evaluate(self, model):
        """Optimize the thresholds for each precision pair and store the result.

        For each (bl, bh) in ``possible_prec_list`` this does the following:
        a fresh AdamW optimizer over all ``requires_grad`` parameters runs
        ``epoch`` passes over the data; the learned thresholds are mapped and
        written into the global ``min_th_arr`` (cleared first); and every
        module's ``th`` is re-randomized.

        Side effect: rebinds the global ``alpha`` when ``--alpha_decay`` applies.
        """
        global alpha
        module_names = sa_proj_list + m1_proj_list + m2_proj_list

        for (bl, bh) in possible_prec_list:
            optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
            model.set_precision_dual(bl, bh)
            print(f"{bl}&{bh}:")
            for ep in range(epoch):
                if ep > 0:
                    # The decay factors divide, so values > 1 shrink lr / alpha.
                    optimizer.param_groups[0]['lr'] /= args.lr_decay
                    alpha /= args.alpha_decay
                with tqdm(self.dataloader, "Run") as tobj:
                    for batch in tobj:
                        optimizer.zero_grad()
                        b = batch['input_ids'].to(self.device)
                        loss = model(input_ids=b, labels=b).loss
                        loss = loss.mean()
                        # Size-weighted sum of expected bit widths over all modules.
                        bit_sum = torch.zeros_like(loss)
                        first_p = None
                        for l in range(layer_count):
                            for n in module_names:
                                th = name2module(model, l, n).th.to(loss.device)
                                bit_sum += self._expected_bits(th, max_mem_dict[(l, n)]) * name_to_mem[n]
                                if first_p is None:
                                    # Expected bits of the very first module (progress display only).
                                    first_p = bit_sum.item() / name_to_mem[n]
                        loss += alpha * ((bit_sum / param_sum - targ_bits) ** 2)
                        loss.backward()
                        optimizer.step()
                        tobj.set_description(f"th={(bit_sum/param_sum).item():.4f}, p={first_p:.2f}")

            # Record each module's threshold, shifted per its maxmem; -1 marks
            # modules capped at 3 bits.
            min_th_arr.clear()
            for l in range(layer_count):
                for n in module_names:
                    maxmem_now = max_mem_dict[(l, n)]
                    if maxmem_now == 3:
                        min_th_arr.append(-1)
                        continue
                    s = torch.sigmoid(name2module(model, l, n).th).item()
                    if maxmem_now == 6:
                        min_th_arr.append(s*3-1.5)
                    elif maxmem_now == 5:
                        min_th_arr.append(s*2-1)
                    elif maxmem_now == 4:
                        min_th_arr.append(s)
                    else:
                        raise RuntimeError(f"Unknown maxmem_now {maxmem_now}")

            # Reset th to logit(U[0,1)) so the next precision pair starts fresh.
            for l in range(layer_count):
                for n in module_names:
                    th = name2module(model, l, n).th
                    rand = random.random()
                    with torch.no_grad():
                        th[()] = math.log(rand / (1 - rand))
                    th.requires_grad = True
            
evaluator = Evaluator(dataset, tokenizer, "cuda")
evaluator.evaluate(model)

# The output file name encodes the run's hyperparameters; the part shared by the
# forced and default variants is built once.
_init_tag = 'init_' if args.init_targ else ''
_hparam_tag = (
    f"_th_pb_train_{args.lr}_{args.lr_decay}d_{args.alpha}_{args.alpha_decay}ad"
    f"_{args.epoch}ep_targ{args.targ_bits}b_{_init_tag}"
)

if args.force_high_prec and args.force_low_prec:
    # NOTE: unlike the default name, there is no "_" before "0-" here; kept as-is
    # so existing output file names stay unchanged.
    save_str = (
        f"{save_path}/whole_3456_max{maxmem}_forced_{args.force_low_prec}b_{args.force_high_prec}b"
        f"{_hparam_tag}0-{dataset_length}_adam.pt"
    )
else:
    save_str = (
        f"{save_path}/whole_3456_max{maxmem}_{min_prec}b-{max_prec}b"
        f"{_hparam_tag}_0-{dataset_length}_adam.pt"
    )

print(f"Saving to {save_str}")
torch.save((min_th_arr, max_mem_dict), save_str)
