import os
import gc
import sys
import time
import json
import copy
import random
import argparse
import wandb
import torch
import numpy as np

from tqdm import tqdm
from omegaconf import OmegaConf as om
from Pruner.datasets.train_data import get_c4
from Pruner.mask.model import Masked_Llama
from Pruner.mask.model_opt import Masked_OPT
from Pruner.optim.optimizer import DecoupledAdamW
from Pruner.optim.scheduler import CosineAnnealingWithWarmupScheduler, WarmupScheduler
from Pruner.evaluator.ppl import PPLMetric
from llmfoundry.optim import (DecoupledAdaLRLion, DecoupledClipLion,
                              DecoupledLionW, DecoupledLionW_8bit)
from transformers import AutoTokenizer


def set_random_seed(seed):
    """Seed every random number generator this script relies on.

    Applies `seed`, in order, to Python's `random` module, NumPy's global
    generator, PyTorch's default generator and the generators of all CUDA
    devices.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    
def build_optimizer(model, name, cfg, pretrain):
    """
        Create the optimizer selected by `name` over up to two parameter groups:
        - main-model parameters (names without "mask"): only added when
          `pretrain` is truthy, with learning rate `cfg.optimizer.lr`
        - mask parameters (names containing "mask"): added when at least one
          exists, with learning rate `cfg.model.mask.mask_lr`

        Every entry of `cfg.optimizer` is forwarded to the optimizer
        constructor as a keyword argument, so the caller must already have
        removed any key the constructor does not accept.

        Raises ValueError when `name` is not one of the supported optimizers.
    """
    named_params = list(model.named_parameters())

    groups = []
    if pretrain:
        backbone_params = [param for pname, param in named_params if "mask" not in pname]
        groups.append({"params": backbone_params, "lr": cfg.optimizer.lr})

    # Report each mask tensor while collecting it for its own group.
    print('================== PRUNING MASK ==================')
    mask_params = []
    for pname, param in named_params:
        if "mask" not in pname:
            continue
        print(pname)
        print(param.shape)
        mask_params.append(param)

    mask_lr = cfg.model.mask.mask_lr
    if mask_params:
        groups.append({"params": mask_params, "lr": mask_lr})

    for i, group in enumerate(groups):
        print(f"Group {i}:", f"{len(group['params'])} tensors", f"{sum(p.numel() for p in group['params'])} params", f"{group['lr']:.2e} lr")

    # Pick the class first; the construction call is identical for all of them.
    if name == 'decoupled_adamw':
        optimizer_cls = DecoupledAdamW
    elif name == 'decoupled_lionw':
        optimizer_cls = DecoupledLionW
    elif name == 'clip_lion':
        optimizer_cls = DecoupledClipLion
    elif name == 'adalr_lion':
        optimizer_cls = DecoupledAdaLRLion
    elif name == 'decoupled_lionw_8b':
        optimizer_cls = DecoupledLionW_8bit
    elif name == 'adam':
        optimizer_cls = torch.optim.Adam
    else:
        raise ValueError(f'Not sure how to build optimizer: {name}')

    return optimizer_cls(groups, **cfg.optimizer)

def build_scheduler(optimizer, t_max, cfg):
    """Create the learning-rate scheduler selected by `cfg.name`.

    Args:
        optimizer: optimizer whose learning rates the scheduler drives.
        t_max: total number of scheduler steps; used by the two cosine
            variants and ignored by 'warmup'.
        cfg: scheduler config. `cfg.name` selects the scheduler;
            'cosine_with_warmup' also reads `cfg.t_warmup` and `cfg.alpha_f`,
            'warmup' reads `cfg.t_warmup`.

    Returns:
        The constructed scheduler.

    Raises:
        ValueError: if `cfg.name` is not a supported scheduler. Failing here
            avoids handing `None` back to a caller that will call `.step()`
            on it.
    """
    scheduler_name = cfg.name
    if scheduler_name == 'cosine_with_warmup':
        return CosineAnnealingWithWarmupScheduler(optimizer, t_warmup=cfg.t_warmup, t_max=t_max, alpha_f=cfg.alpha_f)
    if scheduler_name == 'cosine':
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, t_max, eta_min=1e-5)
    if scheduler_name == 'warmup':
        return WarmupScheduler(optimizer, t_warmup=cfg.t_warmup)
    raise ValueError(f'Not sure how to build scheduler: {scheduler_name}')

def _build_model_and_tokenizer(model_cfg):
    """Instantiate the masked model and the tokenizer selected by `model_cfg.name`.

    Raises ValueError when the name matches none of the supported families.
    """
    model_name = model_cfg.name
    # Order matters when a name matches more than one pattern: 'opt' has the
    # highest priority, then llama3.1/3.2, then llama2, then the exact name
    # 'llama3_8b'.
    if 'opt' in model_name:
        model = Masked_OPT(model_cfg)
        tokenizer = AutoTokenizer.from_pretrained(model_cfg.file_path)
    elif 'llama3.1' in model_name or 'llama3.2' in model_name:
        from llama3_1.api.tokenizer import Tokenizer
        model = Masked_Llama(model_cfg)
        tokenizer = Tokenizer(model_path=model_cfg.tokenizer_path)
    elif 'llama2' in model_name:
        from llama2.llama.tokenizer import Tokenizer
        model = Masked_Llama(model_cfg)
        tokenizer = Tokenizer(model_path=model_cfg.tokenizer_path)
    elif model_name == 'llama3_8b':
        from llama3.llama.tokenizer import Tokenizer
        model = Masked_Llama(model_cfg)
        tokenizer = Tokenizer(model_path=model_cfg.tokenizer_path)
    else:
        raise ValueError(f'Not sure how to build model: {model_name}')
    return model, tokenizer


def _evaluate_ppl(model, tokenizer, device):
    """Put `model` in eval mode and return the PPLMetric result on wikitext2 and ptb."""
    model.eval()
    # The separate lm_head is only handed over for OPT models.
    lm_head = model.lm_head if 'opt' in model.cfg.name else None
    return PPLMetric(model.model, tokenizer, ['wikitext2', 'ptb'], mask_module=model.mask_module, seq_len=2048, device=device, sparse=False, lm_head=lm_head)


def main(args):
    """Train the pruning-mask scores of a masked model on C4.

    Loads the YAML config at `args.config_path`, builds the masked model and
    tokenizer, freezes every parameter whose name does not contain 'mask',
    and then runs `args.training_epochs` epochs of mask-score updates,
    logging the snapshot loss and per-module score ratios to wandb.

    Args:
        args: parsed command-line namespace. `pretrain` is optional and
            treated as False when absent; all other attributes read here
            must be present.

    Raises:
        ValueError: if the configured model, optimizer or scheduler name is
            not supported.
    """
    with open(args.config_path) as f:
        yaml_cfg = om.load(f)

    mask_cfg = yaml_cfg.model.mask
    prune_module = '+'.join(mask_cfg.pruning_modules)
    name = f'{yaml_cfg.model.name}_{prune_module}_{mask_cfg.target_sparsity}_{mask_cfg.mask_lr}'
    # Label the wandb run with the model / modules / sparsity / mask-lr combination.
    wandb.init(name=name)

    set_random_seed(args.seed)

    model, tokenizer = _build_model_and_tokenizer(yaml_cfg.model)

    # Only the scores of the selected pruning modules are trainable ...
    for module in mask_cfg.pruning_modules:
        getattr(model.mask_module.masks, module).score.requires_grad = True

    # ... and every parameter outside the mask module is frozen.
    for n, p in model.named_parameters():
        if 'mask' not in n:
            p.requires_grad = False

    model.to(args.device, dtype=torch.bfloat16)

    train_loader = get_c4(tokenizer=tokenizer, n_samples=args.samples, seq_len=args.max_seq_len, batch_size=args.batch_size, name=yaml_cfg.model.name.split('_')[0])

    # The command line may not define `pretrain`; default to mask-only training.
    pretrain = getattr(args, 'pretrain', False)
    # `name` is popped because the remaining optimizer entries are forwarded
    # verbatim to the optimizer constructor.
    optimizer = build_optimizer(model, yaml_cfg.optimizer.pop("name"), yaml_cfg, pretrain)
    scheduler = build_scheduler(optimizer, len(train_loader) * args.training_epochs, yaml_cfg.scheduler)
    model.initialize_score(method=args.initialize_method, tokenizer=tokenizer, device=args.device, model_name=model.cfg.name)

    round_size = args.instantation_freq
    step = 0
    snapshot_loss = 0.
    print('================== START PRUNING! ==================')
    for epoch in range(args.training_epochs):
        # Buffers for the current update round. They are reset at the start of
        # every epoch, so batches left over at the end of an epoch (fewer than
        # `round_size`) are dropped without an update.
        fn_list = []
        grad_list = []
        outdated_masks = []
        outdated_zs = []
        batch_list = []
        for t, x in enumerate(tqdm(train_loader)):
            model.train()
            batch_list.append(x)
            if len(batch_list) == round_size:
                # Phase 1: K_inner instantiations, each evaluated on every
                # buffered batch. Entry k * round_size + b of fn_list /
                # grad_list belongs to instantiation k and batch b.
                for _ in range(args.K_inner):
                    outdated_masks_t, outdated_zs_t = model.sim_instantation(device=args.device, test_mask=args.instantation_test_mask, test_batch=batch_list[-1])
                    outdated_masks.append(outdated_masks_t)
                    outdated_zs.append(outdated_zs_t)
                    for batch in batch_list:
                        inputs = batch.to(args.device)
                        output = model.sim_forward(inputs, outdated_zs_t)
                        loss = model.loss(output, inputs)
                        fn_list.append(loss)
                        grad_list.append(output['grads'])

                # Phase 2: one optimizer/scheduler step per buffered batch,
                # using that batch's results across all K_inner instantiations.
                for b in range(round_size):
                    fn_list_t = [fn_list[b + round_size * k] for k in range(args.K_inner)]
                    grad_list_t = [grad_list[b + round_size * k] for k in range(args.K_inner)]
                    optimizer.zero_grad(set_to_none=False)
                    model.mask_module.update(t, fn_list_t, grad_list_t, args.K_inner, outdated_masks, outdated_zs)
                    optimizer.step()
                    scheduler.step()
                    target_mask_num = model.mask_module.calculate_target_mask_num(epoch, t, args.training_epochs, len(train_loader))
                    model.mask_module.constrain_score(target_mask_num)

                mean_loss = sum(fn_list) / len(fn_list)
                snapshot_loss += mean_loss
                print(f'------ STEP {t} Loss: {mean_loss}')
                batch_list = []
                fn_list = []
                grad_list = []
                outdated_masks = []
                outdated_zs = []
            step += 1

            if step % args.snapshot_freq == 0:
                # NOTE(review): snapshot_loss grows once per update round
                # (every `instantation_freq` batches) but is divided by
                # `snapshot_freq` - confirm the intended normalisation.
                log_dict = {'Snapshot_loss': snapshot_loss / args.snapshot_freq}
                for module in mask_cfg.pruning_modules:
                    mask = getattr(model.mask_module.masks, module)
                    log_dict[f'{module}_sparsity'] = torch.sum(mask.score.data) / mask.mask_param_num()
                wandb.log(log_dict)
                snapshot_loss = 0.

            if args.ppl_during_train and step % args.ppl_freq == 0:
                # model.train() at the top of the next iteration restores training mode.
                wandb.log(_evaluate_ppl(model, tokenizer, args.device))

    # NOTE(review): despite its name, `test_before_train` gates the perplexity
    # test that runs after the pruning loop.
    if args.test_before_train:
        print('================== PPL TEST AFTER PRUNING ==================')
        ppl = _evaluate_ppl(model, tokenizer, args.device)
        print(f"PPL after Pruning: {ppl}")
    print('================== PRUNING OVER! ==================')

def build_arg_parser():
    """Return the command-line parser for the pruning script.

    Kept separate from the entry point so the option set can be inspected
    without launching a training run.
    """
    parser = argparse.ArgumentParser(description='Pruning LLaMA (huggingface version)')

    parser.add_argument('--device', type=str, default="cuda:3", help='device')
    parser.add_argument('--seed', type=int, default=0, help='seed')
    parser.add_argument('--save_model', action='store_true', help='if save model')

    parser.add_argument('--config_path', type=str, default='configs/llama2/7b.yaml', help='the config file path')
    parser.add_argument('--log_path', type=str, default='save')

    # The flag name is kept for compatibility; main() uses it to gate the
    # perplexity test that runs once the pruning loop has finished.
    parser.add_argument('--test_before_train', action='store_true', help='whether to run the PPL test after pruning')
    # main() reads args.pretrain when building the optimizer parameter groups.
    parser.add_argument('--pretrain', action='store_true', help='also add the non-mask model parameters to the optimizer')

    parser.add_argument('--samples', type=int, default=120000, help='samples for mask training')
    parser.add_argument('--batch_size', type=int, default=8, help='each devices training batch size')
    parser.add_argument('--max_seq_len', type=int, default=128, help='max sequence length input')
    parser.add_argument('--training_epochs', type=int, default=1, help='total training epochs')
    parser.add_argument('--K_inner', type=int, default=1, help='inner loop iterations for sampling Mask')
    parser.add_argument('--snapshot_freq', type=int, default=10, help='steps between snapshot loss/sparsity logs to wandb')
    parser.add_argument('--save_freq', type=int, default=500, help='steps to save the checkpoint')

    parser.add_argument('--instantation_model', action='store_true', help='whether use instantation during pruning')
    parser.add_argument('--instantation_freq', type=int, default=500, help='steps to instantiate the dense model')
    parser.add_argument('--instantation_test_mask', action='store_true', help='whether sample some masks to test before instantation')

    parser.add_argument('--initialize_method', type=str, default='mean', help='method to initialize the score')

    parser.add_argument('--ppl_during_train', action='store_true', help='whether test ppl during training')
    parser.add_argument('--ppl_freq', type=int, default=500, help='steps interval to test ppl')

    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    main(args)
