import torch
import os
from tqdm import tqdm
import torch.nn.functional as F
from AlignModule import AlignModule
from SVMugenDataset import SVMugenDataset
import argparse

def contrastive_loss(logits):
    """Return the negated mean of the main-diagonal entries of ``logits``.

    For a 2-D input, ``torch.diag`` extracts the main diagonal, so the
    result gets smaller as the diagonal scores grow. Off-diagonal entries
    do not contribute to the value.

    Args:
        logits: Score tensor accepted by ``torch.diag``.
            NOTE(review): presumably a square pairwise score matrix whose
            diagonal holds the matched pairs -- confirm against the caller.

    Returns:
        A 0-d tensor equal to ``-mean(diag(logits))``.
    """
    matched_scores = torch.diag(logits)
    mean_matched = matched_scores.mean()
    return -mean_matched

def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
    """Average two ``contrastive_loss`` terms computed on ``similarity``.

    Both terms are evaluated on the same, untransposed matrix, so the
    result equals a single ``contrastive_loss(similarity)`` call.

    Args:
        similarity: Score tensor forwarded unchanged to ``contrastive_loss``.

    Returns:
        The mean of the two loss terms.
    """
    text_side = contrastive_loss(similarity)
    video_side = contrastive_loss(similarity)
    combined = text_side + video_side
    return combined / 2.0

def build_loaders(args, split):
    """Create a ``DataLoader`` over ``SVMugenDataset`` for the given split.

    Shuffling and dropping of the last incomplete batch are enabled only
    when ``split == "train"``; every other split is read in order and in
    full.

    Args:
        args: Namespace providing ``batch_size``, ``num_workers``,
            ``save_video`` and ``video_save_dir`` (it is also passed through
            to the dataset constructor).
        split: Name of the dataset split, e.g. ``"train"`` or ``"val"``.

    Returns:
        A ``torch.utils.data.DataLoader`` that uses the dataset's own
        ``collate_fn``.
    """
    is_train = split == "train"
    split_dataset = SVMugenDataset(
        args,
        split,
        save_video=args.save_video,
        video_save_dir=args.video_save_dir,
    )
    return torch.utils.data.DataLoader(
        split_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=is_train,
        drop_last=is_train,
        collate_fn=split_dataset.collate_fn,
    )


class Trainer():
    """Train and evaluate an ``AlignModule`` video/text matcher.

    The model returns two tensors per batch: a match prediction and a
    constraint-violation prediction. Both are scored with
    ``F.binary_cross_entropy`` (which requires inputs in ``[0, 1]``): the
    match prediction against an identity matrix, and the violation
    prediction against an all-zero matrix, weighted by
    ``args.constraint_weight``.
    """

    def __init__(self, args, train_loader, valid_loader):
        """Build the model and optimizer from ``args``.

        Args:
            args: Namespace with the model, optimizer and path settings read
                below. It is kept on ``self.args`` so that ``train`` does not
                depend on a module-level ``args`` global.
            train_loader: Iterable of batch dicts used by ``train_epoch``.
            valid_loader: Iterable of batch dicts used by ``eval_epoch`` and
                ``test_epoch``.

        Raises:
            FileNotFoundError: If ``args.load`` is set and the checkpoint
                ``args.model_save_dir/args.model_name`` does not exist.
        """
        self.args = args
        self.train_loader = train_loader
        self.valid_loader = valid_loader

        # Shared constructor arguments; ``load_path`` is only added when a
        # checkpoint is requested, otherwise AlignModule's default applies.
        model_kwargs = dict(
            batch_size=args.batch_size, video_enc=True, audio_enc=False, text_enc=True,
            pretrained=args.pretrained, provenance=args.provenance, scl_filename=args.scl_filename,
            train_top_k=args.train_top_k, test_top_k=args.test_top_k,
            trainable=args.trainable, text_embedding=768, debug=args.debug_prov,
            video_decoder_layers=args.video_decoder_layers, text_decoder_layers=args.text_decoder_layers,
            multi_text=args.multi_text, gt_text=args.use_text_gt,
            save_pred=args.save_pred, pred_save_dir=args.pred_save_dir,
        )
        if args.load:
            model_path = os.path.join(args.model_save_dir, args.model_name)
            # Explicit raise instead of ``assert`` so the check survives ``python -O``.
            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model checkpoint not found: {model_path}")
            model_kwargs["load_path"] = model_path
        self.model = AlignModule(**model_kwargs).to(args.device)

        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        self.constraint_violation_loss = F.binary_cross_entropy
        self.match_loss = clip_loss
        self.device = args.device
        self.constraint_weight = args.constraint_weight
        self.alternative_train = args.alternative_train

    def accuracy(self, y_pred, y):
        """Count rows and columns of ``y_pred`` whose argmax is on the diagonal.

        Args:
            y_pred: 2-D score matrix; entry ``[i, i]`` is treated as the
                correct pairing for index ``i``.
            y: Only its length is used (reported back as the batch size).

        Returns:
            ``(img_acc, cap_acc, batch_size)`` where ``img_acc`` counts rows
            whose argmax equals the row index, ``cap_acc`` counts columns
            whose argmax equals the column index. Both are plain ``int``s so
            running totals stay Python numbers rather than tensors.
        """
        batch_size = len(y)
        target = torch.arange(len(y_pred), device=y_pred.device)

        img2cap_match_idx = y_pred.argmax(dim=1)
        cap2img_match_idx = y_pred.argmax(dim=0)

        img_acc = int((img2cap_match_idx == target).sum().item())
        cap_acc = int((cap2img_match_idx == target).sum().item())
        return (img_acc, cap_acc, batch_size)

    def _to_device(self, batch):
        """Move every tensor value of ``batch`` to ``self.device``; keep other values as-is."""
        return {
            key: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for key, value in batch.items()
        }

    def _compute_loss(self, pred_match, pred_constraint_violation, batch_size):
        """Return ``(loss, ground_truth)`` for one batch.

        ``ground_truth`` is the ``batch_size`` x ``batch_size`` identity
        matrix and the violation target is all zeros.
        """
        ground_truth = torch.eye(batch_size, device=self.device)
        no_violation = torch.zeros(batch_size, batch_size, device=self.device)
        match_term = self.constraint_violation_loss(pred_match, ground_truth)
        violation_term = self.constraint_violation_loss(pred_constraint_violation, no_violation)
        return match_term + self.constraint_weight * violation_term, ground_truth

    def _run_epoch(self, loader, forward, label, optimize=False):
        """Run one pass over ``loader`` and report running metrics via tqdm.

        Args:
            loader: Iterable yielding batch dicts that contain ``'text_idx'``.
            forward: Callable mapping a batch to
                ``(pred_match, pred_constraint_violation)``.
            label: Text placed between the brackets of the progress line.
            optimize: When true, gradients are enabled and an optimizer step
                is taken per batch; otherwise the pass runs without gradients.

        Returns:
            ``(avg_loss, correct_img_perc, correct_text_perc)``. For an empty
            loader this is ``(nan, 0.0, 0.0)``.
        """
        loss_sum = 0.0
        total_img_correct = 0
        total_text_correct = 0
        total_count = 0
        # Defaults so that an empty loader yields a defined result.
        avg_loss = float('nan')
        correct_img_perc = 0.0
        correct_text_perc = 0.0

        with torch.set_grad_enabled(optimize):
            iterator = tqdm(loader)
            for i, batch in enumerate(iterator):
                if optimize:
                    if self.alternative_train:
                        self.model.toggle_training_model()
                    self.optimizer.zero_grad()

                batch = self._to_device(batch)
                batch_size = len(batch['text_idx'])
                pred_match, pred_constraint_violation = forward(batch)
                loss, ground_truth = self._compute_loss(pred_match, pred_constraint_violation, batch_size)

                if optimize:
                    loss.backward()
                    self.optimizer.step()

                img_acc, cap_acc, batch_size = self.accuracy(pred_match, ground_truth)
                # Running sum instead of re-summing a growing list every step.
                loss_sum += loss.item()
                total_img_correct += img_acc
                total_text_correct += cap_acc
                total_count += batch_size

                avg_loss = loss_sum / (i + 1)
                correct_img_perc = (total_img_correct / total_count) * 100.0 if total_count else 0.0
                correct_text_perc = (total_text_correct / total_count) * 100.0 if total_count else 0.0

                iterator.set_description(f"[{label}] Avg Loss: {avg_loss}, Video Accu: {total_img_correct}/{total_count} ({correct_img_perc:.2f}%), Text Accu: {total_text_correct}/{total_count} ({correct_text_perc:.2f}%)")

        return avg_loss, correct_img_perc, correct_text_perc

    def train_epoch(self, epoch):
        """Run one optimisation pass over ``self.train_loader``.

        With ``alternative_train`` set, ``toggle_training_model`` is called
        once before the loop and once per batch (as in the original code);
        otherwise the model is put in ``train()`` mode.

        Returns:
            ``(avg_loss, correct_img_perc, correct_text_perc)``.
        """
        if self.alternative_train:
            self.model.toggle_training_model()
        else:
            self.model.train()
        return self._run_epoch(self.train_loader, self.model, f"Train Epoch {epoch}", optimize=True)

    def eval_epoch(self, epoch):
        """Evaluate the model on ``self.valid_loader`` without gradients.

        Returns:
            ``(avg_loss, correct_img_perc, correct_text_perc)``.
        """
        self.model.eval()
        # Uses the validation loader; iterating the training loader here would
        # make best-checkpoint selection depend on training data.
        return self._run_epoch(self.valid_loader, self.model, f"Test Epoch {epoch}")

    def test_epoch(self):
        """Evaluate ``self.model.predict`` on ``self.valid_loader`` without gradients.

        Returns:
            ``(avg_loss, correct_img_perc, correct_text_perc)``.
        """
        self.model.eval()
        return self._run_epoch(self.valid_loader, self.model.predict, "Test")

    def train(self):
        """Train for ``args.epochs`` epochs, checkpointing after each one.

        ``latest_checkpoint.pt`` is overwritten every epoch and
        ``best_checkpoint.pt`` whenever the validation loss improves; both
        are written to ``args.model_save_dir``.
        """
        best_loss = float('inf')
        save_dir = self.args.model_save_dir

        for epoch in range(self.args.epochs):
            self.train_epoch(epoch)
            val_avg_loss, _, _ = self.eval_epoch(epoch)
            self.model.save(os.path.join(save_dir, "latest_checkpoint.pt"))

            if val_avg_loss < best_loss:
                best_loss = val_avg_loss
                self.model.save(os.path.join(save_dir, "best_checkpoint.pt"))
                print("Saved Best Model!")

    def test(self):
        """Run a single evaluation pass (``eval_epoch`` with epoch index 0)."""
        self.eval_epoch(0)

def parse_args():
    """Parse command-line options and derive the run configuration.

    After parsing, a number of options are overwritten with hard-coded
    values (encoders, annotation source, CUDA use, debug flags, ...), so the
    corresponding command-line flags have no effect. In the ``"test"`` phase
    checkpoint loading, prediction saving and video saving are forced on and
    the batch size is set to 16.

    ``--default_root_dir`` is likewise replaced by the ``models`` directory
    next to this file; the model, prediction and video output directories
    are created beneath it if missing.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        ValueError: If ``--scl_type`` is not one of ``action``,
            ``horizontal`` or ``direction``.
    """

    def str2bool(value):
        # ``type=bool`` would turn any non-empty string (including "False")
        # into True, so boolean-valued options are parsed explicitly.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser = SVMugenDataset.add_data_specific_args(parser)
    parser.add_argument('--model_name', type=str, default="best_checkpoint.pt")
    parser.add_argument('--batch_size', type=int, default=3)
    parser.add_argument('--num_workers', type=int, default=1)
    parser.add_argument('--default_root_dir', type=str, default='saved_checkpoints')
    parser.add_argument('--load', action='store_true')

    parser.add_argument('--train_data_ct', type=int, default=5000)
    parser.add_argument('--test_data_ct', type=int, default=1000)

    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--constraint_weight', type=float, default=0.01)

    parser.add_argument('--weight_decay', type=float, default=0.001)
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--alternative_train', action='store_true')

    parser.add_argument('--video_enc', action='store_true')
    parser.add_argument('--audio_enc', action='store_true')
    parser.add_argument('--text_enc', action='store_true')
    parser.add_argument('--multi_text', action='store_true')
    parser.add_argument('--pretrained', action='store_true')
    parser.add_argument('--trainable', action='store_true')
    parser.add_argument('--use_cuda', action='store_true')
    parser.add_argument('--debug_prov', action='store_true')
    parser.add_argument('--use_text_gt', action='store_true')

    parser.add_argument('--gpu', type=int, default=0)

    parser.add_argument('--provenance', type=str, default="diffaddmultprob")
    parser.add_argument('--train_top_k', type=int, default=5)
    parser.add_argument('--test_top_k', type=int, default=5)
    parser.add_argument('--video_decoder_layers', type=int, default=2)
    parser.add_argument('--text_decoder_layers', type=int, default=2)
    parser.add_argument('--scl_type', type=str, default="action")
    parser.add_argument('--folder_name', type=str, default=None)
    parser.add_argument('--save_video', type=str2bool, default=False)
    parser.add_argument('--save_video_dir', type=str, default=None)

    parser.add_argument('--alternative_train_freq', type=int, default=10)
    parser.add_argument('--seed', type=int, default=None)
    parser.add_argument('--phase', type=str, default="test")
    parser.add_argument('--save_pred', type=str2bool, default=True)

    args = parser.parse_args()

    # Hard-coded experiment configuration; these override the parsed flags.
    args.text_enc = True
    args.video_enc = True
    args.audio_enc = False
    args.trainable = True
    args.pretrained = True
    args.get_audio = False
    args.get_text_desc = True
    args.use_manual_annotation = False
    args.use_auto_annotation = True
    args.get_game_frame = True
    args.use_cuda = True
    args.debug = True
    args.debug_prov = False
    args.multi_text = False
    args.use_text_gt = True
    args.alternative_train = False

    if args.phase == "test":
        args.load = True
        args.save_pred = True
        args.save_video = True
        args.batch_size = 16
    if args.debug_prov:
        args.batch_size = 1

    scl_files = {
        "action": "mugen_contrast_learning_action.scl",
        "horizontal": "mugen_contrast_learning_horizontal.scl",
        "direction": "mugen_contrast_learning_direction.scl",
    }
    if args.scl_type not in scl_files:
        # Raising a bare string is itself a TypeError; raise a real exception.
        raise ValueError(f"Unknown scl type: {args.scl_type!r}")
    args.scl_filename = scl_files[args.scl_type]

    # Resolves to "<directory of this file>/models".
    args.default_root_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../models"))
    if args.folder_name is None:
        args.folder_name = ""
    args.device = f"cuda:{args.gpu}" if args.use_cuda else "cpu"
    args.model_save_dir = os.path.join(args.default_root_dir, args.folder_name)
    args.pred_save_dir = os.path.join(args.model_save_dir, 'pred')
    args.video_save_dir = os.path.join(args.model_save_dir, 'video')

    # makedirs also creates missing parents (e.g. the "models" root itself),
    # which a plain os.mkdir would fail on.
    for directory in (args.model_save_dir, args.pred_save_dir, args.video_save_dir):
        os.makedirs(directory, exist_ok=True)

    if args.seed is not None:
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    print(args)
    return args


if __name__ == "__main__":
    # Script entry point: parse the configuration, build the train and
    # validation loaders, then run the phase selected by --phase.
    # `args` is deliberately bound at module level rather than inside a
    # function, so it stays visible as a global to the rest of this file.
    args = parse_args()

    train_loader = build_loaders(args, "train")
    valid_loader = build_loaders(args, "val")
    trainer = Trainer(
        args=args,
        train_loader=train_loader,
        valid_loader=valid_loader,
    )

    # Any phase other than "train" runs a single test pass.
    run_phase = trainer.train if args.phase == "train" else trainer.test_epoch
    run_phase()
