import os
import time
import re
import random
import tqdm
import wandb
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from transformers import DataCollatorWithPadding, DataCollatorForLanguageModeling

from embodied_cd.common.print_utils import *
from embodied_cd.common.dataset_utils import PromptTemplate, VirtualHomeDataset
from embodied_cd.common.llm_utils import OpenAILLM
from embodied_cd.trl.algos.pipe import SentenceSimilarityPipeline
from embodied_cd.trl.algos.core import FixedKLController, AdaptiveKLController
from embodied_cd.trl.algos.core import (
    custom_collate,
    logprobs_from_logits, 
    entropy_from_logits, 
    clip_by_value, 
    whiten, 
    flatten_dict, 
    stack_dicts, 
    stats_to_np, 
    WANDB_PADDING,
)


class PlanTrainer:
    """
    Plan model (phase 2) value function trainer.

    For every sample in a batch the trainer scores one positive action and two
    negative actions with the model's value output and optimizes a combined
    loss over the three scores with Adam.

    Hyper-parameters are read from ``default_params`` and may be overridden
    per instance through ``**params``:

    - ``total_epochs``: number of passes over the dataloader.
    - ``lr``: learning rate of the Adam optimizer.
    - ``batch_size``: dataloader batch size (incomplete batches are dropped).
    - ``gamma``: not read anywhere in this class.
    - ``eta``: divisor applied to the first negative loss term.
    - ``loss_type``: ``"mse"`` or ``"cross_entropy"``.
    """

    default_params = {
        "total_epochs": 100,
        "lr": 2.82e-6,
        "batch_size": 4,
        "gamma": 0.99,
        "eta": 0.5,
        "loss_type": "mse",
    }

    # Loss types understood by `_compute_losses`.
    _LOSS_TYPES = ("cross_entropy", "mse")

    def __init__(self, env_name, model, ref_model, tokenizer, ref_tokenizer, dataset, output_dir, device='cuda', **params):
        """
        Args:
            env_name: environment name, forwarded to ``PromptTemplate``.
            model: model to train; must expose ``device`` and ``parameters()``.
            ref_model: reference model; falls back to ``model`` when ``None``.
            tokenizer: tokenizer used to build the value-function queries.
            ref_tokenizer: reference tokenizer; falls back to ``tokenizer``
                when ``None``.
            dataset: dataset wrapped in a shuffling ``DataLoader``.
            output_dir: stored on the instance; not used by this class itself.
            device: accepted for interface compatibility but not used;
                tensors are placed on ``model.device``.
            **params: overrides for the entries of ``default_params``.
        """
        self.env_name = env_name
        self.output_dir = output_dir

        # Merge into a fresh dict: updating `default_params` in place would
        # leak one instance's overrides into every other instance.
        self.params = {**self.default_params, **params}

        self.model = model
        self.tokenizer = tokenizer
        self.ref_model = ref_model if ref_model is not None else model
        self.ref_tokenizer = ref_tokenizer if ref_tokenizer is not None else tokenizer
        self.device = self.model.device

        self.dataset = dataset
        self.dataset_len = len(dataset)

        # drop_last=True guarantees every batch holds exactly `batch_size`
        # samples, which `train` relies on when indexing into the batch.
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=self.params['batch_size'],
            shuffle=True,
            collate_fn=custom_collate,
            drop_last=True,
        )
        self.vf_template = PromptTemplate(env_name, 'cd-reward')
        self.vf_optimizer = torch.optim.Adam(self.model.parameters(), lr=self.params['lr'])

    def train(self):
        """
        Run ``total_epochs`` epochs of value-function training.

        One optimizer step is taken per batch on the mean per-sample loss.
        After each epoch the mean of every tracked loss term is printed and
        sent to ``wandb.log``.

        Raises:
            ValueError: if ``params['loss_type']`` is not a supported loss.
        """
        loss_type = self.params['loss_type']
        if loss_type not in self._LOSS_TYPES:
            # Fail before any forward pass instead of crashing mid-batch.
            raise ValueError(
                f"Unsupported loss_type {loss_type!r}; expected one of {self._LOSS_TYPES}")
        batch_size = self.params['batch_size']

        for epoch in tqdm.tqdm(range(1, self.params['total_epochs'] + 1), desc="epoch"):  # iterate over epochs
            all_stats = []
            for batch in self.dataloader:
                idxs = list(range(batch_size))
                random.shuffle(idxs)

                batch_vf_loss = 0.
                # train the discriminator
                for idx in idxs:
                    # --- Sample positive & negative actions ---
                    # positive action: the action recorded for this sample
                    p_action = batch['actions'][idx]
                    # negative action 1: the action recorded for the next
                    # sample in the batch (wrapping around).
                    # NOTE(review): with batch_size == 1, or when neighbouring
                    # samples share an action, this equals the positive action.
                    n_action_1 = batch['actions'][(idx + 1) % len(idxs)]
                    # negative action 2: a random entry of the action list
                    # built from the current state.
                    # NOTE(review): nothing prevents this from matching the
                    # positive action - confirm whether that is acceptable.
                    action_format = PromptTemplate.load_env_action_format(self.env_name)
                    action_list = PromptTemplate.get_action_list(batch['states'][idx], action_format)
                    random.shuffle(action_list)
                    n_action_2 = action_list[0]

                    # --- Calculate action values ---
                    p_value = self.get_value(batch, idx, p_action)
                    n_value_1 = self.get_value(batch, idx, n_action_1)
                    n_value_2 = self.get_value(batch, idx, n_action_2)

                    # The reward is only needed (and only read) for the MSE loss.
                    reward = batch['rewards'][idx] if loss_type == 'mse' else None
                    vf_loss, p_loss, n_loss_1, n_loss_2 = self._compute_losses(
                        p_value, n_value_1, n_value_2, reward)
                    batch_vf_loss += vf_loss

                    # Detach before storing: the stats live for the whole epoch
                    # and would otherwise keep every sample's autograd graph alive.
                    all_stats.append({
                        "vf_loss": vf_loss.detach(),
                        "vf_p_loss": p_loss.detach(),
                        "vf_n_loss_1": n_loss_1.detach(),
                        "vf_n_loss_2": n_loss_2.detach(),
                    })

                batch_vf_loss = batch_vf_loss / batch_size
                self.vf_optimizer.zero_grad()
                batch_vf_loss.backward()
                self.vf_optimizer.step()

            # epoch logging
            if not all_stats:
                # Happens when the dataset is smaller than one batch
                # (drop_last=True yields no batches); nothing to aggregate.
                print_warn(f"Epoch {epoch}: no batches were produced, skipping logging")
                continue
            train_stats = stack_dicts(all_stats)
            stats = {}
            for k, v in train_stats.items():
                stats[f'{k}'] = torch.mean(v, axis=0)
            stats = stats_to_np(stats)
            print_warn(f"Epoch {epoch}: {stats}")
            wandb.log(stats)

    def _compute_losses(self, p_value, n_value_1, n_value_2, reward=None):
        """
        Compute the value loss terms for one sample.

        Args:
            p_value: value predicted for the positive action.
            n_value_1: value predicted for the first negative action.
            n_value_2: value predicted for the second negative action.
            reward: regression target for ``p_value``; only used by the
                ``"mse"`` loss.

        Returns:
            Tuple ``(vf_loss, p_loss, n_loss_1, n_loss_2)`` where
            ``vf_loss = p_loss + n_loss_1 / eta + n_loss_2``.

        Raises:
            ValueError: if ``params['loss_type']`` is not a supported loss.
        """
        loss_type = self.params['loss_type']
        if loss_type == 'cross_entropy':
            # NOTE(review): assumes the values lie strictly inside (0, 1);
            # otherwise the logarithms produce inf/NaN - confirm the value
            # head's output range.
            p_loss = -torch.log(p_value)
            n_loss_1 = -torch.log(1 - n_value_1)
            n_loss_2 = -torch.log(1 - n_value_2)
        elif loss_type == 'mse':
            # Positive value regresses to the reward, negatives to zero.
            p_loss = (p_value - reward) ** 2
            n_loss_1 = (n_value_1 - 0.0) ** 2
            n_loss_2 = (n_value_2 - 0.0) ** 2
        else:
            raise ValueError(
                f"Unsupported loss_type {loss_type!r}; expected one of {self._LOSS_TYPES}")

        vf_loss = p_loss + n_loss_1 / self.params['eta'] + n_loss_2
        return vf_loss, p_loss, n_loss_1, n_loss_2

    def get_value(self, batch, idx, action):
        """
        Score ``action`` for sample ``idx`` of ``batch`` with the value model.

        The instruction, state, think text, history and the given action are
        rendered with the value-function template, converted to chat format,
        tokenized and passed through the model.

        Args:
            batch: collated batch providing ``instructions``, ``states``,
                ``thinks``, ``histories`` and ``query_ids``.
            idx: index of the sample inside the batch.
            action: action to be scored.

        Returns:
            The squeezed last-position entry of ``model_output.values``.
        """
        example = self.vf_template(
            batch['instructions'][idx], batch['states'][idx], batch['thinks'][idx], action, batch['histories'][idx])
        query_id = self.tokenizer.apply_chat_template(
            VirtualHomeDataset._convert_to_chat(example), tokenize=True, return_tensors='pt').to(self.device)
        # Number of tokens by which the rendered query is longer than the
        # stored query for this sample.
        num_logits_to_keep = len(query_id[0]) - len(batch['query_ids'][idx][0])
        # NOTE(review): the mask has len(query_id[0]) + num_logits_to_keep
        # entries, i.e. it is longer than the input sequence - confirm this
        # is what the model's forward pass expects.
        attention_mask = torch.cat([torch.ones(len(query_id[0])), torch.ones(num_logits_to_keep)]).unsqueeze(0).to(self.device)
        model_output = self.model(query_id, attention_mask=attention_mask, num_logits_to_keep=num_logits_to_keep, average_pool=True)
        value = model_output.values[:, -1:, :].squeeze()
        return value

    def save_pretrained(self, output_dir: str):
        """
        Save the model, the tokenizer and the value head to ``output_dir``.

        The state dict of ``model.vl_head`` is written separately to
        ``vl_head.pth`` inside ``output_dir``.
        """
        self.model.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)
        torch.save(self.model.vl_head.state_dict(), os.path.join(output_dir, 'vl_head.pth'))
