import os

import warp as wp
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

from simulator import MPMSimulator

class GradientOptimizer:
    """Fit material parameters of an MPM simulation by gradient descent.

    Physical parameters are read from the simulator's warp arrays and mirrored
    as torch leaf tensors (stiffnesses kept as natural logs). Each frame they
    are pushed back into the simulator under the simulator's warp tape and
    updated with SGD. Frames before ``simulator.train_frame`` drive the
    optimisation; later frames are only evaluated, and their average loss
    decides when the current parameters are saved as the best ones.
    """

    # Simulator stepping method used for each supported ``gripper_type``.
    _STEP_METHODS = {
        'single_gripper': 'step_single_gripper',
        'double_gripper': 'step_double_gripper',
        'push': 'step_push',
    }
    # ``T_max`` of the cosine LR schedule: epochs to anneal from the base LR
    # down to its minimum.
    # NOTE(review): with ``self.iteration = 50`` the cosine schedule bottoms
    # out at epoch 30 and then rises again -- confirm this is intended.
    _LR_T_MAX = 30

    def __init__(self, simulator: "MPMSimulator", optim_path='./optim_stage2'):
        """Initialise the scene and mirror the simulator's parameters as torch leaves.

        Args:
            simulator: Simulator providing the MPM solver, warp tape, loss and
                the parameter setters/savers used by the optimisation stages.
            optim_path: Root directory under which ``best_params.pkl`` is
                written (one sub-directory per dataset name).
        """
        self.simulator = simulator
        self.iteration = 50  # number of optimisation epochs
        self.push_idx = 0
        self.simulator.init_scene(self.push_idx)
        self.gripper_type = self.simulator.gripper_type
        self.material = self.simulator.material
        self.save_root = self.simulator.output_path
        self.optim_save_root = optim_path
        self.optimizer = None
        self.scheduler = None

        mpm_model = self.simulator.mpm_solver.mpm_model

        # Stiffnesses are optimised as natural logs; the simulator setters
        # receive these log values directly.
        E_torch = wp.to_torch(mpm_model.E).clone().detach()
        self.log_E = torch.log(E_torch).requires_grad_(True)

        # Column layout of E_cloth (warp, weft, shear) follows the original
        # variable naming -- assumed, not verified here.
        E_cloth_torch = wp.to_torch(mpm_model.E_cloth).clone().detach()
        self.log_warp_stiffness = torch.log(E_cloth_torch[:, 0]).requires_grad_(True)
        self.log_weft_stiffness = torch.log(E_cloth_torch[:, 1]).requires_grad_(True)
        self.log_shear_stiffness = torch.log(E_cloth_torch[:, 2]).requires_grad_(True)

        density_torch = wp.to_torch(self.simulator.mpm_solver.mpm_state.particle_density).clone().detach()
        self.density = density_torch.requires_grad_(True)
        friction_torch = wp.to_torch(mpm_model.friction).clone().detach()
        self.friction = friction_torch.requires_grad_(True)

    def optimize_stage1(self):
        """Stage 1: fit a single shared value per stiffness parameter.

        Each stiffness tensor is collapsed to its first entry and optimised
        through the ``*_single`` simulator setters. When the best parameters
        are saved, the shared value is broadcast back to the original length.
        Density and friction are optimised as full tensors.

        Raises:
            ValueError: if ``gripper_type`` is not supported.
        """
        step = self._get_step_fn()  # validate before mutating any parameters
        # NOTE(review): the broadcast length comes from log_E even for cloth;
        # assumes E and E_cloth have the same number of rows -- confirm.
        length = self.log_E.shape[0]
        self.log_E = self.log_E[0].detach().clone().requires_grad_(True)
        self.log_warp_stiffness = self.log_warp_stiffness[0].detach().clone().requires_grad_(True)
        self.log_weft_stiffness = self.log_weft_stiffness[0].detach().clone().requires_grad_(True)
        self.log_shear_stiffness = self.log_shear_stiffness[0].detach().clone().requires_grad_(True)

        def set_phys_properties():
            if self.material == 'cloth':
                self.simulator.wp_set_phys_property_cloth_D_single(
                    self.log_warp_stiffness, self.log_weft_stiffness,
                    self.log_shear_stiffness, self.density)
            else:
                self.simulator.wp_set_phys_property_ED_single(self.log_E, self.density)

        def save_best(path):
            # Broadcast the shared scalar stiffnesses back to per-element arrays.
            if self.material == 'cloth':
                self.simulator.save_params_cloth(
                    torch.full((length,), self.log_warp_stiffness, dtype=torch.float32),
                    torch.full((length,), self.log_weft_stiffness, dtype=torch.float32),
                    torch.full((length,), self.log_shear_stiffness, dtype=torch.float32),
                    self.density, self.friction, path)
            else:
                self.simulator.save_params(
                    torch.full((length,), self.log_E, dtype=torch.float32),
                    self.density, self.friction, path)

        self._run_optimization(step, set_phys_properties, save_best)

    def optimize_stage2(self):
        """Stage 2: fit the full per-element parameter tensors.

        Raises:
            ValueError: if ``gripper_type`` is not supported.
        """
        step = self._get_step_fn()

        def set_phys_properties():
            if self.material == 'cloth':
                self.simulator.wp_set_phys_property_cloth_D(
                    self.log_warp_stiffness, self.log_weft_stiffness,
                    self.log_shear_stiffness, self.density)
            else:
                self.simulator.wp_set_phys_property_ED(self.log_E, self.density)

        def save_best(path):
            if self.material == 'cloth':
                self.simulator.save_params_cloth(
                    self.log_warp_stiffness, self.log_weft_stiffness,
                    self.log_shear_stiffness, self.density, self.friction, path)
            else:
                self.simulator.save_params(self.log_E, self.density, self.friction, path)

        self._run_optimization(step, set_phys_properties, save_best)

    def _get_step_fn(self):
        """Return the simulator stepping method for ``self.gripper_type``.

        Raises:
            ValueError: if the gripper type is unknown, instead of silently
                skipping the simulation step.
        """
        try:
            method_name = self._STEP_METHODS[self.gripper_type]
        except KeyError:
            raise ValueError(f'Unsupported gripper_type: {self.gripper_type!r}') from None
        return getattr(self.simulator, method_name)

    def _run_optimization(self, step, set_phys_properties, save_best):
        """Run the shared epoch loop used by both optimisation stages.

        Args:
            step: Simulator stepping method (see ``_get_step_fn``).
            set_phys_properties: Zero-argument callable that pushes the current
                stiffness/density parameters into the simulator; it is called
                inside the simulator's warp tape context.
            save_best: Callable taking the output ``.pkl`` path. It is called
                under ``torch.no_grad()`` whenever the selection loss improves.
        """
        data_name = os.path.basename(self.simulator.data_path)
        optim_save_path = os.path.join(self.optim_save_root, data_name)
        save_path = os.path.join(self.save_root, data_name)
        os.makedirs(optim_save_path, exist_ok=True)
        os.makedirs(save_path, exist_ok=True)
        best_params_path = os.path.join(optim_save_path, 'best_params.pkl')

        self.set_optimizer()
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self._LR_T_MAX)

        best_loss = float('inf')
        avg_loss_list = []
        avg_loss_test_list = []
        writer = SummaryWriter(log_dir=os.path.join(save_path, 'log'))
        try:
            for epoch in range(self.iteration):
                total_loss, n_train, total_loss_test, n_test = self._run_epoch(
                    epoch, step, set_phys_properties, save_path)
                # Average over the frames actually summed (NaN for an empty split).
                avg_loss = self._mean(total_loss, n_train)
                avg_loss_test = self._mean(total_loss_test, n_test)
                avg_loss_list.append(avg_loss)
                avg_loss_test_list.append(avg_loss_test)

                self.scheduler.step()

                writer.add_scalar('Loss/train', total_loss, epoch)
                writer.add_scalar('Loss/test', total_loss_test, epoch)
                writer.add_scalar('LR', self.optimizer.param_groups[0]['lr'], epoch)
                self._plot_loss_curve(avg_loss_list, save_path, 'Train')
                self._plot_loss_curve(avg_loss_test_list, save_path, 'Test')

                # Select on held-out frames; fall back to the train loss when
                # the test split is empty.
                selection_loss = avg_loss_test if n_test else avg_loss
                update = selection_loss < best_loss
                if update:
                    best_loss = selection_loss
                    with torch.no_grad():
                        save_best(best_params_path)
                print(f'Epoch {epoch}: Train_Loss={total_loss:.6f}, Test_Loss={total_loss_test:.6f}, LR={self.optimizer.param_groups[0]["lr"]:.2e}, update:{update}')
        finally:
            writer.close()

    def _run_epoch(self, epoch, step, set_phys_properties, save_path):
        """Simulate one pass over frames ``[start_frame, end_frame)``.

        Training frames (``frame < train_frame``) back-propagate through the
        warp tape and take one SGD step followed by a clamp. The remaining
        frames only accumulate loss.

        Returns:
            ``(train_loss_sum, n_train_frames, test_loss_sum, n_test_frames)``.
        """
        sim = self.simulator
        total_loss = 0.0
        total_loss_test = 0.0
        n_train = 0
        n_test = 0
        pbar = tqdm(range(sim.start_frame, sim.end_frame))

        sim.init_solver(self.push_idx)
        for frame in pbar:
            # The parameter upload is recorded on the warp tape (presumably so
            # tape.backward can reach the torch leaves -- confirm).
            with sim.tape:
                set_phys_properties()
                sim.set_friction(self.friction)
            step(last_frame_grad=True)
            sim.calculate_loss_gradient()
            frame_loss = sim.temp_loss.item()

            if frame < sim.train_frame:
                total_loss += frame_loss
                n_train += 1
                sim.tape.backward(sim.grad_loss)
                self.optimizer.step()
                with torch.no_grad():
                    self.params_clamp()
            else:
                total_loss_test += frame_loss
                n_test += 1

            with torch.no_grad():
                ee_primitive = sim.mpm_solver.primitive.visual_3dgs if self.gripper_type == 'push' else None
                sim.get_current_state(save_path=save_path, epoch=epoch, draw_primitive=ee_primitive, draw_gt=True)
                info = f'density: min: {self.density.min():.4f}, max: {self.density.max():.4f}; '
                info += f' friction: {self.friction.item():.4f}'
                pbar.set_description(info)

            # Per-frame cleanup: drop gradients, reset the tape and loss state.
            self.optimizer.zero_grad()
            sim.tape.reset()
            sim.clear_loss()
            torch.cuda.empty_cache()
        return total_loss, n_train, total_loss_test, n_test

    @staticmethod
    def _mean(total, count):
        """Return ``total / count``, or NaN when ``count`` is zero."""
        return total / count if count else float('nan')

    @staticmethod
    def _plot_loss_curve(loss_history, save_path, name):
        """Save a per-epoch loss curve to ``<save_path>/<name>_loss.jpg``."""
        plt.figure(figsize=(10, 6))
        plt.plot(loss_history, label=f"{name} Loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title(f"{name} Loss Curve")
        plt.legend()
        plt.grid(True)
        plt.savefig(os.path.join(save_path, f'{name}_loss.jpg'))
        plt.close()

    def set_optimizer(self):
        """Create an SGD optimiser with one parameter group per fitted tensor.

        Learning rates are looked up as ``lr_<name>`` keys in
        ``simulator.preprocessing_params``. Weight decay is disabled for every
        group.
        """
        params = self.simulator.preprocessing_params
        if self.material == 'cloth':
            names = ('warp_stiffness', 'weft_stiffness', 'shear_stiffness', 'density', 'friction')
            tensors = (self.log_warp_stiffness, self.log_weft_stiffness,
                       self.log_shear_stiffness, self.density, self.friction)
        else:
            names = ('E', 'density', 'friction')
            tensors = (self.log_E, self.density, self.friction)
        lrs = [params[f'lr_{name}'] for name in names]
        self.optimizer = torch.optim.SGD(
            [{'params': tensor, 'lr': lr, 'weight_decay': 0.00} for tensor, lr in zip(tensors, lrs)])
        if self.material == 'cloth':
            print(*lrs[:3])

    def params_clamp(self):
        """Clamp the fitted parameters in place to fixed bounds.

        Callers wrap this in ``torch.no_grad()``. The tensors are leaves that
        require grad, so in-place ops on them would otherwise be rejected.
        """
        if self.material == 'cloth':
            torch.clamp_(self.log_warp_stiffness, 4, 14)
            torch.clamp_(self.log_weft_stiffness, 4, 14)
            torch.clamp_(self.log_shear_stiffness, 4, 14)
        else:
            torch.clamp_(self.log_E, 5, 13)
        torch.clamp_(self.density, 20, 2000)
        torch.clamp_(self.friction, 0.01, 10)
