import os
import sys
import uuid
import glob
import time
from dataclasses import dataclass

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
import torch._inductor.config as config
from models.model_time_token import GPT, GPTConfig
from optimizer import Muon
import wandb
import numpy as np

# Run configuration: the OmegaConf file given as the first CLI argument.
from omegaconf import OmegaConf
# NOTE(review): this rebinds `config`, shadowing the
# `import torch._inductor.config as config` at the top of the file; every later
# `config.` in this script refers to the OmegaConf run config, not inductor's.
config = OmegaConf.load(sys.argv[1])

import importlib

def load_class(module_name, class_name):
    """Import ``module_name`` and return its attribute named ``class_name``.

    Lets the config choose implementations by name (e.g. the flow-matching
    helper class and the model class).

    Raises:
        ImportError: if the module cannot be imported.
        AttributeError: if the module has no attribute ``class_name``.
    """
    return getattr(importlib.import_module(module_name), class_name)
# Flow-matching (FM) helper: config.fm.type names a class in the FM_utils
# module, which is constructed from the key/value pairs of config.fm_config.
fm_type = load_class("FM_utils",config.fm.type)
fm = fm_type(**dict(config.fm_config))


# -----------------------------------------------------------------------------
# Our own simple Distributed Data Loader

import boto3
from botocore.config import Config as boto_config

# botocore client settings (long timeouts, up to 13 attempts), presumably meant
# for the S3 client used as a fallback source of data shards.
config_boto = boto_config(
    read_timeout=2000,
    connect_timeout=2000,
    retries={"max_attempts": 13}
)
# NOTE(review): no S3 client is created here and config_boto appears unused in
# the visible code, so the S3 fallback in the shard readers cannot succeed as
# written. Presumably this should be boto3.client("s3", config=config_boto, ...)
# with the right endpoint/credentials — confirm.
s3 = None


def _read_shard_header(f):
    """Read and validate a shard header from the binary stream ``f``.

    The header is 256 int32 values laid out as
    ``[magic=20240520, version=1, ntok, ...]``.

    Returns:
        The claimed token count ``ntok`` as a Python int.

    Raises:
        ValueError: if the header is truncated or magic/version do not match.
            Explicit raises are used because ``assert`` is stripped under ``-O``.
    """
    header = np.frombuffer(f.read(256 * 4), dtype=np.int32)
    if header.size < 3:
        raise ValueError("truncated header in the data .bin file")
    if header[0] != 20240520:
        raise ValueError("magic number mismatch in the data .bin file")
    if header[1] != 1:
        raise ValueError("unsupported version")
    return int(header[2])


def _read_shard_tokens(f, ntok):
    """Read the token payload that follows the header and check its length.

    Tokens are stored as uint16 when the module-level ``num_vocab`` fits in
    16 bits, otherwise as uint32.

    Raises:
        ValueError: if the payload length differs from the header's ``ntok``.
    """
    dtype = np.uint16 if num_vocab < 2**16 else np.uint32
    tokens = np.frombuffer(f.read(), dtype=dtype)
    if len(tokens) != ntok:
        raise ValueError("number of tokens read does not match header?")
    return tokens


def _read_shard(filename, reader, iters_try):
    """Return ``reader(stream)`` for a shard, trying the local file first.

    If the local file is missing, unreadable or fails validation, the object
    with the same base name is fetched from the ``fineweb`` S3 bucket. That
    fetch is retried up to ``iters_try`` times.

    Raises:
        RuntimeError: if no S3 client is configured (module-level ``s3`` is
            None) or every S3 attempt failed. The last underlying error is
            chained as ``__cause__``.
    """
    try:
        with open(filename, "rb") as f:
            return reader(f)
    except (OSError, ValueError) as err:
        # missing / unreadable / corrupt local copy: fall back to S3 below
        local_err = err
    if s3 is None:
        # Fail fast instead of spinning through iters_try guaranteed failures.
        raise RuntimeError(
            f"cannot read {filename!r} locally and no S3 client is configured"
        ) from local_err
    key = filename.split("/")[-1]
    last_err = local_err
    for _ in range(iters_try):
        try:
            body = s3.get_object(Bucket='fineweb', Key=key)['Body']
            try:
                return reader(body)
            finally:
                body.close()
        except Exception as err:  # botocore/network errors vary widely; retry on any
            last_err = err
    raise RuntimeError(f"Iters to read file is out: {filename!r}") from last_err


def _peek_data_shard(filename, iters_try=1000):
    """Return the token count claimed by a shard's header (only the header is read)."""
    return _read_shard(filename, _read_shard_header, iters_try)


def _load_data_shard(filename, iters_try=1000):
    """Return all tokens of a shard as a read-only numpy array, after validating it."""
    def _reader(f):
        ntok = _read_shard_header(f)
        return _read_shard_tokens(f, ntok)
    return _read_shard(filename, _reader, iters_try)

class DistributedDataLoader:
    """Sequential, rank-strided reader over tokenized ``.bin`` shards.

    Each :meth:`next_batch` call does the following:
    - Slices ``B * self.T`` tokens for this rank and reshapes them to
      ``(B, self.T)`` rows.
    - Treats column 0 of every row as a sample index and the remaining
      columns as data.
    - Builds flow-matching inputs with the module-level ``fm`` object.

    With ``condition=True`` each row holds ``2*T + 1`` values: the index,
    ``T`` tokens and ``T`` per-token mask values. ``self.T`` is widened to
    match.

    Args:
        vocab_size: forwarded to ``fm.sampler_0``.
        filename_pattern: glob for the shard files. Per-sample ``<idx>.pt``
            side files are loaded from the same directory.
        B, T: rows per batch and (unwidened) tokens per row.
        process_rank, num_processes: position of this reader among readers.
        num_tokens_done: resume offset.
        condition: whether rows carry a mask half (see above).
    """

    def __init__(self, vocab_size, filename_pattern, B, T, process_rank, num_processes,num_tokens_done=0, condition=False):
        self.process_rank = process_rank
        self.num_processes = num_processes
        self.B = B
        # conditional rows: 1 index + T tokens + T mask values
        self.T = 2 * T + 1 if condition else T
        self.vocab_size = vocab_size
        # directory part of the pattern (everything before the last "/")
        self.path = "/".join(filename_pattern.split("/")[:-1])
        self.files = sorted(glob.glob(filename_pattern))
        assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"

        # validate every shard header and count the tokens in total
        ntok_total = 0
        for fname in self.files:
            shard_ntok = _peek_data_shard(fname)
            # NOTE(review): this checks against the unwidened T, not self.T
            # (the row length actually read in conditional mode) — confirm.
            assert shard_ntok >= num_processes * B * T + 1
            ntok_total += int(shard_ntok)
        self.ntok_total = ntok_total
        self.num_tokens_done = num_tokens_done
        self.condition = condition
        self.reset()

    def reset(self):
        """Seek to the shard/offset implied by ``num_tokens_done`` and load that shard."""
        # NOTE(review): assumes every shard holds exactly 10**8 tokens — confirm.
        self.current_shard = 0 + self.num_tokens_done // (10**8)
        self.current_position = self.process_rank * self.B * self.T + self.num_tokens_done % (10**8)
        self.tokens = _load_data_shard(self.files[self.current_shard])

    def advance(self):
        """Move to the start of the next shard (wrapping around) and load it."""
        self.current_shard = (self.current_shard + 1) % len(self.files)
        self.current_position = self.process_rank * self.B * self.T
        self.tokens = _load_data_shard(self.files[self.current_shard])

    def next_batch(self):
        """Return ``(xt, x1, mask_cond, mask_ques, test_list)`` for the next batch.

        The returned values are:
        - ``x1``: data tokens, on the GPU.
        - ``xt``: ``fm.interpolate(...)`` applied to a fresh ``fm.sampler_0`` draw.
        - ``mask_cond``: boolean GPU tensor marking mask values > 0.5.
        - ``mask_ques``: boolean GPU tensor marking mask values == 2.
        - ``test_list``: the unpickled ``<idx>.pt`` side file of the first row.

        ``mask_cond`` and ``mask_ques`` are ``None`` when the loader was built
        with ``condition=False``. Previously that case raised UnboundLocalError.
        """
        B = self.B
        T = self.T
        buf = self.tokens[self.current_position : self.current_position+B*T]
        buf = torch.tensor(buf.astype(np.int32), dtype=torch.long)
        rows = buf.view(B, T)
        # column 0 is a per-sample index naming the "<idx>.pt" side file
        idx, x1 = rows[:, 0], rows[:, 1:]
        mask_cond = mask_ques = None
        if self.condition:
            x1, mask = x1.chunk(2, dim=1)
            mask_cond = (mask > 0.5).cuda()
            mask_ques = (mask == 2).cuda()
        x1 = x1.cuda()
        x0 = fm.sampler_0(x1, self.vocab_size)
        # NOTE(review): assumes fm.interpolate accepts mask=None in the
        # unconditional case — confirm against FM_utils.
        _, xt = fm.interpolate(x0, x1, mask=mask_cond)
        # Only the first row's side file is loaded (the inference loop asserts B == 1).
        # torch.load unpickles the file: only point this at trusted data.
        test_list = torch.load(os.path.join(self.path, f"{idx[0].item()}.pt"))
        # advance current position and load next shard if necessary
        self.current_position += B * T * self.num_processes
        if self.current_position + (B * T * self.num_processes+1) > len(self.tokens):
            self.advance()
        return xt, x1, mask_cond, mask_ques, test_list


# -

# -----------------------------------------------------------------------------
# int main

class Hyperparameters:
    """Flat, attribute-style view of the run-config values used by this script.

    Every value is read once, at class-definition time, from the OmegaConf
    ``config`` loaded from ``sys.argv[1]``. Config sections are bound to
    short private aliases while reading. The aliases are deleted afterwards,
    so the class exposes exactly the listed attributes.
    """

    # --- data ----------------------------------------------------------------
    _data = config.data
    input_bin = _data.input_bin
    input_val_bin = _data.input_val_bin
    data_condition = _data.condition
    input_bin_pretrain = _data.input_bin_pretrain
    input_val_bin_pretrain = _data.input_val_bin_pretrain

    # --- training schedule (token budget is configured in billions) ----------
    _train = config.training_config
    num_tokens_to_train = _train.num_tokens_to_train * 10**9
    batch_size = _train.batch_size
    device_batch_size = _train.device_batch_size
    sequence_length = _data.sequence_length

    # --- optimizer -----------------------------------------------------------
    _optim = config.optimizer
    embed_learning_rate = _optim.embed_learning_rate
    muon_learning_rate = _optim.muon_learning_rate
    warmup_iters = _optim.warmup_iters
    warmdown_iters = _optim.warmdown_iters
    weight_decay = _optim.weight_decay

    # --- logging / checkpointing ---------------------------------------------
    val_loss_every = _train.val_loss_every
    save_every = _train.save_every
    project_name = _train.project_name
    run_name = _train.run_name
    checkpoint = config.inference.checkpoint
    pretrain = _train.pretrain

    # drop the temporary section aliases so they do not become class attributes
    del _data, _train, _optim


args = Hyperparameters()
# Second CLI argument: substituted for the literal "postfix" placeholder in the
# validation-data pattern and appended to the output directory name.
postfix = sys.argv[2]
# A checkpoint is mandatory. Raise explicitly: `assert False` is stripped under
# `python -O`, which would leave ckpt_path undefined until the model is loaded.
if args.checkpoint is None:
    raise ValueError("Define path to model checkpoint")
ckpt_path = args.checkpoint
# Single-process inference (no DDP): rank 0 in a world of size 1.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required: this script runs the model on the GPU")
ddp_rank = 0
ddp_world_size = 1
device = f'cuda:{ddp_rank}'
print(f"using device: {device}")
#master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc.
master_process = True
# convenience variables; num_vocab is also read as a global by the shard loader
# to choose between uint16 and uint32 token storage
num_vocab = config.model.vocab_size
B, T = config.inference.B_data, args.sequence_length
B_sub = config.inference.B_sub_data


# train_loader = DistributedDataLoader(num_vocab, args.input_bin, B, T, ddp_rank, ddp_world_size,condition=args.data_condition)
val_loader = DistributedDataLoader(num_vocab, args.input_val_bin.replace("postfix",postfix), B, T, ddp_rank, ddp_world_size,condition=args.data_condition)

# +
import tiktoken
class CustomTokenizer(object):
    """tiktoken encoding extended with extra id -> text entries for decoding.

    ``decode`` concatenates the raw bytes of all tokens before UTF-8 decoding.
    Multi-byte characters split across several byte-level BPE tokens are
    therefore reconstructed. Decoding token by token turned them into
    U+FFFD replacement characters.

    Args:
        name: tiktoken encoding name, e.g. ``"gpt2"``.
        special_tokens: optional ``{token_id: text}`` entries used only when
            decoding. They override ids of the base encoding if they collide.

    Attributes:
        ind_to_bytes: token id -> raw bytes (used by ``decode``).
        ind_to_str: token id -> per-token text (U+FFFD for partial UTF-8),
            kept for backward compatibility.
        n_vocab: base vocabulary size plus ``len(special_tokens)``.
    """

    def __init__(self, name, special_tokens=None):
        # None instead of a shared mutable `{}` default
        special_tokens = {} if special_tokens is None else special_tokens
        self.name = name
        self.tokenizer = tiktoken.get_encoding(name)
        base = self.tokenizer
        self.ind_to_bytes = {i: base.decode_single_token_bytes(i) for i in range(base.n_vocab)}
        for key, value in special_tokens.items():
            self.ind_to_bytes[key] = value.encode("utf-8")
        # Same strings tiktoken's decode([i]) produces (utf-8, errors="replace").
        self.ind_to_str = {i: b.decode("utf-8", errors="replace") for i, b in self.ind_to_bytes.items()}
        self.n_vocab = base.n_vocab + len(special_tokens)

    def encode(self, text):
        """Encode ``text`` with the base tiktoken encoding."""
        return self.tokenizer.encode(text)

    def decode(self, ind):
        """Decode token ids to text; unknown ids raise KeyError."""
        return b"".join(self.ind_to_bytes[idx] for idx in ind).decode("utf-8", errors="replace")

# Tokenizer used to decode questions, targets and predictions; id 50257 (one
# past the 50257 GPT-2 ids 0..50256) is rendered as "<|pad|>".
tokenizer = CustomTokenizer("gpt2",special_tokens={50257:"<|pad|>"})

# +
# Build the model class named by config.model_type (resolved as
# `models.<model_type>.GPT`) from config.model and restore its weights.
model_type = load_class(f"models.{config.model_type}","GPT")
model = model_type(config.model)

# Weights live under the checkpoint's "model" key. The "_orig_mod." prefix is
# stripped so keys match the uncompiled module (presumably the checkpoint was
# saved from a torch.compile-wrapped model, which adds that prefix).
ckpt = torch.load(ckpt_path,map_location="cpu")["model"]
ckpt = {name.replace("_orig_mod.",""):value for name,value in ckpt.items()}
model.load_state_dict(ckpt)
if master_process:
    print("Model loaded from checkpoint")
    
model = model.cuda()
# NOTE(review): `config` is the OmegaConf run config here (it shadows the
# torch._inductor.config import), so this never enables inductor's
# coordinate-descent tuning; at most it touches the run config. If enabling
# the inductor option was the intent, target torch._inductor.config explicitly.
if hasattr(config, "coordinate_descent_tuning"):
    config.coordinate_descent_tuning = True # suggested by @Chillee
# model = torch.compile(model)
# Inference only: freeze every parameter.
for param in model.parameters():
    param.requires_grad = False
# bf16 autocast context used by model_forward (single process, no DDP wrapper).
ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)

model.eval()
def model_forward(t, x):
    """Run the frozen model on ``(t, x)`` under bf16 autocast.

    Returns element 0 of ``model(t, x, return_logits=True)``; presumably the
    logits — confirm against the model class.
    """
    with ctx:
        outputs = model(t, x, return_logits=True)
    return outputs[0]
from tqdm import tqdm  # NOTE(review): appears unused in the visible code


# Compile the FM sampling routine once; dynamic=False specializes on input shapes.
inference_mixed = torch.compile(fm.inference_mixed,dynamic=False)
# inference_mixed = fm.inference_mixed
# Cap the sample count by how many conditional rows (2*T + 1 values each:
# index, T tokens, T mask values) the validation shards hold.
N_samples = min(config.inference.N_samples,int(val_loader.ntok_total / (2*T + 1)))
print(N_samples)
# Output directory: inference/<checkpoint's parent directory name>_<postfix>/
os.makedirs(f"inference/{ckpt_path.split('/')[-2]}_{postfix}",exist_ok=True)
def process_code(code):
    """Reassemble a generated program and strip special tokens.

    ``code`` is split on ``<|endoftext|>`` into a head followed by ``tail``
    pieces. According to the author, a well-formed sample contains
    ``2 * n_masked + 1`` separators.

    With ``n = len(tail) // 2``, the last ``n`` pieces (presumably the
    infills) are interleaved in front of the first ``n`` pieces. The result
    is ``head + fill_0 + seg_0 + ... + fill_{n-1} + seg_{n-1}``. A leftover
    middle piece (odd ``len(tail)``) is dropped. Finally every ``<|pad|>`` and
    ``<|endoftext|>`` is removed.
    """
    head, *tail = code.split("<|endoftext|>")
    n_masked = len(tail) // 2
    if n_masked:
        segments = tail[:n_masked]
        fills = tail[len(tail) - n_masked:]
        head += "".join(fill + seg for fill, seg in zip(fills, segments))
    return head.replace("<|pad|>", "").replace("<|endoftext|>", "")

# Decoder vocabulary size (GPT-2 ids plus "<|pad|>"); predictions are
# restricted to these ids before argmax.
n_vocab = tokenizer.n_vocab
# One iteration per validation batch; each row is expanded into B_sub
# sampled completions that are saved to inference/<run>_<postfix>/<index>.pt.
for i in range(N_samples // B):
    x0, x1, mask_cond, mask_ques, test_list = val_loader.next_batch()
    # Repeat the row B_sub times so several completions are sampled in parallel.
    mask_cond, x1 = mask_cond.repeat(B_sub,1).bool(), x1.repeat(B_sub,1)
    assert B == 1, "conditional inference not implemented for B > 1"
    # Number of question tokens; assumes they form a prefix of the row, since
    # the question is x1[0,:cond_size] and the target x1[0,cond_size:] below.
    cond_size = mask_ques[0].int().sum(dim=-1).item()
    # Count of non-conditioned positions (presumably the number to generate).
    fm.update_N(torch.logical_not(mask_cond[0]).int().sum(dim=-1).item())
    # Fresh source sample, interpolated at t = 0 with the conditioning mask
    # (presumably keeping x1 at conditioned positions — confirm in FM_utils).
    x0 = fm.sampler_0(x1,val_loader.vocab_size)
    _,x0 = fm.interpolate(x0,x1,mask=mask_cond,t=torch.zeros(x0.size(0),device=device))
    
    x1_hat = inference_mixed(model_forward,x0,mask=mask_cond.float())

    # "GPT" outputs are reassembled with process_code; other FM types only
    # have special tokens stripped.
    if config.fm.type == "GPT":
        process = process_code
    else:
        process = lambda x: x.replace("<|pad|>","").replace("<|endoftext|>","")
    # For "GPT", x1_hat holds token ids; otherwise it holds per-position scores
    # over the vocabulary and the argmax over the first n_vocab entries is taken.
    # The comprehension's `i` is local to it; the outer loop `i` is used below.
    if config.fm.type == "GPT":
        pred = [process(tokenizer.decode(x1_hat[i,cond_size:].tolist())) for i in range(x1_hat.size(0))]
    else:
        pred = [process(tokenizer.decode(x1_hat[i,cond_size:,:n_vocab].argmax(dim=-1).tolist())) for i in range(x1_hat.size(0))]
    tar = [process(tokenizer.decode(x1[0,cond_size:].tolist()))]
    
    ques = [tokenizer.decode(x1[0,:cond_size].tolist())]
    # Drop empty completions; skip the sample entirely if none remain.
    pred = [p for p in pred if len(p) > 0]
    if len(pred) == 0:
        continue
    for j in range(B):
        torch.save({"pred":pred[j*B_sub:(j+1)*B_sub],
                    "ques":ques[j],
                    "tar":tar[j],
                    "test_list":test_list,
                   },
                   f"inference/{ckpt_path.split('/')[-2]}_{postfix}/{i*B+j}.pt"
                  )


