import os
import pdb
import torch

import numpy as np
import wandb
from utils import *
from torch.utils.data import TensorDataset, DataLoader
from task_merger import get_merge_handler

base_path = "./data/heads"

import argparse

def get_LP_loader(dataloader, merged_model, num_classes, num_samples_per_task, device):
    """Cache normalized encoder features for a class-balanced subset of a loader.

    Iterates ``dataloader`` and keeps samples until ``num_samples_per_task``
    have been collected, never taking more than a per-class quota. Each kept
    sample is encoded by ``merged_model`` on ``device``, L2-normalized along
    the last dimension and stored on the CPU.

    Args:
        dataloader: Iterable of ``(x, y)`` batches where ``x`` exposes
            ``.pixel_values`` and ``y`` supports ``.tolist()``.
        merged_model: Encoder module; it is switched to eval mode and moved to
            ``device``.
        num_classes: Number of classes; labels are used as keys into the
            quota table built from ``range(num_classes)``.
        num_samples_per_task: Total number of samples to collect.
        device: Device on which the encoder is run.

    Returns:
        A shuffling ``DataLoader`` (batch size 64) over a ``TensorDataset`` of
        stacked features and integer labels. Every per-sample model output is
        stacked as-is, so the features keep the leading singleton batch
        dimension of each forward pass.

    Raises:
        ValueError: If no sample was selected (e.g. empty loader or a zero
            sample budget).
    """
    # Per-class quotas: split the budget evenly and hand the remainder to
    # randomly chosen classes.
    base, rem = divmod(num_samples_per_task, num_classes)
    classes = random.sample(range(num_classes), num_classes)
    targets = {c: base + (1 if i < rem else 0) for i, c in enumerate(classes)}
    counts = defaultdict(int)
    total = 0

    all_data = []
    all_labels = []

    merged_model.eval()
    # Inputs are moved to `device` below, so the model has to live there too.
    merged_model.to(device)

    # Features are only cached, so no autograd graph is needed.
    with torch.no_grad():
        for x, y in dataloader:
            if total >= num_samples_per_task:
                break
            x = x.pixel_values.squeeze(1)
            for i, lbl in enumerate(y.tolist()):
                if total >= num_samples_per_task:
                    break
                # Check the quota per sample (not once per batch) so a batch
                # dominated by one class cannot overshoot that class's target.
                if counts[lbl] >= targets[lbl]:
                    continue
                counts[lbl] += 1
                total += 1
                xi = x[i].unsqueeze(0).to(device)
                vision_outputs = merged_model(xi)
                vision_outputs = vision_outputs / vision_outputs.norm(dim=-1, keepdim=True)
                all_data.append(vision_outputs.cpu())
                all_labels.append(torch.tensor(lbl))

    if not all_data:
        raise ValueError("get_LP_loader selected no samples from the given dataloader")

    all_data = torch.stack(all_data)
    all_labels = torch.stack(all_labels)

    train_dataset = TensorDataset(all_data, all_labels)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    return train_loader


def linear_probe(merged_model, zeroshot_head, dataloaders, task_idx, num_samples_per_task, device):
    """Fine-tune a linear classification head on cached features of one task.

    Features for a class-balanced subset of the task's train and validation
    loaders are cached with ``get_LP_loader``. A bias-free linear layer is
    initialised from ``zeroshot_head`` and trained with Adam and
    cross-entropy, using early stopping on the summed validation loss.

    Args:
        merged_model: Encoder passed through to ``get_LP_loader``.
        zeroshot_head: Initial classifier weights; it is copied, so the
            caller's tensor is not modified by training.
        dataloaders: Indexable collection; entry ``task_idx`` must provide
            ``["train"]["full"]`` and ``["test"]["val"]`` loaders.
        task_idx: Index of the task; also selects the class count from the
            hard-coded table below.
        num_samples_per_task: Sample budget for each cached subset.
        device: Device used for the classifier and the batches.

    Returns:
        The classifier weight tensor with the lowest validation loss.
    """
    print("Cacheing data...")
    cur_train_loader = dataloaders[task_idx]["train"]["full"]
    cur_val_loader = dataloaders[task_idx]["test"]["val"]

    # Class count per task index — presumably in the same order as the
    # datasets in the experiment config; confirm when adding tasks.
    num_classes = [196, 47, 10, 43, 10, 45, 397, 10]
    n_classes = num_classes[task_idx]

    train_loader = get_LP_loader(cur_train_loader, merged_model, n_classes, num_samples_per_task, device)
    val_loader = get_LP_loader(cur_val_loader, merged_model, n_classes, num_samples_per_task, device)

    feature_dim = train_loader.dataset[0][0].shape[-1]
    # Only the weight matrix is returned, so the probe is trained without a
    # bias; a frozen randomly-initialised bias would skew the learned weights.
    classifier = nn.Linear(feature_dim, n_classes, bias=False).to(device)
    # Clone so in-place optimizer updates never touch the caller's tensor, and
    # move to `device` so the weights live next to the batches.
    classifier.weight.data = zeroshot_head.detach().clone().to(device)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)

    best_val_loss = float("inf")
    # Start from the zero-shot weights so a value is always available, even if
    # the validation loss never improves (e.g. it is NaN).
    best_classifier = classifier.weight.detach().clone()
    patience = 100
    counter = 0

    # Log only when a wandb run is active, instead of relying on a
    # module-level `args` that exists only when the file runs as a script.
    log_to_wandb = wandb.run is not None

    print("Linear probing...")
    for _ in tqdm(range(2000)):
        classifier.train()
        # The caller may be inside torch.no_grad(); training needs gradients.
        with torch.enable_grad():
            for x_batch, y_batch in train_loader:
                x_batch, y_batch = x_batch.to(device), y_batch.to(device)
                # Cached features carry a singleton dim at index 1; drop it.
                outputs = classifier(x_batch[:, 0])
                loss = loss_fn(outputs, y_batch)
                if log_to_wandb:
                    wandb.log({"LP_loss": loss.item()})
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

        val_loss = 0.0
        classifier.eval()
        with torch.no_grad():
            for x_batch, y_batch in val_loader:
                x_batch, y_batch = x_batch.to(device), y_batch.to(device)
                outputs = classifier(x_batch[:, 0])
                val_loss += loss_fn(outputs, y_batch).item()
        if log_to_wandb:
            wandb.log({"val_LP_loss": val_loss})

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Snapshot a copy: the optimizer updates the live weights in
            # place, so keeping a reference would silently track the latest
            # weights instead of the best ones.
            best_classifier = classifier.weight.detach().clone()
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                break

    # Evaluate the best checkpoint on the cached validation subset.
    classifier.eval()
    all_preds = []
    all_labels = []
    classifier.weight.data = best_classifier
    with torch.no_grad():
        for x_batch, y_batch in val_loader:
            x_batch = x_batch.to(device)
            outputs = classifier(x_batch[:, 0])
            preds = torch.argmax(outputs, dim=1).cpu()
            all_preds.append(preds)
            all_labels.append(y_batch)

    acc = (torch.cat(all_labels) == torch.cat(all_preds)).float().mean()
    if log_to_wandb:
        # NOTE: the key says "train" but the value is measured on the cached
        # validation subset; the key is kept so existing dashboards still work.
        wandb.log({"LP_train_accuracy": acc * 100})

    print(f"Validation accuracy linear probing = {acc:.4f}")

    return best_classifier

def _parse_cli_bool(text):
    """Convert a command-line string such as ``"false"`` or ``"1"`` to a bool."""
    lowered = text.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")


def _parse_cli_literal(text):
    """Parse a Python literal from the command line, else return the raw string."""
    import ast

    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return text


def _cli_type_for(default):
    """Pick an argparse ``type`` converter that can override ``default``."""
    # bool must be tested first: it is a subclass of int, and `bool("False")`
    # is True, so `type=bool` would be wrong.
    if isinstance(default, bool):
        return _parse_cli_bool
    if isinstance(default, (int, float, str)):
        return type(default)
    # None and containers: `NoneType("x")` raises and `list("ab")` splits the
    # string into characters, so parse a literal instead.
    return _parse_cli_literal


def _reference_accuracies(is_large_model, use_linear_probing):
    """Return the per-dataset reference accuracies used for normalisation."""
    # zeroshot acc [78.108, 79.271, 62.998, 52.113, 76.601, 74.225, 85.894, 59.96]
    # TODO: compute fine-tuned acc after linear probing for vitL
    if is_large_model:
        return {
            'stanford_cars': 99.76682729675113,
            'dtd': 70.0531914893617,
            'eurosat': 98.59259259259259,
            'gtsrb': 97.19912905779889,
            'mnist': 99.525,
            'resisc45': 95.69841269841269,
            'sun397': 79.59697732997482,
            'svhn': 97.72399884759435,
        }
    # zeroshot acc 59.63, 43.883, 44.556, 32.185, 47.912, 60.667, 63.224, 31.739
    if use_linear_probing:
        # using single task vectors for each task, we here report
        # max(linear probing acc, original finetuned acc)
        return {
            'stanford_cars': 79.388,
            'dtd': 69.521,
            'eurosat': 99.0,
            'gtsrb': 92.7,
            'mnist': 99.3,
            'resisc45': 88.4,
            'sun397': 64.5,
            'svhn': 96.2,
        }
    return {
        'stanford_cars': 74.0,
        'dtd': 58.3,
        'eurosat': 99.0,
        'gtsrb': 92.7,
        'mnist': 99.3,
        'resisc45': 88.4,
        'sun397': 64.5,
        'svhn': 96.2,
    }


def run_BIG_function(args):
    """Merge the configured models and evaluate the result on every dataset.

    Loads the experiment config named by ``args.config_name``, exposes every
    ``task_merge_config`` key as a command-line override, builds the merged
    model and reports raw and normalised accuracy per dataset (optionally
    after linear probing of each classification head).

    Note: this function extends the module-level ``parser`` created in the
    ``__main__`` block and re-parses ``sys.argv`` with it.

    Args:
        args: Namespace with ``eval_split``, ``bigseed``, ``config_name``,
            ``use_linear_probing`` and ``wandb_project``.

    Returns:
        dict: The merge configuration plus per-dataset accuracies
        (``<name>`` and ``<name>_norm_acc``), ``Average_acc`` and
        ``Average_norm_acc``.
    """
    set_seed(args.bigseed)
    # Use arguments
    print(f"Evaluating on split: {args.eval_split}")
    print(f"Using seed: {args.bigseed}")
    print(f"Using config: {args.config_name}")

    # Example config names: 'vitB_r16_knots_ties', 'vitL_r16_knots_ties',
    # 'vitL_r16_knots_dare_ties', 'vit_b_r16_knots_dare_ties'
    print("Running with config: ", args.config_name)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    raw_config = get_config_from_name(args.config_name, device=device)

    # Register every merge-config key as a CLI flag (defaulting to the config
    # value), then re-parse so command-line values override the config.
    for key, value in raw_config["task_merge_config"].items():
        parser.add_argument(f"--{key}", type=_cli_type_for(value), default=value)
    args_dict = {**vars(args), **vars(parser.parse_args())}
    args = argparse.Namespace(**args_dict)
    raw_config["task_merge_config"].update({
        key: args_dict[key]
        for key in raw_config["task_merge_config"]
        if key in args_dict
    })

    # Get clip encodings
    all_clip_encodings = [get_clip_encodings(f"{base_path}/{i['clip_encodings']}") for i in raw_config['dataset']]
    all_clip_encodings = [class_vectors / class_vectors.norm(dim=-1, keepdim=True) for class_vectors in all_clip_encodings]
    config = prepare_experiment_config(raw_config)

    log_to_wandb = args.wandb_project is not None
    if log_to_wandb:
        wandb.init(entity="default", project=args.wandb_project, config=args)

    dataset_names = np.array([i['name'] for i in raw_config['dataset']])
    dataloaders = np.array([i for i in config['data']])

    fine_tuned_acc = _reference_accuracies(
        "large" in raw_config["model"]["base_type"],
        args.use_linear_probing,
    )

    print(raw_config['task_merge_config'])
    with torch.no_grad():
        # Copy so metrics written below do not leak into the merge config.
        all_results = dict(config['task_merge_config'])
        print('Creating Merge')
        # initialize merging function
        models = np.array([i.cpu() for i in config['models']['bases']])
        MergeClass = get_merge_handler(config['task_merge_config']['representation'])
        Merge = MergeClass(
            models,
            pretrained_model=config['models']['new'],
            param_handler=config['param_handler'],
            device=device,
            merge_config=config['task_merge_config'],
            dataloaders=dataloaders,
        )
        Merge.transform(config['task_merge_config'])
        # set task scaling coefficients
        Merge.set_scaling_coeffs(config['task_merge_config']['scaling_coeffs'])
        merged_model = Merge.merge(config['task_merge_config'])
        print('Evaluate Merged Model on Each Dataset')
        print("Using config: ", config['task_merge_config'])
        avg_accuracy = 0.
        avg_norm_accuracy = 0.
        for i, loader_dict in enumerate(dataloaders):
            if args.use_linear_probing:
                # Replace the zero-shot head with the linearly probed one.
                all_clip_encodings[i] = linear_probe(merged_model, all_clip_encodings[i], dataloaders, i, num_samples_per_task=1880, device=device)

            loader = loader_dict['test'][args.eval_split]
            acc = evaluate_cliphead(merged_model.to(device), loader, class_vectors=all_clip_encodings[i].to(device), linear_probe=args.use_linear_probing)
            # Accuracy relative to the reference accuracy, as a percentage.
            norm_acc = (acc * 100) / fine_tuned_acc[dataset_names[i]] * 100
            print(f"{dataset_names[i]} Normalized accuracy is {np.round(norm_acc, 3)}")
            print(f"{dataset_names[i]} accuracy is {np.round(acc * 100, 3)}")
            all_results[dataset_names[i]] = acc * 100
            all_results[dataset_names[i]+'_norm_acc'] = norm_acc
            avg_accuracy += acc * 100
            avg_norm_accuracy += norm_acc
            if log_to_wandb:
                wandb.log({f"{dataset_names[i]}_accuracy": acc * 100, f"{dataset_names[i]}_norm_accuracy": norm_acc})

        avg_accuracy /= len(dataloaders)
        avg_norm_accuracy /= len(dataloaders)
        if log_to_wandb:
            wandb.log({"avg_accuracy": avg_accuracy, "avg_norm_accuracy": avg_norm_accuracy})

        print(f'Average Accuracy is {np.round(avg_accuracy, 3)}')
        print(f'Average Normalized Accuracy is {np.round(avg_norm_accuracy, 3)}')
        all_results['Average_acc'] = avg_accuracy
        all_results['Average_norm_acc'] = avg_norm_accuracy
        # Pick up any config values changed while merging.
        all_results.update(config['task_merge_config'])
        print('Finished!')

    return all_results
    
    
if __name__ == "__main__":
    # `parser` is kept as a module-level name: run_BIG_function later extends
    # it with one flag per merge-config key and parses the command line again.
    parser = argparse.ArgumentParser(
        description="Run BIG function with configurable parameters."
    )
    parser.add_argument(
        "--eval_split",
        type=str,
        default="test",
        help="Evaluation split (e.g., test, val)",
    )
    parser.add_argument(
        "--bigseed",
        type=int,
        default=420,
        help="Random seed",
    )
    parser.add_argument(
        "--config_name",
        type=str,
        default="vitB_r16_knots_ties",
        help="Configuration name",
    )
    # store_true flags default to False; string options default to None.
    parser.add_argument("--use-linear-probing", action="store_true")
    parser.add_argument("--wandb-project", type=str)

    # First pass: tolerate flags that are not registered yet (the merge-config
    # overrides) and keep only the recognised namespace.
    args, _ = parser.parse_known_args()

    run_BIG_function(args)

