import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm
# import wandb
import datetime
import json, gc

import sys
sys.path.append("/home//work/doob_apps/hug")

from src.utils.img_func import show_images, make_grid, preprocess, transform
from src.finetune.obj import objective_dpo
from src.models.model_potential import ModelPotential
from src.utils.set_seed import set_seed
from src.models.CT_model_predictor import RotationPredictorCNN

import wandb

class Inner_Loop:
    """Fit a potential network to regression targets built from a previous potential.

    One call to :meth:`train_potential` performs one "inner loop":

    1. build regression targets on a fixed set of reference images
       (:meth:`_prepare_data`), and
    2. regress ``self.model_potential`` onto those targets with an MSE loss.

    Two modes are supported: ``"butterfly"`` (uses ``ModelPotential``) and
    ``"CT"`` (uses ``RotationPredictorCNN``).  Hyper-parameters are read from
    JSON files under ``_CONFIG_DIR``; the ``"CT"`` mode reads the files whose
    names carry a ``_CT`` suffix.
    """

    # Directory holding the JSON configuration files.
    _CONFIG_DIR = "/home//work/doob_apps/hug/configs"
    # Supported modes mapped to the file-name suffix of their config files.
    _CONFIG_SUFFIX_BY_MODE = {"butterfly": "", "CT": "_CT"}
    # GPU ids used for DataParallel when the caller does not pass ``device_ids``.
    _DEFAULT_DEVICE_IDS = (0, 1, 2, 3)
    # The reference set is evaluated in (roughly) this many mini batches.
    _NUM_MINI_BATCHES = 16

    def __init__(self, beta = 1, device = "cuda", mode = "butterfly", device_ids = None):
        """Build the potential network and load configuration and reference images.

        Args:
            beta: Scale factor applied to the clipping bounds and to the
                target weights in :meth:`_prepare_data`.
            device: Device the network and the training data are moved to.
            mode: ``"butterfly"`` or ``"CT"``; selects the network class and
                the set of config files.
            device_ids: GPU ids handed to ``torch.nn.DataParallel`` when more
                than one GPU is visible.  ``None`` (default) means ids
                0, 1, 2, 3 restricted to the GPUs that actually exist.

        Raises:
            ValueError: If ``mode`` is not a supported mode.
        """
        if mode not in self._CONFIG_SUFFIX_BY_MODE:
            raise ValueError(
                "unsupported mode: {!r} (expected one of {})".format(
                    mode, sorted(self._CONFIG_SUFFIX_BY_MODE)
                )
            )

        # NOTE: anomaly detection is a process-wide switch and slows autograd down.
        torch.autograd.set_detect_anomaly(True)
        self.mode = mode
        self.device = device

        # The potential function that will be newly trained.
        if mode == "butterfly":
            model = ModelPotential()
        else:
            model = RotationPredictorCNN()
        model = model.to(device)

        # Parallelise over several GPUs with DataParallel when available.
        n_gpus = torch.cuda.device_count()
        if n_gpus > 1:
            if device_ids is None:
                # Only keep default ids that exist on this host; DataParallel
                # rejects ids of devices that are not present.
                device_ids = [i for i in self._DEFAULT_DEVICE_IDS if i < n_gpus]
            model = torch.nn.DataParallel(model, device_ids=list(device_ids), output_device=device)
        self.model_potential = model

        self.beta = beta
        # Fallback values; _load_config() overwrites all three from the JSON files.
        self.y_min = -2.0
        self.y_max = 2.0
        self.n_samples_from_ref = 64
        self._load_config()
        self._load_data()

    @staticmethod
    def _mini_batch_size(n_samples, n_chunks = 16):
        """Return the mini-batch size used to sweep ``n_samples`` items.

        The size is ``n_samples // n_chunks`` but never less than 1, so that it
        is always a valid ``range`` step even for fewer than ``n_chunks`` items.
        """
        return max(1, n_samples // n_chunks)

    def _config_path(self, stem):
        """Return the path of config file ``stem`` for the current mode.

        Raises:
            ValueError: If ``self.mode`` is not a supported mode.
        """
        try:
            suffix = self._CONFIG_SUFFIX_BY_MODE[self.mode]
        except KeyError:
            raise ValueError(
                "unsupported mode: {!r} (expected one of {})".format(
                    self.mode, sorted(self._CONFIG_SUFFIX_BY_MODE)
                )
            ) from None
        return "{}/{}{}.json".format(self._CONFIG_DIR, stem, suffix)

    @staticmethod
    def _read_json(path):
        """Parse and return the JSON document stored at ``path``."""
        with open(path, "r") as f:
            return json.load(f)

    def _load_config(self):
        """Read the three JSON config files of the current mode into attributes.

        Sets ``image_size``, ``image_ref_path``, ``n_samples_from_ref``, ``lr``,
        ``num_epochs``, ``y_min``, ``y_max`` and ``gamma``.

        Raises:
            ValueError: If ``self.mode`` is not a supported mode.
        """
        # Resolve all paths first so an unsupported mode fails before any file I/O.
        potential_path = self._config_path("train_potential")
        reg_path = self._config_path("regularization")
        general_path = self._config_path("configs")

        config_potential = self._read_json(potential_path)
        config_reg = self._read_json(reg_path)
        config = self._read_json(general_path)

        self.image_size = config["image_size"]
        self.image_ref_path = config["image_ref_path"]
        self.n_samples_from_ref = config_potential["num_samples"]
        self.lr = config_potential["lr"]
        self.num_epochs = config_potential["num_epochs"]
        self.y_min = config_reg["y_min"]
        self.y_max = config_reg["y_max"]
        self.gamma = config["gamma"]

    def _load_data(self):
        """Load the reference images and keep a random subset of them.

        At most ``self.n_samples_from_ref`` images are kept, chosen by a random
        permutation of the first dimension.
        """
        # NOTE(review): torch.load unpickles the file; image_ref_path comes from
        # the config file, so only point it at trusted data.
        images = torch.load(self.image_ref_path)
        # Random sampling without replacement; indexing with the truncated
        # permutation gives the same rows as permuting everything and slicing.
        permutation = torch.randperm(images.size()[0])
        self.images_ref = images[permutation[:self.n_samples_from_ref]]

    def _evaluate_in_mini_batches(self, fn, idx_loop, shape_label):
        """Apply ``fn`` to ``self.x_train`` chunk by chunk and concatenate the results.

        Args:
            fn: Callable taking a slice of ``self.x_train`` and returning a
                tensor; it is called under ``torch.no_grad()``.
            idx_loop: Outer-loop index, only used in the wandb progress key.
            shape_label: Prefix printed in front of each chunk's shape.

        Returns:
            The per-chunk outputs (detached copies) concatenated along dim 0.
        """
        n_samples = len(self.x_train)
        step = self._mini_batch_size(n_samples, self._NUM_MINI_BATCHES)
        outputs = []
        for start in range(0, n_samples, step):
            print("#"*40)
            print("generating mini batch... i: ", start)
            wandb.log({"generating mini batch..., idx_loop:" + str(idx_loop): start})
            print("#"*40)
            # Slicing past the end simply yields the shorter final chunk.
            chunk = self.x_train[start:start + step]
            print(shape_label, chunk.shape)
            with torch.no_grad():
                outputs.append(fn(chunk).detach().clone())
            # Release temporaries of this chunk before evaluating the next one.
            gc.collect()
            torch.cuda.empty_cache()
        # Concatenate once instead of re-copying the accumulated tensor per chunk.
        return torch.cat(outputs, dim=0)

    def _prepare_data(self, objective_class, previousPotential, idx_loop, savedirname):
        """Build ``self.x_train`` / ``self.y_train`` for one inner loop.

        ``x_train`` is the reference image set moved to ``self.device``.
        ``y_train`` is derived from ``objective_class.potential(x, previousPotential)``
        (called ``dF_dp`` below): it is mean-centred, clipped to
        ``[y_min * beta, y_max * beta]`` and then

        * for ``idx_loop == 1``: scaled by ``1 / (3 * beta)``;
        * otherwise: combined as
          ``idx_loop / (idx_loop + 2) * previousPotential(x)
          + 2 * idx_loop / (beta * (idx_loop + 2) * (idx_loop + 1)) * dF_dp``.

        Finally ``y_train`` is shifted so that its minimum is 0.

        Args:
            objective_class: Object exposing ``potential(x, previousPotential)``.
            previousPotential: The already trained potential network; it is put
                into eval mode.
            idx_loop: Index of the current outer loop.
            savedirname: Unused here; kept for interface compatibility.

        Raises:
            ValueError: If there are no reference images, or if the shapes of
                the combined terms / of ``x_train`` and ``y_train`` disagree.
        """
        print("Preparing data...")

        device = self.device
        beta = self.beta

        previousPotential.eval()

        self.x_train = self.images_ref.detach().to(device)
        print("x_train: ", self.x_train.shape)
        if self.x_train.shape[0] == 0:
            raise ValueError("x_train should have at least one element")

        # Evaluate dF_dp on x_train in mini batches and join the pieces.
        dF_dp = self._evaluate_in_mini_batches(
            lambda chunk: objective_class.potential(chunk, previousPotential),
            idx_loop,
            "x_train: ",
        )

        print("dF_dp_all: ", dF_dp.shape)
        print("x_train: ", self.x_train.shape)

        # Clipping bounds for dF_dp (the same for every idx_loop).
        upper = torch.tensor(self.y_max * beta, dtype=dF_dp.dtype, device=dF_dp.device)
        lower = torch.tensor(self.y_min * beta, dtype=dF_dp.dtype, device=dF_dp.device)

        # Centre dF_dp, report its statistics, then clip outliers to the bounds.
        dF_dp = dF_dp - torch.mean(dF_dp)
        mean_dF_dp = torch.mean(dF_dp)
        std_dF_dp = torch.std(dF_dp)
        print("mean_dF_dp: ", mean_dF_dp, ", std_dF_dp: ", std_dF_dp)
        print("max of dF_dp: ", torch.max(dF_dp).item(), ", min of dF_dp: ", torch.min(dF_dp).item())

        print("max: ", upper, ", min: ", lower)

        dF_dp = torch.where(dF_dp > upper, upper, dF_dp)
        dF_dp = torch.where(dF_dp < lower, lower, dF_dp)

        if idx_loop == 1:
            weight = 1 / (3 * beta)
            self.y_train = weight * dF_dp
        else:
            # Learn the averaged functional derivative: blend the previous
            # potential's output with the new dF_dp.
            old_weight = (idx_loop / (idx_loop + 2))
            previousPotential_all = self._evaluate_in_mini_batches(
                previousPotential, idx_loop, "x_train_i: "
            )
            if self.mode == "CT":
                previousPotential_all = previousPotential_all.squeeze(1)
            print("previousPotential_all: ", previousPotential_all.shape)
            old = old_weight * previousPotential_all
            new_weight = ((2 * idx_loop) / (beta * (idx_loop + 2) * (idx_loop + 1)))
            new = new_weight * dF_dp
            print("old: ", old_weight, ", new: ", new_weight)
            print("old: ", old[:10])
            print("new: ", new[:10])
            # Guard against silent broadcasting when the two terms are added.
            if old.shape != new.shape:
                raise ValueError("old and new should have the same shape")
            self.y_train = old + new
            del previousPotential_all, old, new
            gc.collect()
            torch.cuda.empty_cache()

        if self.y_train.shape[0] != self.x_train.shape[0]:
            raise ValueError("x_train and y_train should have the same shape in dim 0")

        # Shift y_train so that its minimum is 0.
        self.y_train = self.y_train - torch.min(self.y_train)

        # Report mean and standard deviation of the final targets.
        mean_y_train = torch.mean(self.y_train)
        std_y_train = torch.std(self.y_train)
        print("mean_y_train: ", mean_y_train, ", std_y_train: ", std_y_train)

        print("Data preparation is done.")

    def train_potential(self, previousPotential, idx_loop, savedirname=None, plot_data=False):
        """Train ``self.model_potential`` on targets derived from ``previousPotential``.

        Args:
            previousPotential: The already trained potential network used to
                build the regression targets.
            idx_loop: Index of the current outer loop; also used as the random
                seed and in log keys / the checkpoint file name.
            savedirname: If not ``None``, directory in which the trained
                ``state_dict`` is stored as ``model_potential_<idx_loop>.pth``.
                The directory is created when missing.
            plot_data: Unused; kept for interface compatibility.

        Returns:
            The trained model (``self.model_potential``; wrapped in
            ``DataParallel`` when several GPUs were visible at construction).
        """
        set_seed(idx_loop)

        print("Training potential...")
        print("idx_loop: ", idx_loop)
        print("numepochs: ", self.num_epochs)
        print("lr: ", self.lr)

        device = self.device
        model_potential = self.model_potential
        self.objective_class = objective_dpo(device=device, mode=self.mode)  # objective function
        # The objective's gamma replaces the value read from the config file.
        self.gamma = self.objective_class.gamma

        self._prepare_data(self.objective_class, previousPotential, idx_loop, savedirname)

        # Build the dataset and its loader from x_train / y_train.
        dataset = torch.utils.data.TensorDataset(self.x_train, self.y_train)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

        gc.collect()
        torch.cuda.empty_cache()

        # Optimizer and loss function.
        optimizer_potential = torch.optim.Adam(model_potential.parameters(), lr=self.lr)
        criterion_potential = nn.MSELoss()

        # Make autograd raise as soon as a NaN appears in the backward pass.
        torch.autograd.set_detect_anomaly(True)

        ## Training ##
        for epoch in range(self.num_epochs):
            model_potential.train()
            running_loss = 0.0
            data_count = 0
            for i, (x_train_batch, y_train_batch) in enumerate(dataloader):
                x_train_batch = x_train_batch.to(device)
                y_train_batch = y_train_batch.to(device)

                if i == 0:
                    print("x_train_batch: ", x_train_batch.shape)
                    print("y_train_batch: ", y_train_batch.shape)
                optimizer_potential.zero_grad()
                y_pred = model_potential(x_train_batch)
                if self.mode == "CT":
                    y_pred = y_pred.squeeze(1)
                if i == 0:
                    print("y_pred: ", y_pred.shape)
                loss = criterion_potential(y_pred, y_train_batch)

                # Best-effort diagnostics only; training continues regardless.
                if not torch.isfinite(loss).all():
                    print("Found NaN or inf in loss. Debugging...")

                loss.backward()
                optimizer_potential.step()

                # MSELoss returns the batch mean, so weight it by the batch size
                # to obtain a true per-sample average over the epoch below.
                batch_n = x_train_batch.size(0)
                running_loss += loss.item() * batch_n
                data_count += batch_n

            epoch_loss = running_loss / data_count
            print("epoch: ", epoch, ", running_loss: ", epoch_loss, "\n")
            wandb.log({"epoch": epoch})
            wandb.log({"loss, idx_loop:"+str(idx_loop) : epoch_loss})

        print("Finished training potential.")
        del dataset, dataloader, optimizer_potential, criterion_potential
        gc.collect()
        torch.cuda.empty_cache()

        ## Save ##
        # Store the model state, creating the target directory if needed.
        if savedirname is not None:
            os.makedirs(savedirname, exist_ok=True)
            torch.save(model_potential.state_dict(), os.path.join(savedirname, 'model_potential_'+str(idx_loop)+'.pth'))
        return model_potential