from mi_estimator.mi_dataset import create_ssl_loader
from config import get_config
from model import set_model_from_config as set_summarizer_from_config
from mi_estimator.mi_model import set_model_from_config as set_mi_model_from_config
from tqdm import tqdm
import torch
from pathlib import Path
from pretrainer.mask_generator import MaskNet
import torch.nn.functional as F
from collections import defaultdict

# Print all arguments and GPU setting
def print_args(args):
    """Print the run arguments and the CUDA/GPU environment.

    Args:
        args: Configuration object exposing ``kwargs`` (printed as-is) and
            ``device`` (a device string such as ``"cpu"`` / ``"cuda:0"``, or a
            ``torch.device``).

    GPU details are printed only when the requested device mentions ``cuda``.
    The function never raises because of a missing GPU: in that case the
    device name is reported as ``None``.
    """
    print(args.kwargs)
    print(f"CUDA: {torch.version.cuda}")
    print(f"cuDNN: {torch.backends.cudnn.version()}")

    # str() lets callers pass a torch.device as well as a plain string;
    # `'cuda' in torch.device(...)` would raise TypeError.
    if 'cuda' in str(args.device):
        gpu_available = torch.cuda.is_available()
        gpu_count = torch.cuda.device_count()
        print(f"GPU: {gpu_available}")
        print(f"GPU count: {gpu_count}")

        # get_device_name() raises when no CUDA device is usable, which would
        # turn this purely informational printout into a crash.
        gpu_name = None
        if gpu_available and gpu_count > 0:
            gpu_name = torch.cuda.get_device_name(0)
        print(f"GPU name: {gpu_name}")


def get_mask_generator_from_config(config):
    """Build a ``MaskNet`` wrapping a summarizer created from ``config``.

    The summarizer comes from ``set_summarizer_from_config(config)`` and the
    activation passed to ``MaskNet`` is ``config.act``.
    """
    return MaskNet(set_summarizer_from_config(config), act=config.act)

def info_nce_loss(summary_feature, video_feature, temperature=0.1):
    """
    Contrastive (InfoNCE) loss that pairs every summary with its own video.

    Row ``i`` of ``summary_feature`` is scored against every row of
    ``video_feature``; row ``i`` of ``video_feature`` is the positive and all
    other rows act as negatives, so ``video_feature`` may hold more rows than
    ``summary_feature``. The cross-entropy is summed (not averaged) over the
    summaries.
    """
    # Unit-length rows turn the dot products below into cosine similarities.
    unit_summary = F.normalize(summary_feature, dim=-1)
    unit_video = F.normalize(video_feature, dim=-1)

    # Pairwise similarities, sharpened by the temperature.
    similarity = torch.matmul(unit_summary, unit_video.t()) / temperature

    # The positive for summary i sits in column i.
    targets = torch.arange(similarity.size(0), device=similarity.device)
    return F.cross_entropy(similarity, targets, reduction='sum')

def get_out_file(config):
    """Build the checkpoint path for this run and stop early when appropriate.

    The path has the form
    ``<pt_ckpt_dir>/M.._RO.._A.._RN.._D../bs.._e.._tau.._ent.._size...pt``.
    Dots in the dropout ratio, tau and penalty coefficients are written as
    ``d`` so that they do not appear in the file name.

    Args:
        config: Run configuration. Attributes read: ``coef_ent``,
            ``coff_size``, ``mask_type``, ``read_out``, ``act``, ``rand_neg``,
            ``Dropout_ratio``, ``batch_size``, ``epochs``, ``tau``,
            ``pt_ckpt_dir`` and ``check``.

    Returns:
        The checkpoint path as a string.

    Raises:
        SystemExit: with code 17 when a checkpoint already exists at the
            path, or with code 0 when ``config.check`` is truthy (the path is
            only reported, nothing is trained).
    """
    penalty = f"ent{config.coef_ent}_size{config.coff_size}".replace(".","d")
    out_file = f'M{config.mask_type}_RO{config.read_out}_A{config.act}_RN{config.rand_neg}_D{str(config.Dropout_ratio).replace(".","d")}/bs{config.batch_size}_e{config.epochs}_tau{str(config.tau).replace(".","d")}_{penalty}.pt'
    out_file = str(Path(config.pt_ckpt_dir) / out_file)
    print(f"Output file: {out_file}")

    # Raise SystemExit directly instead of calling exit(): that helper is
    # injected by the `site` module for interactive sessions and is missing
    # under `python -S` or in frozen builds.
    if Path(out_file).exists():
        print(f"Checkpoint already exists at {out_file}.")
        raise SystemExit(17)
    if config.check:
        raise SystemExit(0)
    return out_file

def run_pretrain(config):
    """Jointly train the mask generator and the MI estimator, then save the summarizer.

    Every item yielded by the loader is pushed through the models on its own.
    Its embeddings and weighted mask penalties are accumulated until
    ``config.batch_size`` items have been seen; then one InfoNCE loss is
    computed over the group and a single Adam step is taken.

    Args:
        config: Run configuration. Attributes read directly here:
            ``pt_datasets``, ``device``, ``learning_rate``, ``weight_decay``,
            ``batch_size``, ``epochs``, ``tau``, ``mask_type``, ``rand_neg``,
            ``coef_ent``, ``coff_size`` and ``kwargs``. The helper functions
            called below read further attributes.

    Side effects:
        Writes a checkpoint containing ``mask_generator.summarizer.state_dict()``
        and ``config.kwargs`` to the path returned by ``get_out_file``. Note that
        ``get_out_file`` may terminate the process instead of returning (existing
        checkpoint, or ``config.check`` set).
    """
    datasets = config.pt_datasets.split(',')
    # Resolved before any model is built: this call exits the process when the
    # checkpoint already exists or when only a check run was requested.
    out_file = get_out_file(config)

    mi_estimator = set_mi_model_from_config(config)
    mi_estimator.to(config.device)
    mi_estimator.train()

    mask_generator = get_mask_generator_from_config(config)
    mask_generator.to(config.device)
    mask_generator.train()
    
    # A single optimizer updates both networks.
    params = list(mask_generator.parameters()) + list(mi_estimator.parameters())
    optimizer = torch.optim.Adam(params,lr=float(config.learning_rate),weight_decay=float(config.weight_decay))
    
    
    pt_loader = create_ssl_loader(datasets=datasets)
    # Number of loader items accumulated before each optimizer step.
    batch_size = int(config.batch_size)

    epochs = config.epochs

    # tau_schedules = torch.linspace(10, 0.1, epochs).tolist()
    # Features of the previously seen item; only populated when
    # mask_type == "marg" and handed to add_mask below. It persists across epochs.
    other_x = None
    for epoch in tqdm(range(1,epochs+1),total=epochs,ncols=70,leave=False,desc=f'Epoches'):
        mask_generator.train()
        mask_generator.summarizer.train()
        # All accumulators are reset at the start of every epoch, so items left
        # over from an incomplete final group of the previous epoch never reach
        # an optimizer step.
        # NOTE(review): confirm that dropping this remainder is intended.
        update_loss = 0.0
        batch = 0
        tau= config.tau
        # Running totals for logging; currently only read by the commented-out
        # print at the end of the epoch loop.
        losses_dict = defaultdict(float)

        batched_summary_embs = []
        batched_video_embs = []
        batched_rand_embs = []
        for feature,e_feature,dataset_name in pt_loader:
            # (1,T,D) for feature and (1,3,T,D) for e_feature
            feature = feature.to(config.device)
            e_feature = e_feature.to(config.device)

            # In "marg" mode the very first item only seeds other_x and is not
            # trained on.
            if config.mask_type=="marg" and other_x is None:
                other_x = feature.detach().clone()
                continue

            # embeddings of videos generated summaries
            feat_mask, mask_losses = mask_generator.sample_mask(
                e_feature, 
                tau=tau,
            ) 
            masked_feat = mi_estimator.add_mask(feature, feat_mask, other_x) # (1,T,D)
            e_masked_feat = masked_feat.unsqueeze(1).expand(-1,3,-1,-1)  # (1,3,T,D)
            summary_emb = mi_estimator(e_masked_feat) # (1,D)
            batched_summary_embs.append(summary_emb)

            # embeddings of videos
            video_emb = mi_estimator(e_feature) # (1,D)
            batched_video_embs.append(video_emb)

            # embeddings of random summaries 
            if config.rand_neg:
                # Uniform random mask with the same shape as the learned one.
                rand_mask = torch.rand_like(feat_mask)  # (1, T)
                rand_feat = mi_estimator.add_mask(feature, rand_mask, other_x) # (1,T,D)
                e_rand_feat = rand_feat.unsqueeze(1).expand(-1,3,-1,-1) # (1,3,T,D)
                rand_emb = mi_estimator(e_rand_feat) # (1,D)
                batched_rand_embs.append(rand_emb)

            # Remember this item so that it serves as other_x for the next one.
            if config.mask_type=="marg":
                other_x = feature.detach().clone()

            # Weighted mask penalties are added per item; the contrastive term
            # is added once per group below.
            update_loss += mask_losses['ent'] * config.coef_ent
            losses_dict['ent'] += mask_losses['ent'].item() * config.coef_ent

            update_loss += mask_losses['size'] * config.coff_size
            losses_dict['size'] += mask_losses['size'].item() * config.coff_size
            
            batch += 1
            if batch==batch_size:
                # NOTE(review): torch.stack inserts a new leading dimension, so
                # the (B,D) shapes noted here hold only if each embedding is
                # 1-D (D,). With (1,D) embeddings, as noted above, the result
                # would be (B,1,D) and the .t() call in info_nce_loss would
                # reject it - confirm the estimator's output shape.
                batched_summary_embs = torch.stack(batched_summary_embs, dim=0) # (B,D)   
                batched_video_embs = torch.stack(batched_video_embs, dim=0) # (B,D)

                if config.rand_neg:
                    # Random-summary embeddings are appended after the video
                    # embeddings, so they only ever act as extra negatives.
                    batched_rand_embs = torch.stack(batched_rand_embs, dim=0) # (B,D)
                    batched_compared_emb = torch.cat([batched_video_embs, batched_rand_embs], dim=0) # (2*B,D)
                else:
                    batched_compared_emb = batched_video_embs

                loss = info_nce_loss(batched_summary_embs, batched_compared_emb)
                losses_dict['loss'] += loss.item()
                # NOTE(review): presumably redundant when the embeddings are
                # attached to the autograd graph - confirm why this is needed.
                loss.requires_grad_(True)
                update_loss += loss

                optimizer.zero_grad()
                # info_nce_loss sums over the group and the penalties were
                # summed per item, so dividing by the item count yields a
                # per-item average.
                update_loss = update_loss / batch
                update_loss.backward()
                optimizer.step()
                update_loss = 0.0
                batch = 0

                # The accumulators were rebound to tensors above; start the
                # next group with fresh lists.
                batched_summary_embs = []
                batched_video_embs = []
                batched_rand_embs = []
        # if epoch%(config.epochs//10)==0 or epoch==config.epochs:
        #     print(f"Epoch {epoch}/{epochs}, loss: {losses_dict['loss']/len(pt_loader):.8f}, ent: {losses_dict['ent']/len(pt_loader):.8f}, size: {losses_dict['size']/len(pt_loader):.8f}")
    
    # Only the summarizer weights are saved, together with the run arguments.
    Path(out_file).parent.mkdir(exist_ok=True,parents=True)
    torch.save({
        'state_dict': mask_generator.summarizer.state_dict(),
        'kwargs': config.kwargs,
    }, out_file)




if __name__ == "__main__":
    # Script entry point: read the configuration, echo it together with the
    # GPU environment, then launch pre-training.
    cfg = get_config()
    print_args(cfg)
    run_pretrain(cfg)