import argparse
import os
import copy
import torch
import wandb
import random
import numpy as np
from torch import nn
from tqdm import tqdm
from torch.optim import AdamW, SGD
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR, StepLR
from transformers import ViTConfig
from model import ViTGLUForImageClassification
from utils import set_seed, save_checkpoint, remove_old_checkpoints
from calculate import conservation_log, precompute_h0_values  # Updated import
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


def data_loader(args):
    """Build the training and validation ``DataLoader`` objects.

    Args:
        args: Parsed command-line namespace. Reads ``data_set``, ``data_path``,
            ``input_size``, ``batch_size`` and ``num_workers``; ``pin_mem`` is
            honoured when the attribute exists (defaults to ``False``).

    Returns:
        A ``(train_loader, val_loader)`` tuple. The validation loader wraps the
        dataset's ``train=False`` split and is not shuffled.

    Raises:
        NotImplementedError: If ``args.data_set`` is neither ``CIFAR10`` nor
            ``MNIST``.
    """
    if args.data_set == 'CIFAR10':
        dataset_cls = datasets.CIFAR10
        mean, std = (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)
        use_flip = True
    elif args.data_set == 'MNIST':
        dataset_cls = datasets.MNIST
        mean, std = (0.1307,), (0.3081,)
        # No horizontal flip for MNIST: the original pipeline did not use one.
        use_flip = False
    else:
        raise NotImplementedError(f"Dataset {args.data_set} is not implemented. Available: CIFAR10, MNIST")

    # Resize before cropping for every dataset. The CIFAR10 training pipeline
    # used to skip this step, so RandomCrop was asked for an `input_size` crop
    # of a 32x32 (+4 padding) image and train/test images were scaled
    # differently whenever input_size != 32.
    train_steps = [
        transforms.Resize(args.input_size),
        transforms.RandomCrop(args.input_size, padding=4),
    ]
    if use_flip:
        train_steps.append(transforms.RandomHorizontalFlip())
    train_steps += [transforms.ToTensor(), transforms.Normalize(mean, std)]
    transform_train = transforms.Compose(train_steps)

    transform_test = transforms.Compose([
        transforms.Resize(args.input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # `--pin-mem` is declared by the CLI; fall back to False for namespaces
    # built elsewhere that do not carry the attribute.
    loader_kwargs = {
        "batch_size": args.batch_size,
        "num_workers": args.num_workers,
        "pin_memory": getattr(args, "pin_mem", False),
    }

    trainset = dataset_cls(root=args.data_path, train=True, download=True, transform=transform_train)
    train_loader = DataLoader(trainset, shuffle=True, **loader_kwargs)

    testset = dataset_cls(root=args.data_path, train=False, download=True, transform=transform_test)
    val_loader = DataLoader(testset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader

def _build_optimizer(args, model):
    """Create the optimizer selected by ``args.opt`` for ``model``'s parameters."""
    if args.opt == "sgd":
        return SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
    if args.opt == "adamw":
        return AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # Fail here with a clear message instead of a later NameError on `optimizer`.
    raise ValueError(f"Unsupported optimizer '{args.opt}'. Available: sgd, adamw")


def _build_scheduler(args, optimizer):
    """Create the per-epoch LR scheduler selected by ``args.lr_scheduler``."""
    name = args.lr_scheduler
    if name == "step":
        # Guard against step_size == 0 (ZeroDivisionError) when epochs < 3.
        return StepLR(optimizer, step_size=max(1, args.epochs // 3), gamma=0.1)
    if name == "cosine":
        return CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=args.eta_min)
    if name in ("linear", "cosine_warm"):
        # The warmup factor is only needed (and only computed) for warmup schedules.
        start_factor = args.warmup_lr / args.lr
        warmup_scheduler = LinearLR(optimizer, start_factor=start_factor, end_factor=1.0, total_iters=args.warmup_epochs)
        if name == "linear":
            return warmup_scheduler
        cosine_scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs - args.warmup_epochs, eta_min=args.eta_min)
        return SequentialLR(optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[args.warmup_epochs])
    # e.g. "plateau", which the CLI may accept but this script does not implement.
    raise NotImplementedError(
        f"LR scheduler '{name}' is not implemented. Available: linear, step, cosine, cosine_warm"
    )


def _evaluate(model, loader, criterion, device):
    """Return ``(mean per-batch loss, accuracy as a fraction)`` of ``model`` on ``loader``."""
    model.eval()
    correct, total, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            logits = model(pixel_values=images).logits
            loss_sum += criterion(logits, labels).item()
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return loss_sum / len(loader), correct / total


def main(args: argparse.Namespace) -> None:
    """Train and validate a ``ViTGLUForImageClassification`` model.

    Initialises a wandb run, builds the model, data loaders, optimizer and
    scheduler from ``args``, optionally restores a checkpoint, then runs the
    epoch loop: training with gradient accumulation, per-step
    ``conservation_log`` calls, validation, wandb logging, and ``last_*`` /
    ``best_*`` checkpointing under ``<save_dir>/<wandb run name>``.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` parser).

    Raises:
        ValueError: If ``args.opt`` is not ``sgd`` or ``adamw``.
        NotImplementedError: If ``args.lr_scheduler`` or ``args.data_set`` is
            not supported.
    """
    # Init wandb
    wandb.init(
        project=args.wandb_project,
        entity=args.wandb_entity,
        group=args.wandb_group,
        id=args.wandb_id,
        name=f"{args.opt}-lr{args.lr}-warmup{args.warmup_lr}-sheduler{args.lr_scheduler}-decay{args.weight_decay}"
        f"-warmepoch{args.warmup_epochs}-momentum{args.momentum}-type{args.mlp_type}-seed{args.seed}",
        save_code=True,
    )
    save_path = os.path.join(args.save_dir, wandb.run.name)
    wandb.config.update(vars(args))
    set_seed(args.seed)
    os.makedirs(save_path, exist_ok=True)

    # === Model and Data ===
    config = ViTConfig(
        hidden_size=args.hidden_size,
        num_hidden_layers=args.num_hidden_layer,
        num_attention_heads=args.num_attention_heads,
        intermediate_size=args.intermediate_size,
        image_size=args.input_size,
        patch_size=args.patch_size,
        num_channels=args.num_channels,
        num_labels=args.num_classes,
        qkv_bias=args.qkv_bias,
        hidden_act=args.hidden_act,
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        mlp_type=args.mlp_type,
        mlp_bias=args.mlp_bias,
    )
    config.save_pretrained(save_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ViTGLUForImageClassification(config).to(device)
    train_loader, val_loader = data_loader(args)

    # Optimizer and Scheduler
    optimizer = _build_optimizer(args, model)
    scheduler = _build_scheduler(args, optimizer)

    curr_epoch, best_acc = 1, 0.0
    if args.restore_path:
        # NOTE: torch.load unpickles the file; only restore trusted checkpoints.
        checkpoint = torch.load(args.restore_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        curr_epoch = checkpoint['epoch'] + 1
        # The 'val_acc' slot carries the best accuracy passed to save_checkpoint.
        best_acc = checkpoint['val_acc']

    # Training Setting
    criterion = nn.CrossEntropyLoss()
    accum_steps = args.grad_accumulation_steps
    num_batches = len(train_loader)
    # Optimizer updates per epoch (ceil division), used to resume the step counter.
    num_steps_per_epoch = (num_batches + accum_steps - 1) // accum_steps
    global_step = (curr_epoch - 1) * num_steps_per_epoch

    # === OPTIMIZATION: Precompute h0 values once ===
    start_model = copy.deepcopy(model)
    print("Precomputing h0 values...")
    cached_h0 = precompute_h0_values(start_model, config)
    print("h0 values cached successfully!")
    # The copy is only handed to precompute_h0_values; drop this reference so
    # it does not stay alive for the whole run.
    del start_model

    print("Starting training...")
    conservation_log(
        global_step=global_step,
        cached_h0=cached_h0,
        current_model=model,
        log_type="step",
        config=config,
        args=args
    )
    for epoch in range(curr_epoch, args.epochs + 1):
        model.train()
        train_loss, train_acc = 0.0, 0.0
        progress_bar = tqdm(train_loader, desc=f"[Epoch {epoch}/{args.epochs}] Training")

        for batch_idx, (images, labels) in enumerate(progress_bar):
            images, labels = images.to(device), labels.to(device)
            outputs = model(pixel_values=images).logits
            loss = criterion(outputs, labels)
            # Scale so the accumulated gradient averages over the micro-batches.
            loss = loss / accum_steps
            loss.backward()

            # Update on every accum_steps-th batch and on the epoch's last batch.
            is_update_step = (batch_idx + 1) % accum_steps == 0 or (batch_idx + 1) == num_batches
            if is_update_step:
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

                # Use cached h0 values instead of start_model
                conservation_log(
                    global_step=global_step,
                    cached_h0=cached_h0,
                    current_model=model,
                    log_type="step",
                    config=config,
                    args=args
                )

            preds = outputs.argmax(dim=1)
            # Undo the accumulation scaling to report the un-scaled batch loss.
            train_loss_step = loss.item() * accum_steps
            # Training accuracy is tracked in percent (validation uses a fraction).
            train_acc_step = (preds == labels).sum().item() * 100 / labels.size(0)
            train_loss += train_loss_step
            train_acc += train_acc_step

            if is_update_step and global_step > 0 and global_step % args.logs_frequency == 0:
                wandb.log({"train_loss_step": train_loss_step, "train_acc_step": train_acc_step}, step=global_step)

            progress_bar.set_postfix({
                "loss": f"{train_loss_step:.4f}",
                "acc": f"{train_acc_step:.4f}",
                "lr": f"{optimizer.param_groups[0]['lr']:.6f}"
            })

        # Epoch-level conservation_log (log_type="epoch") is currently disabled;
        # only the per-step calls above are made.

        avg_train_loss = train_loss / num_batches
        avg_train_acc = train_acc / num_batches
        current_lr = optimizer.param_groups[0]["lr"]

        # === Validation ===
        avg_val_loss, avg_val_acc = _evaluate(model, val_loader, criterion, device)
        print(f"Epoch {epoch}, Train Loss:{avg_train_loss:.4f}, Train Acc:{avg_train_acc:.2f}, "
              f"Val Loss:{avg_val_loss:.4f}, Val Acc:{avg_val_acc*100:.2f}%, LR:{current_lr:.4f}")

        wandb.log({
            "train_loss": avg_train_loss,
            "train_acc": avg_train_acc,
            "val_acc": avg_val_acc,
            "val_loss": avg_val_loss,
            "learning_rate": current_lr,
            "epoch": epoch
        }, step=global_step)

        scheduler.step()

        # Update best_acc BEFORE writing the "last" checkpoint so that resuming
        # from it restores the up-to-date best accuracy rather than a stale one.
        is_best = avg_val_acc > best_acc
        if is_best:
            best_acc = avg_val_acc
        # Read after scheduler.step(), so this is the LR for the next epoch.
        next_lr = scheduler.get_last_lr()[0]

        # Save last checkpoint
        remove_old_checkpoints(save_path, "last_")
        last_path = os.path.join(save_path, f"last_{epoch}.pt")
        save_checkpoint(last_path, model, optimizer, scheduler, epoch, best_acc, next_lr)

        # Save best checkpoint
        if is_best:
            remove_old_checkpoints(save_path, "best_")
            best_path = os.path.join(save_path, f"best_{epoch}.pt")
            save_checkpoint(best_path, model, optimizer, scheduler, epoch, best_acc, next_lr)
            print(f"Best model saved at epoch {epoch}")

def get_args_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the training script.

    Choices are limited to the values this script actually implements so that
    an unsupported dataset, optimizer or scheduler is rejected at parse time
    instead of after wandb/model initialisation. The underscore spellings
    ``--num_workers`` and ``--weight_decay`` remain accepted next to their
    hyphenated forms.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description="Train ViT on CIFAR-10")
    # --- Data Config ---
    parser.add_argument("--data-path", type=str, required=True)
    # Only CIFAR10 and MNIST are implemented by data_loader; the previous
    # default ("IMNET") always ended in NotImplementedError.
    parser.add_argument("--data-set", default="CIFAR10", choices=["CIFAR10", "MNIST"])
    parser.add_argument("--input-size", type=int, default=224)
    parser.add_argument("--num-channels", type=int, default=3)
    parser.add_argument("--num-workers", "--num_workers", dest="num_workers", type=int, default=12)
    parser.add_argument("--pin-mem", action="store_true")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--color-jitter", type=float, default=0.4)
    parser.add_argument("--aa", type=str, default="rand-m9-mstd0.5-inc1")
    parser.add_argument("--train-interpolation", type=str, default="bicubic")
    parser.add_argument("--reprob", type=float, default=0.25)
    parser.add_argument("--remode", type=str, default="pixel")
    parser.add_argument("--recount", type=int, default=1)
    # Model hyperparameters
    parser.add_argument("--hidden-size", type=int, default=768)
    parser.add_argument("--num-hidden-layer", type=int, default=12)
    parser.add_argument("--intermediate-size", type=int, default=3072)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--num-attention-heads", type=int, default=12)
    parser.add_argument("--patch-size", type=int, default=16)
    parser.add_argument("--position-embedding", type=str, default=None, choices=["rope", "learnable"])
    parser.add_argument("--mlp-type", type=str, default="mlp", help="Type of MLP to use")
    parser.add_argument("--mlp-bias", action="store_true", help="Use bias in MLP")
    parser.add_argument("--qkv-bias", action="store_true", help="Use bias in QKV")
    parser.add_argument("--hidden-act", type=str, default="silu", help="Hidden activation function")
    # Optimizer settings
    parser.add_argument("--opt", type=str, default="sgd", choices=["sgd", "adamw"])
    parser.add_argument("--epochs", type=int, default=300)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--weight-decay", "--weight_decay", dest="weight_decay", type=float, default=1e-5)
    # "plateau" used to be accepted here although no such scheduler is built in main().
    parser.add_argument("--lr-scheduler", type=str, default="cosine", choices=["linear", "step", "cosine", "cosine_warm"])
    parser.add_argument("--warmup-lr", type=float, default=1e-6, metavar="LR", help="warmup learning rate (default: 1e-6)")
    parser.add_argument("--eta-min", type=float, default=1e-5, metavar="LR", help="lower lr bound for cyclic schedulers that hit 0 (1e-5)")
    parser.add_argument("--warmup-epochs", type=int, default=0, metavar="N", help="epochs to warmup LR, if scheduler supports")
    # Save and Logging
    parser.add_argument("--grad-accumulation-steps", type=int, default=1, help="Number of steps to accumulate gradients")
    parser.add_argument("--save-dir", type=str, required=True)
    parser.add_argument("--restore-path", type=str, default=None)
    parser.add_argument("--wandb-project", type=str, default=None)
    parser.add_argument("--wandb-group", type=str, default=None)
    parser.add_argument("--wandb-entity", type=str, default=None, help="wandb entity for logging")
    parser.add_argument("--wandb-id", type=str, default=None, help="wandb id for logging")
    parser.add_argument("--logs-frequency", type=int, default=50, help="Log training loss every N steps")
    return parser


if __name__ == "__main__":
    main(get_args_parser().parse_args())
