# Copyright 2023 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import json
import os
import sys
from pathlib import Path
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
from avalanche.benchmarks import SplitCIFAR10, SplitCIFAR100
from omegaconf import OmegaConf
from sklearn.metrics import euclidean_distances
from torch.utils.data import DataLoader, TensorDataset, Subset
from torchvision import transforms
from tqdm import tqdm

from solo.args.knn import parse_args_knn
from solo.data.classification_dataloader import (
    prepare_dataloaders,
    prepare_datasets,
    prepare_transforms,
)
from solo.methods import METHODS
from solo.utils.knn import WeightedKNNClassifier
import torchvision.models as models
import torch.nn.functional as F


def select_random_even(targets, samples_per_class=3):
    """Randomly draw up to ``samples_per_class`` positions from every class.

    Args:
        targets: 1-D labels. A torch tensor (moved to CPU first) or anything
            ``np.asarray`` accepts.
        samples_per_class (int): maximum number of positions drawn per class.
            Classes with fewer samples contribute all of them.

    Returns:
        np.ndarray: int64 positions into ``targets``, grouped by class in
        ascending label order. Always an integer array, even when nothing is
        selected, so it can be used directly as an index.
    """
    if isinstance(targets, torch.Tensor):
        targets = targets.detach().cpu().numpy()
    else:
        targets = np.asarray(targets)

    selected_indices = []

    print("Available samples per class:")
    for cls in np.unique(targets):
        class_indices = np.where(targets == cls)[0]
        print(f"Class {cls}: {len(class_indices)} samples available")
        # clamp at 0 so a non-positive budget selects nothing instead of raising
        num_samples = max(0, min(samples_per_class, len(class_indices)))
        chosen_indices = np.random.choice(class_indices, num_samples, replace=False)
        selected_indices.extend(chosen_indices)
        print(f"Selected {num_samples} samples for class {cls}")

    # explicit dtype: np.array([]) would otherwise default to float64, which
    # cannot be used for indexing
    return np.array(selected_indices, dtype=np.int64)
def select_least_fat_vectors_evenly(features, targets, total_samples=30):
    """Pick the lowest-norm feature vectors, spread evenly over the classes.

    Every class first contributes its ``total_samples // num_classes``
    smallest-norm vectors (or all of them if it has fewer). If that leaves the
    selection short, classes are visited round-robin and each visit adds the
    smallest-norm vector of that class that is not selected yet.

    Args:
        features (torch.Tensor): feature matrix, one row per sample.
        targets (torch.Tensor): 1-D class label of every row of ``features``.
        total_samples (int): number of indices to return. When fewer samples
            exist, all of them are returned instead of looping forever.

    Returns:
        torch.Tensor: int64 CPU tensor of row indices into ``features``.
    """
    features = features.cpu()
    targets = targets.cpu()

    unique_classes = torch.unique(targets)
    num_classes = len(unique_classes)
    if num_classes == 0 or total_samples <= 0:
        # nothing to choose from (also avoids dividing by zero classes)
        return torch.empty(0, dtype=torch.long, device=features.device)
    samples_per_class = total_samples // num_classes

    selected_indices = []
    leftovers = []  # per class: unselected indices, ascending norm

    for cls in unique_classes:
        class_positions = torch.where(targets == cls)[0]

        # Rank the whole class once by norm (ascending); the fill step below
        # reuses this ranking instead of recomputing norms per candidate.
        norms = torch.norm(features[class_positions], dim=1)
        _, order = torch.topk(norms, k=len(norms), largest=False)
        ranked = class_positions[order].tolist()

        num_samples = min(samples_per_class, len(ranked))
        selected_indices.extend(ranked[:num_samples])
        leftovers.append(ranked[num_samples:])

    # Round-robin fill: one extra sample per class per round.
    cursors = [0] * len(leftovers)
    while len(selected_indices) < total_samples:
        progressed = False
        for class_pos, remaining in enumerate(leftovers):
            if len(selected_indices) >= total_samples:
                break
            if cursors[class_pos] < len(remaining):
                selected_indices.append(remaining[cursors[class_pos]])
                cursors[class_pos] += 1
                progressed = True
        if not progressed:
            # every class is exhausted; stop rather than spin forever
            break

    # Ensure we return at most total_samples
    selected_indices = selected_indices[:total_samples]

    return torch.tensor(selected_indices, dtype=torch.long, device=features.device)


def select_random_vectors_evenly(features, targets, total_samples=30):
    """Pick the highest-norm feature vectors, spread evenly over the classes.

    Every class first contributes its ``total_samples // num_classes``
    largest-norm vectors (or all of them if it has fewer). If that leaves the
    selection short, classes are visited round-robin and each visit adds the
    *smallest*-norm vector of that class that is not selected yet.

    NOTE(review): despite the name, nothing here is random, and the fill step
    ranks in the opposite direction to the main step. Both behaviours are
    kept as they were; confirm which one is intended.

    Args:
        features (torch.Tensor): feature matrix, one row per sample.
        targets (torch.Tensor): 1-D class label of every row of ``features``.
        total_samples (int): number of indices to return. When fewer samples
            exist, all of them are returned instead of looping forever.

    Returns:
        torch.Tensor: int64 CPU tensor of row indices into ``features``.
    """
    features = features.cpu()
    targets = targets.cpu()

    unique_classes = torch.unique(targets)
    num_classes = len(unique_classes)
    if num_classes == 0 or total_samples <= 0:
        # nothing to choose from (also avoids dividing by zero classes)
        return torch.empty(0, dtype=torch.long, device=features.device)
    samples_per_class = total_samples // num_classes

    selected_indices = []
    leftovers = []  # per class: unselected indices, descending norm

    for cls in unique_classes:
        class_positions = torch.where(targets == cls)[0]

        # Rank the whole class once by norm, largest first.
        norms = torch.norm(features[class_positions], dim=1)
        _, order = torch.topk(norms, k=len(norms), largest=True)
        ranked = class_positions[order].tolist()

        num_samples = min(samples_per_class, len(ranked))
        selected_indices.extend(ranked[:num_samples])
        leftovers.append(ranked[num_samples:])

    # Round-robin fill: one extra sample per class per round. The leftovers
    # are sorted by descending norm, so pop() yields the smallest remaining.
    while len(selected_indices) < total_samples:
        progressed = False
        for remaining in leftovers:
            if len(selected_indices) >= total_samples:
                break
            if remaining:
                selected_indices.append(remaining.pop())
                progressed = True
        if not progressed:
            # every class is exhausted; stop rather than spin forever
            break

    # Ensure we return at most total_samples
    selected_indices = selected_indices[:total_samples]

    return torch.tensor(selected_indices, dtype=torch.long, device=features.device)
def extract_features_with_aug(
        loader: DataLoader,
        model: nn.Module,
        transform,
        num_aug_per_class: int = 5000,  # target number of samples PER CLASS (originals + augmentations)
        batch_size: int = 256,
        device: torch.device = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Extract features and top every class up with augmented samples.

    The originals yielded by ``loader`` are embedded first. Then, for each
    class seen, randomly chosen samples of that class (with replacement) are
    passed through ``transform`` and embedded until the class holds
    ``num_aug_per_class`` entries. Classes that already have at least that
    many originals get no augmentations.

    Args:
        loader (DataLoader): yields ``(images, labels)`` batches;
            ``loader.dataset[i]`` must return ``(image, label, ...)``.
        model (nn.Module): called on an image batch, must return a mapping
            with the keys ``"feats"`` and ``"z"``. It is moved to ``device``.
        transform: callable applied to one dataset image; its outputs are
            stacked with ``torch.stack``, so it must return same-shaped
            tensors.
        num_aug_per_class (int): target sample count per class.
        batch_size (int): maximum number of augmented images per forward pass.
        device (torch.device, optional): defaults to CUDA when available,
            otherwise CPU.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: backbone features,
        projector features and labels, originals first, then the
        augmentations class by class.

    Note:
        The model is left in train mode on return, whatever its mode was on
        entry.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.to(device)
    model.eval()

    backbone_features, proj_features, labels = [], [], []
    initial_counts = {}

    # First pass: Extract features of original images and count initial samples per class
    for im, lab in tqdm(loader, desc="Processing original images"):
        im = im.to(device, non_blocking=True)
        lab = lab.to(device, non_blocking=True)

        with torch.no_grad():
            outs = model(im)
            backbone_features.append(outs["feats"].detach())
            proj_features.append(outs["z"])
            labels.append(lab)

        # one host transfer per batch instead of one .item() sync per label
        for l_item in lab.tolist():
            initial_counts[l_item] = initial_counts.get(l_item, 0) + 1

    print("Initial class counts:", initial_counts)  # Debug print

    # Calculate augmentations needed per class
    augmentations_needed = {}
    for cls, initial_count in initial_counts.items():
        augmentations_needed[cls] = num_aug_per_class - initial_count
        print(f"Class {cls}: Initial count = {initial_count}, Augmentations needed = {augmentations_needed[cls]}")  # Debug

    # Second pass: Add augmentations to reach the desired sample count per class
    dataset = loader.dataset
    for cls in initial_counts:
        num_to_augment = augmentations_needed[cls]
        # NOTE(review): this only sets an attribute on the dataset object. If
        # the dataset applies ``self.transform`` in __getitem__, images are
        # augmented twice (here and explicitly below) — confirm it is wanted.
        dataset.transform = transform

        if num_to_augment <= 0:
            continue  # class already has enough samples

        # Collect the images of this class ONCE per class; the dataset does not
        # change between batches, so re-reading it for every batch was wasted work.
        class_samples = []
        for i in range(len(dataset)):
            item = dataset[i]
            if item[1] == cls:
                class_samples.append(item[0])
        if not class_samples:
            print(f"Warning: No samples of class {cls} found in the loader.")
            continue

        num_augmented = 0
        while num_augmented < num_to_augment:
            num_samples_to_add = min(batch_size, num_to_augment - num_augmented)
            # Sample with replacement so small classes can still fill a full batch.
            indices = np.random.choice(len(class_samples), size=num_samples_to_add, replace=True)

            images_to_augment = torch.stack(
                [transform(class_samples[i]) for i in indices]
            ).to(device, non_blocking=True)
            # Labels are built explicitly: every image in this batch belongs to cls.
            batch_labels = torch.full((len(images_to_augment),), cls, device=device)

            num_augmented += num_samples_to_add

            with torch.no_grad():
                aug_outs = model(images_to_augment)
                backbone_features.append(aug_outs["feats"].detach())
                proj_features.append(aug_outs["z"])
                labels.append(batch_labels)
            del images_to_augment, aug_outs
            torch.cuda.empty_cache()
    print("Augmentation complete")

    model.train()

    # Concatenate everything while keeping on the same device
    backbone_features = torch.cat(backbone_features)
    proj_features = torch.cat(proj_features)
    labels = torch.cat(labels)

    # Calculate and print final class counts
    final_counts = {}
    for l_item in labels.tolist():
        final_counts[l_item] = final_counts.get(l_item, 0) + 1

    print("Final class counts:", final_counts)

    # Verify shapes match
    assert backbone_features.size(0) == labels.size(0), \
        f"Shape mismatch: features {backbone_features.size(0)} vs labels {labels.size(0)}"
    assert proj_features.size(0) == labels.size(0), \
        f"Shape mismatch: proj_features {proj_features.size(0)} vs labels {labels.size(0)}"

    return backbone_features, proj_features, labels


@torch.no_grad()
def extract_features(
    loader: DataLoader, model: nn.Module
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Extract features from a data loader using a model.

    Every batch is moved to CUDA, so a GPU is required; the model itself is
    not moved by this function. The model is switched to eval mode for the
    extraction and left in train mode on return.

    Args:
        loader (DataLoader): dataloader yielding ``(images, labels)`` batches.
        model (nn.Module): torch module used to extract features; its output
            must be a mapping with the keys ``"feats"`` and ``"z"``.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: the backbone
        features, projector features and labels, concatenated over batches.
    """

    model.eval()
    backbone_features, proj_features, labels = [], [], []
    for im, lab in tqdm(loader):
        im = im.cuda(non_blocking=True)
        lab = lab.cuda(non_blocking=True)
        outs = model(im)
        backbone_features.append(outs["feats"].detach())
        proj_features.append(outs["z"])
        labels.append(lab)
    model.train()
    backbone_features = torch.cat(backbone_features)
    proj_features = torch.cat(proj_features)
    labels = torch.cat(labels)
    return backbone_features, proj_features, labels


@torch.no_grad()
def run_knn(
    train_features: torch.Tensor,
    train_targets: torch.Tensor,
    test_features: torch.Tensor,
    test_targets: torch.Tensor,
    k: int,
    T: float,
    distance_fx: str,
) -> Tuple[float]:
    """Evaluate an offline weighted k-NN classifier on pre-extracted features.

    Args:
        train_features (torch.Tensor): features of the reference (train) set.
        train_targets (torch.Tensor): labels of the reference set.
        test_features (torch.Tensor): features of the query (test) set.
        test_targets (torch.Tensor): labels of the query set.
        k (int): number of neighbors.
        T (float): temperature for the exponential. Only used with cosine
            distance.
        distance_fx (str): distance function.

    Returns:
        Tuple[float]: the k-NN acc@1 and acc@5, in that order.
    """

    classifier = WeightedKNNClassifier(k=k, T=T, distance_fx=distance_fx)

    # hand the train and test features/targets to the classifier in one call
    classifier(
        train_features=train_features,
        train_targets=train_targets,
        test_features=test_features,
        test_targets=test_targets,
    )

    top1, top5 = classifier.compute()

    # drop the classifier right away so whatever it holds can be freed
    del classifier

    return top1, top5


# Random augmentation pipeline; main() passes it to extract_features_with_aug
# to synthesise extra views of the replay-buffer samples.
transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        # increased-intensity colour jitter, applied to 30% of the images
        transforms.RandomApply(
            [transforms.ColorJitter(brightness=0.6, contrast=0.6, saturation=0.6, hue=0.2)],
            p=0.3,
        ),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply(
            [transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))],
            p=0.2,
        ),
        transforms.RandomRotation(degrees=15),
        transforms.RandomPerspective(distortion_scale=0.4, p=0.1),
        transforms.RandomAffine(degrees=30, translate=(0.2, 0.2), scale=(0.8, 1.2), shear=15),
    ]
)
def save_representations(train_features_bb, train_targets, output_path, prefix):
    """Write the features and labels of every class to ``.npy`` files.

    For each distinct label ``c`` two files with identical content are written
    into ``output_path``: ``{prefix}_{c}.npy`` and ``{prefix}_all_{c}.npy``.
    Each holds a dict ``{'features': ..., 'labels': ...}``; because a dict is
    stored as a pickled object, load it with
    ``np.load(path, allow_pickle=True).item()``.

    NOTE(review): the "all" file currently duplicates the per-class file —
    confirm whether it was meant to accumulate more than one class.

    Args:
        train_features_bb (torch.Tensor): features, one row per sample.
        train_targets (torch.Tensor): label of every row.
        output_path: destination directory (str or path-like); created if it
            does not exist.
        prefix (str): file-name prefix.
    """
    # detach() so tensors that still track gradients can be converted
    features_np = train_features_bb.detach().cpu().numpy()
    labels_np = train_targets.detach().cpu().numpy()

    # np.save does not create missing directories
    os.makedirs(output_path, exist_ok=True)

    for class_id in np.unique(labels_np):
        class_mask = (labels_np == class_id)
        class_representations = features_np[class_mask]
        payload = {
            'features': class_representations,
            'labels': labels_np[class_mask],
        }

        # same payload under the individual and the "all" file name
        for stem in (f'{prefix}_{int(class_id)}', f'{prefix}_all_{int(class_id)}'):
            np.save(os.path.join(output_path, f'{stem}.npy'), payload)

        print(f"Saved class {int(class_id)} representations: {class_representations.shape}")


def main():
    """Evaluate a pretrained checkpoint on one continual-learning experience.

    Steps, in order:
      1. Parse the k-NN arguments, load ``args.json`` and the first ``.ckpt``
         file found in the checkpoint directory, and rebuild the model.
      2. Split the dataset with avalanche (5 experiences for cifar10, 10 for
         cifar100). Train features come from experience ``args.exp``; test
         features from experiences ``0..args.exp``.
      3. For every earlier experience, draw a small class-balanced random
         buffer, enlarge it with augmentations, and append its backbone
         features, projector features and labels to the train set.
      4. Save the test representations per class into the checkpoint dir.
      5. Train a ``FeatureClassifier`` on the backbone features (its
         validation pass is currently disabled).
      6. Run weighted k-NN for every requested combination of parameters.

    Returns:
        tuple: ``(acc1, test_acc)``. ``acc1`` is the acc@1 of the LAST k-NN
        configuration that ran (``None`` if none ran). ``test_acc`` is the
        classifier accuracy, which stays ``0`` while validation is disabled.

    Raises:
        FileNotFoundError: if the checkpoint directory holds no ``.ckpt`` file.
        ValueError: if ``args.dataset`` is neither ``cifar10`` nor ``cifar100``.
    """
    args = parse_args_knn()

    # build paths
    ckpt_dir = Path(args.pretrained_checkpoint_dir)
    args_path = ckpt_dir / "args.json"
    ckpt_files = [ckpt_dir / ckpt for ckpt in os.listdir(ckpt_dir) if ckpt.endswith(".ckpt")]
    if not ckpt_files:
        raise FileNotFoundError(f"No .ckpt file found in {ckpt_dir}")
    ckpt_path = ckpt_files[0]

    # load arguments
    with open(args_path) as f:
        method_args = json.load(f)
    cfg = OmegaConf.create(method_args)

    # build the model
    model = METHODS[method_args["method"]].load_from_checkpoint(ckpt_path, strict=False, cfg=cfg)

    # prepare data (the validation transform T is used for both splits)
    _, T = prepare_transforms(args.dataset)
    train_dataset, val_dataset = prepare_datasets(
        args.dataset,
        T_train=T,
        T_val=T,
        train_data_path=args.train_data_path,
        val_data_path=args.val_data_path,
        data_format=args.data_format,
    )
    if args.dataset == "cifar10":
        scenario = SplitCIFAR10(
            n_experiences=5,
            return_task_id=False,
            fixed_class_order=[i for i in range(10)],
            shuffle=False,
            seed=args.seed
        )
    elif args.dataset == "cifar100":
        scenario = SplitCIFAR100(
            n_experiences=10,
            return_task_id=False,
            fixed_class_order=[i for i in range(100)],
            shuffle=False,
            seed=args.seed
        )
    else:
        # without this, `scenario` would be unbound and fail with a NameError below
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}: only 'cifar10' and 'cifar100' have a split scenario."
        )

    # NOTE(review): `_flat_data._indices` is a private avalanche attribute;
    # presumably it holds the sample indices of the experience — verify it
    # against the installed avalanche version.
    current_indices = scenario.train_stream[args.exp].dataset._flat_data._indices
    test_indices = scenario.test_stream[args.exp].dataset._flat_data._indices

    train_dataset_cur = torch.utils.data.Subset(train_dataset, current_indices)
    # the test set covers the current experience and every earlier one
    for i in range(args.exp):
        buffer_val_indices = scenario.test_stream[args.exp-1-i].dataset._flat_data._indices
        test_indices = np.concatenate((test_indices, buffer_val_indices))
    val_dataset = torch.utils.data.Subset(val_dataset, test_indices)
    train_loader, val_loader = prepare_dataloaders(
        train_dataset_cur,
        val_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )

    # extract train features of the current experience
    train_features_bb, train_features_proj, train_targets = extract_features(train_loader, model)
    for i in range(args.exp):
        # Get indices of previous experiences
        buffer_indices = scenario.train_stream[args.exp - 1 - i].dataset._flat_data._indices
        buffer_indices = np.array(buffer_indices)

        # Create a subset dataset from buffer indices
        buffer_dataset = Subset(train_dataset, buffer_indices)

        # NOTE(review): buffer_loader is never used; kept in case
        # prepare_dataloaders has side effects.
        buffer_loader, _ = prepare_dataloaders(
            buffer_dataset, val_dataset, batch_size=args.batch_size, num_workers=args.num_workers
        )

        # Collect targets
        buffer_targets = torch.tensor(
            [buffer_dataset.dataset[buffer_dataset.indices[j]][1] for j in range(len(buffer_dataset))]
        )

        print("Buffer targets before selection:", torch.unique(buffer_targets, return_counts=True))
        num_classes = len(torch.unique(buffer_targets))
        # a total budget of 300 buffer samples, shared by all earlier experiences
        num_samples_per_class = 300 // (num_classes * args.exp)
        selected_indices = select_random_even(buffer_targets, samples_per_class=num_samples_per_class)
        selected_dataset = Subset(buffer_dataset, selected_indices.tolist())

        print("Unique labels in selected dataset:", torch.unique(buffer_targets[selected_indices], return_counts=True))
        selected_loader = DataLoader(selected_dataset, batch_size=args.batch_size, shuffle=False,
                                     num_workers=args.num_workers)
        # Embed the randomly selected buffer samples plus their augmentations.
        # Both dataset branches used the same 5000 samples per class, so one call is enough.
        buffer_features_bb, buffer_features_proj, buffer_targets = extract_features_with_aug(
            selected_loader, model, transform, num_aug_per_class=5000
        )
        train_features_bb = torch.cat((train_features_bb, buffer_features_bb))
        # keep the projector features aligned with train_targets; without
        # this, k-NN on "projector" features gets mismatched sizes
        train_features_proj = torch.cat((train_features_proj, buffer_features_proj))
        train_targets = torch.cat((train_targets, buffer_targets))
    train_features = {"backbone": train_features_bb, "projector": train_features_proj}
    # save_representations(
    #     train_features_bb,
    #     train_targets,
    #     ckpt_dir,
    #     prefix=f'{args.dataset}_{method_args["method"]}_trying_train'
    # )

    # save_representations(
    #     buffer_features_bb,
    #     buffer_targets,
    #     ckpt_dir,
    #     prefix=f'cifar10_{method_args["method"]}_trying_train_buffer'
    # )
    # extract test features
    test_features_bb, test_features_proj, test_targets = extract_features(val_loader, model)
    test_features = {"backbone": test_features_bb, "projector": test_features_proj}
    save_representations(
        test_features_bb,
        test_targets,
        ckpt_dir,
        prefix=f'{args.dataset}_{method_args["method"]}_trying_test'
    )

    # train a classifier on the frozen backbone features
    # (from here on `model` is the classifier, not the pretrained backbone)
    model =  FeatureClassifier(input_dim=512, num_classes=100)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                           factor=0.5, patience=10, verbose=True)

    train_dataset = TensorDataset(train_features_bb, train_targets)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    print("\nInitial Diagnostics:")
    print(f"Train features shape: {train_features_bb.shape}")
    print(f"Test features shape: {test_features_bb.shape}")
    print(f"Train class distribution: {torch.unique(train_targets, return_counts=True)}")
    print(f"Test class distribution: {torch.unique(test_targets, return_counts=True)}")
    test_acc=0  # never updated while the validation pass below stays disabled
    val_dataset = TensorDataset(test_features_bb, test_targets)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
    model = model.to('cuda')
    best_acc = 0
    epochs=100
    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0
        correct = 0
        total = 0

        for batch_features, batch_labels in train_loader:
            batch_features = batch_features.to('cuda')
            batch_labels = batch_labels.to('cuda')

            optimizer.zero_grad()
            outputs = model(batch_features)
            loss = criterion(outputs, batch_labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += batch_labels.size(0)
            correct += predicted.eq(batch_labels).sum().item()

        train_acc = 100. * correct / total

        # Validation
        # model.eval()
        # test_loss = 0
        # correct = 0
        # total = 0
        #
        # with torch.no_grad():
        #     for batch_features, batch_labels in val_loader:
        #         batch_features = batch_features.to('cuda')
        #         batch_labels = batch_labels.to('cuda')
        #
        #         outputs = model(batch_features)
        #         loss = criterion(outputs, batch_labels)
        #
        #         test_loss += loss.item()
        #         _, predicted = outputs.max(1)
        #         total += batch_labels.size(0)
        #         correct += predicted.eq(batch_labels).sum().item()
        #
        # test_acc = 100. * correct / total
        # scheduler.step(test_acc)
        #
        # # Save best model
        # if test_acc > best_acc:
        #     best_acc = test_acc
        #     best_model_state = model.state_dict().copy()
        #
        # if (epoch + 1) % 5 == 0:
        #     print(f'Epoch {epoch + 1}/{epochs}:')
        #     print(f'Train Loss: {train_loss / len(train_loader):.4f}, '
        #           f'Train Acc: {train_acc:.2f}%')
        #     print(f'Test Loss: {test_loss / len(val_loader):.4f}, '
        #           f'Test Acc: {test_acc:.2f}%')

    # run k-nn for all possible combinations of parameters
    acc1 = None  # stays None if no k-NN configuration is requested
    for feat_type in args.feature_type:
        print(f"\n### {feat_type.upper()} ###")
        for k in args.k:
            for distance_fx in args.distance_function:
                # the temperature only matters for cosine distance
                temperatures = args.temperature if distance_fx == "cosine" else [None]
                for T in temperatures:
                    print("---")
                    print(f"Running k-NN with params: distance_fx={distance_fx}, k={k}, T={T}...")
                    acc1, acc5 = run_knn(
                        train_features=train_features[feat_type],
                        train_targets=train_targets,
                        test_features=test_features[feat_type],
                        test_targets=test_targets,
                        k=k,
                        T=T,
                        distance_fx=distance_fx,
                    )
                    print(f"Result: acc@1={acc1}, acc@5={acc5}")
    return acc1, test_acc


class FeatureClassifier(nn.Module):
    """Residual MLP classifier operating on flat feature vectors.

    Layout: an input stem (BatchNorm -> Linear to 64 -> BatchNorm -> ReLU),
    four stages of two ``ResBlock`` each (widths 64, 128, 256, 512) and a
    final linear layer producing ``num_classes`` logits.

    Args:
        input_dim (int): size of the incoming feature vectors.
        num_classes (int): number of output logits.
    """

    def __init__(self, input_dim=512, num_classes=4):
        super().__init__()

        # input stem
        self.initial = nn.Sequential(
            nn.BatchNorm1d(input_dim),
            nn.Linear(input_dim, 64),
            nn.BatchNorm1d(64),
            nn.ReLU()
        )

        # residual stages: (attribute name, width in, width out, stride)
        stage_specs = (
            ("layer1", 64, 64, 1),
            ("layer2", 64, 128, 2),
            ("layer3", 128, 256, 2),
            ("layer4", 256, 512, 2),
        )
        for attr_name, width_in, width_out, stage_stride in stage_specs:
            setattr(self, attr_name, self._make_layer(width_in, width_out, 2, stride=stage_stride))

        # classification head
        self.fc = nn.Linear(512, num_classes)

        self.apply(self._init_weights)

    def _make_layer(self, in_channels, out_channels, num_blocks, stride=1):
        """Build one stage: a width-changing block followed by same-width blocks."""
        blocks = [ResBlock(in_channels, out_channels, stride)]
        blocks += [ResBlock(out_channels, out_channels) for _ in range(num_blocks - 1)]
        return nn.Sequential(*blocks)

    def _init_weights(self, m):
        """Xavier-uniform weights and zero bias for linear layers; others untouched."""
        if not isinstance(m, nn.Linear):
            return
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return the class logits for a batch of feature vectors."""
        x = self.initial(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return self.fc(x)


class ResBlock(nn.Module):
    """Two-layer fully connected residual block.

    Main path: Linear -> BatchNorm -> ReLU -> Linear -> BatchNorm. The input
    is added back through ``shortcut`` before the final ReLU. The shortcut is
    an identity unless ``stride != 1`` or the width changes, in which case it
    is a Linear + BatchNorm projection. ``stride`` has no other effect.

    Args:
        in_channels (int): input feature size.
        out_channels (int): output feature size.
        stride (int): only consulted to decide whether the shortcut projects.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        self.conv1 = nn.Linear(in_channels, out_channels)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Linear(out_channels, out_channels)
        self.bn2 = nn.BatchNorm1d(out_channels)

        needs_projection = not (stride == 1 and in_channels == out_channels)
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Linear(in_channels, out_channels),
                nn.BatchNorm1d(out_channels)
            )
        else:
            # an empty Sequential behaves as the identity
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the block to a batch of feature vectors."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        main_path = self.bn2(self.conv2(hidden))
        return F.relu(main_path + self.shortcut(x))
if __name__ == "__main__":
    # main() returns (k-NN acc@1, classifier accuracy)
    knn_acc, linear_classifier_acc = main()
    summary_lines = (
        "\nFinal Output:",
        f"k-NN Accuracy: {knn_acc}",
        f"Linear Classifier Accuracy: {linear_classifier_acc}",
    )
    for line in summary_lines:
        print(line)