import os
import time
import torch
import random
import numpy as np
from copy import deepcopy

from src.datasets import get_dataloader, get_dataset, maybe_dictionarize
from src.eval.eval import eval_single_dataset
from src.models import ImageClassifier, ImageEncoder, get_classification_head
from src.utils.utils import LabelSmoothing, cosine_lr

def set_seed(seed=42):
    """Seed every random number generator used by this script for reproducibility.

    Seeds Python's ``random`` module, NumPy's global generator, and PyTorch
    (the CPU generator, the current CUDA device and all CUDA devices). It also
    switches cuDNN to deterministic kernels with benchmark autotuning disabled,
    and writes the seed to the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Integer seed applied to every generator (default 42).
    """
    # Python-level and NumPy generators, in that order.
    for seeder in (random.seed, np.random.seed):
        seeder(seed)

    # PyTorch: CPU generator, then the current CUDA device, then every CUDA device.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Prefer deterministic cuDNN algorithms over autotuned (benchmark) ones.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # Note: the running interpreter's hash seed is fixed at startup, so this
    # only influences subprocesses launched after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    
# Fine-tuning epoch budget per dataset, keyed by the base dataset name
# (the dataset name with any "Val" marker removed).
_FINETUNE_EPOCHS = {
    "Cars": 35, "DTD": 76, "EuroSAT": 12, "GTSRB": 11,
    "MNIST": 5, "RESISC45": 15, "SUN397": 14, "SVHN": 4,
    "CIFAR10": 6, "CIFAR100": 6, "STL10": 60, "Food101": 4,
    "Flowers102": 147, "FER2013": 10, "PCAM": 1, "OxfordIIITPet": 82,
    "RenderedSST2": 39, "EMNIST": 2, "FashionMNIST": 5, "KMNIST": 5,
}
# Epoch budget for datasets not listed in _FINETUNE_EPOCHS.
_DEFAULT_FINETUNE_EPOCHS = 10


def _num_epochs_for(train_dataset):
    """Return the epoch budget for ``train_dataset``, ignoring any "Val" in its name."""
    base_name = train_dataset.replace("Val", "")
    return _FINETUNE_EPOCHS.get(base_name, _DEFAULT_FINETUNE_EPOCHS)


def _make_loss_fn(label_smoothing):
    """Return LabelSmoothing when ``label_smoothing > 0``, else plain cross-entropy."""
    if label_smoothing > 0:
        return LabelSmoothing(label_smoothing)
    return torch.nn.CrossEntropyLoss()


def continual_finetune(args, train_dataset, starting_model_path, output_path):
    """
    Fine-tune an image encoder, starting from a saved state dict, on one dataset.

    The classification head for ``train_dataset`` is frozen, so only the
    encoder's trainable parameters are optimized. The optimizer is AdamW with
    the ``cosine_lr`` schedule, gradient accumulation, and gradient-norm
    clipping at 1.0. The encoder is evaluated every 5 epochs and after the
    last epoch. The state dict with the highest ``top1`` score is saved to
    ``output_path``. A CUDA device is required.

    Args:
        args: Parameter object. Must provide ``model`` and ``data_location``.
            Optional attributes (default in parentheses): ``batch_size`` (64),
            ``ls`` label smoothing (0.0), ``lr`` (1e-5), ``wd`` (0.1),
            ``warmup_length`` (0.1), ``num_grad_accumulation`` (2).
        train_dataset: Training dataset name.
        starting_model_path: Path to the starting encoder state dict.
        output_path: Path to save the fine-tuned encoder state dict.

    Returns:
        The best evaluated ``top1`` accuracy. If no evaluation beat the
        initial sentinel (e.g. every score was NaN), the last measured value
        is returned instead; this is NaN if no evaluation ran.
    """
    # Set random seed for reproducibility
    set_seed(42)

    print(f"Starting fine-tuning on dataset {train_dataset}")
    print(f"Starting model: {starting_model_path}")
    print(f"Output path: {output_path}")

    epochs = _num_epochs_for(train_dataset)

    # Load the starting encoder. map_location="cpu" lets checkpoints saved on
    # any device load here; the whole model is moved to the GPU below.
    image_encoder = ImageEncoder(args.model)
    image_encoder.load_state_dict(torch.load(starting_model_path, map_location="cpu"))

    # Freeze the classification head so only the feature extractor is tuned.
    classification_head = get_classification_head(args, train_dataset)
    model = ImageClassifier(image_encoder, classification_head)
    model.freeze_head()
    model = model.cuda()

    # Prepare dataset and training loader.
    dataset = get_dataset(
        train_dataset,
        model.train_preprocess,
        location=args.data_location,
        batch_size=getattr(args, 'batch_size', 64),
    )
    loader = get_dataloader(dataset, is_train=True, args=args, image_encoder=None)
    num_batches = len(dataset.train_loader)

    loss_fn = _make_loss_fn(getattr(args, 'ls', 0.0))

    # Optimize only the parameters left trainable after freezing the head.
    lr = getattr(args, 'lr', 1e-5)
    wd = getattr(args, 'wd', 0.1)
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=lr, weight_decay=wd)

    # The schedule length is counted in optimizer steps (one per accumulation window).
    warmup_length = getattr(args, 'warmup_length', 0.1)
    num_grad_accumulation = getattr(args, 'num_grad_accumulation', 2)
    scheduler = cosine_lr(
        optimizer,
        lr,
        warmup_length,
        epochs * num_batches // num_grad_accumulation,
    )

    best_model = None
    best_accuracy = -1.0
    # Bound up front so the final report never hits an unbound name.
    acc = float("nan")

    print(f"Starting training, total {epochs} epochs")
    for epoch in range(epochs):
        model.train()

        epoch_loss = 0.0
        epoch_data_time = 0.0
        epoch_batch_time = 0.0
        num_epoch_batches = 0

        for i, batch in enumerate(loader):
            start_time = time.time()

            # Global optimizer-step index passed to the LR scheduler.
            step = (
                i // num_grad_accumulation
                + epoch * num_batches // num_grad_accumulation
            )

            batch = maybe_dictionarize(batch)
            inputs = batch["images"].cuda()
            labels = batch["labels"].cuda()
            # Measures dictionarize + host-to-device copy, not the loader fetch itself.
            data_time = time.time() - start_time

            logits = model(inputs)
            loss = loss_fn(logits, labels)
            # Gradients accumulate (unscaled) over num_grad_accumulation batches.
            loss.backward()

            if (i + 1) % num_grad_accumulation == 0:
                scheduler(step)
                torch.nn.utils.clip_grad_norm_(params, 1.0)
                optimizer.step()
                optimizer.zero_grad()

            batch_time = time.time() - start_time

            loss_value = loss.item()
            epoch_loss += loss_value
            epoch_data_time += data_time
            epoch_batch_time += batch_time
            num_epoch_batches += 1

            if i % 10 == 0:
                print(f"Epoch {epoch+1}/{epochs}, Batch {i+1}/{num_batches}, Loss: {loss_value:.6f}", end="\r")

        # NOTE(review): when num_batches is not a multiple of num_grad_accumulation,
        # the trailing batches' gradients are not zeroed here. They carry into the
        # first accumulation window of the next epoch, and after the last epoch
        # they are never applied. Confirm this is intended.

        # Guard against an empty loader (would otherwise raise ZeroDivisionError).
        denom = max(num_epoch_batches, 1)
        avg_loss = epoch_loss / denom
        avg_data_time = epoch_data_time / denom
        avg_batch_time = epoch_batch_time / denom

        print(
            f"Train Epoch: {epoch+1}/{epochs} completed\t"
            f"Avg Loss: {avg_loss:.6f}\t"
            f"Avg Data time: {avg_data_time:.3f}s\t"
            f"Avg Batch time: {avg_batch_time:.3f}s\t"
        )

        # Evaluate every 5 epochs and after the final epoch.
        if (epoch + 1) % 5 == 0 or epoch == epochs - 1:
            acc = eval_single_dataset(image_encoder, train_dataset, args)['top1']
            print(f"Epoch {epoch+1} Validation accuracy: {acc*100:.2f}%")

            # Snapshot the best encoder weights seen so far.
            if acc > best_accuracy:
                best_accuracy = acc
                best_model = deepcopy(image_encoder.state_dict())
                print(f"Found new best model, accuracy: {best_accuracy*100:.2f}%")

    # Save the best snapshot, or the current weights if none was recorded.
    final_model = best_model if best_model is not None else image_encoder.state_dict()
    torch.save(final_model, output_path)

    final_acc = best_accuracy if best_model is not None else acc
    print(f"Fine-tuning completed. Best accuracy: {final_acc*100:.2f}%")
    print(f"Fine-tuned model saved to: {output_path}")

    return final_acc