import argparse 
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import clip
from eva_vit import create_eva_vit_g
from processors.blip_processors import Blip2ImageTrainProcessor
from predefined_concepts import concepts 
import pickle
import multiprocessing
import logging
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
import time 
import json
import random

def parse_args():
    """Parse the command-line options of the training script.

    Options:
        --dataset (str, required): name of the dataset to train on.
        --epoch (int, default 100): value later used as the epoch budget.

    Returns:
        argparse.Namespace: the parsed ``dataset`` and ``epoch`` attributes.
    """
    cli = argparse.ArgumentParser()
    option_table = (
        ('--dataset', {'type': str, 'required': True}),
        ('--epoch', {'type': int, 'default': 100}),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()

class ConceptDataset(Dataset):
    """Dataset pairing each image of a split with its precomputed concept probabilities.

    The image list comes from ``get_img_list`` and the probabilities are read
    from ``./precomputed/{dataset}_precomputed_probs.pkl``, which is indexed by
    image name.

    Args:
        dataset: dataset name; selects both the annotation naming scheme and
            the precomputed-probabilities file.
        dataset_path: directory holding the split annotation JSON files.
        image_dir: directory the image names are resolved against.
        concepts: stored on the instance as ``self.concepts``; not otherwise
            used by this class.
        transform: optional callable applied to the RGB PIL image.
        precomputed_probs: accepted for interface compatibility; the
            constructor does not read it (probabilities always come from the
            pickle file).
        mode: split name forwarded to ``get_img_list``.
    """

    def __init__(self, dataset, dataset_path, image_dir, concepts, transform=None, precomputed_probs=None, mode='train'):
        self.dataset, self.image_dir = dataset, image_dir
        self.concepts, self.transform = concepts, transform
        self.image_files = get_img_list(dataset_path, mode, dataset=dataset)

        # NOTE: pickle.load can execute arbitrary code; only point this at
        # trusted, locally generated files.
        probs_path = f"./precomputed/{dataset}_precomputed_probs.pkl"
        with open(probs_path, 'rb') as handle:
            self.image_probs = pickle.load(handle)

    def __len__(self):
        """Return the number of images in the split."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, probs)`` for the image at position ``idx``."""
        file_name = self.image_files[idx]
        rgb_image = Image.open(os.path.join(self.image_dir, file_name)).convert('RGB')
        # Look the target up before transforming so a missing entry fails early.
        target = self.image_probs[file_name]
        if not self.transform:
            return rgb_image, target
        return self.transform(rgb_image), target

class LoggingCallback(Callback):
    """Callback that writes per-batch train/validation losses to a Python logger.

    A line is emitted whenever ``batch_idx`` is a multiple of
    ``trainer.log_every_n_steps``.

    Args:
        logger: object exposing ``info(msg)`` (e.g. a ``logging.Logger``).
    """

    def __init__(self, logger):
        super().__init__()
        self.logger = logger

    @staticmethod
    def _extract_loss(outputs):
        """Return the scalar loss carried by a step output, or None if absent.

        Accepts either a mapping with a ``'loss'`` entry or the loss itself.
        """
        if isinstance(outputs, dict):
            outputs = outputs.get('loss')
        if outputs is None:
            return None
        # Tensors expose .item(); plain numbers are converted directly.
        return outputs.item() if hasattr(outputs, 'item') else float(outputs)

    def _log_loss(self, trainer, batch_idx, label, outputs):
        """Log one loss line for ``label`` when the batch index is on the logging interval."""
        if batch_idx % trainer.log_every_n_steps != 0:
            return
        loss_value = self._extract_loss(outputs)
        if loss_value is None:
            # The step produced no loss (e.g. it returned None); nothing to report.
            return
        self.logger.info(f"Epoch {trainer.current_epoch}, Step {batch_idx}, {label} Loss: {loss_value:.4f}")

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Log the training loss of the batch that just finished."""
        self._log_loss(trainer, batch_idx, "Train", outputs)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """Log the validation loss of the batch that just finished.

        ``dataloader_idx`` is accepted (and ignored) so the hook also works when
        the trainer passes it, i.e. with more than one validation dataloader.
        """
        self._log_loss(trainer, batch_idx, "Val", outputs)

class ConceptClassifier(pl.LightningModule):
    """Frozen EVA ViT-g visual encoder followed by a trainable MLP concept head.

    Only ``concept_classifier`` is optimised; every encoder parameter has
    ``requires_grad=False`` and the encoder is pinned to eval mode.

    Args:
        num_concepts: size of the output layer (number of concept logits).
        learning_rate: learning rate handed to AdamW.
    """

    # State-dict key prefixes that are written to / expected from checkpoints.
    _CHECKPOINT_PREFIXES = ('concept_classifier', 'ln_vision')

    def __init__(self, num_concepts=100, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.visual_encoder = create_eva_vit_g(
            img_size=364,
            drop_path_rate=0,
            use_checkpoint=False,
            precision="fp16"
        )
        for param in self.visual_encoder.parameters():
            param.requires_grad = False
        self.visual_encoder.eval()
        # Pin the encoder in eval mode: any later .train(...) call becomes a
        # no-op. ``mode`` defaults to True to mirror nn.Module.train, so a bare
        # ``visual_encoder.train()`` no longer raises TypeError.
        self.visual_encoder.train = lambda mode=True: self.visual_encoder

        # The first layer expects 1408 input features, i.e. the encoder's
        # embedding width is assumed to be 1408 — TODO confirm against eva_vit.
        self.concept_classifier = nn.Sequential(
            nn.Linear(1408, 1024),
            nn.LayerNorm(1024),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(1024, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_concepts)
        )

        self.learning_rate = learning_rate

    def forward(self, x):
        """Return concept logits for a batch of images.

        The encoder runs under ``torch.no_grad()``; its output is averaged over
        dimension 1 (presumably the token axis — verify against the encoder)
        before being fed to the MLP head.
        """
        with torch.no_grad():
            image_embeds = self.visual_encoder(x)
        image_embeds = image_embeds.mean(dim=1)
        return self.concept_classifier(image_embeds)

    def _kl_loss(self, batch):
        """Compute the batch-mean KL divergence between predictions and targets.

        ``batch`` is an ``(images, labels)`` pair; ``labels`` are used as the
        target distribution of ``F.kl_div`` (so they are expected to be
        probabilities, not log-probabilities).
        """
        images, labels = batch
        logits = self(images)
        return F.kl_div(
            F.log_softmax(logits, dim=-1),
            labels,
            reduction='batchmean'
        )

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the training loss."""
        loss = self._kl_loss(batch)
        self.log('train_loss', loss, on_step=True, on_epoch=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log (as ``val_loss``) and return the validation loss."""
        loss = self._kl_loss(batch)
        self.log('val_loss', loss, on_step=True, on_epoch=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        """Return an AdamW optimiser over the MLP head parameters only."""
        optimizer = torch.optim.AdamW(self.concept_classifier.parameters(), lr=self.learning_rate)
        return optimizer

    def on_save_checkpoint(self, checkpoint):
        """Drop every state-dict entry except the checkpointed prefixes.

        This keeps the frozen encoder weights out of the checkpoint file.
        """
        checkpoint['state_dict'] = {
            k: v for k, v in checkpoint['state_dict'].items()
            if k.startswith(self._CHECKPOINT_PREFIXES)
        }

    def on_load_checkpoint(self, checkpoint):
        """Back-fill entries that ``on_save_checkpoint`` stripped.

        Keys missing from the checkpoint that do not belong to the checkpointed
        prefixes are copied from the current model, so the state dict being
        loaded contains them again.
        """
        state_dict = checkpoint['state_dict']
        for k, v in self.state_dict().items():
            if k not in state_dict and not k.startswith(self._CHECKPOINT_PREFIXES):
                state_dict[k] = v

def get_img_list(dataset_path, mode='train', dataset='flickr8k'):
    """Return the sorted, de-duplicated image names listed in a split's annotation file.

    Args:
        dataset_path: directory containing the annotation JSON files.
        mode: split name; the file read is ``{mode}.json``.
        dataset: dataset name; for ``'coco'`` the split name is prefixed with
            ``coco_karpathy_``.

    Returns:
        list: unique values of the ``"image"`` field of every record, sorted.
    """
    split_name = ('coco_karpathy_' + mode) if dataset == 'coco' else mode
    annotation_file = os.path.join(dataset_path, f"{split_name}.json")
    with open(annotation_file, 'r') as handle:
        records = json.load(handle)

    unique_images = {record["image"] for record in records}
    return sorted(unique_images)

def _resolve_dataset_paths(dataset):
    """Return ``(annotation_dir, image_dir)`` for a supported dataset name.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    locations = {
        "flickr8k": (
            "/pretrained/lavis_cache/flickr8k/annotations/",
            "/pretrained/lavis_cache/flickr8k/images/",
        ),
        "coco": (
            "/pretrained/lavis_cache/coco/annotations",
            "/pretrained/lavis_cache/coco/images/",
        ),
        "flickr30k": (
            "/pretrained/lavis_cache/flickr30k/annotations",
            "/pretrained/lavis_cache/flickr30k/images/",
        ),
    }
    try:
        return locations[dataset]
    except KeyError:
        # Fail immediately instead of continuing with undefined paths.
        raise ValueError(f"{dataset} is not supported") from None


def main():
    """Train the concept classifier on the dataset selected on the command line.

    Steps: parse arguments, resolve the dataset directories, set up file
    logging under ``./training_outputs_{dataset}_epoch_{epoch}_debug``, build
    the train/val loaders, fit the model with DDP on GPUs 0-7 and log where the
    best checkpoint (lowest ``val_loss``) was saved.

    Raises:
        ValueError: if the requested dataset is not supported.
    """
    args = parse_args()

    dataset = args.dataset
    epoch = args.epoch

    dataset_path, image_dir = _resolve_dataset_paths(dataset)

    multiprocessing.set_start_method('spawn', force=True)
    torch.set_float32_matmul_precision('high')
    local_save_dir = f"./training_outputs_{dataset}_epoch_{epoch}_debug"
    os.makedirs(local_save_dir, exist_ok=True)

    logging.basicConfig(
        filename=os.path.join(local_save_dir, f'{dataset}_cc_training.log'),
        level=logging.INFO,
        format='%(asctime)s - %(message)s'
    )
    logger = logging.getLogger()

    vis_processor = Blip2ImageTrainProcessor()

    logger.info("loading dataset...")
    train_dataset = ConceptDataset(dataset, dataset_path, image_dir, concepts, transform=vis_processor, mode='train')
    val_dataset = ConceptDataset(dataset, dataset_path, image_dir, concepts, transform=vis_processor, mode='val')
    train_loader = DataLoader(
        train_dataset,
        batch_size=64,
        shuffle=True,
        num_workers=0,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=64,
        shuffle=False,
        num_workers=0,
    )

    logger.info(f"Dataset loaded: {len(train_dataset)} training samples, {len(val_dataset)} validation samples")
    # NOTE(review): the model is built with its default num_concepts (100);
    # confirm this matches the width of the precomputed probability vectors.
    model = ConceptClassifier()
    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(local_save_dir, 'checkpoints'),
        filename='concept-classifier_best',
        save_top_k=1,
        monitor='val_loss',
        mode='min',
        save_last=True
    )

    logging_callback = LoggingCallback(logger)

    trainer = pl.Trainer(
        max_epochs=epoch,
        accelerator='gpu',
        devices=[0, 1, 2, 3, 4, 5, 6, 7],
        callbacks=[checkpoint_callback, logging_callback],
        precision=16,
        # Passed exactly once: repeating a keyword argument is a SyntaxError.
        strategy='ddp_find_unused_parameters_true',
        default_root_dir=local_save_dir,
        logger=False,
        num_sanity_val_steps=0  # Disable sanity check to prevent early shared memory overload
    )

    logger.info("start training...")
    trainer.fit(model, train_loader, val_loader)

    if checkpoint_callback.best_model_path:
        logger.info(f"Best model saved at: {checkpoint_callback.best_model_path}")
        logger.info(f"Best validation score: {checkpoint_callback.best_model_score:.4f}")
    else:
        logger.info("No best model checkpoint was saved")

if __name__ == "__main__":
    main()