import os
import sys
import copy
import time
import tqdm
import numpy as np
import pickle
import matplotlib.pyplot as plt
from torch.func import vmap

import torch
import torch.nn.functional as F

import clip.clip as clip

from src.args import parse_arguments
from src.datasets.common import get_dataloader, maybe_dictionarize
from src.models.eval import evaluate
from src.models.modeling import ClassificationHead, ImageEncoder, ImageClassifier
from src.models.utils import cosine_lr, torch_load, LabelSmoothing, label_for_samples,adjust_lambda_reg,pca

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
from poly.wd_regularization_torch import (
    polynomial_regularization,
    precompute_chebyshev_matrix,
)

import src.datasets as datasets


def plot_training_curves(args, train_accs, test_accs, train_losses, test_losses, reg_terms, save_path='training_plots'):
    """Save accuracy/loss curves as a PNG and pickle the raw per-epoch series.

    Args:
        args: Namespace providing ``exp_name``; a falsy value falls back to
            ``"finetune"`` as the output file stem.
        train_accs: Per-epoch training accuracies (defines the x-axis length).
        test_accs: Per-epoch test accuracies.
        train_losses: Per-epoch training losses.
        test_losses: Per-epoch test losses.
        reg_terms: Per-epoch regularization values; the curve is drawn only
            when at least one entry is truthy (non-zero).
        save_path: Output directory, created if it does not exist.

    Writes ``<save_path>/<exp_name>.png`` and ``<save_path>/<exp_name>_data.pkl``.
    """
    x_axis = range(1, len(train_accs) + 1)
    fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: accuracy curves.
    accuracy_series = (
        (train_accs, 'b-', 'Training Accuracy'),
        (test_accs, 'r-', 'Test Accuracy'),
    )
    for values, fmt, legend_name in accuracy_series:
        acc_ax.plot(x_axis, values, fmt, label=legend_name)
    acc_ax.set_title('Training and Test Accuracy')
    acc_ax.set_xlabel('Epochs')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.legend()
    acc_ax.grid(True)

    # Right panel: loss curves, plus the regularizer when it is ever non-zero.
    loss_series = [
        (train_losses, 'g-', 'Training Loss (w/o reg)'),
        (test_losses, 'orange', 'Test Loss'),
    ]
    if any(reg_terms):
        loss_series.append((reg_terms, 'm--', 'Regularization Term'))
    for values, fmt, legend_name in loss_series:
        loss_ax.plot(x_axis, values, fmt, label=legend_name)
    loss_ax.set_title('Training and Test Loss')
    loss_ax.set_xlabel('Epochs')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()
    loss_ax.grid(True)

    fig.tight_layout()

    # The output directory may not exist yet on the first call.
    os.makedirs(save_path, exist_ok=True)
    exp_name = args.exp_name or "finetune"
    fig.savefig(os.path.join(save_path, f'{exp_name}.png'))
    plt.close(fig)

    # Persist the raw numbers next to the figure for later analysis.
    history = dict(
        train_accs=train_accs,
        test_accs=test_accs,
        train_losses=train_losses,
        test_losses=test_losses,
        reg_terms=reg_terms,
    )
    with open(os.path.join(save_path, f'{exp_name}_data.pkl'), 'wb') as f:
        pickle.dump(history, f)


def _to_mixup_coefficient(alpha, use_label):
    """Map nodes in [-1, 1] to mixup interpolation weights.

    With ``use_label`` the weights span [0, 1]; otherwise they span
    [-0.1, 1.1], i.e. the path extends slightly past both endpoints.
    """
    if use_label:
        return (alpha + 1) * 0.5
    return -0.1 + (alpha + 1) * 0.6


def _sample_inner_alphas(n_inner, resolution, device):
    """Draw ``n_inner`` jittered, cosine-spaced nodes strictly inside (-1, 1)."""
    steps = torch.arange(n_inner, device=device, dtype=torch.float32)
    jitter = torch.rand(n_inner, device=device, dtype=torch.float32)
    # One uniformly jittered point per stratum of [0, 1).
    raw_fraction = (steps + jitter) / n_inner
    # Shrink towards the centre so the sampled nodes never reach the endpoints.
    padding = 1.0 / resolution
    compressed_fraction = padding + raw_fraction * (1 - 2 * padding)
    # Project the angles onto [-1, 1] via the cosine.
    theta = compressed_fraction * np.pi
    alpha_inner, _ = torch.sort(-torch.cos(theta))
    return alpha_inner


def finetune(args):
    """Fine-tune a classifier with an optional polynomial regularizer on mixup paths.

    For every batch, ``num_pairs`` random pairs of distinct samples are drawn.
    Intermediate inputs are generated along the segment between the two
    samples, pushed through the network together with the regular batch, and
    the resulting probability sequence (framed by the one-hot labels of the
    two endpoints) is scored by ``polynomial_regularization``.

    Args:
        args: Parsed command-line namespace. Fields read here include
            ``load``, ``train_dataset``, ``freeze_encoder``, ``data_location``,
            ``batch_size``, ``ls``, ``lr``, ``wd``, ``warmup_length``,
            ``epochs``, ``max_degree``, ``resolution``, ``miu``,
            ``have_const``, ``use_norm``, ``random_alpha``, ``label``,
            ``num_pairs``, ``pca_reg``, ``pca_k``, ``square``,
            ``degree_mode``, ``reg_anneal``, ``min_lambda_reg``,
            ``lambda_reg`` and ``save``. ``args.eval_datasets`` and
            ``args.current_epoch`` are overwritten by this function.

    Returns:
        Path of the last checkpoint written when ``args.save`` is set (``None``
        if no epoch ran), otherwise ``None``.
    """
    assert args.load is not None, "Please provide the path to a checkpoint through --load."
    assert args.train_dataset is not None, "Please provide a training dataset."

    image_classifier = ImageClassifier.load(args.load)

    if args.freeze_encoder:
        print('Fine-tuning a linear classifier')
        model = image_classifier.classification_head
        input_key = 'features'
        preprocess_fn = image_classifier.val_preprocess
        image_enc = image_classifier.image_encoder
        print_every = 1000
    else:
        print('Fine-tuning end-to-end')
        model = image_classifier
        input_key = 'images'
        preprocess_fn = image_classifier.train_preprocess
        image_enc = None
        image_classifier.process_images = True
        print_every = 100

    dataset_class = getattr(datasets, args.train_dataset)
    dataset = dataset_class(
        preprocess_fn,
        location=args.data_location,
        batch_size=args.batch_size
    )
    num_batches = len(dataset.train_loader)

    model = model.cuda()
    devices = list(range(torch.cuda.device_count()))
    print('Using devices', devices)
    model = torch.nn.DataParallel(model, device_ids=devices)
    model.train()

    if args.ls > 0:
        loss_fn = LabelSmoothing(args.ls)
    else:
        loss_fn = torch.nn.CrossEntropyLoss()

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wd)

    scheduler = cosine_lr(optimizer, args.lr, args.warmup_length, args.epochs * num_batches)

    # NOTE(review): these timers are only ever printed, never updated, so the
    # log always reports 0.0000 for them.
    pca_time_total = 0
    reg_time = 0

    # Regularization setup
    max_degree = args.max_degree
    resolution = int(args.resolution)
    miu = args.miu
    have_const = args.have_const
    use_norm = args.use_norm
    # Number of interior nodes per path; the two endpoints use hard labels.
    n_inner = resolution - 2

    # EMA of the un-regularized training loss (drives reg_anneal).
    ema_alpha = 0.1
    ema_train_loss = None

    # Precompute alpha_values
    cached = precompute_chebyshev_matrix(resolution, max_degree)
    alpha_values = cached['alpha_values'].to(torch.float32)

    # Deterministic mode: the interior mixup weights never change, so map them
    # once. They are kept 1-D here and reshaped per batch, because the input
    # rank differs between images and pre-computed features.
    fixed_inner_mixup = None
    if not args.random_alpha and n_inner > 0:
        fixed_inner_mixup = _to_mixup_coefficient(alpha_values[1:-1].reshape(-1), args.label)

    # The regularizer wrapper and its vmap are loop-invariant: build them once.
    def reg_wrapper(alpha, sample_output):
        return polynomial_regularization(
            alpha,
            sample_output,
            resolution,
            miu,
            max_degree,
            have_const,
            use_norm,
            args.random_alpha,
            args.square,
            args.degree_mode
        )

    # Random mode supplies one alpha grid per pair (mapped over dim 0);
    # deterministic mode shares a single grid across all pairs.
    alpha_in_dims = 0 if args.random_alpha else None
    batched_reg_func = vmap(reg_wrapper, in_dims=(alpha_in_dims, 0))

    # Report the labels of the first two training samples (fixed reference pair).
    fixed_pair_labels = None
    temp_loader = get_dataloader(dataset, is_train=True, args=args, image_encoder=image_enc)
    for batch in temp_loader:
        batch = maybe_dictionarize(batch)
        labels = batch['labels']
        if batch[input_key].size(0) >= 2:
            fixed_pair_labels = [labels[0].item(), labels[1].item()]
            break
    print(f"Fixed pair labels for landscape: {fixed_pair_labels}")

    train_accs, test_accs, train_losses, test_losses, reg_terms = [], [], [], [], []
    # NOTE(review): this overrides whatever evaluation datasets the caller passed.
    args.eval_datasets = ['ImageNet']
    # Stays None if no epoch runs, so the final return is always well-defined.
    model_path = None
    for epoch in range(args.epochs):
        model.train()
        lambda_reg = adjust_lambda_reg(epoch, args)
        train_loss_total = 0.0
        reg_term_total = 0.0
        correct = 0
        total = 0
        total_time = 0.0
        data_loader = get_dataloader(
            dataset, is_train=True, args=args, image_encoder=image_enc)

        for i, batch in enumerate(data_loader):
            start_time = time.time()

            step = i + epoch * num_batches
            scheduler(step)
            optimizer.zero_grad()

            batch = maybe_dictionarize(batch)
            inputs = batch[input_key].cuda()
            labels = batch['labels'].cuda()
            data_time = time.time() - start_time

            # The regularizer contributes to the loss for any non-negative
            # lambda (with a small tolerance). Otherwise it is still evaluated
            # on a single pair, purely for logging.
            compute_reg_for_loss = (lambda_reg > -0.00001)
            num_pairs_actual = args.num_pairs if compute_reg_for_loss else 1
            can_compute_reg = (inputs.size(0) >= num_pairs_actual * 2)

            reg_term = torch.tensor(0.0, device=inputs.device)

            if can_compute_reg:
                batch_size = inputs.size(0)
                device = inputs.device
                x1_indices = torch.randint(low=0, high=batch_size, size=(num_pairs_actual,), device=device)
                # Generate x2 != x1 by adding a random shift in [1, N-1] to x1 (modulo N)
                offset = torch.randint(low=1, high=batch_size, size=(num_pairs_actual,), device=device)
                x2_indices = (x1_indices + offset) % batch_size

                # The endpoint one-hot targets below are needed in every mode,
                # so the pair labels must always be gathered.
                # NOTE(review): with args.label unset the mixup weights span
                # [-0.1, 1.1] while the endpoints are still the hard labels of
                # x1/x2 -- confirm this is the intended pairing.
                label1_batch = labels[x1_indices]
                label2_batch = labels[x2_indices]

                x1_batch = inputs[x1_indices]
                x2_batch = inputs[x2_indices]

                # Singleton dims that let a per-node weight broadcast over one
                # sample, whatever the input rank ([B, C, H, W] or [B, D]).
                trailing_ones = [1] * (inputs.dim() - 1)

                # === 1. Alpha grids: full grid for the regularizer, interior
                # mixup weights for sample generation ===
                if args.random_alpha:
                    neg_one = torch.tensor([-1.0], device=device)
                    pos_one = torch.tensor([1.0], device=device)
                    full_alpha_list = []
                    inner_mixup_list = []
                    for _ in range(num_pairs_actual):
                        if n_inner > 0:
                            alpha_inner_cheb = _sample_inner_alphas(n_inner, resolution, device)
                            inner_mixup_list.append(
                                _to_mixup_coefficient(alpha_inner_cheb.view(n_inner, *trailing_ones), args.label)
                            )
                        else:
                            alpha_inner_cheb = torch.tensor([], device=device)
                        # Full grid [-1, inner..., 1] used by the regularizer.
                        full_alpha_list.append(torch.cat([neg_one, alpha_inner_cheb, pos_one]))

                    batched_alpha_inputs = torch.stack(full_alpha_list)
                    # [num_pairs, n_inner, 1, ...]
                    inner_alpha = torch.stack(inner_mixup_list) if n_inner > 0 else None
                else:
                    batched_alpha_inputs = alpha_values.to(device)
                    if fixed_inner_mixup is not None:
                        # [1, n_inner, 1, ...]; broadcasts over the pair dimension.
                        inner_alpha = fixed_inner_mixup.to(device).view(1, -1, *trailing_ones)
                    else:
                        inner_alpha = None

                # === 2. Generate mixup samples (interior nodes only) ===
                samples_flat = None
                if inner_alpha is not None and inner_alpha.numel() > 0:
                    x1_expanded = x1_batch.unsqueeze(1)  # [num_pairs, 1, ...]
                    x2_expanded = x2_batch.unsqueeze(1)
                    inner_samples = x1_expanded + inner_alpha * (x2_expanded - x1_expanded)
                    # Flatten to [num_pairs * n_inner, ...] so it can join the batch.
                    samples_flat = inner_samples.reshape(-1, *inputs.shape[1:])
                has_inner = samples_flat is not None and samples_flat.numel() > 0

                # === 3. Single forward pass over the batch plus mixup samples ===
                if has_inner:
                    all_inputs = torch.cat([inputs, samples_flat], dim=0)
                else:
                    all_inputs = inputs
                all_outputs = model(all_inputs)

                # === 4. Split outputs ===
                outputs = all_outputs[:batch_size]
                loss = loss_fn(outputs, labels)
                num_classes = outputs.size(1)

                if has_inner:
                    inner_probs = F.softmax(all_outputs[batch_size:], dim=1)
                    # [num_pairs, n_inner, num_classes]
                    inner_probs_reshaped = inner_probs.reshape(num_pairs_actual, n_inner, -1)
                else:
                    inner_probs_reshaped = torch.tensor([], device=device).reshape(num_pairs_actual, 0, num_classes)

                # === 5. Hard one-hot labels at both endpoints ===
                target_dtype = outputs.dtype
                endpoint_1 = F.one_hot(label1_batch, num_classes=num_classes).to(dtype=target_dtype, device=device).unsqueeze(1)
                endpoint_2 = F.one_hot(label2_batch, num_classes=num_classes).to(dtype=target_dtype, device=device).unsqueeze(1)

                # === 6. Full sequence: [hard 1] - [soft interior] - [hard 2] ===
                full_sequence = torch.cat([endpoint_1, inner_probs_reshaped, endpoint_2], dim=1)

                if args.pca_reg:
                    full_sequence = pca(full_sequence, num_pairs_actual, k=args.pca_k)

                # One regularization value per pair, averaged over pairs.
                reg_terms_tensor = batched_reg_func(batched_alpha_inputs, full_sequence)
                reg_term = torch.mean(reg_terms_tensor)
            else:
                outputs = model(inputs)
                loss = loss_fn(outputs, labels)

            # Update EMA loss (before the regularizer is added)
            base_loss_value = loss.item()
            if ema_train_loss is None:
                ema_train_loss = base_loss_value
            else:
                ema_train_loss = ema_alpha * base_loss_value + (1 - ema_alpha) * ema_train_loss

            # Add regularization to loss
            if compute_reg_for_loss and can_compute_reg:
                if args.reg_anneal:
                    # Scale lambda by the EMA loss relative to log(num_classes),
                    # the cross-entropy of a uniform prediction.
                    train_loss_max = np.log(outputs.size(1))
                    lambda_reg_current = max(args.min_lambda_reg, args.lambda_reg * ema_train_loss / train_loss_max)
                else:
                    lambda_reg_current = lambda_reg

                loss = loss + lambda_reg_current * reg_term

            loss.backward()

            # Update stats
            # NOTE(review): this accumulates the loss *including* the
            # regularizer, while the plot legend calls the curve "w/o reg".
            train_loss_total += loss.item()
            reg_term_total += reg_term.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            torch.nn.utils.clip_grad_norm_(params, 1.0)

            optimizer.step()
            batch_time = time.time() - start_time
            total_time += batch_time
            if i % print_every == 0:
                percent_complete = 100 * i / len(data_loader)
                print(
                    f"Train Epoch: {epoch} [{percent_complete:.0f}% {i}/{len(dataset.train_loader)}]\t"
                    f"[Epoch {epoch}] PCA total time: {pca_time_total:.4f} seconds\t"
                    f"[Epoch {epoch}] Reg total time: {reg_time:.4f} seconds\t"
                    f"Loss: {loss.item():.6f}\tReg: {reg_term.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {total_time / (i+1):.3f}", flush=True
                )

        # Unwrap DataParallel; in linear-probe mode re-attach the frozen encoder.
        if args.freeze_encoder:
            image_classifier = ImageClassifier(image_classifier.image_encoder, model.module)
        else:
            image_classifier = model.module

        # Calculate epoch stats
        avg_train_loss = train_loss_total / len(data_loader)
        avg_reg_term = reg_term_total / len(data_loader)
        avg_train_acc = 100. * correct / total

        train_losses.append(avg_train_loss)
        reg_terms.append(avg_reg_term)
        train_accs.append(avg_train_acc)

        # Saving model
        if args.save is not None:
            os.makedirs(args.save, exist_ok=True)
            model_path = os.path.join(args.save, f'checkpoint_{epoch+1}.pt')
            print('Saving model to', model_path)
            image_classifier.save(model_path)
            optim_path = os.path.join(args.save, f'optim_{epoch+1}.pt')
            torch.save(optimizer.state_dict(), optim_path)

        # Evaluate
        args.current_epoch = epoch
        eval_results = evaluate(image_classifier, args)

        # Extract test stats
        if args.eval_datasets:
            main_dataset = args.eval_datasets[0]
            test_acc = eval_results.get(f'{main_dataset}:top1', 0.0) * 100.0
            test_loss = eval_results.get(f'{main_dataset}:loss', 0.0)
            test_accs.append(test_acc)
            test_losses.append(test_loss)

            print(f"Epoch {epoch} Summary: Train Acc: {avg_train_acc:.2f}, Test Acc: {test_acc:.2f}, Train Loss: {avg_train_loss:.4f}, Test Loss: {test_loss:.4f}, Reg Term: {avg_reg_term:.6f}")

        plot_training_curves(args, train_accs, test_accs, train_losses, test_losses, reg_terms, save_path=args.save if args.save else 'training_plots')

    # model_path is only ever assigned when args.save is set.
    return model_path


if __name__ == '__main__':
    # Script entry point: parse the command-line options, then run fine-tuning.
    cli_args = parse_arguments()
    finetune(cli_args)
