import torch
import os
import argparse
from PIL import Image
from pathlib import Path
from torch.utils.data import Dataset
import lightning as L
import transformers
from lightning.fabric.strategies import DeepSpeedStrategy
import torch.nn.functional as F

# Module-level cross-entropy criterion. No other code visible in this file
# references it.
cls_criterion = torch.nn.CrossEntropyLoss()

def batch_align(fabric, x):
    """Gather ``x`` from all processes and flatten it to a 2-D matrix.

    Args:
        fabric: Object exposing ``all_gather(tensor, sync_grads=...)``.
        x: Per-process tensor whose first axis is the local batch.

    Returns:
        A 2-D tensor. When the gather prepends an extra leading axis, that
        axis is merged with the batch axis; when the gather returns a tensor
        of the same rank as ``x`` (nothing was stacked), only the trailing
        axes are flattened and the batch axis is kept.
    """
    gathered = fabric.all_gather(x, sync_grads=True)
    if gathered.dim() == x.dim():
        # No leading axis was added, so merging the first two axes would fold
        # the feature axis into the batch axis.
        return gathered.reshape(gathered.shape[0], -1)
    # reshape (rather than view) also accepts non-contiguous gather results.
    return gathered.reshape(gathered.shape[0] * gathered.shape[1], -1)

class DLoader(Dataset):
    """Caption dataset that encodes each caption with two processors.

    Every item is the tuple ``(old_ids, new_ids, old_mask, new_mask)`` built
    from the first row of the ``input_ids`` / ``attention_mask`` entries
    returned by the old and the new processor for the same caption.
    """

    def __init__(self, text_list, old_processor, new_processor):
        self.text_list = text_list
        self.old_processor = old_processor
        self.new_processor = new_processor

    def __len__(self):
        return len(self.text_list)

    def _load_text(self, index):
        """Return the caption stored at ``index``."""
        return self.text_list[index]

    def __getitem__(self, idx):
        # Tensor indices are converted to plain Python values before lookup.
        index = idx.tolist() if torch.is_tensor(idx) else idx
        caption = self._load_text(index)

        # Both processors receive the same caption and the same options.
        encode_options = dict(return_tensors="pt", truncation=True, padding="max_length")
        old_encoded = self.old_processor(text=caption, **encode_options)
        new_encoded = self.new_processor(text=caption, **encode_options)

        return (
            old_encoded['input_ids'][0],
            new_encoded['input_ids'][0],
            old_encoded['attention_mask'][0],
            new_encoded['attention_mask'][0],
        )

def collate_fn(batch):
    """Collate a batch with the default collate after removing ``None`` items."""
    kept = [sample for sample in batch if sample is not None]
    return torch.utils.data.dataloader.default_collate(kept)

class MLPModel(torch.nn.Module):
    """Three-layer MLP with LayerNorm, GELU and Dropout after the first two layers.

    Args:
        input_dim: Size of the last axis of the input.
        hidden_dim: Width of both hidden layers.
        output_dim: Size of the last axis of the output.
        dropout: Dropout probability passed to both ``Dropout`` layers.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: float):
        super().__init__()
        # The attribute name and layer order are kept so state-dict keys
        # ("layers.0.weight", ...) stay the same.
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.LayerNorm(hidden_dim),
            torch.nn.GELU(),
            torch.nn.Dropout(p=dropout),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.LayerNorm(hidden_dim),
            torch.nn.GELU(),
            torch.nn.Dropout(p=dropout),
            torch.nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer stack to ``x``."""
        return self.layers(x)

class OLDModel_CLIP(torch.nn.Module):
    """Thin wrapper exposing the text branch of a wrapped model.

    The wrapped ``model`` must provide ``text_model`` and ``text_projection``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward_text(self, input_ids, attn):
        """Encode token ids and return the projected text embedding."""
        encoder_out = self.model.text_model(input_ids=input_ids, attention_mask=attn)
        # Element 1 of the text-model output is the one fed to the projection.
        return self.model.text_projection(encoder_out[1])

class NEWModel_CLIP(torch.nn.Module):
    """Wrapped model plus an MLP ``Phi`` mapping ``new_dim`` to ``old_dim``.

    Args:
        model: Object providing ``text_model`` and ``text_projection``.
        old_dim: Output width of ``Phi``; its hidden width is ``old_dim * 4``.
        new_dim: Input width of ``Phi``.
        logit_scale_init_value: Stored unchanged as ``self.logit_scale``.
    """

    def __init__(self, model, old_dim, new_dim, logit_scale_init_value):
        super().__init__()
        self.model = model
        self.Phi = MLPModel(
            input_dim=new_dim,
            hidden_dim=old_dim*4,
            output_dim=old_dim,
            dropout=0.0,
        )
        self.logit_scale = logit_scale_init_value

    def forward_text(self, input_ids, attn):
        """Encode token ids and return the projected text embedding.

        ``Phi`` is not applied here.
        """
        encoder_out = self.model.text_model(input_ids=input_ids, attention_mask=attn)
        # Element 1 of the text-model output is the one fed to the projection.
        return self.model.text_projection(encoder_out[1])
    
def get_args_parser():
    """Build the command-line parser for the text-only training script.

    The parser is created with ``add_help=False``, so it does not register
    ``-h/--help`` itself.

    Returns:
        An ``argparse.ArgumentParser`` defining every option that ``main``
        reads, including the data-loader options ``--num_workers`` and
        ``--pin_mem`` / ``--no_pin_mem``.
    """
    parser = argparse.ArgumentParser('Text-only Training', add_help=False)
    parser.add_argument('--batch_size', default=1024, type=int, help='Batch size per GPU')
    parser.add_argument('--epochs', default=1, type=int)
    parser.add_argument('--text', default='datasets', type=str)
    parser.add_argument('--old_model', default='openai/clip-vit-base-patch32', type=str)
    parser.add_argument('--new_model', default='openai/clip-vit-large-patch14', type=str)
    parser.add_argument('--std', type=float, default=0.05)
    parser.add_argument('--weight_decay', type=float, default=0.01,
                        help='weight decay (default: 0.01)')
    parser.add_argument('--init_lr', type=float, default=1e-4, metavar='LR')
    parser.add_argument('--output_dir', default='Phi',
                        help='path where to save, empty for no saving')
    parser.add_argument('--world_size', default=8, type=int,
                        help='number of distributed processes')

    # main() passes these to the DataLoader; without them it fails with
    # AttributeError on the parsed namespace.
    parser.add_argument('--num_workers', default=4, type=int,
                        help='number of data loading worker processes')
    parser.add_argument('--pin_mem', action='store_true',
                        help='pin CPU memory in the DataLoader')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem',
                        help='do not pin CPU memory in the DataLoader')
    parser.set_defaults(pin_mem=True)

    return parser

def main(args):
    """Set up Fabric/DeepSpeed, both CLIP wrappers and the data, then train.

    Args:
        args: Parsed command-line namespace. Reads ``batch_size``,
            ``world_size``, ``output_dir``, ``text``, ``old_model``,
            ``new_model``, ``init_lr`` and ``weight_decay``. ``num_workers``
            and ``pin_mem`` are optional; when absent the DataLoader defaults
            (0 workers, no pinned memory) are used. ``old_model`` and
            ``new_model`` are lower-cased in place on ``args``.
    """
    ds_config = {
        "train_micro_batch_size_per_gpu": args.batch_size,
        "zero_optimization": {"stage": 2},
    }

    fabric = L.Fabric(
        accelerator="cuda",
        devices=args.world_size,
        strategy=DeepSpeedStrategy(config=ds_config),
        precision="bf16",
    )
    fabric.launch()
    # The seed is offset by the global rank, so it differs between processes.
    fabric.seed_everything(1337 + fabric.global_rank)

    # os.makedirs("") raises, so an empty output_dir is skipped here.
    if args.output_dir and fabric.global_rank == 0:
        os.makedirs(args.output_dir, exist_ok=True)

    # One caption per line; surrounding whitespace (including the newline that
    # readlines() would keep) is stripped and blank lines are dropped.
    with open(args.text, 'r', encoding='utf-8') as file:
        text_list = [line.strip() for line in file if line.strip()]

    old_name = args.old_model
    new_name = args.new_model

    old_config = transformers.AutoConfig.from_pretrained(old_name)
    new_config = transformers.AutoConfig.from_pretrained(new_name)

    args.new_model = args.new_model.lower()
    args.old_model = args.old_model.lower()

    new_dim = new_config.projection_dim
    old_dim = old_config.projection_dim

    with fabric.device:
        new_processor = transformers.AutoProcessor.from_pretrained(new_name)
        new_model = transformers.CLIPModel.from_pretrained(new_name)
        # Freeze the pretrained weights; the Phi head added by NEWModel_CLIP
        # below is created afterwards and stays trainable.
        for name, param in new_model.named_parameters():
            param.requires_grad = False
        new_model = NEWModel_CLIP(new_model, old_dim, new_dim, new_config.logit_scale_init_value).bfloat16()

        old_processor = transformers.AutoProcessor.from_pretrained(old_name)
        old_model = transformers.CLIPModel.from_pretrained(old_name)
        old_model = OLDModel_CLIP(old_model).bfloat16()

    # DLoader takes exactly (text_list, old_processor, new_processor).
    dataset_train = DLoader(text_list, old_processor, new_processor)

    train_loader = torch.utils.data.DataLoader(
        dataset_train, batch_size=args.batch_size,
        # Fall back to the DataLoader defaults when the namespace lacks these.
        num_workers=getattr(args, "num_workers", 0),
        pin_memory=getattr(args, "pin_mem", False),
        drop_last=True,
        shuffle=True,
        collate_fn=collate_fn,
    )

    train_loader = fabric.setup_dataloaders(train_loader)

    # The old model is set up together with an optimizer, but the wrapped
    # optimizer is discarded; only the new model's optimizer is passed on.
    optimizer = torch.optim.AdamW(old_model.parameters(), lr=args.init_lr, weight_decay=args.weight_decay)
    old_model, _ = fabric.setup(old_model, optimizer)

    optimizer = torch.optim.AdamW(new_model.parameters(), lr=args.init_lr, weight_decay=args.weight_decay)
    new_model, optimizer = fabric.setup(new_model, optimizer)

    train(fabric, old_model, new_model, optimizer, train_loader)

def _symmetric_contrastive_loss(similarity: torch.Tensor) -> torch.Tensor:
    """Average the row-wise and column-wise cross-entropy of a square matrix.

    The target for row/column ``i`` is index ``i``, i.e. matching pairs sit on
    the diagonal. This is the symmetric form used by the Hugging Face CLIP
    reference loss.
    """
    targets = torch.arange(similarity.shape[0], device=similarity.device)
    row_loss = F.cross_entropy(similarity, targets)
    col_loss = F.cross_entropy(similarity.t(), targets)
    return (row_loss + col_loss) / 2


def train(
    fabric: L.Fabric,
    old_model: torch.nn.Module,
    new_model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    train_loader,
    config=None,
) -> None:
    """Train ``new_model.Phi`` to align new text embeddings with the old ones.

    For each batch the old model's text embeddings are computed without
    gradients, the new model's embeddings are detached, perturbed with
    Gaussian noise, normalized and passed through ``Phi``. Both sides are
    gathered across processes and a symmetric contrastive loss is taken over
    their scaled cosine-similarity matrix.

    Args:
        fabric: Fabric instance used for gather, logging, backward and saving.
        old_model: Model exposing ``forward_text``; put in eval mode.
        new_model: Model exposing ``forward_text``, ``Phi`` and
            ``logit_scale``; put in train mode.
        optimizer: Optimizer stepping the new model.
        train_loader: Iterable of ``(old_ids, new_ids, old_attn, new_attn)``.
        config: Namespace providing ``epochs``, ``std`` and ``output_dir``.
            When ``None`` the module-level ``args`` set by the ``__main__``
            guard is used.
    """
    # Local import: the file's import block provides neither math nor numpy.
    import math

    cfg = args if config is None else config

    old_model.train(False)
    new_model.train(True)

    # The stored logit scale is exponentiated before use; it is read once
    # because nothing in this loop modifies it.
    logit_scale = math.exp(float(new_model.logit_scale))

    step = 0
    total_steps = len(train_loader) * cfg.epochs

    for epoch in range(cfg.epochs):
        optimizer.zero_grad()

        for old_ids, new_ids, old_attn, new_attn in train_loader:
            with torch.no_grad():
                old_emb = old_model.forward_text(old_ids, old_attn)
                old_emb = batch_align(fabric, old_emb)

            # Detached so that only Phi receives gradients from the loss.
            new_emb = new_model.forward_text(new_ids, new_attn).detach()

            # add noise
            noise = torch.randn_like(new_emb) * cfg.std
            mapped = new_model.Phi(F.normalize(new_emb + noise.bfloat16()))
            mapped = batch_align(fabric, mapped)

            similarity = logit_scale * F.normalize(mapped) @ F.normalize(old_emb).t()
            loss = _symmetric_contrastive_loss(similarity)

            fabric.print(f"epoch {epoch} iter {step} ({(step/total_steps)*100:.4f}%) clip_loss {loss.item():.4f}")

            fabric.backward(loss, model=new_model)
            optimizer.step()
            # Clear gradients after every step so they never accumulate
            # across iterations.
            optimizer.zero_grad()

            step += 1

    fabric.save(cfg.output_dir, {"model": new_model})
    fabric.barrier()

if __name__ == "__main__":
    # `args` has to be a module-level name: train() reads it as a global.
    cli_parser = get_args_parser()
    args = cli_parser.parse_args()
    # Create the output directory (and any missing parents) unless it is empty.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
    main(args)