# =============================================================================
# CONFIDENTIAL - FOR REVIEW ONLY
# This code is submitted as supplementary material for paper review.
# DO NOT DISTRIBUTE - Pending patent application.
# =============================================================================

import os
import sys
import warnings
import random
import datetime
import argparse
from contextlib import contextmanager

import numpy as np
import torch
torch.set_float32_matmul_precision('high')
from torch import optim
from torch.cuda.amp import autocast
import einops

from generator import multi_tree_layers, gen_streach_matrix2, log_prob_normal
from adapters.gpt import get_model_with_parallel_input
import wandb


import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def ddp_is_enabled(ddp: bool) -> bool:
    """Return whether distributed data parallel should actually be used.

    DDP is used only when it was requested via ``ddp`` *and* the
    ``WORLD_SIZE`` environment variable reports more than one process
    (``WORLD_SIZE`` defaults to 1 when unset).
    """
    if not ddp:
        return ddp
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    return world_size > 1

def ddp_setup(backend: str = "nccl"):
    """Initialise the default process group once; later calls are no-ops.

    No ``init_method`` is passed, so torch uses its default rendezvous
    (environment variables such as those exported by a launcher like torchrun).
    """
    if not dist.is_initialized():
        dist.init_process_group(backend=backend)

def ddp_cleanup():
    """Synchronise all ranks, then tear down the default process group.

    Does nothing when no process group has been initialised.
    """
    if not dist.is_initialized():
        return
    # Barrier first so no rank destroys the group while others still use it.
    dist.barrier()
    dist.destroy_process_group()

def get_local_rank() -> int:
    """Return this process's node-local rank from ``LOCAL_RANK`` (0 when unset)."""
    local_rank = os.environ.get("LOCAL_RANK", "0")
    return int(local_rank)

def get_global_rank() -> int:
    """Return this process's global rank from ``RANK`` (0 when unset)."""
    global_rank = os.environ.get("RANK", "0")
    return int(global_rank)

def get_world_size() -> int:
    """Return the total number of processes from ``WORLD_SIZE`` (1 when unset)."""
    world_size = os.environ.get("WORLD_SIZE", "1")
    return int(world_size)

def is_master() -> bool:
    """Return True on the global rank-0 process (the one that logs and checkpoints)."""
    rank = get_global_rank()
    return rank == 0


def set_seed(seed: int):
    """Seed the Python, NumPy and torch (CPU + CUDA) RNGs and make cuDNN deterministic.

    cuDNN autotuning (``benchmark``) is also disabled, trading speed for
    reproducibility.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        torch.cuda.manual_seed,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def kl_divergence(mu1, sigma1, mu2, sigma2):
    """Elementwise closed-form KL( N(mu1, sigma1**2) || N(mu2, sigma2**2) ).

    ``sigma1`` / ``sigma2`` are standard deviations (they are squared below),
    and all arguments broadcast against each other as torch tensors.
    """
    log_scale_ratio = torch.log(sigma2 / sigma1)
    numerator = sigma1 ** 2 + (mu1 - mu2) ** 2
    denominator = 2 * sigma2 ** 2
    return log_scale_ratio + numerator / denominator - 0.5


def get_inputs(
    batch_size: int = 1,
    num_layers: int = None,
    seq_len: int = 1024,
    num_parallel_input: int = 1,
    device: torch.device = None,
):
    """Sample one synthetic batch from the multi-layer tree generator.

    Args:
        batch_size: Number of sequences in the batch.
        num_layers: Generator depth. ``None`` or values below
            ``int(log2(seq_len + 1)) // 2`` are raised to that minimum.
        seq_len: Number of positions in each returned sequence.
        num_parallel_input: Number of parallel input channels per sequence.
        device: Target device; defaults to CUDA when available, else CPU.

    Returns:
        ``(x, x_mu, x_sigma)``, each shaped ``(batch_size, seq_len, num_parallel_input)``.
        ``x`` is the generated sequence shifted right by one with a zero at
        position 0, so ``x[:, t]`` holds generated step ``t - 1`` while
        ``x_mu[:, t]`` / ``x_sigma[:, t]`` belong to step ``t`` (presumably
        the generator's mean / std — confirm against ``multi_tree_layers``).

    Raises:
        ValueError: if the generator returns fewer positions than ``seq_len``
            requires (the slices below would otherwise yield mismatched lengths).
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): presumably the minimum depth needed to cover seq_len with the
    # n=4 tree below; the explicit length check further down guards the case where
    # this estimate is too small.
    min_num_layers = int(np.log2(seq_len + 1)) // 2
    if num_layers is None or num_layers < min_num_layers:
        num_layers = min_num_layers

    # NOTE(review): sqrt(4 / 3.2) looks like a scale normalisation of the tree
    # operator — confirm its origin against the generator definition.
    tree_op = gen_streach_matrix2(1., 1., 1, 0.2).to(device) * np.sqrt(4 / 3.2)

    x, x_mu, x_sigma, _, _ = multi_tree_layers(
        tree_op,
        m=3,
        n=4,
        b=batch_size * num_parallel_input * 3 ** (num_layers - 1),
        num_layers=num_layers,
        device=device,
        return_more=False
    )

    # The shifted input needs seq_len - 1 generated positions and the targets need
    # seq_len. Slicing shorter tensors would not fail here, but leave x one position
    # longer than x_mu / x_sigma, which breaks (or silently broadcasts in) the loss.
    if x.shape[1] < seq_len - 1 or min(x_mu.shape[1], x_sigma.shape[1]) < seq_len:
        raise ValueError(
            f"generator with num_layers={num_layers} produced sequence lengths "
            f"x={x.shape[1]}, mu={x_mu.shape[1]}, sigma={x_sigma.shape[1]}, but "
            f"seq_len={seq_len} needs at least {seq_len - 1} / {seq_len} / {seq_len}; "
            f"increase num_layers"
        )

    # Split the flat '(b d)' leading axis into batch and parallel-input channels.
    x = einops.rearrange(x[:, :seq_len - 1], '(b d) s -> b s d', b=batch_size, d=num_parallel_input)
    # Prepend a zero step so the model at position t only sees steps < t.
    x = torch.cat([torch.zeros(batch_size, 1, num_parallel_input, device=device), x], dim=1)
    x_mu = einops.rearrange(x_mu[:, 0:seq_len], '(b d) s -> b s d', b=batch_size, d=num_parallel_input)
    x_sigma = einops.rearrange(x_sigma[:, 0:seq_len], '(b d) s -> b s d', b=batch_size, d=num_parallel_input)

    return x, x_mu, x_sigma


@torch.no_grad()
def evaluate(model, num_batches: int, curr_step: int, **input_config):
    """Estimate the model's mean KL loss on freshly generated batches.

    Args:
        model: Model (optionally DDP-wrapped) mapping inputs to ``(mu, sigma)``.
        num_batches: Number of synthetic batches to average over; must be >= 1.
        curr_step: Training step used in the printed message and the wandb log.
        **input_config: Keyword arguments forwarded to ``get_inputs``.

    Returns:
        ``(total_loss, token_wise_loss)``: the scalar mean loss, and a list with
        one loss per sequence position, averaged over batch and parallel-input
        dims.

    Raises:
        ValueError: if ``num_batches < 1``.

    Only the master rank prints and logs to wandb. The model's train/eval mode
    is restored before returning.
    """
    if num_batches < 1:
        raise ValueError(f"num_batches must be >= 1, got {num_batches}")

    was_training = model.training
    model.eval()
    device_type = next(model.parameters()).device.type
    total_loss = 0.
    token_wise_loss = 0.

    for _ in range(num_batches):
        inputs, mu_true, sigma_true = get_inputs(**input_config)
        # torch.cuda.amp.autocast does not accept device_type; torch.autocast does
        # and also works on CPU.
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            mu, sigma = model(inputs)
        # KL (log, division, squares) is evaluated in fp32 rather than bf16.
        # get_inputs returns (b, s, d) tensors; reduce over b and d -> one value per position.
        loss = kl_divergence(mu_true, sigma_true, mu.float(), sigma.float()).mean(dim=(0, 2))
        token_wise_loss += loss.cpu().numpy()
        total_loss += loss.mean().item()

    total_loss /= num_batches
    token_wise_loss /= num_batches

    if is_master():
        print(f"Step {curr_step}: Evaluation Loss: {total_loss:.4f}")
        wandb.log({"eval_loss": total_loss}, step=curr_step)

    model.train(was_training)
    return total_loss, token_wise_loss.tolist()


def train(
    model_name: str = "openai-community/gpt2",
    seq_len: int = 1024,
    batch_size: int = 4,
    num_steps: int = 10000,
    learning_rate: float = 1e-4,
    weight_decay: float = 0.01,
    warmup_steps: int = 500,
    eval_every: int = 500,
    num_parallel_input: int = 1,
    gradient_accumulation: int = 1,
    seed: int = 42,
    use_ddp: bool = True,
    wandb_project: str = "lcube",
    wandb_run_name: str = None,
    save_dir: str = None,
    **model_kwargs,
):
    """Train a parallel-input model on synthetic sequences with a Gaussian KL loss.

    Args:
        model_name: Model identifier passed to ``get_model_with_parallel_input``.
        seq_len: Sequence length; also the default ``n_positions`` /
            ``max_position_embeddings`` for the model.
        batch_size: Per-process batch size.
        num_steps: Number of loop (micro-batch) steps.
        learning_rate: Peak AdamW learning rate.
        weight_decay: AdamW weight decay.
        warmup_steps: Linear warmup length, in loop steps. The LR then follows a
            cosine decay that reaches 0 at ``num_steps``.
        eval_every: Evaluate every this many loop steps (>= 1). Checkpoints are
            written every ``5 * eval_every`` steps when ``save_dir`` is set.
        num_parallel_input: Number of parallel input channels.
        gradient_accumulation: Micro-batches per optimizer update (>= 1).
        seed: Base seed; each process uses ``seed + global_rank``.
        use_ddp: Use DistributedDataParallel when more than one process is launched.
        wandb_project / wandb_run_name: wandb settings (master rank only).
        save_dir: Directory for checkpoints; ``None`` disables saving.
        **model_kwargs: Extra keyword arguments for the model factory.

    Raises:
        ValueError: if ``gradient_accumulation`` or ``eval_every`` is < 1.
    """
    from contextlib import nullcontext

    if gradient_accumulation < 1:
        raise ValueError(f"gradient_accumulation must be >= 1, got {gradient_accumulation}")
    if eval_every < 1:
        raise ValueError(f"eval_every must be >= 1, got {eval_every}")

    distributed = ddp_is_enabled(use_ddp)
    if distributed:
        ddp_setup()
        device = torch.device(f"cuda:{get_local_rank()}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Offset by rank so each process samples different synthetic batches; DDP
    # broadcasts rank 0's initial parameters, so the weights still start in sync.
    set_seed(seed + get_global_rank())

    model_kwargs.setdefault("n_positions", seq_len)
    model_kwargs.setdefault("max_position_embeddings", seq_len)

    if is_master():
        wandb.init(
            project=wandb_project,
            name=wandb_run_name or f"{model_name.split('/')[-1]}_seq{seq_len}",
            config={
                "model_name": model_name,
                "seq_len": seq_len,
                "batch_size": batch_size,
                "learning_rate": learning_rate,
                "num_steps": num_steps,
                "weight_decay": weight_decay,
                "warmup_steps": warmup_steps,
                "num_parallel_input": num_parallel_input,
                "gradient_accumulation": gradient_accumulation,
                "seed": seed,
                "world_size": get_world_size() if distributed else 1,
                "model_kwargs": dict(model_kwargs),
            }
        )
        print(f"Loading model: {model_name}")

    model, _ = get_model_with_parallel_input(
        model_name,
        pretrained=False,
        num_parallel_input=num_parallel_input,
        **model_kwargs
    )
    model = model.to(device)

    if distributed:
        model = DDP(model, device_ids=[get_local_rank()])
    # Unwrapped module, used for checkpointing.
    raw_model = model.module if distributed else model

    optimizer = optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,
    )

    def lr_schedule(update_idx):
        # LambdaLR counts optimizer updates, but warmup_steps / num_steps are in
        # loop steps. Convert, so the schedule is correct when gradient_accumulation > 1.
        step = update_idx * gradient_accumulation
        if step < warmup_steps:
            return step / warmup_steps
        # max(1, ...) avoids 0/0 when num_steps == warmup_steps.
        decay_steps = max(1, num_steps - warmup_steps)
        progress = min(1.0, (step - warmup_steps) / decay_steps)
        return 0.5 * (1 + np.cos(np.pi * progress))

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)

    input_config = {
        "batch_size": batch_size,
        "seq_len": seq_len,
        "num_parallel_input": num_parallel_input,
        "device": device,
    }

    model.train()
    optimizer.zero_grad()

    # NOTE: if num_steps is not a multiple of gradient_accumulation, the trailing
    # micro-batches' gradients are accumulated but never applied.
    for step in range(1, num_steps + 1):
        is_update_step = step % gradient_accumulation == 0
        inputs, mu_true, sigma_true = get_inputs(**input_config)

        # Skip DDP's gradient all-reduce on micro-steps that only accumulate.
        sync_ctx = model.no_sync() if distributed and not is_update_step else nullcontext()
        with sync_ctx:
            with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
                mu, sigma = model(inputs)
            # KL computed in fp32 rather than bf16.
            loss = kl_divergence(mu_true, sigma_true, mu.float(), sigma.float()).mean()
            (loss / gradient_accumulation).backward()

        if is_update_step:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        if step % 100 == 0 and is_master():
            train_loss = loss.item()
            current_lr = scheduler.get_last_lr()[0]
            print(f"Step {step}/{num_steps}, Loss: {train_loss:.4f}, LR: {current_lr:.2e}")
            wandb.log({"train_loss": train_loss, "lr": current_lr}, step=step)

        if step % eval_every == 0:
            evaluate(model, num_batches=10, curr_step=step, **input_config)
            model.train()

        if save_dir and step % (eval_every * 5) == 0 and is_master():
            os.makedirs(save_dir, exist_ok=True)
            torch.save({
                "step": step,
                "model_state_dict": raw_model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
            }, os.path.join(save_dir, f"checkpoint_{step}.pt"))

    evaluate(model, num_batches=50, curr_step=num_steps, **input_config)

    if distributed:
        ddp_cleanup()

    if is_master():
        wandb.finish()


if __name__ == "__main__":
    # CLI entry point: every flag maps 1:1 onto a train() keyword with the same default.
    parser = argparse.ArgumentParser(description="Train on L-CUBE benchmark")
    parser.add_argument("--model_name", type=str, default="openai-community/gpt2")
    parser.add_argument("--seq_len", type=int, default=1024)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_steps", type=int, default=10000)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--warmup_steps", type=int, default=500)
    parser.add_argument("--eval_every", type=int, default=500)
    parser.add_argument("--num_parallel_input", type=int, default=1)
    parser.add_argument("--gradient_accumulation", type=int, default=1)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--no_ddp", dest="use_ddp", action="store_false",
                        help="disable DistributedDataParallel even when WORLD_SIZE > 1")
    parser.add_argument("--save_dir", type=str, default=None)
    parser.add_argument("--wandb_project", type=str, default="lcube")
    parser.add_argument("--wandb_run_name", type=str, default=None)

    args = parser.parse_args()
    train(**vars(args))

