import json
import os
import sys

import numpy as np
import torch
from sklearn.metrics.pairwise import euclidean_distances
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from tqdm import tqdm

# Put the current working directory on the import path so the `task_tracker`
# package resolves when the script is launched from the repository root.
sys.path.append(os.path.abspath("."))
from task_tracker.CONFIG import current_risk

# Name of the model whose activations are trained on; also used as a key into
# CONSTANTS_ALL_MODELS and in every file-list name further down.
MODEL = "vicuna"
# NOTE(review): this directory is created here, but the checkpoints written by
# this script go to config['out_dir'] (derived from MODEL_OUTPUT_DIR) instead.
OUTPUT_DIR = "/".join(("/guardrail/TaskTracker/store/model", current_risk, MODEL))
os.makedirs(OUTPUT_DIR, exist_ok=True)

from task_tracker.training.dataset import ActivationsDatasetDynamic, ActivationsDatasetDynamicPrimaryText
from task_tracker.training.dataset import ActivationsDatasetDynamicReturnText
from task_tracker.training.helpers.data import load_file_paths
from task_tracker.training.utils.constants import CONSTANTS_ALL_MODELS, OOD_POISONED_FILE
from task_tracker.training.triplet_probe.loss_functions.triplet_loss import TripletLoss, triplet_mining_unique
from task_tracker.training.triplet_probe.models.processing_per_layer import ParallelConvProcessingModel
from task_tracker.training.utils.constants import MODEL_OUTPUT_DIR
from task_tracker.training.helpers.training import load_checkpoint, save_checkpoint, compute_ROC_AUC


# Layer indices listed for each supported model name: index 0 followed by
# every eighth index starting at 7, up to the last listed layer (79 for
# llama3_70b, 31 for all the others). Each model gets its own list object.
LAYERS_PER_MODEL = {'llama3_70b': [0, *range(7, 80, 8)]}
LAYERS_PER_MODEL.update(
    (model_name, [0, *range(7, 32, 8)])
    for model_name in (
        'phi3',
        'mixtral',
        'mistral',
        'llama3_8b',
        'mistral_no_priming',
        'vicuna',
    )
)

# Per-model data locations, looked up in the shared constants table.
ACTIVATION_FILE_LIST_DIR = CONSTANTS_ALL_MODELS[MODEL]['ACTIVATION_FILE_LIST_DIR']
ACTIVATIONS_DIR = CONSTANTS_ALL_MODELS[MODEL]['ACTIVATIONS_DIR']
ACTIVATIONS_VAL_DIR = CONSTANTS_ALL_MODELS[MODEL]['ACTIVATIONS_VAL_DIR']

# Single source of truth for the run; it is also dumped to config.json below,
# so every value here must stay JSON-serialisable.
config = {
    # --- data locations ---
    'activations': ACTIVATIONS_DIR,          # root dir handed to the training dataset
    'activations_ood': ACTIVATIONS_VAL_DIR,  # root dir handed to the validation dataset
    'ood_poisoned_file': OOD_POISONED_FILE,  # passed to process_val_data during validation
    'exp_name': f'metric_learning_{MODEL}',  # becomes the output sub-directory name
    # --- optimisation ---
    'margin': 0.3,            # given to TripletLoss and to triplet mining
    'epochs': 6,
    'num_layers': (0, 5),  # start to end layer (both inclusive)
    'files_chunk': 10,        # number of activation files loaded per dataset chunk
    'batch_size': 500,  # batch size used for triplet mining
    'learning_rate': 0.0005,  # Adam learning rate
    'restart': False,  # Set to True if restarting from a checkpoint
    # --- model architecture ---
    'feature_dim': 275,
    'pool_first_layer': 3 if MODEL != 'llama3_70b' else 5,  # llama3 has larger pool layer to reduce its dim faster
    'dropout': 0.5,           # NOTE(review): not read anywhere in the visible script
    'check_each': 50,         # validate every this many optimiser steps
    'conv': True,
    'layer_norm': False,
    # --- LR schedule (StepLR) ---
    'delay_lr_factor': 0.95,  # gamma
    'delay_lr_step': 800,     # step_size, counted in optimiser steps
}

# Create the experiment directory and store the configuration next to the
# checkpoints that will be written there.
config['out_dir'] = os.path.join(MODEL_OUTPUT_DIR, config['exp_name'])
os.makedirs(config['out_dir'], exist_ok=True)
with open(os.path.join(config['out_dir'], 'config.json'), 'w') as f:
    json.dump(config, f)

# Read the lists of activation files for each split (train / val / test) and
# each class (clean / poisoned). All lists live in ACTIVATION_FILE_LIST_DIR and
# are named after MODEL.
train_files_clean = load_file_paths(
    os.path.join(ACTIVATION_FILE_LIST_DIR, f'train_clean_files_{MODEL}.txt'))
train_files_poisoned = load_file_paths(
    os.path.join(ACTIVATION_FILE_LIST_DIR, f'train_poisoned_files_{MODEL}.txt'))
val_files_clean = load_file_paths(
    os.path.join(ACTIVATION_FILE_LIST_DIR, f'val_clean_files_{MODEL}.txt'))
val_files_poisoned = load_file_paths(
    os.path.join(ACTIVATION_FILE_LIST_DIR, f'val_poisoned_files_{MODEL}.txt'))
test_files_clean = load_file_paths(
    os.path.join(ACTIVATION_FILE_LIST_DIR, f'test_clean_files_{MODEL}.txt'))
# NOTE(review): the poisoned test list is the only one with a "_3" suffix, and
# neither test list is used later in the visible script -- confirm intent.
test_files_poisoned = load_file_paths(
    os.path.join(ACTIVATION_FILE_LIST_DIR, f'test_poisoned_files_{MODEL}_3.txt'))

# Model, Optimizer, and Loss Function Setup
model = ParallelConvProcessingModel(
    feature_dim=config['feature_dim'],
    num_layers=config['num_layers'],
    conv=config['conv'],
    layer_norm=config['layer_norm'],
    pool_first_layer=config['pool_first_layer'],
).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'])
# The scheduler is stepped once per optimiser step inside one_epoch_train, so
# 'delay_lr_step' is measured in optimiser steps, not epochs.
scheduler = StepLR(optimizer, step_size=config['delay_lr_step'], gamma=config['delay_lr_factor'])

# Triplet loss with hard and semi-hard mining
triplet_loss = TripletLoss(config['margin'])

# Mutable training state shared with one_epoch_train / validation_ood through
# module-level globals.
global_counter_for_save = 1  # optimiser-step counter; drives validation cadence
start_epoch = 0
best_roc_auc = 0

# Optionally load from checkpoint
if config['restart']:
    # NOTE(review): start_epoch is always 0 at this point, so this looks for
    # 'epoch_model_0_checkpoint.pth', while the training loop in this script
    # only writes 'best_model_checkpoint.pth' -- confirm the intended file.
    checkpoint_file = os.path.join(config['out_dir'], f'epoch_model_{start_epoch}_checkpoint.pth')
    if os.path.exists(checkpoint_file):
        # Only model and optimizer are handed to load_checkpoint; the scheduler
        # and global_counter_for_save keep their fresh values.
        start_epoch, best_roc_auc = load_checkpoint(checkpoint_file, model, optimizer)
        print(f"Restarting from epoch {start_epoch} with best roc auc {best_roc_auc}")
    else:
        # Previously a missing checkpoint was ignored without any message,
        # making a restart indistinguishable from a fresh run in the logs.
        print(f"Restart requested but no checkpoint found at {checkpoint_file}; training from scratch")

def one_epoch_train(epoch_num, model, loss_function, optimizer, scheduler, train_files, config):
    """Run one training epoch of triplet mining followed by optimisation.

    Training files are consumed in chunks of ``config['files_chunk']``. For
    each loaded batch the embeddings are first computed without gradients to
    mine triplets; the mined triplets are then replayed in mini-batches of
    1024 with gradients enabled, taking one optimiser and one scheduler step
    per mini-batch.

    Side effects on module-level state: ``global_counter_for_save`` is
    incremented after every optimiser step, and every ``config['check_each']``
    steps ``validation_ood`` is run on the module-level validation file lists.
    When the returned ROC AUC exceeds ``best_roc_auc`` the latter is updated
    and ``best_model_checkpoint.pth`` is written to ``config['out_dir']``.

    Args:
        epoch_num: Epoch index; stored in the checkpoint and printed plus one.
        model: Embedding network, called on batches that were moved to CUDA.
        loss_function: Called as ``loss_function(sim_ap, sim_an, step)``.
        optimizer: Optimiser stepped once per mined mini-batch.
        scheduler: LR scheduler stepped once per mined mini-batch.
        train_files: Sliceable sequence of training activation files.
        config: Run configuration; reads 'files_chunk', 'activations',
            'num_layers', 'batch_size', 'margin', 'check_each', 'out_dir'.

    Returns:
        None. The average loss for the epoch is printed.
    """
    global global_counter_for_save
    global best_roc_auc
    model.train()
    step = 0
    total_loss = 0
    total_batches = 0
    batch_size = 1024  # mini-batch size used when replaying mined triplets

    for i in range(0, len(train_files), config['files_chunk']):
        chunk_files = train_files[i: i + config['files_chunk']]
        # NOTE(review): TRAINING_TEXT_DIR is neither defined nor imported in
        # the visible file; this line raises NameError unless it is provided
        # elsewhere -- confirm where it should come from.
        dataset = ActivationsDatasetDynamicReturnText(chunk_files, config['activations'], TRAINING_TEXT_DIR, num_layers=config['num_layers'])
        training_loader = DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)

        for data in tqdm(training_loader):
            optimizer.zero_grad()

            # Triplet mining and loss computation
            step += 1
            # Mining only needs embedding values, so skip autograd bookkeeping.
            with torch.no_grad():
                primary, clean, poisoned, text_batch = data[0].cuda(), data[1].cuda(), data[2].cuda(), data[3]
                print(primary.size())
                print(len(text_batch))
                # For models that are read as float16
                # NOTE(review): confirm the installed PyTorch accepts float32
                # as a CUDA autocast dtype.
                with torch.autocast(device_type='cuda', dtype=torch.float32):
                    primary_embs = model(primary)
                    clean_embs = model(clean)
                    poisoned_embs = model(poisoned)
                # Hard mining is switched on once more than 3000 optimiser
                # steps have been taken.
                triplet_combinations = triplet_mining_unique(primary_embs, clean_embs, poisoned_embs, config['margin'], text_batch, hard=global_counter_for_save > 3000, step=step)
                print(len(triplet_combinations))

            for k in range(0, len(triplet_combinations), batch_size):
                # Column 0 indexes the anchor/positive pair, column 1 the negative.
                indices_clean = triplet_combinations[k:k + batch_size, 0]
                indices_poisoned = triplet_combinations[k:k + batch_size, 1]

                # Batches of primary, clean, secondary
                anchor_inputs = primary[indices_clean, :]
                positive_inputs = clean[indices_clean, :]
                negative_inputs = poisoned[indices_poisoned, :]

                # Forward pass again, this time with gradients enabled.
                # For models that are read as float16
                with torch.autocast(device_type='cuda', dtype=torch.float32):
                    anchor_emb_output = model(anchor_inputs)
                    positive_emb_output = model(positive_inputs)
                    negative_emb_output = model(negative_inputs)

                # These are cosine similarities, not distances.
                # NOTE(review): presumably TripletLoss expects similarities -- confirm.
                sim_ap = torch.nn.functional.cosine_similarity(anchor_emb_output, positive_emb_output, dim=-1)
                sim_an = torch.nn.functional.cosine_similarity(anchor_emb_output, negative_emb_output, dim=-1)

                # Calculate loss for the selected triplets
                loss = loss_function(sim_ap, sim_an, step)
                # Down-weight a trailing partial mini-batch in proportion to its size.
                loss = loss * (anchor_emb_output.size(0) / batch_size)
                loss.backward()
                total_loss += loss.item()

                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                total_batches += 1

                global_counter_for_save += 1
                if global_counter_for_save % config['check_each'] == 0:
                    roc_auc = validation_ood(model, val_files_clean, val_files_poisoned, config)
                    if roc_auc > best_roc_auc:
                        print('=== New best model ===')
                        best_roc_auc = roc_auc
                        save_checkpoint({
                            'epoch': epoch_num,
                            'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'best_roc_auc': best_roc_auc
                        }, config['out_dir'], 'best_model_checkpoint.pth')

    avg_loss = (total_loss / total_batches) if total_batches > 0 else 0
    print(f'Epoch {epoch_num + 1}: Training Loss: {avg_loss}')

def compute_distances_validation(model, val_files, config):
    """Return cosine similarities between paired validation embeddings.

    Files are loaded in chunks of ``config['files_chunk']`` from
    ``config['activations_ood']``. Every batch yields two tensors
    (``primary`` and ``primary_with_text``); both are embedded with ``model``
    on CUDA and compared with cosine similarity along the last dimension.

    Gradients are disabled inside the function: the results are converted
    with ``.numpy()``, which fails on tensors that require grad, so the
    function no longer relies on the caller wrapping it in ``torch.no_grad()``.
    The model's train/eval mode is not changed here.

    Args:
        model: Embedding network already placed on the GPU.
        val_files: Sliceable sequence of validation activation files.
        config: Run configuration; reads 'files_chunk', 'num_layers',
            'activations_ood' and 'batch_size'.

    Returns:
        A flat list with the similarity values of all batches, in file order.
        Despite the function name, these are similarities, not distances.
    """
    similarities = []
    chunk_size = config.get('files_chunk')

    with torch.no_grad():
        for start in range(0, len(val_files), chunk_size):
            chunk_files = val_files[start: start + chunk_size]
            dataset = ActivationsDatasetDynamicPrimaryText(chunk_files, num_layers=config['num_layers'], root_dir=config.get('activations_ood'))
            # Keep file order stable so results line up across validation runs.
            val_loader = DataLoader(dataset, batch_size=config.get('batch_size'), shuffle=False)

            for data in val_loader:
                primary, primary_with_text = [d.cuda() for d in data]
                # For models that are read as float16
                with torch.autocast(device_type='cuda', dtype=torch.float32):
                    primary_embs = model(primary)
                    primary_with_text_embs = model(primary_with_text)
                cosine_similarity = torch.nn.functional.cosine_similarity(primary_embs, primary_with_text_embs, dim=-1)
                similarities.extend(cosine_similarity.cpu().numpy())
    return similarities

def validation_ood(model, val_files_clean, val_files_poisoned, config):
    """Evaluate the model on the validation files and return the ROC AUC.

    The model is switched to eval mode for the duration of the call and then
    put back into whichever mode it was in on entry, even if validation
    raises. Similarities for the clean and poisoned files are computed with
    ``compute_distances_validation``; the poisoned ones are passed through
    ``process_val_data`` together with ``config['ood_poisoned_file']`` before
    ``compute_ROC_AUC`` is applied.

    Args:
        model: Embedding network (a ``torch.nn.Module``).
        val_files_clean: Validation files of the clean class.
        val_files_poisoned: Validation files of the poisoned class.
        config: Run configuration; reads 'ood_poisoned_file' here, and is
            forwarded to ``compute_distances_validation``.

    Returns:
        The value returned by ``compute_ROC_AUC``.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():  # Disable gradient computation
            distances_clean = compute_distances_validation(model, val_files_clean, config)
            distances_poisoned = compute_distances_validation(model, val_files_poisoned, config)

        # NOTE(review): process_val_data is neither defined nor imported in the
        # visible file; this raises NameError unless it is provided elsewhere.
        distances_poisoned = process_val_data(config.get('ood_poisoned_file'), distances_poisoned)
        roc_auc = compute_ROC_AUC(distances_clean, distances_poisoned)

        print(f'Training step {global_counter_for_save + 1}: ROC AUC on OOD data: {roc_auc}')
    finally:
        # Restore the caller's mode instead of unconditionally forcing train().
        model.train(was_training)
    return roc_auc

# Train for the configured number of epochs, starting at start_epoch (0 unless
# a checkpoint was loaded). Clean and poisoned file lists are concatenated
# into a single training list for each epoch.
for epoch in range(start_epoch, config['epochs']):
    one_epoch_train(
        epoch_num=epoch,
        model=model,
        loss_function=triplet_loss,
        optimizer=optimizer,
        scheduler=scheduler,
        train_files=train_files_clean + train_files_poisoned,
        config=config,
    )