import os
import argparse
import random
import time
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder, DatasetFolder
from torch.amp import GradScaler, autocast
from tqdm import tqdm
from datasets import load_dataset
from PIL import Image
# Restrict CUDA to device index 2 for this process.
# NOTE(review): this is hard-coded and overwrites any CUDA_VISIBLE_DEVICES value
# already present in the caller's environment — confirm that is intended.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"


import open_clip

def set_seed(seed):
    """Seed the NumPy, Python and PyTorch (CPU and CUDA) random generators.

    Args:
        seed: Integer seed handed unchanged to every generator.
    """
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

class CLIP_ImageFolder_custom(DatasetFolder):
    """Image-folder dataset that pairs each image with caption data for CLIP.

    The class layout on disk is discovered with a temporary ``ImageFolder``;
    ``DatasetFolder.__init__`` itself is not called, the attributes needed by
    ``__getitem__``/``__len__`` are set here directly.

    Args:
        root: Directory containing one sub-directory per class.
        preprocess: Callable applied to every loaded image.
        tokenizer: Callable turning a caption string into a token tensor whose
            leading dimension is squeezed away.
        captions: Mapping ``class name -> {'last': tensor, 'text': sequence}``
            used in ``'train'`` mode. ``'last'`` is indexed along its first
            dimension after ``squeeze(1)``; ``'text'`` is indexed with the same
            position.
        train_test: ``'train'`` to sample a caption/hidden-state pair per item;
            any other value yields a fixed "A photo of a <class>." prompt.
    """

    def __init__(self, root, preprocess, tokenizer, captions, train_test='train'):
        self.root = root
        imagefolder_obj = ImageFolder(root)
        self.loader = imagefolder_obj.loader

        self.domain_target = root.split('/')[-1]
        self.preprocess = preprocess
        self.tokenizer = tokenizer
        # ImageFolder already discovered the classes; reuse them instead of
        # scanning the directory a second time.
        self.cur_classes = imagefolder_obj.classes

        self.train_test = train_test
        # Stored as a string array: column 0 is the path, column 1 the class
        # index (converted back to int in __getitem__).
        self.samples = np.array(imagefolder_obj.samples)
        self.captions = captions

    def __getitem__(self, index):
        """Return ``(image, target, caption_tokens, hidden_state)`` for *index*.

        In ``'train'`` mode one caption of the image's class is drawn at random
        (NumPy global RNG) together with its matching hidden state. Otherwise
        the caption is a fixed prompt and the hidden state is a two-element
        zero placeholder.
        """
        path = self.samples[index][0]
        target = int(self.samples[index][1])
        class_name = self.cur_classes[target]

        sample = self.preprocess(self.loader(path))

        if self.train_test == 'train':
            class_captions = self.captions[class_name]
            mllm_hidden_states = class_captions['last'].squeeze(1)
            random_idx = np.random.choice(np.arange(len(mllm_hidden_states)), size=1)[0]
            mllm_hidden_states = mllm_hidden_states[random_idx]

            mllm_text = class_captions['text'][random_idx]
            mllm_text_tokens = self.tokenizer(mllm_text).squeeze(0)
            return sample, target, mllm_text_tokens, mllm_hidden_states

        caption = f"A photo of a {class_name}."
        # The prompt was previously built but never tokenized, so this branch
        # raised NameError on `tokenized_caption`.
        tokenized_caption = self.tokenizer(caption).squeeze(0)
        return sample, target, tokenized_caption, torch.tensor([0., 0.])

    def __len__(self):
        """Return the number of images found under ``root``."""
        return len(self.samples)

class GCC_ImageFolder_custom(DatasetFolder):
    """Image/caption dataset driven by a ``metadata.json`` next to the image root.

    The metadata file is read from ``<root>/../metadata.json`` and is indexed
    by position; every entry must provide an ``'image'`` path (relative to
    ``root``) and a ``'caption'`` string.

    Args:
        root: Directory holding the image files.
        preprocess: Callable applied to every opened PIL image.
        tokenizer: Callable turning a caption string into a token tensor whose
            leading dimension is squeezed away.
    """

    def __init__(self, root, preprocess, tokenizer):
        self.root = root
        # Use a context manager so the metadata file handle is closed
        # deterministically instead of being left to the garbage collector.
        with open(f'{self.root}/../metadata.json', encoding='utf-8') as metadata_file:
            self.samples = json.load(metadata_file)
        self.preprocess = preprocess
        self.tokenizer = tokenizer

    def __getitem__(self, index):
        """Return ``(preprocessed_image, caption_tokens)`` for entry *index*."""
        entry = self.samples[index]
        sample = self.preprocess(Image.open(f"{self.root}/{entry['image']}"))
        tokenized_caption = self.tokenizer(entry['caption']).squeeze(0)
        return sample, tokenized_caption

    def __len__(self):
        """Return the number of metadata entries."""
        return len(self.samples)

class CLIP_with_head(nn.Module):
    """Wrap a CLIP model and attach extra projection heads and logit scales.

    Args:
        clip_model: Callable invoked as ``clip_model(image, text)`` that
            returns a 3-tuple.
        lshape: Output width of the three bias-free linear heads and length of
            the two learnable log-scale vectors (initialised to ``log(1/0.07)``).
    """

    def __init__(self, clip_model, lshape=512):
        super().__init__()
        self.clip = clip_model

        # Both scales are stored in log space and exponentiated in encode_all.
        initial_log_scale = np.log(1 / 0.07)
        self.mllm_proj_logit_scale = nn.Parameter(torch.ones(lshape) * initial_log_scale)
        self.domain_proj_logit_scale = nn.Parameter(torch.ones(lshape) * initial_log_scale)

        # Heads take 768-, 512- and 3072-wide inputs respectively.
        self.clip_projector = nn.Sequential(nn.Linear(768, lshape, bias=False))
        self.clip_text_projector = nn.Sequential(nn.Linear(512, lshape, bias=False))
        self.mllm_projector = nn.Sequential(nn.Linear(3072, lshape, bias=False))

    def forward(self, image, text):
        """Run the wrapped CLIP model and return its three outputs unchanged."""
        image_out, text_out, logits_out = self.clip(image, text)
        return image_out, text_out, logits_out

    def encode_all(self, image, text):
        """Return the three CLIP outputs followed by both exponentiated scales."""
        image_out, text_out, logits_out = self.clip(image, text)
        mllm_scale = self.mllm_proj_logit_scale.exp()
        domain_scale = self.domain_proj_logit_scale.exp()
        return image_out, text_out, logits_out, mllm_scale, domain_scale

def fylp_loss(image, text, logit_scale, device):
    """Symmetric cross-entropy over image/text similarity logits.

    Row ``i`` of ``image`` is treated as the match of row ``i`` of ``text``.

    Args:
        image: 2-D tensor of image embeddings.
        text: 2-D tensor of text embeddings with the same shape as ``image``.
        logit_scale: Multiplier applied to the left operand before the matrix
            product (scalar or anything broadcastable against the embeddings).
        device: Device on which the target indices are created.

    Returns:
        Mean of the image-to-text and text-to-image cross-entropy losses.
    """
    image_to_text = (logit_scale * image) @ text.T
    text_to_image = (logit_scale * text) @ image.T
    targets = torch.arange(image.shape[0], dtype=torch.long, device=device)
    loss_i2t = F.cross_entropy(image_to_text, targets)
    loss_t2i = F.cross_entropy(text_to_image, targets)
    return (loss_i2t + loss_t2i) / 2

def disentangle_loss(x, y):
    """Sum of squared diagonal entries of the scaled product of standardized inputs.

    Each input is standardized along dimension 0 (mean removed, divided by the
    standard deviation plus ``1e-4``). The product ``x @ y.T`` is divided by
    ``len(x)`` and the squares of its diagonal are summed.

    NOTE(review): ``x @ y.T`` has one row/column per entry along dimension 0;
    a feature-wise cross-correlation would be ``x.T @ y`` — confirm which one
    is intended.
    """
    def _standardize(values):
        return (values - values.mean(0)) / (values.std(0) + 1e-4)

    x_std = _standardize(x)
    y_std = _standardize(y)
    scaled_product = torch.matmul(x_std, y_std.T) / len(x_std)
    return scaled_product.diagonal().pow(2).sum()

def MLLM_CLIP_loss(x_, y_, mllm_proj_logit_scale, device):
    """Symmetric cross-entropy between two L2-normalized embedding batches.

    Both inputs are normalized along their last dimension; row ``i`` of ``x_``
    is treated as the match of row ``i`` of ``y_``.

    Args:
        x_: First batch of embeddings.
        y_: Second batch of embeddings.
        mllm_proj_logit_scale: Multiplier applied to the left operand before
            the matrix product.
        device: Device on which the target indices are created.

    Returns:
        Mean of the two directional cross-entropy losses.
    """
    y_unit = F.normalize(y_, dim=-1)
    x_unit = F.normalize(x_, dim=-1)
    targets = torch.arange(x_unit.shape[0], dtype=torch.long, device=device)
    logits_xy = (mllm_proj_logit_scale * x_unit) @ y_unit.T
    logits_yx = (mllm_proj_logit_scale * y_unit) @ x_unit.T
    loss_xy = F.cross_entropy(logits_xy, targets)
    loss_yx = F.cross_entropy(logits_yx, targets)
    return (loss_xy + loss_yx) / 2

def get_dataloader(args, cur_source_domains, preprocess, tokenizer, captions, return_ds=False):
    """Build one shuffled training ``DataLoader`` over several domain folders.

    A ``CLIP_ImageFolder_custom`` dataset is created for every entry of
    ``cur_source_domains`` under ``{args.datadir}/{args.dataset}/{domain}`` and
    the datasets are concatenated.

    Args:
        args: Namespace providing ``datadir``, ``dataset``, ``batch_size`` and
            ``num_workers``.
        cur_source_domains: Iterable of domain directory names.
        preprocess: Image transform forwarded to each dataset.
        tokenizer: Caption tokenizer forwarded to each dataset.
        captions: Caption mapping forwarded to each dataset.
        return_ds: When true, also return the concatenated dataset.

    Returns:
        The dataloader, or ``(dataloader, dataset)`` if ``return_ds`` is true.
    """
    domain_datasets = [
        CLIP_ImageFolder_custom(root=f"{args.datadir}/{args.dataset}/{domain}",
                                preprocess=preprocess,
                                tokenizer=tokenizer,
                                captions=captions)
        for domain in cur_source_domains
    ]

    all_train_ds = torch.utils.data.ConcatDataset(domain_datasets)
    print(f"Num samples total: {len(all_train_ds)}")

    # DataLoader rejects persistent_workers / prefetch_factor when loading
    # happens in the main process, so only pass them with worker processes.
    worker_kwargs = {}
    if args.num_workers > 0:
        worker_kwargs = {"persistent_workers": True, "prefetch_factor": 4}

    dataloader = torch.utils.data.DataLoader(dataset=all_train_ds,
                                             batch_size=args.batch_size,
                                             num_workers=args.num_workers,
                                             shuffle=True,
                                             pin_memory=True,
                                             drop_last=True,
                                             **worker_kwargs)
    if return_ds:
        return dataloader, all_train_ds
    return dataloader

def get_batch(dataloader_iter, dataloader):
    """Fetch the next batch, restarting the iterator once it is exhausted.

    Args:
        dataloader_iter: Iterator currently being consumed.
        dataloader: Iterable used to create a fresh iterator when the current
            one has no items left.

    Returns:
        ``(batch, iterator)`` where ``iterator`` is the one to pass in on the
        next call (the same object unless a restart happened).
    """
    exhausted = object()
    batch = next(dataloader_iter, exhausted)
    if batch is exhausted:
        dataloader_iter = iter(dataloader)
        batch = next(dataloader_iter)
    return batch, dataloader_iter

def _split_image_text_batch(batch, tokenizer):
    """Return ``(images, caption_tokens)`` from a collated image/caption batch.

    Accepts either a dict with ``"image"`` and ``"label"`` entries or a
    two-element ``(images, captions)`` sequence. Captions that are not already
    a tensor (i.e. raw strings) are passed through ``tokenizer``.
    """
    if isinstance(batch, dict):
        images, captions = batch["image"], batch["label"]
    else:
        images, captions = batch
    if not torch.is_tensor(captions):
        captions = tokenizer(list(captions))
    return images, captions

def train(args, clip_model, tokenizer, gcc_dataloader, diffusion_dataloader, imagenet_dataloader, device):
    """Optimize ``clip_model`` on up to three data sources per step.

    Every step sums the enabled objectives into one loss:

    * ``args.diffusion_loss``: MLLM-head vs. text loss, domain-head vs. text
      loss and two disentangle terms on a diffusion batch.
    * ``args.imagenet_loss``: contrastive loss plus an image/domain-head
      disentangle term on an ImageNet batch.
    * ``args.gcc_loss``: weighted contrastive loss plus an image/domain-head
      disentangle term on a GCC batch.

    The pooled ``ln_post`` activation of the visual tower is captured by a
    forward hook and fed to ``clip_model.clip_projector``.

    Args:
        args: Namespace with ``lr``, ``wd``, ``epochs``, ``save_every``,
            ``weights_dir``, ``weight_name``, the three ``*_loss`` flags and
            the loss weights. An optional ``steps_per_epoch`` overrides the
            default of 5000 optimizer steps per epoch.
        clip_model: ``CLIP_with_head`` instance (possibly compiled).
        tokenizer: Tokenizer used for batches whose captions arrive as strings.
        gcc_dataloader: Loader yielding ``(images, caption_tokens)``.
        diffusion_dataloader: Loader yielding
            ``(images, targets, caption_tokens, mllm_hidden_states)``.
        imagenet_dataloader: Loader yielding a dict with ``"image"`` and
            ``"label"`` entries, or an ``(images, captions)`` pair.
        device: Device the batches are moved to.

    Raises:
        ValueError: If all three losses are disabled.
    """
    if not (args.diffusion_loss or args.imagenet_loss or args.gcc_loss):
        raise ValueError("train() needs at least one of diffusion_loss, imagenet_loss or gcc_loss enabled")

    scaler = GradScaler('cuda')
    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, clip_model.parameters()), lr=args.lr, weight_decay=args.wd)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

    activation = {}
    def get_embeddings_image(name):
        def hook(model, input, output):
            activation[name] = clip_model.clip.visual._global_pool(output.clone())[0]
        return hook
    hook_handle = clip_model.clip.visual.ln_post.register_forward_hook(get_embeddings_image('ln_post'))

    # Only start iterators (and their worker processes) for enabled losses.
    diffusion_dataloader_iter = iter(diffusion_dataloader) if args.diffusion_loss else None
    gcc_dataloader_iter = iter(gcc_dataloader) if args.gcc_loss else None
    imagenet_dataloader_iter = iter(imagenet_dataloader) if args.imagenet_loss else None

    num_steps = getattr(args, 'steps_per_epoch', 5000)
    try:
        for epoch in range(args.epochs):
            start_time = time.time()
            total_loss, dg_loss_acc, dm_loss_acc, gcc_loss_acc, inet_loss_acc = 0.0, 0.0, 0.0, 0.0, 0.0

            tqdm_object = tqdm(range(num_steps), total=num_steps, desc=f"Epoch {epoch+1}/{args.epochs}")
            for _ in tqdm_object:
                optimizer.zero_grad()
                with autocast('cuda', dtype=torch.bfloat16):
                    loss = 0

                    if args.diffusion_loss:
                        # get_batch returns (batch, iterator); the batch itself has
                        # four fields, of which the class target is not used here.
                        diffusion_batch, diffusion_dataloader_iter = get_batch(diffusion_dataloader_iter, diffusion_dataloader)
                        diffusion_images, _, diffusion_mllm_tokenized_text, diffusion_mllm_hidden_states = diffusion_batch
                        diffusion_images, diffusion_mllm_hidden_states, diffusion_mllm_tokenized_text = \
                            diffusion_images.to(device, non_blocking=True), diffusion_mllm_hidden_states.to(device, non_blocking=True), diffusion_mllm_tokenized_text.to(device, non_blocking=True)

                        embeddings = clip_model.encode_all(diffusion_images, diffusion_mllm_tokenized_text)
                        diffusion_image_embeddings, diffusion_text_embeddings, _, mllm_proj_logit_scale, domain_proj_logit_scale = embeddings

                        mllm_head_embeddings = clip_model.mllm_projector(diffusion_mllm_hidden_states)
                        image_domain_head_embeddings = clip_model.clip_projector(activation['ln_post'].requires_grad_(True))

                        diff_mllm_loss = MLLM_CLIP_loss(mllm_head_embeddings, diffusion_text_embeddings, mllm_proj_logit_scale, device) * args.mllm_loss_weight
                        diff_domain_loss = fylp_loss(image_domain_head_embeddings, diffusion_text_embeddings, domain_proj_logit_scale, device) * args.clip_domain_weight

                        diff_dist_loss_image_image = disentangle_loss(diffusion_image_embeddings, image_domain_head_embeddings) * args.diag_image_image_loss_weight
                        # NOTE(review): this text/image term is scaled by the
                        # "text_text" weight — confirm the pairing is intended.
                        diff_dist_loss_text_image = disentangle_loss(diffusion_text_embeddings, diffusion_image_embeddings) * args.diag_text_text_loss_weight
                        diff_dist_loss = (diff_dist_loss_image_image + diff_dist_loss_text_image) / 2

                        loss += diff_mllm_loss + diff_domain_loss + diff_dist_loss
                        dg_loss_acc += diff_dist_loss.item()
                        dm_loss_acc += (diff_mllm_loss + diff_domain_loss).item()

                    if args.imagenet_loss:
                        imagenet_batch, imagenet_dataloader_iter = get_batch(imagenet_dataloader_iter, imagenet_dataloader)
                        # Captions may arrive as raw strings and are tokenized here.
                        imagenet_images, imagenet_captions = _split_image_text_batch(imagenet_batch, tokenizer)
                        imagenet_images, imagenet_captions = imagenet_images.to(device), imagenet_captions.to(device)

                        image, text, logit_scale = clip_model.clip(imagenet_images, imagenet_captions)
                        inet_loss = fylp_loss(image, text, logit_scale, device)

                        image_domain_head_embeddings = clip_model.clip_projector(activation['ln_post'].requires_grad_(True))
                        inet_dist_loss_image_image = disentangle_loss(image, image_domain_head_embeddings) * args.diag_image_image_loss_weight

                        loss += inet_loss + inet_dist_loss_image_image
                        inet_loss_acc += inet_loss.item()
                        dg_loss_acc += inet_dist_loss_image_image.item()

                    if args.gcc_loss:
                        gcc_batch, gcc_dataloader_iter = get_batch(gcc_dataloader_iter, gcc_dataloader)
                        gcc_images, gcc_captions = _split_image_text_batch(gcc_batch, tokenizer)
                        gcc_images, gcc_captions = gcc_images.to(device), gcc_captions.to(device)

                        image, text, logits = clip_model.clip(gcc_images, gcc_captions)
                        gcc_loss_val = fylp_loss(image, text, logits, device) * args.gcc_loss_weight

                        image_domain_head_embeddings = clip_model.clip_projector(activation['ln_post'].requires_grad_(True))
                        gcc_dist_loss_image_image = disentangle_loss(image, image_domain_head_embeddings) * args.diag_image_image_loss_weight

                        loss += gcc_loss_val + gcc_dist_loss_image_image
                        gcc_loss_acc += gcc_loss_val.item()
                        dg_loss_acc += gcc_dist_loss_image_image.item()

                # Backward and the optimizer step run outside autocast; only the
                # forward pass and loss computation belong inside it.
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                step_loss = loss.item()
                total_loss += step_loss
                tqdm_object.set_postfix(loss=step_loss)

            scheduler.step()
            if (epoch + 1) % args.save_every == 0:
                torch.save(clip_model.clip.state_dict(), f'{args.weights_dir}/UNLEARN_{args.weight_name}_{epoch+1}e.pth')

            print(f"Epoch: {epoch+1} | Loss: {total_loss/num_steps:.3f} | DG: {dg_loss_acc/num_steps:.3f} | DM: {dm_loss_acc/num_steps:.3f} | GCC: {gcc_loss_acc/num_steps:.3f} | INET: {inet_loss_acc/num_steps:.3f} | Time: {time.time() - start_time:.3f}s")
    finally:
        # Do not leave the capture hook attached to the model after training.
        hook_handle.remove()

def main():
    """Parse command-line options, build model and dataloaders, then train.

    Loads the caption/embedding JSON, creates an OpenAI ViT-B-32 CLIP wrapped
    in ``CLIP_with_head`` (optionally restoring ``--pretrained_weights``),
    builds the diffusion, GCC and ImageNet dataloaders, runs ``train`` and
    saves the final CLIP weights under ``--weights_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='domainnet_loo_test')
    parser.add_argument('--datadir', type=str, default='../../data')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--num_workers', type=int, default=6)
    parser.add_argument('--epochs', type=int, default=80)
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--wd', type=float, default=0.1)
    parser.add_argument('--mllm_loss_weight', type=float, default=1e-4)
    parser.add_argument('--clip_domain_weight', type=float, default=1e-4)
    parser.add_argument('--diag_text_text_loss_weight', type=float, default=1e-4)
    parser.add_argument('--diag_image_image_loss_weight', type=float, default=1e-4)
    parser.add_argument('--gcc_loss_weight', type=float, default=1.0)
    parser.add_argument('--no_diffusion_loss', action='store_false', dest='diffusion_loss')
    parser.add_argument('--no_gcc_loss', action='store_false', dest='gcc_loss')
    parser.add_argument('--no_imagenet_loss', action='store_false', dest='imagenet_loss')
    parser.add_argument('--weight_name', type=str, default='run1')
    parser.add_argument('--weights_dir', type=str, default='./weights')
    parser.add_argument('--save_every', type=int, default=10)
    parser.add_argument('--caption_file', type=str, default='../../data/text_and_embedding_data/text_and_embedding_gemini_pro.json')
    parser.add_argument('--gcc_metadata', type=str, default='../../data/eleven/Pretrain_CC3M/metadata.json')
    parser.add_argument('--gcc_dir', type=str, default='../../data/eleven/Pretrain_CC3M/images')
    parser.add_argument('--diffusion_dir', type=str, default='../new_diffusion_images')
    parser.add_argument('--pretrained_weights', type=str, default='../weights/CLIP_GCC_UNLEARN_20e.pt')
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    set_seed(args.seed)
    torch.set_num_threads(12)
    os.makedirs(args.weights_dir, exist_ok=True)

    with open(args.caption_file, "r") as f:
        captions = json.load(f)
    # Convert every class's 'last' entry to an L2-normalized tensor with at
    # least two dimensions.
    for folder in captions:
        cur_list = torch.tensor(captions[folder]['last'])
        cur_list = F.normalize(cur_list, dim=-1)
        if len(cur_list.shape) == 1:
            cur_list = cur_list.unsqueeze(0)
        captions[folder]['last'] = cur_list

    model, _, preprocess_clip = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai')
    tokenizer = open_clip.get_tokenizer('ViT-B-32')
    clip_model = CLIP_with_head(model).to(device)
    if os.path.exists(args.pretrained_weights):
        # map_location lets a checkpoint saved on a GPU load on any device.
        clip_model.clip.load_state_dict(torch.load(args.pretrained_weights, map_location=device))
    else:
        print(f"Pretrained weights not found at {args.pretrained_weights}; continuing with the 'openai' checkpoint.")
    clip_model = torch.compile(clip_model)

    # DataLoader rejects these options when num_workers == 0.
    use_workers = args.num_workers > 0
    persistent_kwargs = {"persistent_workers": True} if use_workers else {}
    prefetch_kwargs = {"prefetch_factor": 4} if use_workers else {}

    # Dataloaders
    # Build the diffusion dataset directly from --diffusion_dir so absolute
    # paths work and no throwaway DataLoader is constructed.
    diffusion_dataset = CLIP_ImageFolder_custom(root=args.diffusion_dir,
                                                preprocess=preprocess_clip,
                                                tokenizer=tokenizer,
                                                captions=captions)
    print(f"Num samples total: {len(diffusion_dataset)}")
    diffusion_dataloader = DataLoader(dataset=diffusion_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True, pin_memory=True, drop_last=True, **persistent_kwargs)

    # NOTE(review): --gcc_metadata is parsed but not used; only --gcc_dir is
    # handed to the GCC dataset.
    gcc_dataset = GCC_ImageFolder_custom(args.gcc_dir, preprocess_clip, tokenizer)
    gcc_dataloader = DataLoader(gcc_dataset, shuffle=True, num_workers=args.num_workers, batch_size=args.batch_size, drop_last=True, **prefetch_kwargs)

    imagenet_dataset = load_dataset('imagenet-1k', split='train')
    def transform_dataset(example):
        # Preprocess images and replace integer labels by caption strings.
        example["image"] = [preprocess_clip(image) for image in example['image']]
        example["label"] = [f"a photo of a {imagenet_dataset.features['label'].int2str(label)}" for label in example["label"]]
        return example
    imagenet_dataset.set_format("torch")
    imagenet_dataset = imagenet_dataset.with_transform(transform_dataset)
    imagenet_dataloader = DataLoader(imagenet_dataset, shuffle=True, num_workers=args.num_workers, batch_size=args.batch_size, drop_last=True, **prefetch_kwargs)

    train(args, clip_model, tokenizer, gcc_dataloader, diffusion_dataloader, imagenet_dataloader, device)

    torch.save(clip_model.clip.state_dict(), f'{args.weights_dir}/UNLEARN_{args.weight_name}_final.pth')

if __name__ == "__main__":
    main()
