import os
import time
import re
import random
import tqdm
import wandb
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from embodied_cd.common.print_utils import *
from embodied_cd.common.dataset_utils import PromptTemplate
from embodied_cd.trl.algos.core import (
    custom_collate, 
    stack_dicts,
    stats_to_np,
)


class RewardTrainer:
    """Contrastive trainer for a sentence-embedding model.

    Each batch provides two lists of texts under the keys ``'thinks'`` and
    ``'thinks_copy'``. Every text is split into sentences on ``'. '``, both
    sides are embedded with mean pooling, and the model is optimized so that
    the i-th sentence of one side is most similar to the i-th sentence of the
    other side, with all remaining sentences in the batch acting as negatives.
    """

    # Hyperparameter defaults; keyword arguments given to ``__init__`` override
    # them per instance.
    default_params = {
        "total_epochs": 100,
        "lr": 2.82e-6,
        "batch_size": 4,
        "temperature": 0.05,
    }

    def __init__(self, model, tokenizer, dataset, output_dir, **params):
        """Set up the dataloader and optimizer.

        Args:
            model: Encoder to train. Must expose ``parameters()``, ``device``
                and ``save_pretrained()``, and return the token embeddings as
                the first element of its output.
            tokenizer: Callable that turns a list of sentences into model
                inputs containing an ``'attention_mask'`` entry.
            dataset: Training dataset; items are batched by ``custom_collate``.
            output_dir: Accepted for interface compatibility; not used here
                (pass the directory to ``save_pretrained`` instead).
            **params: Overrides for the entries of ``default_params``.
        """
        # Merge into a fresh dict. Updating ``default_params`` in place would
        # leak this instance's overrides into the class attribute and thereby
        # into every trainer created afterwards.
        self.params = {**self.default_params, **params}

        self.model = model
        self.tokenizer = tokenizer
        self.dataset = dataset
        self.dataset_len = len(dataset)

        self.dataloader = DataLoader(
            self.dataset,
            batch_size=self.params['batch_size'],
            shuffle=True,
            collate_fn=custom_collate,
            drop_last=True,
        )

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.params['lr'])
        self.device = self.model.device

    def train(self):
        """Run ``total_epochs`` passes over the dataloader.

        After every epoch the mean loss is printed and sent to ``wandb.log``.
        NOTE(review): presumably the caller has already started a wandb run --
        confirm.
        """
        epochs = range(1, self.params['total_epochs'] + 1)
        for epoch in tqdm.tqdm(epochs, desc="epoch"):  # iterate over epochs
            all_stats = []
            for batch in self.dataloader:  # iterate over dataset
                # Flatten the batch into two aligned lists of sentences.
                think1_list, think2_list = [], []
                for think1, think2 in zip(batch['thinks'], batch['thinks_copy']):
                    think1_list.extend(think1.split('. '))
                    think2_list.extend(think2.split('. '))

                think1_embeddings = self.encode(think1_list)
                think2_embeddings = self.encode(think2_list)
                loss = self.compute_loss(think1_embeddings, think2_embeddings)

                # Keep only the value for logging; holding the graph-attached
                # tensor would keep autograd references alive for the whole epoch.
                all_stats.append({"loss": loss.detach()})

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            train_stats = stack_dicts(all_stats)
            stats = {f'{k}': torch.mean(v, axis=0) for k, v in train_stats.items()}
            stats = stats_to_np(stats)
            print_warn(f"Epoch {epoch}: {stats}")
            wandb.log(stats)

    def encode(self, sentences: list):
        """Embed ``sentences`` by attention-masked mean pooling.

        Args:
            sentences: List of strings to embed.

        Returns:
            One embedding per sentence: the average of the model's token
            embeddings over the positions where ``attention_mask`` is set.
        """
        encoded = self.tokenizer(
            sentences, padding=True, truncation=True, return_tensors='pt'
        ).to(self.device)
        token_embeddings = self.model(**encoded)[0]

        # Broadcast the mask over the hidden dimension so padded positions
        # contribute nothing to the sum.
        mask = encoded['attention_mask'].unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * mask, 1)
        # Clamp guards against division by zero for an all-padding row.
        token_counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / token_counts

    def compute_loss(self, think1_embeddings, think2_embeddings):
        """Symmetric contrastive cross-entropy between two aligned embedding sets.

        Row ``i`` of ``think1_embeddings`` and row ``i`` of ``think2_embeddings``
        form a positive pair; every other row of either input is a negative.
        Similarity is the dot product divided by ``params['temperature']``.

        Args:
            think1_embeddings: 2-D tensor, one embedding per row.
            think2_embeddings: 2-D tensor with the same shape.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: If the two inputs contain a different number of rows.
        """
        batch_size = think1_embeddings.shape[0]
        if think2_embeddings.shape[0] != batch_size:
            raise ValueError(
                "think1_embeddings and think2_embeddings must have the same number "
                f"of rows, got {batch_size} and {think2_embeddings.shape[0]}"
            )

        embeddings = torch.cat([think1_embeddings, think2_embeddings], dim=0)
        sim_matrix = torch.matmul(embeddings, embeddings.T) / self.params['temperature']
        device = sim_matrix.device

        # Exclude self-similarity. -inf removes the diagonal from the softmax
        # for any logit scale and any floating dtype, which a large finite
        # constant does not guarantee.
        diagonal = torch.eye(2 * batch_size, dtype=torch.bool, device=device)
        sim_matrix = sim_matrix.masked_fill(diagonal, float("-inf"))

        # Row i (first half) targets column i + batch_size; row i + batch_size
        # (second half) targets column i.
        positives = torch.cat(
            [
                torch.arange(batch_size, 2 * batch_size, device=device),
                torch.arange(0, batch_size, device=device),
            ],
            dim=0,
        )
        loss = F.cross_entropy(sim_matrix, positives)
        # NOTE(review): cross_entropy already averages over rows by default, so
        # this division rescales the loss a second time. Kept for backward
        # compatibility of logged values -- confirm whether it is intended.
        return loss / batch_size

    def save_pretrained(self, output_dir: str):
        """Write the model and the tokenizer to ``output_dir``."""
        self.model.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)
