import argparse
import datetime
import os
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm
from transformers import GPT2TokenizerFast,GPT2LMHeadModel
import torch.nn as nn
import sampling
import utils
from load_model import load_model,load_model_RADD
import time


class CustomDDP(nn.parallel.DistributedDataParallel):
    """DistributedDataParallel variant that is transparent to its callers.

    Two behaviours differ from the stock wrapper:

    * attribute lookups that fail on the wrapper fall through to the wrapped
      ``module``, so code written against the bare model keeps working;
    * ``forward`` calls the wrapped module directly instead of going through
      the parent class's ``forward``.
    """

    def __getattr__(self, name):
        """Resolve ``name`` on the wrapper first, then on the wrapped module.

        Raises:
            AttributeError: if neither the wrapper nor the wrapped module
                provides ``name``.
        """
        try:
            return super().__getattr__(name)
        except AttributeError:
            # ``self.module`` is itself resolved through ``__getattr__``.  If
            # it has not been registered (e.g. the instance is only partially
            # constructed) falling through would recurse until RecursionError,
            # so surface the ordinary AttributeError instead.
            if name == "module":
                raise
            return getattr(self.module, name)

    def forward(self, *args, **kwargs):
        """Call the wrapped module directly with the given arguments."""
        return self.module(*args, **kwargs)

def _run(rank, args):
    """Generate samples on one rank and report their GPT-2 perplexity.

    The worker loads the diffusion model selected by ``args.model_path``,
    draws ``args.times`` batches of ``args.batch_size`` sequences of length
    ``args.length``, writes every decoded sample to
    ``{work_dir}/{steps}_{method}/sample_{index}.txt`` and scores the token
    ids with the GPT-2 checkpoint under ``assets/gpt2-large``.  Perplexity and
    timing are averaged over iterations, reduced across ranks, and logged by
    rank 0.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        args: Parsed command-line namespace (``work_dir``, ``steps``,
            ``method``, ``model_path``, ``batch_size``, ``length``, ``times``,
            ``ngpus`` are read here).
    """
    sample_dir = f'{args.work_dir}/{args.steps}_{args.method}'
    # Every rank writes its samples into this directory, so every rank makes
    # sure it exists instead of relying on rank 0 having created it in time.
    os.makedirs(sample_dir, exist_ok=True)
    if rank == 0:
        logger = utils.get_logger(os.path.join(args.work_dir, f"language_generation_{args.steps}_{args.method}_comparison_ppl_logs"))

    def mprint(msg):
        """Log ``msg`` on rank 0; a no-op on every other rank."""
        if rank == 0:
            logger.info(msg)

    torch.cuda.set_device(rank)
    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")

    mprint("================================")

    if device.type == "cuda":
        mprint("Found {} CUDA devices.".format(torch.cuda.device_count()))
        for i in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(i)
            mprint(
                "{} \t Memory: {:.2f}GB".format(
                    props.name, props.total_memory / (1024 ** 3)
                )
            )
    else:
        mprint("WARNING: Using device {}".format(device))
    mprint(f"Found {os.cpu_count()} total number of CPUs.")

    mprint("================================")

    # Record the full run configuration in the log.
    for arg_name, arg_value in vars(args).items():
        mprint(f"{arg_name}: {arg_value}")

    if args.model_path == 'louaaron/sedd-small':
        model, graph, noise = load_model(args.model_path, device)
    else:
        model, graph, noise = load_model_RADD(args.model_path, device)
    # Wrap the networks for distributed execution on this rank's device.
    model = CustomDDP(model, device_ids=[rank], static_graph=True)
    noise = DDP(noise, device_ids=[rank], static_graph=True)
    tokenizer = GPT2TokenizerFast.from_pretrained('assets/gpt2-large')
    gpt2_model = GPT2LMHeadModel.from_pretrained('assets/gpt2-large').to(device).eval()
    sampling_fn = sampling.get_pc_sampler(
        graph, noise, (args.batch_size, args.length), args.method, args.steps, device=device
    )

    total_ppl = torch.tensor(0.).to(device)
    total_time = torch.tensor(0.).to(device)
    mprint("================================")

    for iter_idx in tqdm(range(args.times)):
        # The timed region covers sampling plus decoding to text.
        time_start = time.time()
        samples = sampling_fn(model)
        text_samples = tokenizer.batch_decode(samples)
        time_end = time.time()
        total_time += time_end - time_start

        # Globally unique sample index: ranks own disjoint, contiguous ranges.
        first_index = (rank * args.times + iter_idx) * args.batch_size
        for batch_idx, text_sample in enumerate(text_samples):
            # Explicit UTF-8 so decoded text is written identically regardless
            # of the process locale.
            with open(f"{sample_dir}/sample_{first_index + batch_idx}.txt", "w", encoding="utf-8") as file:
                file.write(text_sample)

        with torch.no_grad():
            # Only the logits are needed; without ``labels`` they are the
            # first element of the model output.
            logits = gpt2_model(samples)[0]
            # cross_entropy expects the class dimension second: (B, V, T).
            logits = logits.transpose(-1, -2)
            # Next-token loss per position -> mean per sequence -> perplexity
            # per sequence -> mean over the batch.
            results = F.cross_entropy(logits[..., :-1], samples[..., 1:], reduction='none').mean(dim=-1).exp().mean()
        total_ppl += results

    total_ppl /= args.times
    total_time /= args.times

    # Sum over ranks, then divide to obtain cross-rank averages; the time is
    # further divided by the batch size.
    dist.all_reduce(total_ppl)
    dist.all_reduce(total_time)
    total_time /= args.ngpus
    total_time /= args.batch_size
    total_ppl /= args.ngpus

    mprint("============================================================================")
    mprint(f'Evaluation PPL: {total_ppl}')
    # NOTE(review): this reads the counter from rank 0's own `sampling` module
    # state; it is not reduced across ranks.
    mprint(f"Total update count: {sampling.total_update_cnt/args.times/args.batch_size}")
    mprint(f"Total time: {total_time}")
    mprint("============================================================================\n\n")


def run(rank, args, port):
    """Worker entry point: join the process group, run the job, tear down.

    Args:
        rank: Index of this worker within the group of ``args.ngpus`` workers.
        args: Parsed command-line namespace, forwarded to ``_run``.
        port: TCP port on localhost used for the process-group rendezvous.
    """
    def setup(rank, world_size, port):
        """Point the rendezvous at localhost:``port`` and join the NCCL group."""
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(port)

        # Initialize the process group
        dist.init_process_group(
            "nccl", rank=rank, world_size=world_size, timeout=datetime.timedelta(minutes=30)
        )

    def cleanup():
        """Destroy the process group if it was successfully created."""
        # If initialization failed there is nothing to destroy, and attempting
        # to do so would raise a second error that hides the original one.
        if dist.is_initialized():
            dist.destroy_process_group()

    try:
        setup(rank, args.ngpus, port)
        _run(rank, args)
    finally:
        cleanup()


def main():
    """Parse the command line and spawn one worker process per GPU.

    Any exception raised while launching or joining the workers is logged at
    CRITICAL level (with traceback) to the run's log file.
    """
    parser = argparse.ArgumentParser(description="Generate some samples")
    # (flag, type, default) in the order they appear in ``--help``.
    # Alternative checkpoint / log directory used for SEDD runs:
    #   --model_path louaaron/sedd-small
    #   --work_dir   ./sedd_small_fp32_generation_logs
    cli_options = (
        ("--batch_size", int, 8),
        ("--length", int, 1024),
        ("--model_path", str, "assets/radd_v0"),
        ("--ngpus", int, 8),
        ("--steps", int, 4096),
        ("--times", int, 16),
        ("--work_dir", str, "./radd_v0_generation_logs"),
        ("--method", str, "cached_euler"),
    )
    for flag, value_type, default_value in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value)
    args = parser.parse_args()

    # Random rendezvous port shared by all workers of this launch.
    rendezvous_port = int(np.random.randint(10000, 20000))

    # The log file lives inside the work directory, so create that first.
    os.makedirs(args.work_dir, exist_ok=True)
    log_path = os.path.join(
        args.work_dir,
        f"language_generation_{args.steps}_{args.method}_comparison_ppl_logs",
    )
    logger = utils.get_logger(log_path)

    try:
        mp.set_start_method("forkserver")
        mp.spawn(run, args=(args, rendezvous_port), nprocs=args.ngpus, join=True)
    except Exception as err:
        logger.critical(err, exc_info=True)
    

# Run the launcher only when executed as a script, not when imported.
if __name__ == "__main__":
    main()