import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision

import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# utils
from utils.dataset_loader import get_dataset
from utils.analysis import embedding_performance
from utils.eval_utils import load_snapshot
from utils.metrics import LinearProbeEval

# model
from models.simclr import SimCLR

import argparse
import yaml
import pandas as pd
import numpy as np
from tqdm import tqdm
from collections import namedtuple
import wandb

# set seed: a fixed seed for torch's CPU generator and for the current CUDA
# device's generator (the CUDA call is silently ignored when no GPU exists).
_DEFAULT_SEED = 123
torch.manual_seed(_DEFAULT_SEED)
torch.cuda.manual_seed(_DEFAULT_SEED)

def load_model(ssl_model, config, ckpt_path, device=None):
    """Load weights from a training snapshot into ``ssl_model`` in place.

    Args:
        ssl_model: Module whose ``load_state_dict`` receives the weights.
        config: Parsed YAML config; ``config['model']['ckpt_path']`` is used
            as the checkpoint path when ``ckpt_path`` is None.
        ckpt_path: Path to a snapshot saved as a dict holding the state dict
            under ``'MODEL_STATE'``, or None to fall back to the config.
        device: ``map_location`` passed to ``torch.load``. Defaults to
            ``'cuda'`` when CUDA is available, otherwise ``'cpu'`` (a
            hard-coded ``'cuda'`` fails on CPU-only machines).

    Returns:
        ``ssl_model`` (the same object, modified in place).

    Raises:
        ValueError: if neither ``ckpt_path`` nor the config provides a path.
    """
    if ckpt_path is None:
        # .get() so a config lacking the key reaches the clear ValueError
        # below instead of failing with a bare KeyError.
        ckpt_path = (config.get('model') or {}).get('ckpt_path')

    if ckpt_path is None:
        raise ValueError("ckpt_path not provided")

    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # weights_only=True restricts unpickling to tensors and plain containers.
    snapshot = torch.load(ckpt_path, map_location=device, weights_only=True)
    ssl_model.load_state_dict(snapshot['MODEL_STATE'])
    print(f"Model loaded from {ckpt_path}")
    return ssl_model


if __name__ == "__main__":
    # Few-shot linear probing of a frozen SSL model; results are appended to
    # a per-dataset CSV so already-evaluated shot counts are skipped.

    # ---- command-line arguments -------------------------------------------
    parser = argparse.ArgumentParser(description='SimCLR Linear Probing')
    parser.add_argument('--config', '-c', required=True, help='path to yaml config file')
    parser.add_argument('--ckpt_path', '-ckpt',
                        default=None,
                        help='path to model checkpoint '
                             '(falls back to model.ckpt_path in the config)')
    parser.add_argument('--N', '-n', default=1, type=int,
                        help='number of samples for few-shot learning')
    # Default equals the seed applied at import time, so runs without --seed
    # behave exactly as before.
    parser.add_argument('--seed', '-s', default=123, type=int,
                        help='torch random seed (CPU and CUDA)')
    parser.add_argument('--log_root', default='/home/understanding-ssl/logs',
                        help='root directory for the few-shot results CSV')
    args = parser.parse_args()

    # --seed was previously parsed but never applied. Seed before the model
    # is built so weight initialisation follows the requested seed.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # ---- configuration ----------------------------------------------------
    with open(args.config, 'r') as file:
        config = yaml.safe_load(file)

    method_type = config['method_type']
    supervision = config['supervision']

    dataset_name = config['dataset']['name']
    dataset_path = config['dataset']['path']

    encoder_type = config['model']['encoder_type']
    width_multiplier = config['model']['width_multiplier']
    hidden_dim = config['model']['hidden_dim']
    projection_dim = config['model']['projection_dim']

    # NOTE: linear.momentum / weight_decay / num_epochs / save_every are not
    # consumed here; LinearProbeEval below is not given them.
    batch_size = config['linear']['batch_size']
    num_output_classes = config['linear']['num_output_classes']
    augment_both = config['linear']['augment_both']
    top_lr = float(config['linear']['top_lr'])
    track_performance = config['linear']['track_performance']

    # Checkpoint: CLI flag first, then the config value.
    ckpt_path = args.ckpt_path or config['model'].get('ckpt_path')
    if not ckpt_path:
        raise ValueError("ckpt_path not provided")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # ---- data -------------------------------------------------------------
    (_, train_loader, _, test_loader,
     train_labels, test_labels) = get_dataset(dataset_name, dataset_path,
                                              batch_size=batch_size,
                                              augment_both_views=augment_both,
                                              test=True)

    # ---- model ------------------------------------------------------------
    if encoder_type == 'resnet50':
        # Randomly initialised backbone, presumably overwritten by the SSL
        # checkpoint loaded below. `weights=None` replaces the deprecated
        # `pretrained=False` keyword.
        encoder = torchvision.models.resnet50(weights=None)
    else:
        raise NotImplementedError(f"{encoder_type} not implemented")

    if method_type == 'simclr':
        ssl_model = SimCLR(model=encoder,
                           dataset=dataset_name,
                           width_multiplier=width_multiplier,
                           hidden_dim=hidden_dim,
                           projection_dim=projection_dim,
                           track_performance=track_performance,
                           )
    else:
        raise NotImplementedError(f"{method_type} not implemented")

    load_snapshot(snapshot_path=ckpt_path,
                  model=ssl_model,
                  device=device)
    ssl_model.to(device)

    # Freeze the backbone: no gradients, and eval mode so normalisation and
    # dropout layers run in inference mode while features are extracted.
    # NOTE(review): LinearProbeEval may also set the mode itself — confirm.
    for param in ssl_model.parameters():
        param.requires_grad = False
    ssl_model.eval()

    # NOTE(review): the positional 501 is interpreted by LinearProbeEval
    # (possibly an epoch count) — confirm against its signature.
    linear_evaluator = LinearProbeEval(
        ssl_model,
        train_loader,
        num_output_classes,
        501,
        top_lr,
        device,
        labels=None,
        log_every=100,
        log_to_wandb=False,
        wandb_project="linear-prob-eval",
        wandb_name="full-shot",
        train_labels=train_labels,
        test_labels=test_labels,
    )

    # ---- evaluation & logging ---------------------------------------------
    # Add more entries to evaluate several shot counts in one run.
    shot_counts = [args.N]
    output_path = os.path.join(args.log_root, dataset_name, 'simclr', 'corollary1')
    # Create the log directory up front; to_csv does not create parents.
    os.makedirs(output_path, exist_ok=True)
    output_logs_file = os.path.join(output_path, f'few_shot_lin_prob_{supervision}_new.csv')

    columns = ['Number of Shots', 'Train Acc', 'Test Acc']
    if os.path.exists(output_logs_file):
        few_shot_df = pd.read_csv(output_logs_file)
    else:
        few_shot_df = pd.DataFrame(columns=columns)
        print('Created a new dataframe for logging results.')

    done_shots = set(few_shot_df['Number of Shots'].tolist())
    new_rows = []
    for n_samples in shot_counts:
        if n_samples in done_shots:
            print(f"Evaluation exists for {n_samples} samples!")
            continue
        res, res_test = linear_evaluator.evaluate(test_loader,
                                                  n_samples=n_samples,
                                                  repeat=2,
                                                  embedding_layer=[1],
                                                  wandb_name=f'few-shot-{n_samples}')
        new_rows.append({
            'Number of Shots': n_samples,
            'Train Acc': res,
            'Test Acc': res_test,
        })
        # Guard against the same shot count appearing twice in shot_counts.
        done_shots.add(n_samples)

    if new_rows:
        new_df = pd.DataFrame(new_rows, columns=columns)
        # Avoid concatenating onto an empty frame (pandas >= 2.1 emits a
        # FutureWarning about dtype inference for empty concat inputs).
        if few_shot_df.empty:
            few_shot_df = new_df
        else:
            few_shot_df = pd.concat([few_shot_df, new_df], ignore_index=True)

    few_shot_df.to_csv(output_logs_file, index=False)