import os
import sys
import uuid
import glob
import time
from dataclasses import dataclass

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
import torch._inductor.config as config
import numpy as np

# Load the run configuration (OmegaConf) from the file given as the first
# command-line argument.
# NOTE: this rebinds the module-level name `config`, shadowing the
# `torch._inductor.config` module imported above under the same name.
from omegaconf import OmegaConf
config = OmegaConf.load(sys.argv[1])

import importlib

def load_class(module_name, class_name):
    """Import ``module_name`` and return its attribute named ``class_name``.

    Lets the config select implementations by name (e.g. a class from
    ``FM_utils`` or ``GPT`` from ``models.<model_type>``).

    Raises:
        ModuleNotFoundError: if the module cannot be imported.
        AttributeError: if the module has no such attribute.
    """
    return getattr(importlib.import_module(module_name), class_name)
# Instantiate the FM helper (presumably flow matching — naming only): the
# class named by `config.fm.type` is looked up in the `FM_utils` module and
# constructed with the keyword arguments in `config.fm_config`. The `fm`
# object provides sampler_0 / interpolate / inference_mixed / t_split used
# further down.
fm_type = load_class("FM_utils",config.fm.type)
fm = fm_type(**dict(config.fm_config))



# -----------------------------------------------------------------------------
# Our own simple Distributed Data Loader

import boto3
from botocore.config import Config as boto_config

# botocore client settings (timeouts and retry count), presumably meant for
# the S3 fallback used by the shard readers below.
config_boto = boto_config(
    read_timeout=2000,
    connect_timeout=2000,
    retries={"max_attempts": 13}
)
# boto3 session built from the "default" credentials profile.
session = boto3.session.Session(profile_name='default')
# NOTE(review): no S3 client is created from `session`/`config_boto` here, so
# `s3` stays None and every S3 fallback attempt in the shard readers
# (`s3.get_object(...)`) cannot succeed. Assign a real client (e.g.
# `session.client("s3", config=config_boto)` plus any endpoint settings) if
# the fallback is needed.
s3 = None


def _peek_data_shard(filename,iters_try=1000):
    # only reads the header, returns header data
    try:
        with open(filename, "rb") as f:
            # first read the header, which is 256 int32 integers (4 bytes each)
            header = np.frombuffer(f.read(256*4), dtype=np.int32)
        assert header[0] == 20240520
        assert header[1] == 1
        ntok = header[2] # number of tokens (claimed)
        return ntok # for 
    except:
        pass
    for i in range(iters_try):
        try:
            f = s3.get_object(Bucket='fineweb',Key=filename.split("/")[-1])['Body']
            header = np.frombuffer(f.read(256*4), dtype=np.int32)
            assert header[0] == 20240520
            assert header[1] == 1
            ntok = header[2] # number of tokens (claimed)
            return ntok # for now just return the number of tokens
        except:
            pass
    assert False, "Iters to read file is out"

def _load_data_shard(filename,iters_try=1000):
    try:
        with open(filename, "rb") as f:
            # first read the header, which is 256 int32 integers (4 bytes each)
            header = np.frombuffer(f.read(256*4), dtype=np.int32)
            assert header[0] == 20240520, "magic number mismatch in the data .bin file"
            assert header[1] == 1, "unsupported version"
            ntok = header[2] # number of tokens (claimed)
            # the rest of it are tokens, stored as uint16
            if num_vocab < 2**16:
                dtype = np.uint16
            else:
                dtype = np.uint32
            tokens = np.frombuffer(f.read(), dtype=dtype)
        assert len(tokens) == ntok, "number of tokens read does not match header?"
        return tokens
    except:
        pass
    for i in range(iters_try):
        try:
            f = s3.get_object(Bucket='fineweb',Key=filename.split("/")[-1])['Body']
            # first read the header, which is 256 int32 integers (4 bytes each)
            header = np.frombuffer(f.read(256*4), dtype=np.int32)
            assert header[0] == 20240520, "magic number mismatch in the data .bin file"
            assert header[1] == 1, "unsupported version"
            ntok = header[2] # number of tokens (claimed)
            # the rest of it are tokens, stored as uint16
            if num_vocab < 2**16:
                dtype = np.uint16
            else:
                dtype = np.uint32
            tokens = np.frombuffer(f.read(), dtype=dtype)
            assert len(tokens) == ntok, "number of tokens read does not match header?"
            return tokens
        except:
            pass
    assert False, "Iters to read file is out"

class DistributedDataLoader:
    """Serve fixed-size token batches from sharded ``.bin`` files.

    Process ``process_rank`` of ``num_processes`` reads a contiguous ``B * T``
    slice of the current shard; after each batch every process steps forward
    by ``B * T * num_processes`` tokens. When the shard cannot supply another
    full step, the loader moves to the next file, wrapping after the last one.

    With ``condition=True`` each row holds ``2 * T`` tokens: the first half is
    returned as ``x1`` and the second half as a boolean ``mask``.

    Batches are turned into model inputs with the module-level ``fm`` helper
    (``sampler_0`` and ``interpolate``).
    """

    def __init__(self, vocab_size, filename_pattern, B, T, process_rank, num_processes,num_tokens_done=0, condition=False):
        self.process_rank = process_rank
        self.num_processes = num_processes
        self.B = B
        # conditional rows carry the tokens followed by an equally long mask
        self.T = 2 * T if condition else T
        self.vocab_size = vocab_size

        # glob files that match the pattern
        self.files = sorted(glob.glob(filename_pattern))
        assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"

        # validate all shard headers and count the tokens in total; every
        # shard must hold at least one full step for all processes, measured
        # with the effective (possibly doubled) row length self.T
        tokens_per_step = num_processes * B * self.T
        ntok_total = 0
        for fname in self.files:
            shard_ntok = int(_peek_data_shard(fname))
            assert shard_ntok >= tokens_per_step + 1, (
                f"shard {fname} has {shard_ntok} tokens, need > {tokens_per_step}"
            )
            ntok_total += shard_ntok
        self.ntok_total = ntok_total
        self.num_tokens_done = num_tokens_done
        self.condition = condition
        # kick things off
        self.reset()

    def reset(self):
        """Position the loader from ``num_tokens_done`` and load that shard.

        Assumes every shard holds exactly 10**8 tokens. The shard index wraps
        around the file list, matching the wrap-around in ``advance``.
        """
        self.current_shard = (self.num_tokens_done // (10**8)) % len(self.files)
        self.current_position = self.process_rank * self.B * self.T + self.num_tokens_done % (10**8)
        self.tokens = _load_data_shard(self.files[self.current_shard])

    def advance(self): # advance to next data shard
        """Move to the start of the next shard (wrapping) and load it."""
        self.current_shard = (self.current_shard + 1) % len(self.files)
        self.current_position = self.process_rank * self.B * self.T
        self.tokens = _load_data_shard(self.files[self.current_shard])

    def next_batch(self):
        """Return ``(t, xt, x1, mask)`` for this process's next slice.

        ``x1`` is the clean token batch on CUDA; ``mask`` is a CUDA bool tensor
        in conditional mode and ``None`` otherwise. ``t`` and ``xt`` come from
        ``fm.interpolate`` between a source sample ``fm.sampler_0(x1, ...)``
        and ``x1``.
        """
        B, T = self.B, self.T
        start = self.current_position
        chunk = self.tokens[start:start + B * T]
        # cast the uint16/uint32 storage to a signed dtype before building
        # the torch tensor
        x1 = torch.tensor(chunk.astype(np.int32), dtype=torch.long).view(B, T)
        mask = None
        if self.condition:
            x1, mask = x1.chunk(2, dim=1)
            mask = mask.bool().cuda()
        x1 = x1.cuda()
        x0 = fm.sampler_0(x1, self.vocab_size)
        # interpolate returns (t, xt) or (t, lt, xt); only t and xt are used
        interp = fm.interpolate(x0, x1, mask=mask)
        t, xt = interp[0], interp[-1]
        # advance current position and load next shard if another full step
        # would run past the end of this one
        self.current_position += B * T * self.num_processes
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            self.advance()
        return t.cuda(), xt, x1, mask


# -

# -----------------------------------------------------------------------------
# int main

class Hyperparameters:
    """Run settings copied as class attributes from the OmegaConf ``config``.

    Only part of these is used by the inference code in this script (data
    paths, ``data_condition``, ``sequence_length``, ``checkpoint``); the
    training/optimizer entries are carried along from the training config.
    """
    # glob pattern for the training shards (passed to DistributedDataLoader)
    input_bin = config.data.input_bin
    # glob pattern for the validation shards (the inference loop reads these)
    input_val_bin = config.data.input_val_bin
    # whether rows carry a conditioning mask (doubles the loader row length)
    data_condition = config.data.condition
    # pretraining data patterns; not referenced by the inference code shown
    input_bin_pretrain = config.data.input_bin_pretrain
    input_val_bin_pretrain = config.data.input_val_bin_pretrain
    # config value is multiplied by 10**9, i.e. it is given in billions of tokens
    num_tokens_to_train = config.training_config.num_tokens_to_train * 10**9
    batch_size = config.training_config.batch_size 
    device_batch_size = config.training_config.device_batch_size
    # sequence length T used for the data loaders
    sequence_length = config.data.sequence_length
    # optimizer settings (training only)
    embed_learning_rate = config.optimizer.embed_learning_rate
    muon_learning_rate = config.optimizer.muon_learning_rate
    warmup_iters = config.optimizer.warmup_iters
    warmdown_iters = config.optimizer.warmdown_iters 
    weight_decay = config.optimizer.weight_decay
    # logging/checkpoint cadence and run naming (training only)
    val_loss_every = config.training_config.val_loss_every
    save_every = config.training_config.save_every
    project_name = config.training_config.project_name
    run_name = config.training_config.run_name
    # path of the checkpoint whose weights are loaded for inference
    checkpoint = config.inference.checkpoint
    pretrain = config.training_config.pretrain
args = Hyperparameters()

# Resolve the checkpoint to evaluate. Raise explicitly instead of `assert`,
# which `python -O` strips (that would leave `ckpt_path` undefined).
if args.checkpoint is None:
    raise ValueError("Define path to model checkpoint")
ckpt_path = args.checkpoint

# Single-process inference: rank and world size are fixed here rather than
# read from torchrun's environment; the loaders still take DDP-style args.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required for inference")
ddp_rank = 0
ddp_world_size = 1
device = f'cuda:{ddp_rank}'
print(f"using device: {device}")
# the only process is always the one that logs and writes outputs
master_process = True
# convenience variables
# NOTE: num_vocab must be set before the loaders are built, because
# _load_data_shard reads this global to pick the on-disk token dtype.
num_vocab = config.model.vocab_size
B, T = config.inference.B_data, args.sequence_length
# rows each conditioning sample is repeated to in the inference loop
B_sub = config.inference.B_sub_data


# NOTE(review): train_loader is not used by the inference loop in this file;
# building it still peeks every training shard and loads one. Drop it if it
# is not needed elsewhere.
train_loader = DistributedDataLoader(num_vocab, args.input_bin, B, T, ddp_rank, ddp_world_size,condition=args.data_condition)
val_loader = DistributedDataLoader(num_vocab, args.input_val_bin, B, T, ddp_rank, ddp_world_size,condition=args.data_condition)

# +
import tiktoken

# GPT-2 BPE tokenizer, used to decode token ids into text; its `n_vocab` is
# also used to truncate model scores to the tokenizer's id range.
# NOTE(review): hard-coded to "gpt2" independently of config.model.vocab_size.
tokenizer = tiktoken.get_encoding("gpt2")

# +
# Build the `GPT` class from the `models.<config.model_type>` module and load
# the checkpoint weights into it.
model_type = load_class(f"models.{config.model_type}","GPT")
model = model_type(config.model)

# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
ckpt = torch.load(ckpt_path,map_location="cpu")["model"]
# drop the "_orig_mod." prefix that torch.compile-wrapped modules add to
# state_dict keys, so the weights fit the uncompiled model
ckpt = {name.replace("_orig_mod.",""):value for name,value in ckpt.items()}
model.load_state_dict(ckpt)
if master_process:
    print("Model loaded from checkpoint")

model = model.cuda()
# The module-level name `config` is the OmegaConf run config (it shadows the
# `torch._inductor.config` import), so the Inductor flag has to be set via
# its full module path to reach the torch.compile call used for inference.
torch._inductor.config.coordinate_descent_tuning = True # suggested by @Chillee
# the model itself is not compiled; uncomment to compile it as well
# model = torch.compile(model)
# inference only: freeze all weights
for param in model.parameters():
    param.requires_grad = False
# bfloat16 autocast context for the model's forward passes
ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)

model.eval()
def model_forward(*inputs):
    """Run ``model(*inputs)`` under bf16 autocast and return its first output."""
    with ctx:
        return model(*inputs)[0]
from tqdm import tqdm


# Compile the FM sampling routine once (dynamic=False: specialize on shapes).
inference_mixed = torch.compile(fm.inference_mixed,dynamic=False)
# inference_mixed = fm.inference_mixed  # uncompiled variant for debugging
N_samples = config.inference.N_samples
# one output directory per (checkpoint directory, fm.t_split) pair
out_dir = f"inference/{ckpt_path.split('/')[-2]}_{fm.t_split}"
os.makedirs(out_dir,exist_ok=True)

if args.data_condition and B != 1:
    # conditional mode decodes the prompt/target of row 0 only, so each
    # loaded batch must hold exactly one conditioning sequence; fail before
    # consuming any data
    raise NotImplementedError("conditional inference not implemented for B > 1")


def _decode_text(ids):
    """Decode a 1-D tensor of token ids, removing end-of-text markers."""
    return tokenizer.decode(ids.tolist()).replace("<|endoftext|>","")


for batch_idx in range(N_samples // B):
    # only x1 and mask are needed: the source x0 is re-drawn at t = 0 below
    _, _, x1, mask = val_loader.next_batch()
    if args.data_condition:
        # draw B_sub completions for the single conditioning row
        mask, x1 = mask.repeat(B_sub,1).bool(), x1.repeat(B_sub,1)

    x0 = fm.sampler_0(x1,val_loader.vocab_size)
    batch = fm.interpolate(x0,x1,mask=mask,t=torch.zeros(x0.size(0),device=device))
    x0 = batch[-1]  # interpolate returns (t, xt) or (t, lt, xt); keep xt

    x1_hat = inference_mixed(model_forward,x0,mask=None if mask is None else mask.float())

    # x1_hat holds either per-position scores (rank 3) or token ids (rank 2);
    # reduce it to ids, ignoring scores beyond the tokenizer's vocabulary
    if x1_hat.dim() == 3:
        ids_hat = x1_hat[...,:tokenizer.n_vocab].argmax(dim=-1)
    else:
        ids_hat = x1_hat

    if args.data_condition:
        # assumes the mask marks a conditioning prefix of cond_size positions
        # — TODO confirm against the data format
        cond_size = mask[0].int().sum(dim=-1).item()
        ques = [tokenizer.decode(ids_hat[0,:cond_size].tolist())]
        tar = [_decode_text(x1[0,cond_size:])]
        pred_ids = ids_hat[:,cond_size:]
    else:
        ques = [None] * B
        tar = [None] * B
        pred_ids = ids_hat

    preds = [_decode_text(row) for row in pred_ids]
    # rows are grouped per loaded sample: B_sub completions in conditional
    # mode, one row per sample otherwise
    per_sample = len(preds) // B
    for j in range(B):
        # drop empty completions within each sample's group so the grouping
        # never drifts across samples
        sample_preds = [p for p in preds[j*per_sample:(j+1)*per_sample] if p]
        if not sample_preds:
            continue
        torch.save({"pred":sample_preds,
                    "ques":ques[j],
                    "tar":tar[j]
                   },
                   f"{out_dir}/{batch_idx*B+j}.pt"
                  )


