#%%
import copy
import random
from tqdm import tqdm
import numpy as np

import wandb

from modules.info_nce import InfoNCE
# %%
def _select_tune_layers(model, config):
    """Return the last ``config["layers"]`` transformer blocks of ``model``.

    The attribute holding the block list depends on ``config["language_model"]``.

    Raises:
        ValueError: if ``config["language_model"]`` is not a supported name.
    """
    language_model = config["language_model"]
    if language_model in ["bert-base", "bert-large", "roberta"]:
        blocks = model.encoder.layer
    elif language_model in ["gpt2", "gpt-neo"]:
        blocks = model.h
    elif language_model in ["llama"]:
        blocks = model.layers
    elif language_model in ["opt"]:
        blocks = model.decoder.layers
    else:
        # Fail with a clear message instead of a NameError further down.
        raise ValueError(
            "unsupported language_model: {!r}".format(language_model)
        )
    return [blocks[-i] for i in range(1, config["layers"] + 1)]


def _embed_texts(texts, tokenizer, model, device, use_cls_token):
    """Tokenize ``texts``, run ``model`` on them and pool one vector per text."""
    encoded = tokenizer(
        texts,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=512
    )
    outputs = model(
        **{key: value.to(device) for key, value in encoded.items()}
    )
    hidden = outputs.last_hidden_state
    if use_cls_token:
        # First position of each sequence.
        return hidden[:, 0, :]
    # NOTE(review): this mean runs over dim 1 without using the attention
    # mask, so padded positions presumably contribute -- confirm whether
    # masked mean pooling is wanted here.
    return hidden.mean(dim=1)


def train_function(
        textual_data,
        tokenizer,
        model,
        optimizer,
        config,
        device):
    """Fine-tune the last layers of ``model`` with an InfoNCE loss.

    For every batch, each text is the anchor, a re-masked copy of it is the
    positive, and other texts sampled from the same batch are the negatives.
    The mean loss of each epoch is printed and sent to ``wandb.log``.

    Args:
        textual_data: sliceable sequence of strings in the
            ``"key is value, key is value"`` format understood by ``parser``.
        tokenizer: callable applied to a list of texts with
            ``return_tensors='pt'``, padding and truncation to 512 tokens.
        model: module whose output exposes ``last_hidden_state``. All of its
            parameters are frozen except those of the last
            ``config["layers"]`` blocks.
        optimizer: object providing ``zero_grad()`` and ``step()``.
        config: mapping read for the keys ``"language_model"``, ``"layers"``,
            ``"epochs"``, ``"batch_size"``, ``"negative_rate"`` and
            ``"num_remask"``.
        device: target passed to ``.to(device)`` for inputs and the loss.

    Returns:
        None.

    Raises:
        ValueError: if ``config["language_model"]`` is not supported.
    """
    use_cls_token = config["language_model"] in ["bert-base", "bert-large", "roberta"]
    # Encoder-style models get '[MASK]' augmentation, the others '[UNK]'.
    mask_fn = remask if use_cls_token else remask_unknown

    # Resolve the layers first so an unsupported model fails before any
    # parameter is modified.
    tune_layers = _select_tune_layers(model, config)

    for param in model.parameters():
        param.requires_grad = False

    for layer in tune_layers:
        for param in layer.parameters():
            param.requires_grad = True

    # Number of texts. The previous ``len(textual_data[0])`` measured the
    # character count of the first text instead.
    n = len(textual_data)

    # Built once and reused for every batch.
    # NOTE(review): assumes InfoNCE() keeps no per-batch state -- confirm.
    info_nce = InfoNCE().to(device)

    for epoch in range(config["epochs"]):
        logs = {
            'info_loss': [],
        }
        for i in tqdm(range(0, n, config["batch_size"]), desc="inner loop..."):
            batch_texts = textual_data[i:min(i + config["batch_size"], n)]
            batch_size_ = len(batch_texts)  # the last batch may be smaller
            num_negatives = int(batch_size_ * config["negative_rate"])

            # Positives: parse each text, overwrite some values with the mask
            # token and serialise it back.
            # NOTE(review): parser() turns numeric values into floats, so the
            # round trip can rewrite e.g. "3" as "3.0" -- confirm this is
            # acceptable for the positive view.
            positive_texts = [
                reformat(mask_fn(parser(text), n_mask=config["num_remask"]))
                for text in batch_texts
            ]

            # Negatives: for each anchor, other texts from the same batch.
            negative_texts = []
            for k in range(batch_size_):
                all_indices = list(range(batch_size_))
                all_indices.remove(k)

                if len(all_indices) >= num_negatives:
                    negative_indices = random.sample(all_indices, num_negatives)
                else:
                    # Not enough other texts: use all of them.
                    negative_indices = all_indices

                for idx in negative_indices:
                    negative_texts.append(batch_texts[idx])

            # If negative_texts is empty, skip this batch
            if len(negative_texts) == 0:
                print(f"Skipping batch {i} due to empty negative_texts")
                continue

            anchor = _embed_texts(
                batch_texts, tokenizer, model, device, use_cls_token)
            positive = _embed_texts(
                positive_texts, tokenizer, model, device, use_cls_token)
            negative = _embed_texts(
                negative_texts, tokenizer, model, device, use_cls_token)

            optimizer.zero_grad()
            info_loss = info_nce(anchor, positive, negative)
            info_loss.backward()
            optimizer.step()

            logs['info_loss'].append(info_loss.item())

        # An epoch in which every batch was skipped reports NaN.
        epoch_means = {
            name: (np.mean(values) if values else float("nan"))
            for name, values in logs.items()
        }

        print_input = "[epoch {:03d}]".format(epoch + 1)
        print_input += ''.join(
            [', {}: {:.4f}'.format(x, y) for x, y in epoch_means.items()]
        )
        print(print_input)

        wandb.log(epoch_means)

    return
# %%
def parser(textual):
    """Parse a ``"key is value, key is value"`` string into a dict.

    Items are separated by ``', '``; within an item the key and the value are
    separated by the first ``' is '``, so a value may itself contain
    ``' is '``. Values accepted by ``float()`` are stored as floats, all
    others are kept as strings.

    Args:
        textual: the text to parse. An empty string yields an empty dict.

    Returns:
        dict mapping each key to a float or a string, in order of appearance.

    Raises:
        ValueError: if an item does not contain ``' is '``.
    """
    data = {}
    if not textual:
        return data

    for item in textual.split(', '):
        # Split on the first separator only so values may contain ' is '.
        key, separator, value = item.partition(' is ')
        if not separator:
            raise ValueError(
                "expected 'key is value', got {!r}".format(item)
            )
        try:
            # Try converting the value to a number.
            data[key] = float(value)
        except ValueError:
            data[key] = value
    return data

def remask(data, n_mask=2):
    """Return a copy of ``data`` with randomly chosen values set to '[MASK]'.

    Columns are drawn from all keys of ``data``, including ones whose value
    is already '[MASK]', so fewer than ``n_mask`` new masks may be introduced.

    Args:
        data: dict of column name to value; it is not modified.
        n_mask: number of columns to draw. Clamped to the number of columns,
            so a short or empty ``data`` no longer raises.

    Returns:
        A deep copy of ``data`` with the drawn columns set to '[MASK]'.

    Raises:
        ValueError: if ``n_mask`` is negative (from ``random.sample``).
    """
    cols = list(data.keys())
    # random.sample raises when asked for more items than the population has.
    n_selected = min(n_mask, len(cols))
    remasked = copy.deepcopy(data)

    for col in random.sample(cols, n_selected):
        remasked[col] = '[MASK]'

    return remasked

def remask_unknown(data, n_mask=2):
    """Return a copy of ``data`` with randomly chosen values set to '[UNK]'.

    Columns are drawn from all keys of ``data``, including ones whose value
    is already '[UNK]', so fewer than ``n_mask`` new '[UNK]' values may be
    introduced.

    Args:
        data: dict of column name to value; it is not modified.
        n_mask: number of columns to draw. Clamped to the number of columns,
            so a short or empty ``data`` no longer raises.

    Returns:
        A deep copy of ``data`` with the drawn columns set to '[UNK]'.

    Raises:
        ValueError: if ``n_mask`` is negative (from ``random.sample``).
    """
    cols = list(data.keys())
    # random.sample raises when asked for more items than the population has.
    n_selected = min(n_mask, len(cols))
    remasked = copy.deepcopy(data)

    for col in random.sample(cols, n_selected):
        remasked[col] = '[UNK]'

    return remasked

def reformat(data):
    """Serialise a dict as ``"key is value"`` items joined by ``', '``."""
    pieces = []
    for column, entry in data.items():
        pieces.append("{} is {}".format(column, entry))
    return ', '.join(pieces)
