import torch
import numpy as np
import os
from tqdm import tqdm
from rtpt import RTPT
from torch import nn
import glob

from neural_concept_binder import NeuralConceptBinder 
from CLEVR_Hans_image_dataset import CLEVRHansDataset
from sysbinder.sysbinder import SysBinderImageAutoEncoder
from argparser_precompute_encodings import get_parser

class EncodingPrecomputer:
    """Compute and save encodings of the CLEVR-Hans splits.

    The encoding kind is selected by ``args.enc_type``:

    * ``'concept_slot'`` / ``'one_hot'``: encoded from images with a
      ``NeuralConceptBinder``.
    * ``'block_slot'``: encoded from images with a ``SysBinderImageAutoEncoder``.
    * ``'one_hot_padded'`` / ``'retrieval_corpus'``: derived from previously
      saved concept-slot encodings. A ``NeuralConceptBinder`` is still loaded
      because ``prior_num_concepts`` / ``retrieval_corpus`` are read from it.
    """

    def __init__(self, args):
        """Store ``args``, select a device, create ``args.result_dir`` and load the model.

        Args:
            args: parsed command-line namespace (from ``get_parser``).
        """
        self.args = args
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.args.result_dir, exist_ok=True)
        self.setup_model()

    def setup_model(self):
        """Initialize the model required by ``args.enc_type``.

        Sets ``self.model`` (left as ``None`` for unrecognized encoding types).
        For ``'retrieval_corpus'`` it also sets ``self.retrieval_corpus``.

        Raises:
            FileNotFoundError: if ``enc_type == 'block_slot'`` and
                ``args.sysbinder_path`` is not an existing file.
            RuntimeError: if the SysBinder checkpoint matches neither a plain
                state dict nor a ``{"model": ..., "temp": ...}`` checkpoint.
        """
        self.model = None
        if self.args.enc_type in ['concept_slot', 'one_hot', 'one_hot_padded', 'retrieval_corpus']:
            print('Loading Neural Concept Binder...')
            self.model = NeuralConceptBinder(self.args)
            self.model.to(self.device)
            self.model.eval()
            if self.args.enc_type == 'retrieval_corpus':
                print('Loading Retrieval Corpus...')
                self.retrieval_corpus = self.model.retrieval_corpus
        elif self.args.enc_type == 'block_slot':
            print('Loading SysBinder...')
            self.model = SysBinderImageAutoEncoder(self.args)
            if not os.path.isfile(self.args.sysbinder_path):
                raise FileNotFoundError("Model path for Sysbinder was not found.")
            checkpoint = torch.load(self.args.sysbinder_path, map_location="cpu")
            try:
                # First assume the file is a plain state dict.
                self.model.load_state_dict(checkpoint)
            except RuntimeError:
                # load_state_dict raises RuntimeError on missing/unexpected keys.
                # Fall back to a training checkpoint wrapping the state dict, but
                # surface the original error if that layout is not present either.
                if not isinstance(checkpoint, dict) or "model" not in checkpoint:
                    raise
                self.model.load_state_dict(checkpoint["model"])
                self.model.image_encoder.sysbinder.prototype_memory.attn.temp = checkpoint["temp"]
            print(f"loaded ...{self.args.sysbinder_path}")
            self.model.to(self.device)
            self.model.eval()

    def load_images(self):
        """Build unshuffled DataLoaders for the train/val/test splits.

        A split whose dataset raises ``FileNotFoundError`` is reported and
        returned as ``None``.

        Returns:
            tuple ``(train_loader, val_loader, test_loader)``; entries may be ``None``.
        """
        print('Loading images ...')
        # Loop-invariant loader settings. Pinned memory only helps host->GPU copies.
        loader_args = {
            'batch_size': self.args.batch_size,
            'pin_memory': self.device.type == 'cuda',
            'num_workers': self.args.num_workers,
            'drop_last': False,
        }
        loaders = {}

        for split in ['train', 'val', 'test']:
            try:
                dataset = CLEVRHansDataset(self.args.data_dir, split, lexi=True)
            except FileNotFoundError as e:
                print(f"Warning: {split} dataset not found. Error: {e}")
                loaders[split] = None
                continue
            loaders[split] = torch.utils.data.DataLoader(dataset, shuffle=False, **loader_args)
            print(f"Successfully loaded {split} dataset with {len(dataset)} samples")

        return loaders['train'], loaders['val'], loaders['test']

    @staticmethod
    def transform_to_onehot(attrs, model):
        """One-hot encode each block of ``attrs`` and concatenate along dim 2.

        Args:
            attrs: tensor of concept ids indexed as ``attrs[:, :, block_id]``
                (presumably (batch, slots, blocks); TODO confirm).
            model: object exposing ``prior_num_concepts``, the number of classes
                for each block.

        Returns:
            CPU float32 tensor whose dim 2 has size
            ``sum(model.prior_num_concepts[b] for b in range(attrs.shape[2]))``.
        """
        n_blocks = attrs.shape[2]
        attrs_one_hot = [
            torch.nn.functional.one_hot(
                attrs[:, :, block_id].long(),
                num_classes=model.prior_num_concepts[block_id]
            )
            for block_id in range(n_blocks)
        ]
        # .type(torch.FloatTensor) both casts to float32 and moves to CPU.
        return torch.cat(attrs_one_hot, dim=2).type(torch.FloatTensor)

    def precompute_cs_and_onehot_encodings(self, loader):
        """Encode every image of ``loader`` with the concept binder.

        Each batch ``sample`` is read as ``sample[0]`` images,
        ``sample[2]`` class ids and ``sample[3]`` file names.

        Returns:
            tuple ``(one_hot_encs, concept_slot_encs, class_ids, fnames)``.
            The first three are stacked CPU tensors; ``fnames`` is a numpy array.
        """
        encs, encs_one_hot, class_ids, fnames_all = [], [], [], []
        rtpt = RTPT(name_initials='XX', experiment_name="Precompute Encodings",
                    max_iterations=len(loader))
        rtpt.start()

        # Inference only. no_grad prevents autograd graphs from being kept for
        # every batch. Results move to CPU so they don't pile up on the GPU.
        with torch.no_grad():
            for sample in tqdm(loader):
                fnames = sample[3]
                imgs, _, img_class_ids = (x.to(self.device) for x in sample[:3])
                concept_ids = self.model.encode(imgs)[0]
                out_one_hot = self.transform_to_onehot(concept_ids, self.model)

                encs.extend(concept_ids.detach().cpu())
                encs_one_hot.extend(out_one_hot)
                class_ids.extend(img_class_ids.detach().cpu())
                fnames_all.extend(fnames)
                rtpt.step()

        return (torch.stack(encs_one_hot), torch.stack(encs),
                torch.stack(class_ids), np.array(fnames_all))

    def precompute_one_hot_padded_encs(self, cs_encodings):
        """One-hot encode concept-slot ids, padding every block to the same width.

        Args:
            cs_encodings: array-like of integer-valued concept ids indexed as
                ``[sample][slot][block]``, e.g. a saved concept-slot ``.npy`` array.
                Values are truncated to ``int64``.

        Returns:
            int64 tensor of shape ``(*cs_encodings.shape, max(prior_num_concepts))``.
        """
        block_size = max(self.model.prior_num_concepts)
        # A single vectorized one_hot replaces the per-sample/per-slot loops.
        # The class dim is appended last, giving the same layout as stacking.
        ids = torch.as_tensor(np.asarray(cs_encodings)).to(torch.long)
        return nn.functional.one_hot(ids, num_classes=block_size)

    def precompute_block_slot_encs(self, loader):
        """Encode every image of ``loader`` with the SysBinder encoder.

        Returns:
            tuple ``(encs, class_ids, fnames)``. ``encs`` and ``class_ids`` are
            stacked CPU tensors; ``fnames`` is a numpy array.
        """
        encs, class_ids, fnames_all = [], [], []
        rtpt = RTPT(name_initials='XX', experiment_name="Gather Encs",
                    max_iterations=len(loader))
        rtpt.start()

        # Inference only: don't build autograd graphs.
        with torch.no_grad():
            for sample in tqdm(loader):
                fnames = sample[3]
                imgs, _, img_class_ids = (x.to(self.device) for x in sample[:3])
                out, _, _, _ = self.model.encode(imgs)

                # Move results off the device so they don't accumulate there.
                out = out.detach().cpu()
                img_class_ids = img_class_ids.detach().cpu()

                encs.extend(out)
                class_ids.extend(img_class_ids)
                fnames_all.extend(fnames)
                rtpt.step()

                # Release batch tensors and cached device memory.
                del out, imgs, img_class_ids
                torch.cuda.empty_cache()

        return torch.stack(encs), torch.stack(class_ids), np.array(fnames_all)

    def precompute_rc_encs(self, cs_encodings):
        """Replace each concept id with its retrieval-corpus encoding.

        For every sample, slot and block ``i`` (``i < args.num_blocks``), the id
        ``int(slot[i])`` selects ``self.retrieval_corpus[i]['encs'][id]``.

        Returns:
            tensor stacked as ``[sample][slot][block][...encoding dims]``.
        """
        # Blocks are the same for every sample/slot, so look them up once.
        blocks = [self.retrieval_corpus[i] for i in range(self.args.num_blocks)]
        rc_encs = []
        for cs_enc in cs_encodings:
            slot_encs = [
                torch.stack(
                    [block['encs'][int(slot[i])] for i, block in enumerate(blocks)],
                    dim=0,
                )
                for slot in cs_enc
            ]
            rc_encs.append(torch.stack(slot_encs, dim=0))
        return torch.stack(rc_encs, dim=0)

    def save_results(self, split, encodings, labels, fnames, suffix=''):
        """Save encodings, labels and file names of ``split`` as ``.npy`` files.

        Writes the following files into ``args.result_dir``:

        * ``{split}_{suffix}_encs_{seed}.npy``
        * ``{split}_labels_{seed}.npy``
        * ``{split}_fnames_{seed}.npy``
        """
        np.save(f'{self.args.result_dir}/{split}_{suffix}_encs_{self.args.model_seed}.npy',
                encodings.detach().cpu().numpy())
        np.save(f'{self.args.result_dir}/{split}_labels_{self.args.model_seed}.npy',
                labels.detach().cpu().numpy())
        np.save(f'{self.args.result_dir}/{split}_fnames_{self.args.model_seed}.npy',
                fnames)
    
def main():
    """Entry point: precompute the encodings selected by ``args.enc_type``.

    * ``concept_slot`` / ``one_hot`` / ``block_slot``: encode the
      train/val/test images found in ``args.data_dir`` and save them to
      ``args.result_dir``.
    * ``one_hot_padded`` / ``retrieval_corpus``: derive encodings from the
      concept-slot encodings previously saved in ``args.data_dir``. The matching
      label and filename files are copied to ``args.result_dir``.
    """
    parser = get_parser()
    args = parser.parse_args()

    precomputer = EncodingPrecomputer(args)

    if args.enc_type in ['concept_slot', 'one_hot', 'block_slot']:
        # For these types we need to load the images.
        train_loader, val_loader, test_loader = precomputer.load_images()

        for split, loader in [('train', train_loader), ('val', val_loader), ('test', test_loader)]:
            if loader is None:
                print(f"Skipping {split} split as it was not found")
                continue

            if args.enc_type == 'concept_slot':
                print(f'Precomputing {split} concept slot encodings...')
                _, encs, labels, fnames = precomputer.precompute_cs_and_onehot_encodings(loader)
            elif args.enc_type == 'one_hot':
                print(f'Precomputing {split} one hot encodings...')
                encs, _, labels, fnames = precomputer.precompute_cs_and_onehot_encodings(loader)
            else:  # 'block_slot'
                print(f'Precomputing {split} block slot encodings...')
                encs, labels, fnames = precomputer.precompute_block_slot_encs(loader)

            precomputer.save_results(split, encs, labels, fnames, suffix=args.enc_type)

    elif args.enc_type in ['one_hot_padded', 'retrieval_corpus']:
        print('Loading concept-slot encodings...')

        # A split is available when its label file exists. Copy its labels and
        # filenames alongside the new encodings.
        available_splits = []
        for split in ['train', 'test', 'val']:
            label_path = f'{args.data_dir}/{split}_labels_{args.model_seed}.npy'
            if os.path.exists(label_path):
                available_splits.append(split)
                for suffix in ['labels', 'fnames']:
                    source = np.load(f'{args.data_dir}/{split}_{suffix}_{args.model_seed}.npy')
                    np.save(f'{args.result_dir}/{split}_{suffix}_{args.model_seed}.npy', source)

        if not available_splits:
            print("Error: No splits found with concept-slot encodings")
            return

        compute_fn = (precomputer.precompute_one_hot_padded_encs if args.enc_type == 'one_hot_padded'
                      else precomputer.precompute_rc_encs)

        print(f'Precomputing {args.enc_type} encodings...')
        for split in available_splits:
            # Prefer the exact file a concept_slot run writes for this model seed.
            # glob order is arbitrary, and the loose pattern can also match other
            # encoding types or seeds, so the fallback is sorted and seed-filtered.
            cs_enc_path = f'{args.data_dir}/{split}_concept_slot_encs_{args.model_seed}.npy'
            if not os.path.isfile(cs_enc_path):
                cs_enc_paths = sorted(glob.glob(f'{args.data_dir}/{split}*encs*.npy'))
                if not cs_enc_paths:
                    print(f"Warning: No encoding file found for {split} split")
                    continue
                seed_matches = [p for p in cs_enc_paths
                                if p.endswith(f'_encs_{args.model_seed}.npy')]
                candidates = seed_matches or cs_enc_paths
                if len(candidates) > 1:
                    print(f"Warning: several encoding files match the {split} split; "
                          f"using {candidates[0]} (candidates: {candidates})")
                cs_enc_path = candidates[0]

            print(f"Processing {cs_enc_path}")
            cs_enc = np.load(cs_enc_path)
            encs = compute_fn(cs_enc)
            np.save(f'{args.result_dir}/{split}_{args.enc_type}_encs_{args.model_seed}.npy',
                    encs.detach().cpu().numpy())

    else:
        # Previously an unrecognized type silently did nothing.
        print(f"Error: unknown enc_type '{args.enc_type}'")

if __name__ == '__main__':
    # Run the precomputation only when executed as a script, not on import.
    main()