import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm

import wandb


# 学習するニューラルネットワークモデルを定義
# 今回はシンプルな4層の全結合層からなるモデルを定義
class ScoreBaseModel(nn.Module):
    def __init__(self, input_dim, mid_dim=64):
        super(ScoreBaseModel, self).__init__()
        self.input_dim = input_dim
        self.mid_dim = mid_dim
        self.fc1 = nn.Linear(input_dim+1, mid_dim) # dtype=torch.float32)
        self.fc2 = nn.Linear(mid_dim, mid_dim//2) # dtype=torch.float32)
        self.fc3 = nn.Linear(mid_dim//2, mid_dim//4) # dtype=torch.float32)
        self.fc4 = nn.Linear(mid_dim//4, input_dim) # dtype=torch.float32)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

         # Move the model parameters to the device
        self.to(self.device)

        return

    def forward(self, x):
        x = F.tanh(self.fc1(x.float().to(self.device)))
        x = F.tanh(self.fc2(x))
        x = F.tanh(self.fc3(x))
        x = self.fc4(x)
        return x

def pretrain_sbm(batch_size, num_epoch, T, dist, model, optimizer, criterion, device, log_dir='logs'):
    wandb.init(project='mixture_model')
    np.random.seed(1234)
    torch.manual_seed(1234)
    # EpochごとのLossを保存するリスト
    out_list = []
    # モデルを学習モードに変更
    model.train()
    # 以下num_epoch回数分の学習を実行
    for epoch in tqdm(range(num_epoch), leave=False):

        # 確率分布からBatch Size分だけサンプリング
        sample = dist.sample((batch_size,))
        t = torch.full((batch_size,), np.random.rand() *T).to(device)
        t = t.unsqueeze(1)
        t_expand = t.expand(-1, 2)
        # サンプリング結果に付加するノイズを生成
        noise = torch.randn_like(sample).to(device)
        # サンプリング結果にノイズ付加
        # alpha_bar_t = exp(-2t)
        # print("t: ", t_expand)
        # print("sample: ", sample)
        # print("noise: ", noise)
        xt = torch.exp(-t_expand) * sample + torch.sqrt(1-torch.exp(-2*t_expand)) * noise

        # ノイズ付加後のサンプリング結果とtをConcatで結合
        x = torch.cat([xt, t], axis=1).to(device)
        ### 古いとconcatはない ###
        # x = torch.concat([xt, t], axis=1).to(device)
        # それらをモデルに入力して予測を実行
        pred_y = model(x)

        # 求めたいもの
        true_y = - noise / torch.sqrt(1-torch.exp(-2*t_expand))

        # 勾配情報を初期化
        optimizer.zero_grad()
        # Lossの計算
        loss = criterion(true_y, pred_y)
        # 誤差逆伝播法で勾配計算
        loss.backward()
        # 計算された勾配に基づいてモデルパラメータの更新
        optimizer.step()

        # EpochごとのLossをリストに追加
        out_list.append([epoch, loss.item()])
        # wandbにLossを保存
        wandb.log({'loss': loss.item()})

        # print(f"Epoch [{epoch+1}/{num_epoch}], Loss: {loss.item()}")

    print("Finished Pretraining")

    df_res = pd.DataFrame(out_list, columns=['epoch','loss'])
    return df_res