from typing import Optional, Tuple
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

class SimpleMlpEncoder(nn.Module):
    """Encode a set of transitions into a single context vector.

    Every row of ``x`` gets its cost appended as an extra column. The widened
    row goes through an MLP, a final linear projection and ``tanh``. The
    per-row results are then averaged into one vector.

    Args:
        input_dim: width of one MLP input row, i.e. ``x.shape[1] + 1`` because
            the cost column is appended before the MLP.
        hidden_dims: widths of the hidden layers; each ``Linear`` is followed
            by a fresh ``activation()`` module.
        output_dim: size of the returned encoding.
        activation: zero-argument factory for the activation module.
    """

    def __init__(
        self,
        input_dim,
        hidden_dims,
        output_dim,
        activation=nn.ReLU
    ) -> None:
        super().__init__()
        widths = [input_dim, *hidden_dims]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.extend((nn.Linear(fan_in, fan_out), activation()))
        self.model = nn.Sequential(*layers)
        self.out = nn.Linear(widths[-1], output_dim)
        self.tanh = nn.Tanh()

    def forward(self, x: torch.Tensor, cost: torch.Tensor) -> torch.Tensor:
        """Return the mean ``tanh`` encoding of the rows of ``x``.

        Args:
            x: 2-D tensor with one transition per row.
            cost: any tensor holding one value per row of ``x``; it is
                flattened into a float32 column before concatenation.

        Returns:
            Tensor of shape ``(output_dim,)``.
        """
        cost_column = cost.reshape(-1, 1).to(torch.float32)
        _num_rows, _num_features = x.shape  # rejects anything that is not 2-D
        rows_with_cost = torch.cat([x, cost_column], dim=-1)
        per_row = self.tanh(self.out(self.model(rows_with_cost)))
        return per_row.mean(0)


class SafetyAwareEncoder(nn.Module):
    """Context encoder with separate MLP branches for safe and unsafe samples.

    Rows whose cost equals 0 go through ``safe_model``. Rows whose cost is
    strictly positive go through ``unsafe_model``. Rows with a negative cost
    match neither condition and are dropped. The cost itself is not fed to
    the MLPs, so ``input_dim`` is the width of one row of ``x``.

    The per-row features are pooled into one vector in one of two ways:

    * ``simple_gate=True``: project each feature row with ``out``, apply
      ``tanh`` and average over rows. ``self.gate`` is built but is not used
      by this path.
    * ``simple_gate=False``: softmax-pool the safe features with
      ``query_gate`` to form a single query. That query then attends over all
      (safe + unsafe) features with scaled dot-product attention, followed by
      ``out`` and ``tanh``. This path requires at least one safe row.
    """

    def __init__(
        self,
        input_dim,
        hidden_dims,
        output_dim,
        activation=nn.ReLU,
        simple_gate=True
    ) -> None:
        super().__init__()
        widths = [input_dim, *hidden_dims]
        feat_dim = widths[-1]
        safe_layers, unsafe_layers = [], []
        for fan_in, fan_out in zip(widths, widths[1:]):
            # Keep the safe-then-unsafe creation order at every depth: it
            # decides parameter initialisation under a fixed RNG seed.
            for branch in (safe_layers, unsafe_layers):
                branch.extend((nn.Linear(fan_in, fan_out), activation()))
        self.safe_model = nn.Sequential(*safe_layers)
        self.unsafe_model = nn.Sequential(*unsafe_layers)
        self.out = nn.Linear(feat_dim, output_dim)
        self.simple_gate = simple_gate
        # Scoring layer for gated pooling; forward() does not currently use it.
        self.gate = nn.Linear(feat_dim, 1)
        # Self-attention path: a simple gate (attention) over the safe patch
        # first produces the query, which then attends over all features.
        self.query_gate = nn.Linear(feat_dim, 1)
        self.query = nn.Linear(feat_dim, feat_dim)
        self.key = nn.Linear(feat_dim, feat_dim)
        self.value = nn.Linear(feat_dim, feat_dim)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=0)

    def forward(self, x: torch.Tensor, cost: torch.Tensor) -> torch.Tensor:
        """Encode the rows of ``x`` into a single ``(output_dim,)`` vector.

        Args:
            x: 2-D tensor with one sample per row.
            cost: tensor with one value per row of ``x`` (flattened here).

        Returns:
            Tensor of shape ``(output_dim,)``. If no row is selected on the
            ``simple_gate`` path, the mean is over zero rows and yields NaN.

        Raises:
            AssertionError: ``simple_gate`` is False and no row has cost 0.
        """
        flat_cost = cost.reshape(-1)
        _num_rows, _num_features = x.shape  # rejects anything that is not 2-D
        safe_rows = x[flat_cost == 0, :]
        unsafe_rows = x[flat_cost > 0, :]
        safe_feat = self.safe_model(safe_rows)
        unsafe_feat = self.unsafe_model(unsafe_rows)

        if safe_rows.shape[0] and unsafe_rows.shape[0]:
            total_feat = torch.cat([safe_feat, unsafe_feat], dim=0)
        elif safe_rows.shape[0]:
            total_feat = safe_feat
        else:
            total_feat = unsafe_feat

        if self.simple_gate:
            return self.tanh(self.out(total_feat)).mean(0)

        assert safe_rows.shape[0] != 0
        # (n_safe, 1) gate logits, softmax-normalised across the safe rows.
        gate_weights = self.softmax(self.query_gate(safe_feat))
        pooled_safe = (gate_weights * safe_feat).sum(0)  # (feat_dim,)
        query = self.query(pooled_safe)
        keys = self.key(total_feat)
        values = self.value(total_feat)
        scale = keys.size(1) ** 0.5
        attn = F.softmax((query @ keys.t()) / scale, dim=-1)  # one weight per row
        return self.tanh(self.out(attn @ values))

    def vis_feat(self, x):
        """Return ``(safe_feat, unsafe_feat)``: ``x`` passed through each branch."""
        return self.safe_model(x), self.unsafe_model(x)

class MultiHeadDecoder(nn.Module):
    """Shared feature trunk followed by next-state, reward and cost heads.

    Args:
        input_dim: width of the last axis of the decoder input.
        feature_hidden_dims: hidden widths of the shared trunk; every trunk
            layer is ``Linear`` followed by a fresh ``activation()``.
        decoder_hidden_dims: hidden widths used by each of the three heads
            (each head gets its own, independently initialised layers).
        state_dim: output width of the next-state head.
        activation: zero-argument factory for the activation module.

    The reward head emits one unconstrained value. The cost head emits one
    value squashed into (0, 1) by a final sigmoid.
    """

    def __init__(
        self,
        input_dim,
        feature_hidden_dims,
        decoder_hidden_dims,
        state_dim,
        activation=nn.ReLU,
    ) -> None:
        super().__init__()
        feature_hidden_dims = [input_dim] + list(feature_hidden_dims)
        feature_extractor = []
        for in_dim, out_dim in zip(feature_hidden_dims[:-1], feature_hidden_dims[1:]):
            feature_extractor += [nn.Linear(in_dim, out_dim), activation()]

        decoder_hidden_dims = [feature_hidden_dims[-1]] + list(decoder_hidden_dims)
        state_decoder, reward_decoder, cost_decoder = [], [], []
        # Keep the per-depth creation order (state, reward, cost): it fixes
        # parameter initialisation under a given seed.
        for in_dim, out_dim in zip(decoder_hidden_dims[:-1], decoder_hidden_dims[1:]):
            state_decoder += [nn.Linear(in_dim, out_dim), activation()]
            reward_decoder += [nn.Linear(in_dim, out_dim), activation()]
            cost_decoder += [nn.Linear(in_dim, out_dim), activation()]
        state_decoder.append(nn.Linear(decoder_hidden_dims[-1], state_dim))
        reward_decoder.append(nn.Linear(decoder_hidden_dims[-1], 1))
        # The sigmoid bounds the cost prediction to (0, 1).
        cost_decoder += [nn.Linear(decoder_hidden_dims[-1], 1), nn.Sigmoid()]

        self.feature_extractor = nn.Sequential(*feature_extractor)
        self.state_decoder = nn.Sequential(*state_decoder)
        self.reward_decoder = nn.Sequential(*reward_decoder)
        self.cost_decoder = nn.Sequential(*cost_decoder)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Decode ``x`` into ``(state_pred, reward_pred, cost_pred)``.

        All leading axes of ``x`` are preserved. Only the last axis changes:
        it becomes ``state_dim``, ``1`` and ``1`` respectively.
        """
        feature = self.feature_extractor(x)
        state_pred = self.state_decoder(feature)
        reward_pred = self.reward_decoder(feature)
        cost_pred = self.cost_decoder(feature)
        return state_pred, reward_pred, cost_pred


class ContextEncoderTrainer:
    """Train a context encoder jointly with a multi-head dynamics decoder.

    Each training step does the following:

    1. Encode every context of a meta-batch into one vector.
    2. Broadcast that vector to every transition of the same context.
    3. Ask the decoder to predict next state, reward and cost from
       ``[encoding, state, action]``.

    Args:
        encoder: module called as ``encoder(context, cost)``. The trainer
            stacks its outputs on a new leading axis and concatenates them
            with the states on the last axis, so it should return one 1-D
            vector per context.
        decoder: module called as ``decoder(x)`` and unpacked into
            ``(next_state_pred, reward_pred, cost_pred)``. ``cost_pred`` is
            scored with binary cross-entropy, so it must lie in [0, 1].
        logger: object with a ``store(tab=..., **metrics)`` method.
        learning_rate, betas: AdamW hyper-parameters (``eps`` is fixed at 1e-5).
        decay_step, decay_rate: ``StepLR`` period and multiplicative factor.
        min_learning_rate: the scheduler is only stepped while the current
            learning rate is strictly above this value. The final rate can
            therefore end up as much as one ``decay_rate`` factor below it.
        state_loss_weight, reward_loss_weight, cost_loss_weight: weights of
            the three terms in the total loss.
        device: stored as ``self.device``; the methods of this class do not
            move tensors to it.
    """

    def __init__(self,
                 encoder,
                 decoder,
                 logger,
                 learning_rate: float = 1e-4,
                 betas: Tuple[float, ...] = (0.9, 0.999),
                 decay_step: int = 2000,
                 decay_rate: float = 0.9,
                 min_learning_rate: float = 3e-5,
                 state_loss_weight: float = 1.0,
                 reward_loss_weight: float = 0.0,
                 cost_loss_weight: float = 2.0,
                 device="cpu",
                 ):
        self.encoder = encoder
        self.decoder = decoder
        # One optimiser updates both networks; each gets its own param group.
        self.optim = torch.optim.AdamW(
            [{'params': encoder.parameters()}, {'params': decoder.parameters()}],
            lr=learning_rate,
            eps=1e-5,
            betas=betas,
        )
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optim, step_size=decay_step, gamma=decay_rate
        )
        self.min_lr = min_learning_rate
        self.logger = logger
        self.state_loss_weight = state_loss_weight
        self.reward_loss_weight = reward_loss_weight
        self.cost_loss_weight = cost_loss_weight
        self.device = device

    def train_one_step(self, state_ls, action_ls, next_state_ls, reward_ls, cost_ls):
        """Run one optimisation step on a meta-batch of contexts and log metrics.

        Every argument is a list with one tensor per context. Tensors of the
        same field must share a shape so they can be stacked. ``state``,
        ``action``, ``next_state`` and ``reward`` are concatenated on their
        last axis; presumably each has shape ``(bs, dim)`` (TODO confirm with
        the data loader). ``cost`` must match the shape of the decoder's cost
        output and hold values in [0, 1].

        Logs ``all_loss``, ``next_state_loss``, ``reward_loss``, ``cost_loss``,
        ``cost_acc`` and ``train_lr`` under ``tab="train"``.
        """
        # One encoding per context. The cost is passed separately so the
        # encoder can treat it specially.
        encodings = torch.stack(
            [
                self.encoder(torch.cat([state, action, next_state, reward], dim=-1), cost)
                for state, action, next_state, reward, cost in zip(
                    state_ls, action_ls, next_state_ls, reward_ls, cost_ls
                )
            ],
            dim=0,
        )  # (meta_bs, encoding_size)

        states = torch.stack(state_ls, dim=0)  # (meta_bs, bs, ...)
        actions = torch.stack(action_ls, dim=0)
        next_states = torch.stack(next_state_ls, dim=0)
        rewards = torch.stack(reward_ls, dim=0)
        costs = torch.stack(cost_ls, dim=0)
        _, bs, _ = states.shape

        # Repeat each context's encoding for every transition of that context.
        broadcast_enc = encodings.unsqueeze(1).expand(-1, bs, -1)
        decoder_input = torch.cat([broadcast_enc, states, actions], dim=-1)
        next_state_pred, reward_pred, cost_pred = self.decoder(decoder_input)

        state_pred_loss = F.mse_loss(next_state_pred, next_states.detach())
        reward_pred_loss = F.mse_loss(reward_pred, rewards.detach())
        cost_pred_loss = F.binary_cross_entropy(cost_pred, costs.detach())
        total_loss = (
            self.state_loss_weight * state_pred_loss
            + self.reward_loss_weight * reward_pred_loss
            + self.cost_loss_weight * cost_pred_loss
        )

        self.optim.zero_grad()
        total_loss.backward()
        self.optim.step()
        # Stop decaying once the rate is no longer above the configured floor.
        if self.scheduler.get_last_lr()[0] > self.min_lr:
            self.scheduler.step()

        with torch.no_grad():
            # Threshold predicted probabilities at 0.5 and compare them with
            # the labels by exact equality (labels are expected to be 0/1).
            predicted_unsafe = (cost_pred > 0.5).to(torch.float32)
            acc = (predicted_unsafe == costs).float().mean()

        self.logger.store(
            tab="train",
            all_loss=total_loss.item(),
            next_state_loss=state_pred_loss.item(),
            reward_loss=reward_pred_loss.item(),
            cost_loss=cost_pred_loss.item(),
            cost_acc=acc.item(),
            train_lr=self.scheduler.get_last_lr()[0],
        )

    def evaluate(self, state_ls, action_ls, next_state_ls, reward_ls, cost_ls):
        """Placeholder for evaluation; not implemented and returns ``None``."""
        pass

    def vis_sample_embeddings(self, contexts, costs, task_ids, num_tasks, save_path):
        """Project per-context encodings to 2-D with t-SNE and save a scatter plot.

        Args:
            contexts, costs: indexable collections; the method computes
                ``self.encoder(contexts[i], costs[i])`` for
                ``i in range(len(task_ids))``.
            task_ids: only its length is used. Points are coloured as
                ``num_tasks`` contiguous, equally sized runs.
                NOTE(review): the id values themselves are ignored; confirm
                the caller orders samples task by task.
            num_tasks: number of colour groups.
            save_path: destination passed to ``Figure.savefig``.

        The encoder is put in eval mode while encoding and restored to its
        previous mode afterwards, even if encoding raises. scikit-learn's TSNE
        uses its default perplexity here, which must be smaller than the
        number of samples.
        """
        was_training = self.encoder.training
        self.encoder.eval()
        try:
            # Inference only: no autograd graph is needed for plotting.
            with torch.no_grad():
                embeddings = np.asarray(
                    [
                        self.encoder(contexts[i], costs[i]).cpu().numpy()
                        for i in range(len(task_ids))
                    ]
                )
        finally:
            self.encoder.train(was_training)

        tsne = TSNE(n_components=2, init='pca', random_state=0)
        points = tsne.fit_transform(embeddings)

        # Min-max normalise each axis to [0, 1]; a zero span is left unscaled
        # instead of dividing by zero.
        lo, hi = np.min(points, 0), np.max(points, 0)
        span = np.where(hi > lo, hi - lo, 1.0)
        data = (points - lo) / span

        group_size = len(task_ids) / num_tasks
        colors = plt.cm.rainbow(np.linspace(0, 1, num_tasks))

        fig, ax = plt.subplots()
        try:
            for i in range(num_tasks):
                start, stop = int(i * group_size), int((i + 1) * group_size)
                ax.scatter(data[start:stop, 0], data[start:stop, 1],
                           color=colors[i], label=str(i))
            ax.set_xticks([])
            ax.set_yticks([])
            ax.legend(bbox_to_anchor=(0, 1.02, 1, 0.102),
                      loc='lower left', ncol=6, mode="expand", borderaxespad=0)
            fig.savefig(save_path)
        finally:
            # pyplot keeps every figure alive until it is explicitly closed.
            plt.close(fig)