from comet_ml import Experiment
from comet_ml import OfflineExperiment
from model.baseline_model import BaselineModel
import torch
from utils.CIL_dataset import CILSetTask
from model.temporalShiftModule.ops.transforms import *
import argparse
import yaml, pickle
import torch.nn as nn
import os
import random
import torch.nn.functional as F
from utils.CILdatasetting import IncrementalDatasetBuilder

# Seed Python's built-in `random` module so its draws are repeatable across runs.
# NOTE(review): this does not seed the torch RNG; confirm whether that is
# seeded elsewhere if full reproducibility is required.
random.seed(10)
 
class TextualContrastiveLoss(nn.Module):
    """Label-aware contrastive loss over two sets of embeddings.

    The two embedding batches are L2-normalised and concatenated into
    ``2 * batch_size`` rows. For every row, all *other* rows that carry the
    same label count as positives; every other row (positive or not) appears
    in the denominator. The per-row loss is
    ``-log(sum_pos exp(sim / T) / sum_others exp(sim / T))`` and the result is
    averaged over the ``2 * batch_size`` rows.

    Args:
        batch_size: Number of samples in each of the two embedding batches.
            The self-similarity mask is pre-built for exactly this size, so
            ``forward`` must receive batches of this size.
        device: Device on which the mask buffer is initially created.
        temperature: Softmax temperature ``T`` dividing the similarities.
    """

    def __init__(self, batch_size, device, temperature=0.5):
        super().__init__()
        self.batch_size = batch_size
        self.device = device
        self.register_buffer("temperature", torch.tensor(temperature))
        # 1.0 everywhere except the diagonal: removes each row's similarity
        # with itself from both the positives and the denominator.
        self.register_buffer("negatives_mask", (~torch.eye(batch_size * 2, batch_size * 2, dtype=torch.bool, device=device)).float())

    def forward(self, emb_i, emb_j, labels):
        """Compute the loss.

        Args:
            emb_i: First embedding batch; normalised along ``dim=1``.
            emb_j: Second embedding batch, paired row-by-row with ``emb_i``.
            labels: One label per row of ``emb_i`` (shared by the matching
                row of ``emb_j``).

        Returns:
            Scalar tensor with the mean loss over all ``2 * batch_size`` rows.
        """
        z_i = F.normalize(emb_i, dim=1)
        z_j = F.normalize(emb_j, dim=1)

        representations = torch.cat([z_i, z_j], dim=0)
        similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=2)

        # Place the labels on the same device as the mask buffer they are
        # combined with. Relying on a module-level ``device`` global here
        # would break as soon as the class is used outside the training script.
        labels = labels.to(device=self.negatives_mask.device, dtype=torch.float)
        label_t = torch.cat([labels, labels], dim=0)
        # ws[a, b] == 1 iff rows a and b share a label and a != b.
        ws = torch.eq(label_t.unsqueeze(1), label_t.unsqueeze(0))
        ws = ws * self.negatives_mask

        exp_sim = torch.exp(similarity_matrix / self.temperature)
        nominator = torch.sum(ws * exp_sim, dim=1)
        denominator = self.negatives_mask * exp_sim

        loss_partial = -torch.log(nominator / torch.sum(denominator, dim=1))
        loss = torch.sum(loss_partial) / (2 * self.batch_size)
        return loss

class ContrastiveLoss(nn.Module):
    """Pairwise contrastive loss between two aligned embedding batches.

    Row ``k`` of ``emb_i`` and row ``k`` of ``emb_j`` form the only positive
    pair for each other; every other row of the concatenated
    ``2 * batch_size`` batch (excluding the row itself) contributes to the
    denominator.

    Args:
        batch_size: Number of rows in each embedding batch. The
            self-similarity mask is built for exactly this size.
        device: Device on which the mask buffer is initially created.
        temperature: Temperature dividing the cosine similarities.
    """

    def __init__(self, batch_size, device, temperature=0.5):
        super().__init__()
        self.batch_size = batch_size
        self_pairs = torch.eye(batch_size * 2, batch_size * 2, dtype=torch.bool, device=device)
        self.register_buffer("temperature", torch.tensor(temperature))
        # Float mask that zeroes the diagonal (a row compared with itself).
        self.register_buffer("negatives_mask", (~self_pairs).float())

    def forward(self, emb_i, emb_j):
        """Return the scalar loss averaged over all ``2 * batch_size`` rows."""
        stacked = torch.cat(
            [F.normalize(emb_i, dim=1), F.normalize(emb_j, dim=1)],
            dim=0,
        )
        pairwise = F.cosine_similarity(stacked.unsqueeze(1), stacked.unsqueeze(0), dim=2)

        # The +/- batch_size off-diagonals hold the (i, j) and (j, i)
        # similarities of each aligned pair.
        shift = self.batch_size
        matched = torch.cat(
            [torch.diag(pairwise, shift), torch.diag(pairwise, -shift)],
            dim=0,
        )

        numer = torch.exp(matched / self.temperature)
        masked_exp = self.negatives_mask * torch.exp(pairwise / self.temperature)
        denom = torch.sum(masked_exp, dim=1)

        per_row = -torch.log(numer / denom)
        return torch.sum(per_row) / (2 * self.batch_size)

def parse_conf(conf, new_dict=None):
    """Flatten a nested configuration dict into a single-level dict.

    Nested dicts are descended into recursively and only their leaf
    ``key: value`` pairs are kept; the names of the intermediate sections are
    dropped. If the same leaf key occurs in several sections, the one visited
    last wins.

    Args:
        conf: Possibly nested mapping to flatten.
        new_dict: Optional dict to write the flattened entries into. A fresh
            dict is created when omitted, so separate calls never share state.

    Returns:
        The dict holding the flattened entries (``new_dict`` itself when one
        was supplied).
    """
    if new_dict is None:
        new_dict = {}
    for k, v in conf.items():
        if isinstance(v, dict):
            parse_conf(v, new_dict)
        else:
            new_dict[k] = v
    return new_dict

def main():
    """Load the YAML configuration, build model, data and losses, and train.

    Steps, in order:
      1. Parse ``--conf_path`` and load the YAML configuration.
      2. Fill the experiment name, memory path and checkpoint path templates
         with dataset / encoder / task / memory / loss settings.
      3. Create the offline Comet experiment and the ``BaselineModel``.
      4. Obtain the task split, either built with
         ``IncrementalDatasetBuilder`` or loaded from a pickle file.
      5. Build the train / val / test ``CILSetTask`` objects and the losses.
      6. Run ``train_loop`` when ``checkpoints.train_mode`` is set.

    Side effects: sets the module-level globals ``dict_conf``, ``device``,
    ``experiment``, ``memory_size``, ``batch_size`` and ``is_activityNet``,
    which ``train_loop`` reads.
    """
    global dict_conf, device, experiment, memory_size, batch_size, is_activityNet

    parser = argparse.ArgumentParser(description="vCLIMB Model")
    parser.add_argument("-conf","--conf_path", default = '/data/whj/CIL_IN_VIDEO/PIVOT-main/conf/Kinetics.yaml')
    args = parser.parse_args()

    # Context manager guarantees the config file handle is closed.
    with open(args.conf_path, 'r') as conf_file:
        print("Conf file dir: ", args.conf_path)
        # NOTE(security): yaml.Loader can construct arbitrary Python objects.
        # Only load configuration files from trusted sources (or switch to
        # yaml.SafeLoader if the configs use plain YAML types only).
        dict_conf = yaml.load(conf_file, Loader=yaml.Loader)

    # The three templates below all take the same six positional fields.
    template_fields = (
        dict_conf['dataset']['name'],
        dict_conf['feature_encoder']['type'],
        dict_conf['feature_encoder']['num_segments'],
        dict_conf['type_task'],
        dict_conf['memory']['memory_size'],
        dict_conf['type_loss'],
    )
    for section, key in (('comet', 'name'), ('memory', 'path_memory'), ('checkpoints', 'path_model')):
        dict_conf[section][key] = dict_conf[section][key].format(*template_fields)

    api_key = dict_conf['comet']['api_key']
    workspace = dict_conf['comet']['workspace']
    project_name = dict_conf['comet']['project_name']

    experiment = OfflineExperiment(api_key=api_key,
                            project_name=project_name, workspace=workspace)

    # NOTE(review): the GPU index is hard-coded; on a machine with fewer than
    # three GPUs this device will not exist.
    device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

    model = BaselineModel(device, dict_conf, experiment)

    dataset_conf = dict_conf['dataset']
    path_data = dataset_conf['path_data']
    path_frames = dataset_conf['path_frames']
    first_task = dataset_conf['first_task']
    num_tasks = dataset_conf['num_tasks']
    train_path = dataset_conf['train_path']
    test_path = dataset_conf['test_path']
    is_use_PIVOT = dataset_conf['is_use_PIVOT']
    is_use_val = dataset_conf['is_use_val']
    dataset_name = dataset_conf['name']

    # Kinetics keeps its frames in a separate directory per split; every
    # other dataset uses the single ``path_frames`` directory for all splits.
    if dataset_name == 'Kinetics':
        train_frames = dataset_conf['path_frames_train']
        test_frames = dataset_conf['path_frames_test']
        val_frames = dataset_conf['path_frames_val']
    else:
        train_frames = val_frames = test_frames = path_frames

    if not is_use_PIVOT:
        builder = IncrementalDatasetBuilder(
            dataset_path=path_frames,
            train_label_file=train_path,
            test_label_file = test_path,
            n_tasks=num_tasks,
            first_task_class_num=first_task,
            val_ratio=0,
            random_seed=2025,
            random_order=True
        )
        data = builder.build()
    else:
        # NOTE(security): pickle.load executes arbitrary code from the file;
        # only load split files from trusted sources.
        with open(path_data, 'rb') as handle:
            data = pickle.load(handle)

    num_class = len(data['train'][0].keys())

    is_activityNet = dataset_conf.get('is_activityNet', False)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std

    # Horizontal flipping is disabled for dataset names containing
    # 'something' or 'jester'.
    use_flip = not ('something' in dataset_name or 'jester' in dataset_name)
    train_augmentation = model.get_augmentation(flip=use_flip)
    memory_size = dict_conf['memory']['memory_size']
    batch_size = dict_conf['batch_size']
    num_workers = dict_conf['num_workers']
    arch = dict_conf['feature_encoder']['type_clip_model']
    num_segments = dict_conf['feature_encoder']['num_segments']

    normalize = GroupNormalize(input_mean, input_std)
    data_length = 1

    is_inception = arch in ['BNInception', 'InceptionV3']
    train_transforms = torchvision.transforms.Compose([
        train_augmentation,
        Stack(roll=is_inception),
        ToTorchFormatTensor(div=not is_inception),
        normalize
    ])
    val_transforms = torchvision.transforms.Compose([
        GroupScale(int(scale_size)),
        GroupCenterCrop(crop_size),
        Stack(roll=is_inception),
        ToTorchFormatTensor(div=not is_inception),
        normalize,
    ])

    train_per_noise = dataset_conf.get('train_per_noise', 0)
    val_per_noise = dataset_conf.get('val_per_noise', 0)
    co_threshold = dataset_conf.get('co_threshold', 0)

    train_cilDatasetList = CILSetTask(data['train'], train_frames, memory_size, batch_size, shuffle=True,
                                      num_workers=num_workers, num_frame_to_save = dict_conf['num_frame_to_save'],
                                      is_activityNet = is_activityNet, per_noise = train_per_noise, co_threshold = co_threshold,
                                      drop_last=True, pin_memory=True, num_segments=num_segments, new_length=data_length,
                                      modality='RGB',transform=train_transforms, dense_sample=False, train_enable = True, name_dataset=dataset_name)

    # Validation and test loaders share every option except the data split
    # and the frame directory.
    eval_kwargs = dict(
        shuffle=False,
        num_workers=num_workers,
        is_activityNet=is_activityNet,
        per_noise=val_per_noise,
        co_threshold=co_threshold,
        pin_memory=True,
        num_frame_to_save=dict_conf['num_frame_to_save'],
        num_segments=num_segments,
        new_length=data_length,
        modality='RGB',
        transform=val_transforms,
        random_shift=False,
        dense_sample=False,
        train_enable=False,
        name_dataset=dataset_name,
    )

    val_cilDatasetList = None
    if is_use_val:
        val_cilDatasetList = CILSetTask(data['val'], val_frames, memory_size, batch_size, **eval_kwargs)

    test_cilDatasetList = None
    if not is_activityNet:
        test_cilDatasetList = CILSetTask(data['test'], test_frames, memory_size, batch_size, **eval_kwargs)

    cls_loss = nn.CrossEntropyLoss().to(device)
    textual_con_loss = TextualContrastiveLoss(dict_conf['batch_size'], device, temperature = dict_conf['temperature']).to(device)

    model.set_losses(cls_loss, textual_con_loss)

    is_use_adapter = dict_conf['Adapter']['is_use_adapter']

    if dict_conf['checkpoints']['train_mode']:
        model.add_num_classes(num_class)
        model.create_fc()
        train_loop(model, train_cilDatasetList, val_cilDatasetList, test_cilDatasetList, is_use_adapter=is_use_adapter)

    
        
def train_loop(model, train_cilDatasetList, val_cilDatasetList, test_cilDatasetList, is_use_adapter=False):
    """Iterate over the incremental tasks, training or evaluating on each.

    For every task yielded by ``train_cilDatasetList`` the loop:
      1. derives the per-class memory budget ``m`` from the global
         ``memory_size`` (``'ALL'`` is passed through unchanged),
      2. adds the task's data to the model memory and, when the budget is
         non-zero, builds an extra dataloader over memory data,
      3. either only validates (no modulation, non-linear classifier, no
         temporal module, no adapter) or trains the task and then tests,
      4. hands the memory back to the dataset wrapper and prepares the model
         for the next task's classes.

    Reads the module-level globals ``memory_size``, ``batch_size``,
    ``is_activityNet`` and ``experiment`` that ``main`` sets.

    Args:
        model: Model object driving training / validation and owning the
            rehearsal memory.
        train_cilDatasetList: Iterable task wrapper for the training split;
            must expose ``num_tasks``, ``get_dataloader`` and ``memory``.
        val_cilDatasetList: Validation task wrapper, or ``None``.
        test_cilDatasetList: Test task wrapper (only used when
            ``is_activityNet`` is false).
        is_use_adapter: When true, the validation-only shortcut is disabled
            and every task is trained.
    """
    iter_trainDataloader = iter(train_cilDatasetList)
    num_tasks = train_cilDatasetList.num_tasks
    classes_mem = None

    for j in range(num_tasks):
        classes, data, train_loader_i, len_data, num_next_classes = next(iter_trainDataloader)
        # Per-class memory budget; 'ALL' is a sentinel meaning "keep everything".
        m = 'ALL' if memory_size == 'ALL' else memory_size // model.num_classes

        old_memory = model.memory
        model.add_samples_to_mem(data, m)
        train_loader_i_mem = None
        # Test the sentinel first: evaluating ``'ALL' > 0`` raises TypeError.
        if m == 'ALL' or m > 0:
            data_mem = model.memory
            if model.num_training_phases > 1:
                # Multi-phase training: previous memory plus the full new task data.
                data_mem = {**old_memory, **data}
            train_loader_i_mem = train_cilDatasetList.get_dataloader(data_mem, batch_size, None, False, False)

        if train_loader_i_mem is not None:
            classes = train_loader_i_mem.dataset.classes if model.num_training_phases > 1 else train_loader_i.dataset.classes
        else:
            # Without a memory loader, accumulate the classes seen so far
            # ourselves, de-duplicated while preserving first-seen order.
            if classes_mem is None:
                classes_mem = []
            classes_mem.extend(train_loader_i.dataset.classes)
            classes = list(dict.fromkeys(classes_mem))

        if model.type_mod == 'None' and model.type_cls != 'Linear' and not model.enable_temporal_module and not is_use_adapter:
            print('init validation')

            if val_cilDatasetList is not None:
                model.validate(val_cilDatasetList, j, classes, 1, type_val = 'val', is_final = True)
            if not is_activityNet:
                with experiment.test():
                    total_acc_test,_ = model.validate(test_cilDatasetList, j, classes, 1, type_val = 'test', is_final = True)

                    experiment.log_metric("Acc_task_{}".format(j+1), total_acc_test)
                    print('Test Accuracy: %d %%' % total_acc_test)

        else:
            print('init training')
            model.train_task(j, train_loader_i, train_loader_i_mem, val_cilDatasetList, num_tasks=num_tasks)
            if not is_activityNet:
                with experiment.test():
                    modulate_vid = model.type_mod == 'Prompt'
                    total_acc_test = model.validate(test_cilDatasetList, j, classes, model.num_training_phases, type_val = 'test', is_final = True, modulate_vid = modulate_vid, modulate_txt = False, end_train=True)
                    # validate may return (accuracy, ...) or a bare accuracy.
                    if isinstance(total_acc_test, tuple):
                        total_acc_test = total_acc_test[0]
                    experiment.log_metric("Acc_task_{}".format(j+1), total_acc_test)
                    print('Test Accuracy: %d %%' % total_acc_test)

        # Single-phase training replays memory through the dataset wrapper;
        # multi-phase training uses the separate memory loader instead, so
        # the wrapper's memory is cleared.
        if model.num_training_phases == 1:
            train_cilDatasetList.memory = model.memory
        else:
            train_cilDatasetList.memory = {}
        print('n_known_classes: ',len(model.memory))

        model.prepare_for_next_classes(num_next_classes)
        

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()