import os
import time
import re
import random
import tqdm
import wandb
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from transformers import DataCollatorWithPadding, DataCollatorForLanguageModeling

from embodied_cd.common.print_utils import *
from embodied_cd.common.dataset_utils import PromptTemplate, VirtualHomeDataset
from embodied_cd.common.llm_utils import OpenAILLM
from embodied_cd.trl.algos.pipe import SentenceSimilarityPipeline
from embodied_cd.trl.algos.core import FixedKLController, AdaptiveKLController
from embodied_cd.trl.algos.core import (
    custom_collate,
    logprobs_from_logits, 
    entropy_from_logits, 
    clip_by_value, 
    whiten, 
    flatten_dict, 
    stack_dicts, 
    stats_to_np, 
    WANDB_PADDING,
)
from embodied_cd.trl.models.value import ValueHeadWithLogit

# Clipping range for summed action log-probabilities (log pi) before they are
# min-max normalised into [0, 1]: values are clipped to [MIN, MAX].
LOG_PI_NORM_MIN, LOG_PI_NORM_MAX = -30, 10


class PlanWBCTrainer:
    """
    Plan model (phase 2) trainer.

    For every training sample it applies two optimizer steps, one after the
    other:

    1. Value (discriminator) step: a ``ValueHeadWithLogit`` head is trained to
       separate the dataset action (positive) from two sampled negative
       actions.
    2. Policy step: the policy model is trained with its response LM loss
       (``bc_loss``). Each of the three actions' loss is re-weighted by the
       detached value-head scores.
    """

    default_params = {
        "total_epochs": 100,
        "lr": 2.82e-6,
        "batch_size": 4,
        "gamma": 0.99,
        "eta": 0.5,
        "loss_type": "mse",
        # Numerical floor that keeps log / division terms finite when the
        # sigmoid value head saturates at exactly 0 or 1.
        "eps": 1e-6,
    }

    def __init__(self, env_name, model, ref_model, tokenizer, ref_tokenizer, dataset, output_dir, device='cuda', **params):
        """
        Args:
            env_name: environment name forwarded to ``PromptTemplate``.
            model: policy model. It is optimised directly, and also used under
                ``torch.no_grad()`` to encode prompts for the value head.
            ref_model: reference model; falls back to ``model`` when None.
                It is not used by ``train``.
            tokenizer: tokenizer for ``model``.
            ref_tokenizer: tokenizer for ``ref_model``; falls back to
                ``tokenizer`` when None.
            dataset: training dataset, batched with ``custom_collate``.
            output_dir: stored on the instance; ``save_pretrained`` takes its
                own directory argument.
            device: unused; the device is taken from ``model.device``.
            **params: overrides for ``default_params``.
        """
        self.env_name = env_name
        self.output_dir = output_dir

        # Copy, so overrides never leak into the class-level defaults shared
        # by every instance.
        self.params = {**self.default_params, **params}

        self.model = model
        self.value_head = ValueHeadWithLogit(
            self.model.config.hidden_size, pdrop=0.1, activation_fn='sigmoid', detach=True).to(self.model.device)
        self.tokenizer = tokenizer
        self.ref_model = ref_model if ref_model is not None else model
        self.ref_tokenizer = ref_tokenizer if ref_tokenizer is not None else tokenizer
        self.device = self.model.device

        self.dataset = dataset
        self.dataset_len = len(dataset)

        self.dataloader = DataLoader(
            self.dataset,
            batch_size=self.params['batch_size'],
            shuffle=True,
            collate_fn=custom_collate,
            drop_last=True,
        )
        self.vf_template = PromptTemplate(env_name, 'cd-reward')
        self.pi_template = PromptTemplate(env_name, 'cd-action-think')

        self.vf_optimizer = torch.optim.Adam(self.value_head.parameters(), lr=self.params['lr'])
        self.pi_optimizer = torch.optim.Adam(self.model.parameters(), lr=self.params['lr'])

    def train(self):
        """Train for ``total_epochs`` epochs.

        The value head and then the policy are each updated once for every
        sample of every batch. At the end of each epoch the mean losses are
        printed and logged to wandb.
        """
        batch_size = self.params['batch_size']
        # The action format depends only on the environment, so load it once.
        action_format = PromptTemplate.load_env_action_format(self.env_name)
        for epoch in tqdm.tqdm(range(1, self.params['total_epochs'] + 1), desc="epoch"):
            all_stats = []
            for batch in self.dataloader:
                # drop_last=True guarantees every batch holds batch_size samples.
                idxs = list(range(batch_size))
                random.shuffle(idxs)
                for idx in idxs:
                    all_stats.append(self._train_step(batch, idx, action_format))
            self._log_epoch(epoch, all_stats)

    def _train_step(self, batch, idx, action_format):
        """Apply one value-head step and one policy step for sample ``idx``.

        Returns:
            dict of detached loss tensors for epoch-level logging.
        """
        # Sample positive & negative actions.
        p_action, n_action_1, n_action_2 = self._sample_actions(batch, idx, action_format)

        # Action log-probs and BC losses. Their graphs are kept for the
        # policy step.
        p_logprob, p_bc_loss = self.get_plan_logprob(batch, idx, p_action)
        n_logprob_1, n_bc_loss_1 = self.get_plan_logprob(batch, idx, n_action_1)
        n_logprob_2, n_bc_loss_2 = self.get_plan_logprob(batch, idx, n_action_2)

        # Action values. The log-probs are detached, so the value loss cannot
        # reach the policy through them.
        p_value = self.get_value(batch, idx, p_action, p_logprob.detach())
        n_value_1 = self.get_value(batch, idx, n_action_1, n_logprob_1.detach())
        n_value_2 = self.get_value(batch, idx, n_action_2, n_logprob_2.detach())

        # Value (discriminator) step. Rewards are only read for the MSE loss.
        reward = batch['rewards'][idx] if self.params['loss_type'] == 'mse' else None
        vf_loss, p_loss, n_loss_1, n_loss_2 = self._value_loss(p_value, n_value_1, n_value_2, reward)
        self.vf_optimizer.zero_grad()
        vf_loss.backward()
        self.vf_optimizer.step()

        # Policy step, weighted by the detached value scores.
        pi_loss, p_corr_loss, n_corr_loss_1, n_corr_loss_2 = self._policy_loss(
            p_value, n_value_1, n_value_2, p_bc_loss, n_bc_loss_1, n_bc_loss_2)
        self.pi_optimizer.zero_grad()
        pi_loss.backward()
        self.pi_optimizer.step()

        stats = {
            "vf_loss": vf_loss,
            "vf_p_loss": p_loss,
            "vf_n_loss_1": n_loss_1,
            "vf_n_loss_2": n_loss_2,
            "pi_loss": pi_loss,
            "p_corr_loss": p_corr_loss,
            "n_corr_loss_1": n_corr_loss_1,
            "n_corr_loss_2": n_corr_loss_2,
        }
        # Detach, so stats held for a whole epoch do not keep autograd
        # graphs alive.
        return {k: v.detach() for k, v in stats.items()}

    def _sample_actions(self, batch, idx, action_format):
        """Return ``(positive, negative_1, negative_2)`` actions for sample ``idx``.

        - ``negative_1`` is the action of the next sample in the batch, taken
          from a different trajectory.
        - ``negative_2`` is drawn at random from the actions available in the
          current state. The positive action is excluded unless it is the
          only candidate.
        """
        p_action = batch['actions'][idx]
        n_action_1 = batch['actions'][(idx + 1) % self.params['batch_size']]
        action_list = PromptTemplate.get_action_list(batch['states'][idx], action_format)
        # Labelling the positive action as a negative gives the value head
        # contradictory targets, so avoid it when possible.
        candidates = [a for a in action_list if a != p_action] or list(action_list)
        n_action_2 = random.choice(candidates)
        return p_action, n_action_1, n_action_2

    def _value_loss(self, p_value, n_value_1, n_value_2, reward=None):
        """Return ``(vf_loss, p_loss, n_loss_1, n_loss_2)`` for the value head.

        - ``cross_entropy``: uses -log(v) for the positive and -log(1 - v) for
          the negatives. The arguments are clamped at ``eps``, so saturated
          sigmoid outputs stay finite.
        - ``mse``: regresses the positive value towards ``reward`` and the
          negatives towards 0.

        In both cases the first negative's loss is divided by ``eta``.

        Raises:
            ValueError: if ``loss_type`` is unknown, or ``mse`` is used
                without a reward.
        """
        loss_type = self.params['loss_type']
        eps = self.params['eps']
        if loss_type == 'cross_entropy':
            p_loss = -torch.log(torch.clamp(p_value, min=eps))
            n_loss_1 = -torch.log(torch.clamp(1.0 - n_value_1, min=eps))
            n_loss_2 = -torch.log(torch.clamp(1.0 - n_value_2, min=eps))
        elif loss_type == 'mse':
            if reward is None:
                raise ValueError("loss_type 'mse' requires a reward")
            # Match the value's dtype/device (the reward may be a CPU tensor
            # or a Python number).
            reward = torch.as_tensor(reward, dtype=p_value.dtype, device=p_value.device)
            p_loss = (p_value - reward) ** 2
            n_loss_1 = n_value_1 ** 2
            n_loss_2 = n_value_2 ** 2
        else:
            raise ValueError(f"Unknown loss_type {loss_type!r}; expected 'cross_entropy' or 'mse'")
        vf_loss = p_loss + n_loss_1 / self.params['eta'] + n_loss_2
        return vf_loss, p_loss, n_loss_1, n_loss_2

    def _policy_loss(self, p_value, n_value_1, n_value_2, p_bc_loss, n_bc_loss_1, n_bc_loss_2):
        """Return ``(pi_loss, p_corr_loss, n_corr_loss_1, n_corr_loss_2)``.

        Each BC loss is scaled by a weight derived from the detached value
        score ``v``:

        - positive: ``eta / (v * (1 - v)) + 1``
        - negatives: ``1 / (1 - v) - 1``; negative scores below 0.5 are zeroed
          first, so their weight is 0.

        Scores are clamped within ``eps`` of 0/1 so the weights stay finite.
        They are computed in float32 so the clamp is representable.
        """
        eps = self.params['eps']
        eta = self.params['eta']

        def gate(value):
            # torch.where builds a new tensor instead of writing into the
            # storage shared with `value` (as in-place masking of a detach()
            # view would).
            v = torch.squeeze(value).detach().float()
            v = torch.where(v < 0.5, torch.zeros_like(v), v)
            return v.clamp(max=1.0 - eps)

        p_clip = torch.squeeze(p_value).detach().float().clamp(eps, 1.0 - eps)
        n_clip_1 = gate(n_value_1)
        n_clip_2 = gate(n_value_2)

        p_corr_loss = p_bc_loss * (eta / (p_clip * (1.0 - p_clip)) + 1.0)
        n_corr_loss_1 = n_bc_loss_1 * (1.0 / (1.0 - n_clip_1) - 1.0)
        n_corr_loss_2 = n_bc_loss_2 * (1.0 / (1.0 - n_clip_2) - 1.0)
        pi_loss = p_corr_loss + n_corr_loss_1 + n_corr_loss_2
        return pi_loss, p_corr_loss, n_corr_loss_1, n_corr_loss_2

    def _log_epoch(self, epoch, all_stats):
        """Print and wandb-log the per-key mean of this epoch's stats."""
        if not all_stats:
            # With drop_last=True an epoch yields no batch when the dataset
            # is smaller than batch_size; stack_dicts would fail on [].
            print_warn(f"Epoch {epoch}: no batches (dataset smaller than batch_size?); skipping logging")
            return
        train_stats = stack_dicts(all_stats)
        stats = {k: torch.mean(v, dim=0) for k, v in train_stats.items()}
        stats = stats_to_np(stats)
        print_warn(f"Epoch {epoch}: {stats}")
        wandb.log(stats)

    def get_plan_logprob(self, batch, idx, action):
        """Score ``action`` under the policy for sample ``idx``.

        Returns:
            ``(logprob, loss)``:

            - ``logprob`` is the summed token log-probability of the last
              ``response_len`` tokens, shape ``(batch, 1)``.
            - ``loss`` is the model's LM loss with every other position
              masked to -100.
        """
        example = self.pi_template(
            batch['instructions'][idx], batch['states'][idx], batch['thinks'][idx], action, batch['histories'][idx])
        query_tensor = self.tokenizer.apply_chat_template(
            VirtualHomeDataset._convert_to_chat(example), return_tensors='pt').to(self.device)
        # NOTE(review): the scored span length comes from the dataset's
        # reference response, not from `action`. If a negative action
        # tokenizes to a different length, the window is misaligned.
        # Confirm this is intended.
        response_len = len(batch['response_ids'][idx][0])
        labels = query_tensor.clone()
        labels[:, :-response_len] = -100
        model_output = self.model(query_tensor, labels=labels)
        # Shift by one: logits at position t predict token t + 1.
        logprobs = logprobs_from_logits(model_output.logits[:, :-1, :], query_tensor[:, 1:])
        logprob = torch.sum(logprobs[:, -response_len:], dim=-1, keepdim=True)
        return logprob, model_output.loss

    def get_value(self, batch, idx, action, logprob):
        """Score ``action`` for sample ``idx`` with the value head.

        The value prompt is encoded by ``self.model`` under
        ``torch.no_grad()``, so no gradient reaches the policy model through
        the encoding.

        Args:
            logprob: policy log-probability of ``action``; callers in this
                class pass it detached.
        """
        example = self.vf_template(
            batch['instructions'][idx], batch['states'][idx], batch['thinks'][idx], action, batch['histories'][idx])
        query_id = self.tokenizer.apply_chat_template(
            VirtualHomeDataset._convert_to_chat(example), tokenize=True, return_tensors='pt').to(self.device)
        num_logits_to_keep = len(batch['response_ids'][idx][0])
        # NOTE(review): this mask is len(query) + num_logits_to_keep long,
        # i.e. longer than `query_id`. Presumably the custom forward with
        # average_pool=True expects this layout; confirm.
        attention_mask = torch.cat([torch.ones(len(query_id[0])), torch.ones(num_logits_to_keep)]).unsqueeze(0).to(self.device)
        with torch.no_grad():
            model_output = self.model(query_id, attention_mask=attention_mask, num_logits_to_keep=num_logits_to_keep, average_pool=True)
        value = self.value_head(model_output.hidden_states.squeeze(0), logprob)
        return value

    def save_pretrained(self, output_dir: str):
        """Save the policy model, its tokenizer and the trained value head to ``output_dir``."""
        self.model.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)
        # The value head trained by this class is self.value_head, created in
        # __init__. It is not an attribute of the model.
        torch.save(self.value_head.state_dict(), os.path.join(output_dir, 'vl_head.pth'))
