import model
import torch
import dataset
from tqdm import tqdm
import wandb
import torch.nn.functional as F
import os
import argparse
import datetime
import numpy as np
import matplotlib.pyplot as plt
from plotting import plot_parameters, plot_training_curves


def parse_args(argv=None):
    """Parse the command-line options for a training run.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            lets tests or notebooks call this function directly.

    Returns:
        argparse.Namespace holding the options defined below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", type=float, required=True)
    parser.add_argument("--gd_lr", type=float, default=0.05, help="Learning rate for gradient descent in data generation")
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--construction", type=int, default=1)
    # Restrict to the optimizers create_model_and_optimizer supports so a typo
    # fails here, before an experiment directory and a wandb run are created.
    parser.add_argument("--optim", type=str, default="adam", choices=["adam", "sgd"])
    parser.add_argument("--epochs", type=int, default=1000)
    parser.add_argument("--d", type=int, default=5, help="Dimension of the data")
    parser.add_argument("--T", type=int, default=100, help="Number of training samples")
    parser.add_argument("--num_datasets", type=int, default=200, help="Number of independent datasets to train on")
    parser.add_argument("--batch_size", type=int, default=128, help="Number of steps to accumulate gradients over before backprop")
    parser.add_argument("--num_eval_datasets", type=int, default=10, help="Number of evaluation datasets to sample and average over during evaluation")
    return parser.parse_args(argv)


def setup_experiment(args):
    """Create a timestamped experiment directory, pin the GPU and start wandb.

    The run name encodes the timestamp and the main hyper-parameters. The
    directory ``experiments/<run name>`` is created if missing,
    ``CUDA_VISIBLE_DEVICES`` is set to ``args.gpu``, and a wandb run with the
    same name is initialised.

    Returns:
        Path of the experiment directory.
    """
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    run_name = (
        f"exp_{stamp}_lr{args.lr}_{args.optim}_ep{args.epochs}_d{args.d}_T{args.T}"
        f"_nds{args.num_datasets}_bs{args.batch_size}_neval{args.num_eval_datasets}"
    )
    run_dir = os.path.join("experiments", run_name)
    os.makedirs(run_dir, exist_ok=True)

    # NOTE: the device mask only has an effect if CUDA has not been initialised
    # in this process yet; main() calls this before torch.cuda.is_available().
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    wandb.init(name=run_name)

    return run_dir


def create_model_and_optimizer(args, device):
    """Build the CoT model on ``device`` together with its optimizer.

    Returns:
        ``(model, optimizer, scheduler)``; the scheduler slot is always
        ``None`` here.

    Raises:
        ValueError: if ``args.optim`` is neither ``"adam"`` nor ``"sgd"``
            (raised after the model has been constructed).
    """
    net = model.CoT(data_num=args.T, d=args.d).to(device)

    optimizer_classes = {
        "adam": torch.optim.Adam,
        "sgd": torch.optim.SGD,
    }
    optimizer_cls = optimizer_classes.get(args.optim)
    if optimizer_cls is None:
        raise ValueError(f"Unknown optimizer: {args.optim}")
    optimizer = optimizer_cls(net.parameters(), lr=args.lr)

    return net, optimizer, None


def generate_single_dataset(d, T, device, gd_lr):
    """Sample one dataset via ``dataset.data_generation`` and move it to ``device``.

    Returns:
        ``(X, y, w, gd, max_iter)``. The four tensors are on ``device``;
        ``max_iter`` is passed through unchanged, and the sixth value returned
        by ``data_generation`` is discarded.
    """
    X, y, w, gd, max_iter, _unused = dataset.data_generation(d=d, n=T, lr=gd_lr)
    X, y, w, gd = (tensor.to(device) for tensor in (X, y, w, gd))
    return X, y, w, gd, max_iter




def compute_loss(model_instance, X, y, w, gd_gt, d, T, device, gd_lr):
    """Roll the model over a whole GD trajectory and score its predictions.

    The context layout matches ``compute_single_step_loss`` (the training
    objective):
    - the top ``d + 1`` rows hold ``[X; y^T]`` in the first ``T`` columns;
    - column ``T + k`` holds ``[w_k; 1]`` in the bottom ``d + 1`` rows.

    For step ``k`` the model sees ``w_0 .. w_k`` (``T + k + 1`` columns), and
    its output is compared with ``gd_gt[:, k]``. This is the same pairing used
    in training, where ``step_idx`` iterates predict ``gd_gt[:, step_idx - 1]``.

    Args:
        model_instance: callable mapping a ``[2(d+1), L]`` context to a prediction.
        X: data matrix with ``T`` columns (stacked on top of ``y^T``).
        y: targets of shape ``[T, 1]`` (``y.T`` is concatenated row-wise with X).
        w: iterates, one per column (``max_step`` columns).
        gd_gt: ground-truth targets, one column per predicted step.
        d, T: data dimension and number of samples.
        device: device for the constructed tensors.
        gd_lr: step size used to roll the predictions into weights.

    Returns:
        ``(loss, w_predicted)``:
        - ``loss`` is the mean over steps of ``||pred_k - gd_gt[:, k]||_2``;
        - ``w_predicted`` is the list ``[w_0, w_1, ...]`` built from
          ``w_{k+1} = w_k - gd_lr * pred_k``, starting at ``w[:, 0]``.

        A trajectory with a single iterate has nothing to predict; it yields a
        zero loss and ``[w[:, 0]]``.
    """
    max_step = w.shape[1] if len(w.shape) > 1 else 1
    num_steps = max_step - 1

    Z = torch.cat((X, y.T))
    ww = torch.cat((w, torch.ones((1, max_step), device=device)), dim=0)
    train_data = torch.zeros((2 * (d + 1), T + max_step), device=device)
    train_data[: d + 1, :T] = Z
    train_data[d + 1 :, T:] = ww

    w_predicted = [w[:, 0]]  # Start with initial weight w_0
    if num_steps < 1:
        # Previously torch.stack([]) raised here; report an empty rollout instead.
        return torch.zeros((), device=device), w_predicted

    # Context for step k ends at column T + k, i.e. includes w_0 .. w_k.
    # reshape(-1) keeps one column per step even when num_steps == 1 (a plain
    # squeeze() would collapse [d, 1] to [d] and broadcast against gd_gt).
    preds = [
        model_instance(train_data[:, : T + k + 1]).reshape(-1)
        for k in range(num_steps)
    ]
    gd = torch.stack(preds, dim=-1)
    target = gd_gt.reshape(gd_gt.shape[0], -1)
    loss = torch.linalg.vector_norm(gd - target, dim=0).mean()

    for k in range(num_steps):
        w_predicted.append(w_predicted[-1] - gd_lr * gd[:, k])

    return loss, w_predicted




def compute_single_step_loss(model_instance, X, y, w, gd_gt, step_idx, d, T, device):
    """Score one next-step prediction; this is the per-sample training loss.

    The context has ``T + step_idx`` columns:
    - the top ``d + 1`` rows of the first ``T`` columns hold ``[X; y^T]``;
    - columns ``T .. T + step_idx - 1`` hold ``[w_k; 1]`` for
      ``k = 0 .. step_idx - 1`` in the bottom ``d + 1`` rows.

    The model output is compared with ``gd_gt[:, step_idx - 1]``, the target
    paired with the last iterate in the context. The returned tensor keeps its
    autograd graph, so callers can backpropagate through it.

    Raises:
        ValueError: if ``step_idx < 1``. The context would then contain no
            iterate, and ``gd_gt[:, -1]`` would silently be used as the target.
    """
    if step_idx < 1:
        raise ValueError(f"step_idx must be >= 1, got {step_idx}")

    Z = torch.cat((X, y.T))  # [d+1, T]
    w_partial = w[:, :step_idx]  # [d, step_idx]: w_0 .. w_{step_idx-1}
    bias_row = torch.ones((1, step_idx), device=device)  # [1, step_idx]
    ww = torch.cat((w_partial, bias_row), dim=0)  # [d+1, step_idx]

    train_data = torch.zeros((2 * (d + 1), T + step_idx), device=device)
    train_data[: d + 1, :T] = Z
    train_data[d + 1 :, T : T + step_idx] = ww

    predicted_w = model_instance(train_data)
    target_w = gd_gt[:, step_idx - 1]
    return torch.linalg.vector_norm(predicted_w - target_w)


def train_batch_steps(model_instance, optimizer, batch_data, d, T, device):
    """Take one optimizer step on the mean single-step loss of a batch.

    Each item of ``batch_data`` is ``(X, y, w, gd_gt, step_idx)``. The
    per-item losses from ``compute_single_step_loss`` are summed, divided by
    the batch length, backpropagated once, and applied with ``optimizer.step()``.

    Returns:
        The batch-mean loss as a Python float.
    """
    optimizer.zero_grad()

    per_item_losses = [
        compute_single_step_loss(model_instance, X, y, w, gd_gt, step_idx, d, T, device)
        for X, y, w, gd_gt, step_idx in batch_data
    ]
    mean_loss = sum(per_item_losses) / len(batch_data)

    mean_loss.backward()
    optimizer.step()
    return mean_loss.item()



def evaluate_model(model_instance, d, T, device, gd_lr, num_eval_datasets=10):
    """Average full-trajectory metrics over freshly sampled datasets.

    For each of ``num_eval_datasets`` new datasets (from
    ``generate_single_dataset``), the model is rolled out with ``compute_loss``.
    Two quantities are recorded:
    - the trajectory loss;
    - the L2 norm of ``X.T @ w_final - y``, where ``w_final`` is the last
      weight rolled out from the predictions.

    Gradients are disabled throughout. The model is put in eval mode and its
    previous train/eval mode is restored afterwards, including when an error
    is raised.

    Returns:
        ``(avg_loss, avg_residual_norm)`` as Python floats.

    Raises:
        ValueError: if ``num_eval_datasets < 1`` (previously this surfaced as a
            ZeroDivisionError after the loop).
    """
    if num_eval_datasets < 1:
        raise ValueError(f"num_eval_datasets must be >= 1, got {num_eval_datasets}")

    was_training = model_instance.training
    model_instance.eval()

    total_loss = 0.0
    total_residual_norm = 0.0
    try:
        with torch.no_grad():
            for _ in range(num_eval_datasets):
                X_eval, y_eval, w_eval, gd_eval, _ = generate_single_dataset(d, T, device, gd_lr)
                loss, w_predicted = compute_loss(model_instance, X_eval, y_eval, w_eval, gd_eval, d, T, device, gd_lr)

                # Residual of the final rolled-out weight on this dataset.
                pred_residual = X_eval.T @ w_predicted[-1].unsqueeze(1) - y_eval

                total_loss += loss.item()
                total_residual_norm += torch.linalg.vector_norm(pred_residual).item()
    finally:
        # Restore whatever mode the caller had, rather than forcing train().
        model_instance.train(was_training)

    return total_loss / num_eval_datasets, total_residual_norm / num_eval_datasets


def plot_ground_truth_convergence(X_eval, y_eval, w_eval, exp_dir):
    """Plot the per-step mean squared residual of the ground-truth GD iterates.

    For each iterate ``w_k`` (column ``k`` of ``w_eval``) this computes
    ``mean((X_eval.T @ w_k - y_eval) ** 2)``. The curve is plotted on a log
    scale, saved as ``ground_truth_convergence.png`` in ``exp_dir``, and
    logged to wandb.

    Note: the plotted value is a mean *squared* residual. The previous y-axis
    label (``||X.T @ w - y||``) described a norm, which is not what was plotted.
    """
    # X_eval.T @ w_eval has one column per iterate, and y_eval ([T, 1])
    # broadcasts across those columns, so a single matmul covers every step.
    residuals = X_eval.T @ w_eval - y_eval
    mse_per_step = (residuals ** 2).mean(dim=0).tolist()

    plt.figure(figsize=(10, 6))
    plt.plot(range(len(mse_per_step)), mse_per_step, 'b-', linewidth=2, marker='o', markersize=4)
    plt.xlabel('Gradient Steps')
    plt.ylabel('mean((X.T @ w - y)^2)')
    plt.title('Ground Truth Gradient Descent Convergence')
    plt.grid(True, alpha=0.3)
    plt.yscale('log')  # Log scale: the residual typically spans several orders of magnitude.
    plt.tight_layout()

    plot_path = os.path.join(exp_dir, 'ground_truth_convergence.png')
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.close()

    wandb.log({"ground_truth_convergence": wandb.Image(plot_path)})

    print(f"Ground truth convergence plot saved to: {plot_path}")


def save_model(model_instance, optimizer, epoch, loss, exp_dir):
    """Write one checkpoint payload to ``exp_dir`` under two file names.

    The payload holds the epoch, the model and optimizer state dicts, and
    ``loss``. It is saved as ``model_epoch_{epoch}.pth`` and also written to
    (overwriting) ``model_latest.pth``.
    """
    payload = dict(
        epoch=epoch,
        model_state_dict=model_instance.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        loss=loss,
    )
    for filename in (f'model_epoch_{epoch}.pth', 'model_latest.pth'):
        torch.save(payload, os.path.join(exp_dir, filename))


def _save_eval_dataset(X_eval, y_eval, w_eval, gd_eval, exp_dir):
    """Save the reference evaluation dataset (moved to CPU) as ``eval_dataset.pth``."""
    torch.save({
        'eval_data': {
            'X': X_eval.cpu(),
            'y': y_eval.cpu(),
            'w': w_eval.cpu(),
            'gd': gd_eval.cpu()
        }
    }, os.path.join(exp_dir, 'eval_dataset.pth'))


def _build_training_pairs(training_datasets):
    """List every ``(dataset_idx, step_idx)`` pair, with ``step_idx`` in ``1 .. max_iter - 1``."""
    return [
        (dataset_idx, step_idx)
        for dataset_idx, (*_, max_iter) in enumerate(training_datasets)
        for step_idx in range(1, max_iter)
    ]


def _describe_batch(batch_pairs):
    """Return the ``(datasets, steps)`` progress-bar strings for a non-empty batch."""
    dataset_ids = {dataset_idx for dataset_idx, _ in batch_pairs}
    step_ids = [step_idx for _, step_idx in batch_pairs]
    return f"{len(dataset_ids)} datasets", f"steps {min(step_ids)}-{max(step_ids)}"


def main():
    """Run a full training experiment.

    Steps:
    1. Parse arguments, set up the experiment directory, wandb and device.
    2. Build the model and optimizer.
    3. Plot and save one reference evaluation dataset.
    4. Generate the training datasets.
    5. Train on shuffled ``(dataset, step)`` pairs in mini-batches with a
       cosine-annealed learning rate, evaluating periodically.
    6. Save the final checkpoint and plots.
    """
    args = parse_args()
    exp_dir = setup_experiment(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_instance, optimizer, _ = create_model_and_optimizer(args, device)
    plot_parameters(model_instance, exp_dir, args.lr, args.optim, epoch=0)

    # One reference dataset for the ground-truth plot and the saved artifact.
    # Periodic evaluation during training samples fresh datasets instead.
    X_eval, y_eval, w_eval, gd_eval, _ = generate_single_dataset(args.d, args.T, device, args.gd_lr)
    plot_ground_truth_convergence(X_eval, y_eval, w_eval, exp_dir)
    _save_eval_dataset(X_eval, y_eval, w_eval, gd_eval, exp_dir)

    print(f"Generating {args.num_datasets} training datasets...")
    training_datasets = [
        generate_single_dataset(args.d, args.T, device, args.gd_lr)
        for _ in range(args.num_datasets)
    ]

    all_training_pairs = _build_training_pairs(training_datasets)
    num_pairs = len(all_training_pairs)

    # The inner loop runs ceil(num_pairs / batch_size) batches in *every*
    # epoch. Taking the ceiling once over the whole run undercounts whenever
    # num_pairs % batch_size != 0, which pushes the cosine schedule past T_max
    # (where the LR climbs back up) and overflows the progress bar.
    batches_per_epoch = (num_pairs + args.batch_size - 1) // args.batch_size
    total_training_steps = batches_per_epoch * args.epochs

    # eta_min is the floor the schedule decays to. If it exceeded the initial
    # LR, cosine annealing would *raise* the LR over training, so clamp it.
    eta_min = min(0.005, args.lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer, T_max=max(1, total_training_steps), eta_min=eta_min
    )

    train_losses = []
    eval_losses = []
    tq = tqdm(total=total_training_steps, desc="Training Batches")
    global_step = 0
    batch_count = 0
    # Evaluate every `eval_every` batches; this is every batch whenever batch_size >= 6.
    eval_every = max(1, 10 // args.batch_size)

    for epoch in range(args.epochs):
        # Reshuffle all pairs with a per-epoch seed for reproducibility.
        # Note: this also reseeds torch's global RNG.
        torch.manual_seed(epoch)
        shuffled_indices = torch.randperm(num_pairs)
        all_training_pairs = [all_training_pairs[i] for i in shuffled_indices]

        for batch_start in range(0, num_pairs, args.batch_size):
            batch_pairs = all_training_pairs[batch_start:batch_start + args.batch_size]
            batch_data = [
                (*training_datasets[dataset_idx][:4], step_idx)
                for dataset_idx, step_idx in batch_pairs
            ]

            train_loss = train_batch_steps(model_instance, optimizer, batch_data, args.d, args.T, device)
            train_losses.append(train_loss)

            if batch_count % eval_every == 0:
                eval_loss, pred_residual_norm = evaluate_model(model_instance, args.d, args.T, device, args.gd_lr, args.num_eval_datasets)
                eval_losses.append(eval_loss)
            else:
                # Between evaluations, report the most recent eval loss.
                eval_loss = eval_losses[-1] if eval_losses else 0.0
                pred_residual_norm = None

            scheduler.step()
            global_step += len(batch_data)  # Count samples, not batches.
            batch_count += 1

            dataset_info, step_range = _describe_batch(batch_pairs)
            tq.set_postfix(
                epoch=f"{epoch+1}/{args.epochs}",
                batch_size=len(batch_data),
                datasets=dataset_info,
                steps=step_range,
                train_loss=f"{train_loss:.4f}",
                eval_loss=f"{eval_loss:.4f}"
            )
            tq.update(1)

            log_dict = {
                "train_loss": train_loss,
                "eval_loss": eval_loss,
                "lr": scheduler.get_last_lr()[0],
                "batch_size": len(batch_data),
                "epoch": epoch,
                "global_step": global_step,
                "batch_count": batch_count
            }
            if pred_residual_norm is not None:
                log_dict["residual_norm/predicted"] = pred_residual_norm
            wandb.log(log_dict)

    tq.close()

    # Guard against zero batches (epochs == 0 or no training pairs), which
    # previously raised IndexError on train_losses[-1] / eval_losses[-1].
    final_train_loss = train_losses[-1] if train_losses else None
    save_model(model_instance, optimizer, args.epochs, final_train_loss, exp_dir)
    plot_training_curves(train_losses, eval_losses, exp_dir)
    plot_parameters(model_instance, exp_dir, args.lr, args.optim, epoch=args.epochs)

    print(f"Experiment completed. Results saved in: {exp_dir}")
    print(f"Trained on {args.num_datasets} independent datasets")
    print(f"Total training batches: {len(train_losses)}")
    print(f"Batch size: {args.batch_size}")
    print(f"Evaluation averaged over {args.num_eval_datasets} datasets per evaluation")
    if eval_losses:
        print(f"Final eval loss: {eval_losses[-1]:.4f}")
    else:
        print("Final eval loss: n/a (no training batches were run)")


if __name__ == "__main__":
    # Launch training only when executed as a script, not when imported.
    main()