import torch
import torch.nn as nn
import torch.nn.functional as F


class EMA:
    """Maintain an exponential moving average (EMA) copy of a model's parameters.

    The averaged value is ``old * beta + (1 - beta) * new``. ``step`` counts
    how many times :meth:`step_ema` has been called.
    """

    def __init__(self, beta):
        super().__init__()
        self.beta = beta
        self.step = 0

    def update_model_average(self, ma_model, current_model):
        """Blend every parameter of ``current_model`` into ``ma_model``.

        Parameters are paired positionally via ``parameters()``; only
        parameters are touched (buffers are not visited here).
        """
        param_pairs = zip(current_model.parameters(), ma_model.parameters())
        for source_param, averaged_param in param_pairs:
            averaged_param.data = self.update_average(
                averaged_param.data, source_param.data
            )

    def update_average(self, old, new):
        """Return the EMA of ``old`` and ``new``; ``new`` itself if ``old`` is None."""
        if old is not None:
            return old * self.beta + (1 - self.beta) * new
        return new

    def step_ema(self, ema_model, model, step_start_ema=2000):
        """Advance the EMA by one step.

        While fewer than ``step_start_ema`` steps have been taken, the EMA
        model is simply overwritten with the current model's state dict;
        afterwards its parameters are averaged with the current ones.
        """
        still_warming_up = self.step < step_start_ema
        if still_warming_up:
            self.reset_parameters(ema_model, model)
        else:
            self.update_model_average(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        """Copy the full state dict of ``model`` into ``ema_model``."""
        current_state = model.state_dict()
        ema_model.load_state_dict(current_state)


class SelfAttention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    The input is viewed as ``(-1, channels, size * size)``, so it is expected
    to hold ``channels`` feature planes of ``size x size`` positions. The
    output is returned in the ``(-1, channels, size, size)`` layout.
    """

    def __init__(self, channels, size):
        super().__init__()
        self.channels = channels
        self.size = size
        # 4 attention heads; sequences are laid out as (batch, tokens, channels).
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        feed_forward_layers = [
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        ]
        self.ff_self = nn.Sequential(*feed_forward_layers)

    def forward(self, x):
        """Apply pre-norm attention and a feed-forward block, both with residuals."""
        num_tokens = self.size * self.size
        # Flatten the spatial grid and move channels last: (B, tokens, channels).
        tokens = x.view(-1, self.channels, num_tokens).transpose(1, 2)
        normed_tokens = self.ln(tokens)
        attended, _ = self.mha(normed_tokens, normed_tokens, normed_tokens)
        hidden = attended + tokens
        hidden = self.ff_self(hidden) + hidden
        # Restore the channel-first spatial layout.
        channels_first = hidden.transpose(2, 1)
        return channels_first.view(-1, self.channels, self.size, self.size)

class Model_conditional(nn.Module):
    """MLP over the concatenation of ``x``, a time encoding and a condition ``y``.

    The first linear layer expects ``time_dim + num_obj + 1`` input features
    (plus one more when ``include_uncond_flag`` is set) and the network
    outputs ``time_dim`` features.

    Args:
        time_dim: Output width; also the assumed feature width of ``x``.
        num_obj: Assumed feature width of the condition ``y``; ``0`` disables
            conditioning entirely.
        device: Stored on the instance for callers; not used by the forward
            pass.
        save_path: Stored on the instance; not used inside this class.
        include_uncond_flag: If true, an extra 0/1 feature marking whether the
            sample is conditioned is appended to the MLP input.
    """

    def __init__(self, time_dim=256, num_obj=1, device="cuda", save_path=None, include_uncond_flag=False):
        super().__init__()
        self.device = device
        self.time_dim = time_dim
        self.num_obj = num_obj
        self.save_path = save_path
        self.include_uncond_flag = include_uncond_flag
        # x (time_dim) + condition y (num_obj) + scalar time encoding (1).
        input_dim = time_dim + num_obj + 1
        if include_uncond_flag:
            # NOTE(review): with num_obj == 0 the forward pass never appends
            # the flag, so this extra input feature makes the first layer
            # reject the input. Confirm whether that combination is intended.
            input_dim += 1
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, time_dim),
        )

    def pos_encoding(self, t, channels=1):
        """Return ``sin(t / 10000)`` element-wise.

        ``channels`` is accepted for interface compatibility but ignored; the
        encoding keeps the shape of ``t``.
        """
        # Build the divisor on t's own device so the encoding does not depend
        # on ``self.device`` still matching where the tensors actually live.
        scale = torch.tensor(10000.0, device=t.device)
        return torch.sin(t / scale)

    def forward(self, x, t, y=0., cond_flag=True):
        """Run the MLP on ``[x, enc(t), y, flag]`` concatenated on the last axis.

        Args:
            x: Input features; assumed shape ``(batch, time_dim)``.
            t: Time steps; a trailing axis is added before encoding, so a
                ``(batch,)`` tensor is assumed.
            y: Condition tensor, required when ``num_obj != 0`` (the float
                default only works for ``num_obj == 0``, where it is ignored).
            cond_flag: Value of the conditioning flag feature (1 if truthy,
                0 otherwise); only used when ``include_uncond_flag`` is set
                and ``num_obj != 0``.

        Returns:
            The MLP output, with ``time_dim`` features on the last axis.
        """
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)
        features = [x, t]
        if self.num_obj != 0:
            features.append(y.type(torch.float))
            if self.include_uncond_flag:
                flag = torch.ones_like(t) if cond_flag else torch.zeros_like(t)
                features.append(flag)
        return self.mlp(torch.cat(features, dim=-1))

class Model_unconditional(Model_conditional):
    """``Model_conditional`` configured with ``num_obj=0`` (no condition input)."""

    def __init__(self, time_dim=256, device="cuda"):
        super().__init__(time_dim=time_dim, num_obj=0, device=device)

class Model_unconditional_p(nn.Module):
    """MLP applied to ``x`` plus a learned linear embedding of the time step.

    Args:
        dim: Input and output feature width of the MLP, and the width of the
            time embedding.
        device: Stored on the instance for callers; not used by the forward
            pass.
        save_path: Stored on the instance; not used inside this class.
    """

    def __init__(self, dim=256, device="cuda", save_path=None):
        super().__init__()
        self.device = device
        self.dim = dim
        self.save_path = save_path
        self.mlp = nn.Sequential(
            nn.Linear(dim, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, dim),
        )
        # Projects the scalar sinusoidal time feature up to ``dim`` features.
        self.time_embed = nn.Linear(1, dim)

    def pos_encoding(self, t):
        """Return ``time_embed(sin(t / 10000))``; ``t`` needs a trailing axis of size 1."""
        # Build the divisor on t's own device so the encoding does not depend
        # on ``self.device`` still matching where the tensors actually live.
        scale = torch.tensor(10000.0, device=t.device)
        pos_enc = torch.sin(t / scale)
        return self.time_embed(pos_enc)

    def forward(self, x, t):
        """Return ``mlp(x + embedding(t))``.

        Args:
            x: Input features; assumed shape ``(batch, dim)``.
            t: Time steps; a trailing axis is added before encoding, so a
                ``(batch,)`` tensor is assumed.
        """
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t)
        return self.mlp(x + t)
    
class Preference_model(nn.Module):
    """Score a pair of inputs, conditioned on a time value.

    Each input is embedded to 24 features by a shared MLP, a 24-feature
    linear embedding of ``t`` is added, and a second MLP maps the result to
    one score per input.

    Args:
        input_dim: Feature width of each of the two inputs.
        device: Stored on the instance; not used inside this class.
        save_path: Stored on the instance; not used inside this class.
    """

    def __init__(self, input_dim=256, device="cuda", save_path=None):
        super().__init__()
        self.device = device
        self.input_dim = input_dim
        self.save_path = save_path
        # Shared encoder applied to both stacked inputs.
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 24),
        )
        self.time_embed = nn.Linear(1, 24)
        # Scoring head: 24 features -> a single scalar.
        self.preference = nn.Sequential(
            nn.Linear(24, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
        )

    def forward(self, x_1, x_2, t):
        """Return the scores for ``x_1`` and ``x_2`` stacked along axis 1.

        ``t`` is fed to ``time_embed`` as given, so it must already be a float
        tensor with a trailing axis of size 1.
        NOTE(review): the embedded ``t`` is added by plain broadcasting to the
        stacked embeddings; for ``(batch, input_dim)`` inputs a shape such as
        ``(batch, 1, 1)`` is presumably expected — confirm against the caller.
        """
        pair = torch.stack([x_1, x_2], dim=1)
        pair_embedding = self.mlp(pair)
        time_embedding = self.time_embed(t)
        conditioned = pair_embedding + time_embedding
        scores = self.preference(conditioned)
        return scores.squeeze(-1)

def save_model(model, save_path, device="cuda"):
    """Write ``model``'s state dict to ``save_path`` from CPU, then move it to ``device``.

    The model is moved to the CPU before its state dict is saved and is moved
    to ``device`` afterwards, whatever device it started on.
    """
    cpu_model = model.to("cpu")
    cpu_state = cpu_model.state_dict()
    torch.save(cpu_state, save_path)
    model.to(device)

def load_model(model, save_path, device="cuda"):
    """Load a state dict from ``save_path`` into ``model`` and move it to ``device``.

    Args:
        model: Module whose parameters are overwritten in place.
        save_path: Path of a checkpoint containing a state dict.
        device: Device the model is moved to after loading.
    """
    # Deserialize onto the CPU first so a checkpoint written on another
    # device can still be read on a host that lacks that device; the model is
    # moved to ``device`` afterwards.
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    state_dict = torch.load(save_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.to(device)
    print(f"Successfully load trained model from {save_path}")