import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm
import datetime
import json
from torch.utils.data import DataLoader, TensorDataset
import gc

import sys
sys.path.append("/home/***/work/doob")

from src.models.nn_potential import ApproxPotential
from src.objective.obj import objective_like_dpo
from src.models.init_models import xavier_init
from src.utils.set_seed import set_seed
from src.utils.sampling import doob_langevin_monte_carlo_modified, new_sbm_based_langevin_monte_carlo

class Potential_for_inner_loop:
    """Fit a neural potential to a (running-averaged) functional derivative on a 2-D grid.

    Each call to ``train_potential`` evaluates ``objective_like_dpo(...).potential`` on a
    regular grid over [-5, 5]^2, shifts/clamps the result, blends it with the previous
    potential's values when ``idx_loop > 1``, and regresses ``self.model_potential``
    onto those targets with an MSE loss.
    """

    def __init__(self, beta=1, beta_prime=1, device="cuda"):
        """Build the potential network and load hyper-parameters from ``configs/``.

        Args:
            beta: Coefficient used in the clamp bounds and averaging weights of
                ``_prepare_data``.
            beta_prime: Second coefficient used in the same bounds and weights.
            device: Torch device for the model and all training tensors.
        """
        # NOTE: enables autograd anomaly detection process-wide (slow, but makes
        # backward passes that produce NaN raise an error).
        torch.autograd.set_detect_anomaly(True)
        self.device = device
        # Newly trained potential function (learns the averaged potential).
        self.model_potential = ApproxPotential(2).to(device)
        xavier_init(self.model_potential)
        self.beta = beta
        self.beta_prime = beta_prime
        # Defaults; overwritten by configs/regularization.json in _load_config().
        self.y_min = 0.0
        self.y_max = 2.0
        self._load_config()

    def _load_config(self):
        """Read training and regularization hyper-parameters from JSON config files.

        Paths are relative to the current working directory.
        """
        with open('configs/train_potential.json', encoding='utf-8') as f:
            config_potential = json.load(f)
        with open('configs/regularization.json', encoding='utf-8') as f:
            config_reg = json.load(f)

        # Number of grid points requested; the grid actually has int(sqrt(n))**2 points.
        self.num_samples = config_potential["num_samples"]  # e.g. 10000
        self.lr = config_potential["lr"]  # e.g. 0.0001
        # NOTE(review): currently unused — train_potential hard-codes 100 epochs.
        self.num_epochs = config_potential["num_epochs"]  # e.g. 1000
        self.y_min = config_reg["y_min"]  # e.g. 0.0
        self.y_max = config_reg["y_max"]  # e.g. 2.0
        # Overwritten by objective_like_dpo(...).gamma in train_potential().
        self.gamma = config_reg["gamma"]

    def _make_grid(self, num_samples):
        """Return a (k*k, 2) tensor of grid points over [-5, 5]^2, k = int(sqrt(num_samples)).

        Raises:
            ValueError: if the grid would be empty (``num_samples < 1``).
        """
        steps = int(np.sqrt(num_samples))
        if steps < 1:
            raise ValueError("num_samples must be >= 1 to build a non-empty grid")
        axis = torch.linspace(-5, 5, steps=steps, device=self.device)
        # "ij" is the historical default of torch.meshgrid; passing it silences the warning.
        x_grid, y_grid = torch.meshgrid(axis, axis, indexing="ij")
        return torch.stack([x_grid.flatten(), y_grid.flatten()], dim=-1)

    def _prepare_data(self, objective_class, previousPotential, model_sbm, idx_loop, savedirname, plot_data):
        """Build the regression inputs ``self.x_train`` and targets ``self.y_train``.

        Args:
            objective_class: Object whose ``potential(x, previousPotential, model_sbm)``
                returns the functional derivative dF/dp at the points ``x`` (one value
                per point; a trailing dim is added here).
            previousPotential: Potential from the previous outer iteration; switched to
                eval mode and, for ``idx_loop != 1``, evaluated on the grid.
            model_sbm: Model forwarded to ``objective_class.potential``; switched to eval mode.
            idx_loop: Outer-loop index. 1 uses dF/dp alone; other values blend dF/dp
                with ``previousPotential`` using the weights below.
            savedirname: Directory for the optional heatmap (nothing saved when None).
            plot_data: Whether to save a scatter heatmap of the targets.

        Raises:
            ValueError: on an empty grid or if the blended tensors disagree in shape.
        """
        beta = self.beta
        beta_prime = self.beta_prime

        previousPotential.eval()
        model_sbm.eval()

        # Sample points on a regular grid.
        self.x_train = self._make_grid(self.num_samples)

        with torch.no_grad():
            dF_dp = objective_class.potential(self.x_train, previousPotential, model_sbm).unsqueeze(1).detach().clone()

        # Clamp range for dF/dp (the same for every idx_loop).
        scale = beta + beta_prime
        upper = torch.tensor(self.y_max * scale, dtype=dF_dp.dtype, device=dF_dp.device)
        lower = torch.tensor(self.y_min * scale, dtype=dF_dp.dtype, device=dF_dp.device)

        # Shift dF/dp so its minimum equals half the width of the clamp range.
        dF_dp = dF_dp - torch.min(dF_dp) + (upper - lower) / 2
        mean_dF_dp = torch.mean(dF_dp)
        std_dF_dp = torch.std(dF_dp)
        print("mean_dF_dp: ", mean_dF_dp, ", std_dF_dp: ", std_dF_dp)
        print("max of dF_dp: ", torch.max(dF_dp).item(), ", min of dF_dp: ", torch.min(dF_dp).item())
        print("max: ", upper, ", min: ", lower)

        # Clip to [lower, upper]: cap at upper first, then raise to lower.
        dF_dp = torch.where(dF_dp > upper, upper, dF_dp)
        dF_dp = torch.where(dF_dp < lower, lower, dF_dp)

        if idx_loop == 1:
            weight = 1 / (beta + 2 * beta_prime)
            self.y_train = weight * dF_dp
            print(f"min of ytrain: {torch.min(self.y_train)}")
            for _ in range(10):
                i = np.random.randint(0, len(self.y_train))
                print("x: ", self.x_train[i], ", y_train: ", self.y_train[i])
        else:
            # Learn the averaged functional derivative: blend the previous potential's
            # values (old) with the newly computed dF/dp (new).
            k = idx_loop
            denom = beta * k * (k + 1) + 2 * beta_prime * (k + 1)
            old_weight = (beta * (k - 1) * k + 2 * beta_prime * k) / denom
            new_weight = (2 * k) / denom
            with torch.no_grad():
                old = old_weight * previousPotential(self.x_train).detach().clone()
            new = new_weight * dF_dp
            print("old: ", old_weight, ", new: ", new_weight)
            print("old: ", old[:10])
            print("new: ", new[:10])
            for _ in range(10):
                i = np.random.randint(0, len(old))
                print("x: ", self.x_train[i], ", old: ", old[i], ", new: ", new[i])
            if old.shape != new.shape:
                raise ValueError("old and new should have the same shape")
            self.y_train = old + new
            for _ in range(10):
                i = np.random.randint(0, len(self.y_train))
                print("x: ", self.x_train[i], ", y_train: ", self.y_train[i])

        if self.y_train.shape[0] != self.x_train.shape[0]:
            raise ValueError("x_train and y_train should have the same shape in dim 0")

        # Shift the targets so their minimum equals y_min.
        # NOTE: no outlier filtering is applied to the targets.
        self.y_train = self.y_train - torch.min(self.y_train) + self.y_min

        # Heatmap of the targets: x_train are coordinates, y_train the potential values.
        if plot_data and savedirname is not None:
            self._plot_data(savedirname, idx_loop)

    def _plot_data(self, savedirname, idx_loop):
        """Save a scatter plot of ``self.x_train`` coloured by ``self.y_train``."""
        import matplotlib.pyplot as plt
        fig = plt.figure()
        try:
            plt.scatter(
                self.x_train[:, 0].cpu().detach().numpy(),
                self.x_train[:, 1].cpu().detach().numpy(),
                c=self.y_train.squeeze(1).cpu().detach().numpy(),
            )
            plt.colorbar()
            filepath = os.path.join(savedirname, "idx_" + str(idx_loop) + "x_train_y_train_heatmap.png")
            plt.savefig(filepath)
        finally:
            # Close the figure so repeated calls do not accumulate open figures.
            plt.close(fig)

    def _plot_potential(self, potential, idx_loop, savedirname):
        """Save a filled contour plot of ``potential`` on a 30x30 grid over [-5, 5]^2."""
        import matplotlib.pyplot as plt
        axis = torch.linspace(-5, 5, 30, device=self.device)
        X, Y = torch.meshgrid(axis, axis, indexing="ij")
        points = torch.stack([X, Y], dim=2).view(-1, 2)
        # Plotting only: no autograd graph needed.
        with torch.no_grad():
            Z = potential(points).view(X.shape)
        fig = plt.figure()
        try:
            plt.contourf(X.cpu().numpy(), Y.cpu().numpy(), Z.cpu().numpy(), levels=100, cmap='jet')
            plt.colorbar()
            # 1:1 aspect ratio.
            plt.gca().set_aspect('equal', adjustable='box')
            now_str = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
            filepath = os.path.join(savedirname, "idx_" + str(idx_loop) + "_potential_heatmap_" + now_str + ".png")
            plt.savefig(filepath)
        finally:
            plt.close(fig)

    def _fit(self, dataloader, num_epochs):
        """Regress ``self.model_potential`` onto ``dataloader`` batches with Adam + MSE.

        Returns:
            list[float]: the loss of every mini-batch, in training order.
        """
        device = self.device
        model = self.model_potential
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)
        criterion = torch.nn.MSELoss()

        losses = []
        for _epoch in range(num_epochs):
            for x_batch, y_batch in dataloader:
                x_batch, y_batch = x_batch.to(device), y_batch.to(device)
                optimizer.zero_grad()
                loss = criterion(model(x_batch), y_batch)
                # Report non-finite losses; with anomaly detection enabled,
                # backward() raises if it produces NaN.
                if not torch.isfinite(loss).all():
                    print("Found NaN or inf in loss. Debugging...")
                loss.backward()
                optimizer.step()
                losses.append(loss.item())
        return losses

    def train_potential(self, previousPotential, model_sbm, idx_loop, savedirname=None, plot_data=False):
        """Build targets for outer iteration ``idx_loop`` and fit ``self.model_potential``.

        Args:
            previousPotential: Potential from the previous outer iteration.
            model_sbm: Model forwarded to the objective's ``potential``.
            idx_loop: Outer-loop index; also passed to ``set_seed``.
            savedirname: If given, the model state_dict is saved there as
                ``model_potential_<idx_loop>.pth`` (and plots, if requested).
            plot_data: Whether to save target/potential heatmaps.

        Returns:
            The trained ``self.model_potential`` (per-batch losses are kept in ``self.losses``).
        """
        set_seed(idx_loop)

        self.objective_class = objective_like_dpo(self.device)  # objective function
        self.gamma = self.objective_class.gamma

        self._prepare_data(self.objective_class, previousPotential, model_sbm, idx_loop, savedirname, plot_data)

        # Raise an error if NaN appears during backward (process-wide setting).
        torch.autograd.set_detect_anomaly(True)

        batch_size = 100
        dataloader = DataLoader(TensorDataset(self.x_train, self.y_train), batch_size=batch_size, shuffle=True)

        # NOTE(review): hard-coded override — the configured self.num_epochs is ignored.
        num_epochs = 100
        self.losses = self._fit(dataloader, num_epochs)

        if plot_data and savedirname is not None:
            self._plot_potential(self.model_potential, idx_loop, savedirname)

        # Save the model state.
        if savedirname is not None:
            torch.save(self.model_potential.state_dict(), os.path.join(savedirname, 'model_potential_' + str(idx_loop) + '.pth'))

        return self.model_potential
