import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from tqdm import tqdm
import wandb
import clip
import os
import json
import spams

class CLIPModel(nn.Module):
    """Wrapper around an OpenAI CLIP model exposing image and text encoders.

    Args:
        clip_model_name: Model name passed to ``clip.load`` (e.g. ``"ViT-B/32"``).
        device: Device to load CLIP onto; defaults to CUDA when available.
    """

    def __init__(self, clip_model_name="ViT-B/32", device=None):
        super().__init__()
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.clip_model, self.preprocess = clip.load(clip_model_name, device=self.device)
        # Raw sub-modules kept for callers that reach into them directly. The
        # encode_* methods below do NOT call them bare: they need the
        # surrounding CLIP pipeline (dtype cast, embeddings, pooling, projection).
        self.image_encoder = self.clip_model.visual
        self.text_encoder = self.clip_model.transformer

    def encode_image(self, x):
        """Embed preprocessed images (e.g. produced by ``self.preprocess``).

        Goes through ``clip_model.encode_image``. In the CLIP reference
        implementation, that method casts the input to the model dtype (fp16 on
        CUDA) before running the visual tower. Calling ``visual`` directly on a
        float32 batch would fail on a half-precision model.
        """
        return self.clip_model.encode_image(x)

    def encode_text(self, text):
        """Tokenize ``text`` (a string or list of strings) and embed it.

        The bare ``transformer`` sub-module operates on already-embedded token
        activations, not on token ids. ``clip_model.encode_text`` applies the
        token and positional embeddings, the transformer, the final layer norm,
        end-of-text pooling and the text projection.
        """
        text_tokens = clip.tokenize(text).to(self.device)
        return self.clip_model.encode_text(text_tokens)

class Dictionary(nn.Module):
    """Learnable bank of unit-norm atoms, one atom per row.

    Args:
        dictionary_size: Number of atoms (rows).
        embedding_dim: Length of each atom (columns).
        device: Device for the parameter; a falsy value selects CUDA when
            available, otherwise CPU.
    """

    def __init__(self, dictionary_size, embedding_dim, device=None):
        super().__init__()
        fallback = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device if device else fallback
        # Draw Gaussian atoms, then rescale every row to unit L2 norm before
        # registering the tensor as a trainable parameter.
        raw_atoms = torch.randn(dictionary_size, embedding_dim, device=self.device)
        unit_atoms = nn.functional.normalize(raw_atoms, dim=1)
        self.dictionary = nn.Parameter(unit_atoms)

    def forward(self, x):
        """Return the atom matrix; the input ``x`` is accepted but not used."""
        atoms = self.dictionary
        return atoms

class LinearProbe(nn.Module):
    """A single linear layer mapping feature vectors to class logits.

    Args:
        input_dim: Size of each input feature vector.
        num_classes: Number of output logits.
        keep_bias: Whether the linear layer has an additive bias.
    """

    def __init__(self, input_dim, num_classes, keep_bias=True):
        super().__init__()
        head = nn.Linear(in_features=input_dim, out_features=num_classes, bias=keep_bias)
        self.classifier = head

    def forward(self, x):
        """Return unnormalised class scores for ``x``."""
        logits = self.classifier(x)
        return logits

class Differentiable_Sparse_Coding(torch.autograd.Function):
    """Lasso / elastic-net sparse coding with an implicit-differentiation backward.

    ``Ds`` stores one atom per ROW, shape ``(dictionary_size, embedding_dim)``.
    This matches ``Dictionary``, which normalises along ``dim=1``. SPAMS wants
    signals and atoms as COLUMNS, so both matrices are transposed before the
    solve. Each row ``x`` of ``x_embed`` is coded by ``spams.lasso`` with
    ``lambda1 = alpha * tau`` and ``lambda2 = alpha * (1 - tau)``.

    Backward differentiates the solution w.r.t. ``x_embed`` on the active set
    S (non-zero coefficients). With ``D_S`` the active atoms as columns, the
    gradient is ``grad_x = D_S (D_S^T D_S + lambda2 I)^-1 grad_c_S``. No
    gradient is returned for the dictionary or for the scalar arguments.
    ``gamma``, ``active_set``, ``maxiter`` and ``device`` are accepted for
    interface compatibility but are not used here.
    """

    @staticmethod
    def forward(ctx, x_embed, Ds, alpha, tau, gamma, sparsity_level, non_negative_sparse_code, active_set, maxiter, device):
        """Solve the sparse codes with SPAMS; returns a float32 (batch, dictionary_size) tensor."""
        # SPAMS layout: signals (embedding_dim, batch), dictionary (embedding_dim, dictionary_size),
        # both Fortran-ordered float32.
        signals_ft = np.asfortranarray(x_embed.detach().cpu().numpy().T, dtype=np.float32)
        atoms_ft = np.asfortranarray(Ds.detach().cpu().numpy().T, dtype=np.float32)

        sparse_coef = spams.lasso(
            signals_ft,
            D=atoms_ft,
            lambda1=float(alpha * tau),
            lambda2=float(alpha * (1.0 - tau)),
            L=int(sparsity_level),
            pos=bool(non_negative_sparse_code),
        )
        # SPAMS returns a sparse (dictionary_size, batch) matrix; densify and put batch first.
        dense_coef = np.asarray(sparse_coef.todense()).T
        coef = torch.as_tensor(np.ascontiguousarray(dense_coef), dtype=torch.float32, device=x_embed.device)

        # Inputs AND outputs must go through save_for_backward; storing the output
        # directly on ctx creates a reference cycle that leaks memory.
        ctx.save_for_backward(x_embed, Ds, coef)
        ctx.alpha = alpha
        ctx.tau = tau
        return coef

    @staticmethod
    def backward(ctx, coef_grad):
        """Return the gradient w.r.t. ``x_embed`` (or None if not required); all others are None."""
        n_inputs = 10
        if not ctx.needs_input_grad[0]:
            # Typical training case here: inputs are data, only the classifier
            # downstream is trained, so skip the per-sample solves entirely.
            return (None,) * n_inputs

        z, D, coef = ctx.saved_tensors
        grad_z = torch.zeros_like(z)
        lambda2 = ctx.alpha * (1.0 - ctx.tau)
        # Work in the input's dtype/device so the solve and the final write agree.
        atoms = D.to(dtype=z.dtype, device=z.device)
        grad_c = coef_grad.to(dtype=z.dtype, device=z.device)

        for i, code in enumerate(coef):
            active = torch.nonzero(code, as_tuple=True)[0].to(z.device)
            if active.numel() == 0:
                continue  # empty support: the code is locally constant in x

            D_star = atoms[active].T  # (embedding_dim, k), active atoms as columns
            eye = torch.eye(D_star.shape[1], device=D_star.device, dtype=D_star.dtype)
            A = D_star.T @ D_star + lambda2 * eye
            try:
                b_nz = torch.linalg.solve(A, grad_c[i, active].unsqueeze(-1)).squeeze(-1)
            except RuntimeError as e:
                # Singular A (e.g. pure lasso with dependent active atoms): best effort,
                # leave this sample's gradient at zero.
                print(f"Error during backward pass at index {i}: {e}")
                continue
            grad_z[i] = D_star @ b_nz

        return (grad_z,) + (None,) * (n_inputs - 1)

class SparseCodingModel(nn.Module):
    """Linear classifier on top of SPAMS sparse codes of the input.

    Each input row is encoded against ``self.dictionary.dictionary`` by
    ``Differentiable_Sparse_Coding``. The resulting ``dictionary_size``-long
    code is fed to a linear layer.

    Args:
        input_dim: Length of each input vector (and of each dictionary atom).
        num_classes: Number of output logits.
        dictionary_size: Number of dictionary atoms (length of each code).
        sparsity_level: Forwarded to the solver (SPAMS ``L``).
        keep_bias: Whether the classifier has a bias.
        device: Device for the dictionary; defaults to CUDA when available.
        alpha, tau, gamma, non_negative_sparse_code, active_set, maxiter:
            Solver hyper-parameters (keyword-only). Defaults equal the values
            that were previously hard-coded. All are stored as plain
            attributes, so callers may still overwrite them after construction.
    """

    def __init__(self, input_dim, num_classes, dictionary_size, sparsity_level, keep_bias=True, device=None,
                 *, alpha=1e-2, tau=1.0, gamma=10.0, non_negative_sparse_code=False,
                 active_set=False, maxiter=100):
        super().__init__()
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dictionary = Dictionary(dictionary_size, input_dim, device=self.device)
        self.classifier = nn.Linear(dictionary_size, num_classes, bias=keep_bias)
        self.sparsity_level = sparsity_level

        # SPAMS solver parameters
        self.alpha = alpha  # overall lasso / elastic-net strength
        self.tau = tau      # elastic net mixing parameter (1.0 = pure lasso)
        self.gamma = gamma  # factor to divide alpha0 by; forwarded to the solver as-is
        self.non_negative_sparse_code = non_negative_sparse_code
        self.active_set = active_set  # forwarded to the solver as-is
        self.maxiter = maxiter        # forwarded to the solver as-is

    def forward(self, x):
        """Return class logits for ``x`` computed from its sparse codes."""
        sparse_codes = Differentiable_Sparse_Coding.apply(
            x,
            self.dictionary.dictionary,
            self.alpha,
            self.tau,
            self.gamma,
            self.sparsity_level,
            self.non_negative_sparse_code,
            self.active_set,
            self.maxiter,
            self.device,
        )
        return self.classifier(sparse_codes)

def load_concept_dictionary(concept_file, clip_model, device=None, batch_size=256):
    """Load concept strings from a JSON file and embed them with ``clip_model``.

    Accepted layouts:

    * ``{"concepts": <dict or list>}``: the wrapper key is unwrapped first.
    * ``{key: [concept, ...], ...}``: concepts are taken in iteration order and
      the keys are ignored.
    * ``[concept, ...]``.

    Args:
        concept_file: Path to the UTF-8 JSON file.
        clip_model: Object with ``encode_text(list_of_str)`` returning one
            embedding row per string (e.g. ``CLIPModel``).
        device: Device for the returned tensor; defaults to CUDA when available.
        batch_size: Number of strings passed to ``encode_text`` per call.

    Returns:
        float32 tensor ``(num_concepts, embedding_dim)`` with L2-normalised
        rows, on ``device``. Computed without autograd tracking.

    Raises:
        ValueError: if the file contains no concepts.
    """
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

    with open(concept_file, 'r', encoding='utf-8') as f:
        concepts = json.load(f)

    if isinstance(concepts, dict) and "concepts" in concepts:
        concepts = concepts["concepts"]

    if isinstance(concepts, dict):
        texts = [concept for concept_list in concepts.values() for concept in concept_list]
    else:
        texts = list(concepts)

    if not texts:
        raise ValueError(f"No concepts found in {concept_file}")

    batch_size = max(1, int(batch_size))
    chunks = []
    # The dictionary is a fixed target: no need to build an autograd graph through CLIP.
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            embeddings = clip_model.encode_text(texts[start:start + batch_size])
            # CLIP runs in fp16 on CUDA; the model's dictionary parameter is float32.
            chunks.append(embeddings.float())

    dictionary = torch.cat(chunks, dim=0)
    dictionary = nn.functional.normalize(dictionary, dim=1)
    return dictionary.to(device)

def _build_optimizer(model, args):
    """Return the optimizer named by ``args.optimizer`` for the trainable parameters.

    For ``args.model_type == 'sparse_coding'`` only the classifier parameters
    are optimized; the dictionary is not handed to the optimizer.

    Raises:
        ValueError: if ``args.optimizer`` is neither ``"adamw"`` nor ``"sgd"``.
    """
    if args.model_type == 'sparse_coding':
        params = model.classifier.parameters()
    else:
        params = model.parameters()

    if args.optimizer == "adamw":
        return optim.AdamW(params, lr=args.lr, weight_decay=args.weight_decay)
    if args.optimizer == "sgd":
        return optim.SGD(params, lr=args.lr, weight_decay=args.weight_decay)
    raise ValueError(f"Unsupported optimizer: {args.optimizer}")


def _evaluate(model, loader, criterion, device):
    """Run ``model`` over ``loader`` in eval mode without gradients.

    Returns:
        ``(summed_batch_loss, num_correct, num_samples)``. The loss is 0.0 when
        ``criterion`` is None.
    """
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            if criterion is not None:
                loss_sum += criterion(outputs, targets).item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return loss_sum, correct, total


def _percent(correct, total):
    """Return ``100 * correct / total``, or 0.0 for an empty split."""
    return 100. * correct / total if total else 0.0


def train_model(model, train_loader, val_loader, test_loader, args):
    """Train ``model`` with cross-entropy, checkpoint the best epoch, then test.

    Args:
        model: ``nn.Module`` producing class logits. A sparse-coding model must
            expose ``.classifier``.
        train_loader, val_loader, test_loader: Iterables of ``(inputs, targets)``.
        args: Namespace providing ``model_type``, ``optimizer``, ``lr``,
            ``weight_decay``, ``lr_scheduler``, ``num_epochs``,
            ``clip_grad_norm``, ``use_model_checkpoint`` and ``use_wandb``.

    Returns:
        ``(test_acc, model, {'checkpoint_path': path_or_None})``. Test accuracy
        is measured on the final-epoch weights, not on the saved best checkpoint.

    Raises:
        ValueError: for an unsupported ``args.optimizer``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    optimizer = _build_optimizer(model, args)
    if args.lr_scheduler == "cosine":
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_epochs)
    else:
        scheduler = None
    criterion = nn.CrossEntropyLoss()

    # -inf (not 0) so the first epoch always yields a checkpoint, even at 0% val accuracy.
    best_val_acc = float("-inf")
    best_model_path = None

    for epoch in range(args.num_epochs):
        # Training phase
        model.train()
        train_loss, train_correct, train_total = 0.0, 0, 0
        for inputs, targets in tqdm(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            if args.clip_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            train_total += targets.size(0)
            train_correct += predicted.eq(targets).sum().item()
        train_acc = _percent(train_correct, train_total)

        # Validation phase
        val_loss, val_correct, val_total = _evaluate(model, val_loader, criterion, device)
        val_acc = _percent(val_correct, val_total)

        if scheduler is not None:
            scheduler.step()

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            if args.use_model_checkpoint:
                checkpoint_dir = os.path.join("checkpoints", args.model_type)
                os.makedirs(checkpoint_dir, exist_ok=True)
                best_model_path = os.path.join(checkpoint_dir, "best_model.pt")
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_acc': val_acc,
                }, best_model_path)

        if args.use_wandb:
            wandb.log({
                'train_loss': train_loss / max(len(train_loader), 1),
                'train_acc': train_acc,
                'val_loss': val_loss / max(len(val_loader), 1),
                'val_acc': val_acc,
                'epoch': epoch
            })

        print(f'Epoch {epoch}: Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%')

    # Test phase (loss not needed, only accuracy)
    _, test_correct, test_total = _evaluate(model, test_loader, None, device)
    test_acc = _percent(test_correct, test_total)
    print(f'Test Accuracy: {test_acc:.2f}%')

    return test_acc, model, {'checkpoint_path': best_model_path}

def _str2bool(value):
    """Convert a command-line string such as ``"False"`` or ``"yes"`` to a bool.

    With ``type=bool``, argparse calls ``bool("False")``, which is True. Boolean
    flags therefore could never be switched off from the command line.

    Raises:
        argparse.ArgumentTypeError: for unrecognised spellings; argparse
            reports this as a usage error.
    """
    import argparse

    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1", "on"):
        return True
    if normalized in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _parse_args(argv=None):
    """Build the CLI parser and parse ``argv`` (``sys.argv[1:]`` when None)."""
    import argparse

    parser = argparse.ArgumentParser()

    # Model parameters
    parser.add_argument('--model_type', type=str, default='linear_probe', choices=['linear_probe', 'sparse_coding'])
    parser.add_argument('--clip_model_name', type=str, default='ViT-B/32')
    parser.add_argument('--input_dim', type=int, default=512)
    parser.add_argument('--num_classes', type=int, default=10)
    parser.add_argument('--dictionary_size', type=int, default=1000)
    parser.add_argument('--sparsity_level', type=int, default=100)
    parser.add_argument('--keep_bias', type=_str2bool, default=True)
    parser.add_argument('--concept_file', type=str, default=None, help='Path to JSON file containing concepts')

    # SPAMS solver parameters
    parser.add_argument('--alpha', type=float, default=1e-2, help='Lasso alpha parameter')
    parser.add_argument('--tau', type=float, default=1.0, help='Elastic net mixing parameter')
    parser.add_argument('--gamma', type=float, default=10.0, help='Factor to divide alpha0 by')
    parser.add_argument('--non_negative_sparse_code', type=_str2bool, default=False, help='Whether to enforce non-negative sparse codes')

    # Training parameters
    parser.add_argument('--optimizer', type=str, default='adamw', choices=['adamw', 'sgd'])
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=1e-2)
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--clip_grad_norm', type=float, default=1.0)
    parser.add_argument('--lr_scheduler', type=str, default='cosine')
    parser.add_argument('--use_model_checkpoint', type=_str2bool, default=True)
    parser.add_argument('--use_wandb', type=_str2bool, default=False)

    return parser.parse_args(argv)


def _make_dummy_loaders(args):
    """Return random ``(train, val, test)`` DataLoaders of ``args.input_dim``-dim features.

    Placeholder for a real dataset: 1000/200/200 samples with labels drawn from
    ``range(args.num_classes)``. Only the training loader shuffles.
    """
    def _split(num_samples):
        data = torch.randn(num_samples, args.input_dim)
        labels = torch.randint(0, args.num_classes, (num_samples,))
        return TensorDataset(data, labels)

    train_loader = DataLoader(_split(1000), batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(_split(200), batch_size=args.batch_size)
    test_loader = DataLoader(_split(200), batch_size=args.batch_size)
    return train_loader, val_loader, test_loader


def main():
    """Parse CLI args, build the requested model, train on dummy data and report test accuracy.

    Raises:
        ValueError: if the concept embedding dimension differs from ``--input_dim``.
    """
    args = _parse_args()

    if args.use_wandb:
        wandb.init(project="model-training", config=args)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Concept embeddings are only consumed by the sparse-coding model; loading
    # CLIP otherwise is wasted time (and possibly a model download).
    dictionary = None
    if args.concept_file and args.model_type == 'sparse_coding':
        clip_model = CLIPModel(args.clip_model_name, device)
        dictionary = load_concept_dictionary(args.concept_file, clip_model, device)
        if dictionary.shape[1] != args.input_dim:
            raise ValueError(
                f"Concept embeddings have dimension {dictionary.shape[1]}, "
                f"but --input_dim is {args.input_dim}; they must match."
            )
        args.dictionary_size = dictionary.shape[0]

    # In practice, you would load your actual dataset here
    train_loader, val_loader, test_loader = _make_dummy_loaders(args)

    if args.model_type == 'linear_probe':
        model = LinearProbe(args.input_dim, args.num_classes, args.keep_bias)
    else:
        model = SparseCodingModel(args.input_dim, args.num_classes, args.dictionary_size,
                                  args.sparsity_level, args.keep_bias, device)
        if dictionary is not None:
            # copy_ keeps the parameter's own float32 dtype and device even if the
            # CLIP embeddings arrived in another precision (assigning .data would not).
            with torch.no_grad():
                model.dictionary.dictionary.copy_(dictionary)
        # Set SPAMS solver parameters
        model.alpha = args.alpha
        model.tau = args.tau
        model.gamma = args.gamma
        model.non_negative_sparse_code = args.non_negative_sparse_code

    test_acc, model, extra_returns = train_model(model, train_loader, val_loader, test_loader, args)

    print(f"Training completed. Test accuracy: {test_acc:.2f}%")
    if extra_returns['checkpoint_path'] is None:
        print("No checkpoint was saved.")
    else:
        print(f"Best model saved at: {extra_returns['checkpoint_path']}")

if __name__ == "__main__":
    # Run training only when this file is executed as a script, not on import.
    main()