import argparse
import os
import copy
import torch
import wandb
import random
import numpy as np
from torch import nn
from tqdm import tqdm
from torch.optim import AdamW, SGD
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR, StepLR
from transformers import ViTConfig
from model import ViTGLUForImageClassification
from utils import set_seed, save_checkpoint, remove_old_checkpoints
from calculate import conservation_log, precompute_h0_values
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from datasets import build_dataset
# ---------- Dataset Loader ----------
def data_loader(args):
    """Create the training and validation DataLoaders described by ``args``.

    Side effect: stores the training split's class count (the second value
    returned by ``build_dataset``) on ``args.nb_classes``.

    Returns:
        ``(train_loader, val_loader)``. The training loader samples randomly
        and drops the final partial batch; the validation loader iterates in
        dataset order and keeps every sample.
    """
    train_set, args.nb_classes = build_dataset(is_train=True, args=args)
    val_set, _ = build_dataset(is_train=False, args=args)

    train_sampler = torch.utils.data.RandomSampler(train_set)
    val_sampler = torch.utils.data.SequentialSampler(val_set)

    # Loader options that are identical for both splits.
    common = {
        "batch_size": args.batch_size,
        "num_workers": args.num_workers,
        "pin_memory": args.pin_mem,
    }
    loaders = (
        torch.utils.data.DataLoader(train_set, sampler=train_sampler, drop_last=True, **common),
        torch.utils.data.DataLoader(val_set, sampler=val_sampler, drop_last=False, **common),
    )
    return loaders

def _build_optimizer(model: nn.Module, args: argparse.Namespace) -> torch.optim.Optimizer:
    """Create the optimizer named by ``args.opt`` over all model parameters.

    Raises:
        ValueError: if ``args.opt`` is neither ``"sgd"`` nor ``"adamw"``.
    """
    if args.opt == "sgd":
        return SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
    if args.opt == "adamw":
        return AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # Fail fast instead of leaving the optimizer unbound (which surfaced later as a NameError).
    raise ValueError(f"Unsupported optimizer {args.opt!r}; expected 'sgd' or 'adamw'.")


def _build_scheduler(optimizer: torch.optim.Optimizer, args: argparse.Namespace):
    """Create the LR scheduler named by ``args.lr_scheduler``.

    ``main`` calls ``scheduler.step()`` once per epoch, so every horizon below
    (``total_iters``, ``step_size``, ``T_max``, ``milestones``) counts epochs.

    Raises:
        ValueError: for an unsupported scheduler name, or when a warmup-based
            scheduler is given a warmup length it cannot handle.
    """
    name = args.lr_scheduler
    # LinearLR scales the optimizer's base lr, so the warmup starting LR is
    # expressed as a fraction of args.lr.
    start_factor = args.warmup_lr / args.lr
    if name in ("linear", "cosine_warm") and args.warmup_epochs < 1:
        # With total_iters=0 LinearLR never ramps up: the LR stays at warmup_lr for the whole run.
        raise ValueError(f"--lr-scheduler {name} requires --warmup-epochs >= 1 (got {args.warmup_epochs}).")
    if name == "linear":
        return LinearLR(optimizer, start_factor=start_factor, end_factor=1.0, total_iters=args.warmup_epochs)
    if name == "step":
        # Divide the LR by 10 every third of training. Clamp to 1 because StepLR
        # evaluates last_epoch % step_size, which fails for step_size=0 (epochs < 3).
        return StepLR(optimizer, step_size=max(1, args.epochs // 3), gamma=0.1)
    if name == "cosine":
        return CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=args.eta_min)
    if name == "cosine_warm":
        if args.warmup_epochs >= args.epochs:
            # The cosine phase would get T_max <= 0, which CosineAnnealingLR cannot evaluate.
            raise ValueError(
                f"--warmup-epochs ({args.warmup_epochs}) must be smaller than --epochs ({args.epochs}) "
                "for --lr-scheduler cosine_warm."
            )
        warmup_scheduler = LinearLR(optimizer, start_factor=start_factor, end_factor=1.0, total_iters=args.warmup_epochs)
        cosine_scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs - args.warmup_epochs, eta_min=args.eta_min)
        return SequentialLR(optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[args.warmup_epochs])
    raise ValueError(
        f"Unsupported LR scheduler {name!r}; expected one of 'linear', 'step', 'cosine', 'cosine_warm'."
    )


def _evaluate(model: nn.Module, loader, criterion, device: torch.device):
    """Evaluate ``model`` on ``loader`` with gradients disabled.

    Returns:
        ``(accuracy, mean_loss)``: accuracy as a fraction in [0, 1] over all
        samples, and ``criterion`` averaged over batches.
    """
    model.eval()
    correct, total, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(pixel_values=images).logits
            loss_sum += criterion(outputs, labels).item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return correct / total, loss_sum / len(loader)


def main(args: argparse.Namespace):
    """Train a ViTGLUForImageClassification model, logging to wandb and checkpointing every epoch.

    Flow: start a wandb run, build the model from ``args``, optionally resume
    from ``args.restore_path``, then for each epoch train (calling
    ``conservation_log`` after every optimizer step), validate, step the LR
    scheduler, and write ``last_<epoch>.pt`` (plus ``best_<epoch>.pt`` when
    validation accuracy improves) under ``<args.save_dir>/<wandb run name>``.

    Args:
        args: The script's parsed command-line arguments.

    Raises:
        ValueError: for an unsupported optimizer/scheduler configuration or a
            training loader that yields no batches.
    """
    # The run name also names the checkpoint directory, so its text
    # (including the historical "sheduler" spelling) is kept unchanged to stay
    # compatible with existing runs.
    # NOTE(review): resuming an existing run via --wandb-id may also require
    # passing `resume=` to wandb.init — confirm against the wandb docs.
    wandb.init(
        project=args.wandb_project,
        entity=args.wandb_entity,
        group=args.wandb_group,
        id=args.wandb_id,
        name=f"{args.opt}-sheduler{args.lr_scheduler}-lr{args.lr}-warmup{args.warmup_lr}-hidden{args.hidden_size}-decay{args.weight_decay}"
        f"-warmepoch{args.warmup_epochs}-momentum{args.momentum}-type{args.mlp_type}-seed{args.seed}",
        save_code=True,
    )
    wandb.config.update(vars(args))
    set_seed(args.seed)
    save_path = os.path.join(args.save_dir, wandb.run.name)
    os.makedirs(save_path, exist_ok=True)

    # === Model and Data ===
    config = ViTConfig(
        hidden_size=args.hidden_size,
        num_hidden_layers=args.num_hidden_layer,
        num_attention_heads=args.num_attention_heads,
        intermediate_size=args.intermediate_size,
        image_size=args.input_size,
        patch_size=args.patch_size,
        num_channels=args.num_channels,
        num_labels=args.num_classes,
        qkv_bias=args.qkv_bias,
        hidden_act=args.hidden_act,
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        mlp_type=args.mlp_type,
        mlp_bias=args.mlp_bias,
    )
    config.save_pretrained(save_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ViTGLUForImageClassification(config).to(device)
    train_loader, val_loader = data_loader(args)
    if len(train_loader) == 0:
        raise ValueError(
            "Training loader yields no batches (dataset empty or smaller than --batch-size with drop_last=True)."
        )
    dataset_classes = getattr(args, "nb_classes", args.num_classes)
    if dataset_classes != args.num_classes:
        # The classifier head is sized by --num-classes, not by the dataset.
        print(
            f"Warning: dataset reports {dataset_classes} classes but the model head has "
            f"{args.num_classes} outputs (--num-classes)."
        )

    # === Optimizer and Scheduler ===
    optimizer = _build_optimizer(model, args)
    scheduler = _build_scheduler(optimizer, args)

    curr_epoch, best_acc = 1, 0.0
    if args.restore_path:
        checkpoint = torch.load(args.restore_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        curr_epoch = checkpoint["epoch"] + 1  # resume after the saved epoch
        best_acc = checkpoint["val_acc"]  # best validation accuracy (fraction) so far

    # === Training ===
    criterion = nn.CrossEntropyLoss()
    global_step = (curr_epoch - 1) * len(train_loader)
    # Snapshot of the weights used as the reference point for conservation_log.
    # NOTE(review): on resume this is the restored model, not the original
    # initialization — confirm that is intended.
    start_model = copy.deepcopy(model)
    print("Precomputing h0 values...")
    cached_h0 = precompute_h0_values(start_model, config)
    print("h0 values cached successfully!")
    # The snapshot is only needed to build cached_h0; drop our reference so the
    # duplicate model can be freed if nothing else holds it.
    del start_model
    print("Starting training...")
    for epoch in range(curr_epoch, args.epochs + 1):
        model.train()
        train_loss, train_acc = 0.0, 0.0
        progress_bar = tqdm(train_loader, desc=f"[Epoch {epoch}/{args.epochs}] Training")
        for images, labels in progress_bar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(pixel_values=images).logits
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            global_step += 1
            conservation_log(global_step=global_step, cached_h0=cached_h0, current_model=model, log_type="step", config=config, args=args)
            preds = outputs.argmax(dim=1)
            train_loss_step = loss.item()
            # Training accuracy is tracked in percent (0-100).
            train_acc_step = (preds == labels).sum().item() * 100 / labels.size(0)
            train_loss += train_loss_step
            train_acc += train_acc_step
            if global_step % args.logs_frequency == 0:
                wandb.log({"train_loss_step": train_loss_step, "train_acc_step": train_acc_step,}, step=global_step)
            progress_bar.set_postfix({
                "loss": f"{train_loss_step:.4f}", "acc": f"{train_acc_step:.4f}", "lr": f"{optimizer.param_groups[0]['lr']:.6f}"
            })
        avg_train_loss = train_loss / len(train_loader)
        avg_train_acc = train_acc / len(train_loader)
        current_lr = optimizer.param_groups[0]["lr"]

        # === Validation === (accuracy is a fraction in [0, 1])
        avg_val_acc, avg_val_loss = _evaluate(model, val_loader, criterion, device)
        print(f"Epoch {epoch}, Train Loss:{avg_train_loss:.4f}, Train Acc:{avg_train_acc:.2f}, Val Loss:{avg_val_loss:.4f}, Val Acc:{avg_val_acc*100 :.2f}%, LR:{current_lr:.6f}")
        wandb.log({
            "train_loss": avg_train_loss, "train_acc": avg_train_acc, "val_acc": avg_val_acc, "val_loss": avg_val_loss,
            "learning_rate": current_lr, "epoch": epoch}, step=global_step
        )
        scheduler.step()

        # Update best_acc BEFORE writing the "last" checkpoint: resuming from it
        # must restore the true best, otherwise a worse later epoch could
        # replace (and delete) the real best checkpoint.
        is_best = avg_val_acc > best_acc
        if is_best:
            best_acc = avg_val_acc
        remove_old_checkpoints(save_path, "last_")
        last_path = os.path.join(save_path, f"last_{epoch}.pt")
        save_checkpoint(last_path, model, optimizer, scheduler, epoch, best_acc, scheduler.get_last_lr()[0])
        if is_best:
            remove_old_checkpoints(save_path, "best_")
            best_path = os.path.join(save_path, f"best_{epoch}.pt")
            save_checkpoint(best_path, model, optimizer, scheduler, epoch, best_acc, scheduler.get_last_lr()[0])
            print(f"Best model saved at epoch {epoch}")

def build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this training script.

    The underscore spellings ``--num_workers`` and ``--weight_decay`` remain
    accepted as aliases of ``--num-workers`` / ``--weight-decay`` so existing
    launch commands keep working. ``--opt`` and ``--lr-scheduler`` only accept
    values that ``main`` actually implements.
    """
    parser = argparse.ArgumentParser(description="Train a ViT image classifier")
    # --- Data Config ---
    parser.add_argument("--data-path", type=str, required=True)
    parser.add_argument('--data-set', default='IMNET', choices=['CIFAR10', 'IMNET', 'INAT', 'INAT19'])
    parser.add_argument("--input-size", type=int, default=224)
    parser.add_argument("--num-channels", type=int, default=3)
    # dest is derived from the first long option, so both spellings set args.num_workers.
    parser.add_argument('--num-workers', '--num_workers', type=int, default=8)
    parser.add_argument('--pin-mem', action='store_true')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--color-jitter', type=float, default=0.4)
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1')
    parser.add_argument('--train-interpolation', type=str, default='bicubic')
    parser.add_argument('--reprob', type=float, default=0.25)
    parser.add_argument('--remode', type=str, default='pixel')
    parser.add_argument('--recount', type=int, default=1)
    # --- Model hyperparameters ---
    parser.add_argument("--hidden-size", type=int, default=768)
    parser.add_argument("--num-hidden-layer", type=int, default=12)
    parser.add_argument("--intermediate-size", type=int, default=3072)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--num-attention-heads", type=int, default=12)
    parser.add_argument("--patch-size", type=int, default=16)
    parser.add_argument("--position-embedding", type=str, default=None, choices=["rope", "learnable"])
    parser.add_argument("--mlp-type", type=str, default="mlp", help="Type of MLP to use")
    parser.add_argument("--mlp-bias", action="store_true", help="Use bias in MLP")
    parser.add_argument("--qkv-bias", action="store_true", help="Use bias in QKV")
    parser.add_argument("--hidden-act", type=str, default="silu", help="Hidden activation function")
    # --- Optimizer settings ---
    parser.add_argument("--opt", type=str, default="sgd", choices=["sgd", "adamw"])
    parser.add_argument("--epochs", type=int, default=300)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--momentum", type=float, default=0.9)
    # dest is derived from the first long option, so both spellings set args.weight_decay.
    parser.add_argument("--weight-decay", "--weight_decay", type=float, default=1e-5)
    # "plateau" is not offered: the training loop has no ReduceLROnPlateau support.
    parser.add_argument("--lr-scheduler", type=str, default="cosine", choices=["linear", "step", "cosine", "cosine_warm"])
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--eta-min', type=float, default=1e-5, metavar='LR', help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
    parser.add_argument('--warmup-epochs', type=int, default=0, metavar='N', help='epochs to warmup LR, if scheduler supports')
    # --- Save and Logging ---
    parser.add_argument("--save-dir", type=str, required=True)
    parser.add_argument("--restore-path", type=str, default=None)
    parser.add_argument("--wandb-project", type=str, default=None)
    parser.add_argument("--wandb-group", type=str, default=None)
    parser.add_argument("--wandb-entity", type=str, default=None, help="wandb entity for logging")
    parser.add_argument("--wandb-id", type=str, default=None, help="wandb id for logging")
    parser.add_argument("--logs-frequency", type=int, default=1, help="Log training loss every N steps")
    return parser


if __name__ == "__main__":
    main(build_parser().parse_args())