import time
import unittest
from pathlib import Path

import torch

from .test_hip import print_errors
from quick_extend.utils.memory_efficient_llm_ce import memory_efficient_llm_ce


class TestCEBackward(unittest.TestCase):
    """Check ``memory_efficient_llm_ce`` against a plain PyTorch reference.

    The reference computes ``linear(hidden, weight) * multiplier`` followed by
    ``torch.nn.CrossEntropyLoss(reduction='mean')``.  Both implementations are
    run forward and backward on CUDA device 0, timed, and their loss and
    gradients are compared within a tolerance.
    """

    @staticmethod
    def _begin_measurement():
        """Start a timing / peak-memory measurement window on CUDA device 0.

        Returns:
            A ``(baseline_mem, start_event, end_event)`` tuple to be handed to
            ``_end_measurement`` once the measured work has been enqueued.
        """
        # Drain pending work so warm-up iterations do not leak into the window.
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats(0)
        baseline_mem = torch.cuda.max_memory_allocated(0)
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        return baseline_mem, start_event, end_event

    @staticmethod
    def _end_measurement(baseline_mem, start_event, end_event):
        """Close a measurement window opened by ``_begin_measurement``.

        Returns:
            ``(elapsed, peak_mem)`` where ``elapsed`` is the value reported by
            ``start_event.elapsed_time(end_event)`` and ``peak_mem`` is the
            growth of ``max_memory_allocated`` over the baseline.
        """
        end_event.record()
        # elapsed_time() needs both events to have completed.
        torch.cuda.synchronize()
        peak_mem = torch.cuda.max_memory_allocated(0) - baseline_mem
        return start_event.elapsed_time(end_event), peak_mem

    @unittest.skipUnless(torch.cuda.is_available(), "requires a CUDA device")
    def test_efficent_ce(self):
        """Loss and gradients of both implementations must agree."""
        # Larger, LLM-sized configuration for manual benchmarking:
        #   HID = 4096, NUM_VOCAB = 32001, N = 4096 * 20
        HID = 512
        NUM_VOCAB = 1024
        N = 128

        # torch.bfloat16 can be substituted here for manual experiments.
        dtype = torch.float32

        torch.random.manual_seed(0)
        hidden = torch.randn((N, HID), dtype=dtype, device=0)
        labels = torch.randint(0, NUM_VOCAB, (N,), dtype=torch.long, device=0)
        weight = torch.randn((NUM_VOCAB, HID), dtype=dtype, device=0)
        # Non-trivial upstream gradient fed into backward().
        out_noise = torch.tensor(3.14, dtype=dtype, device=0)
        multiplier = 2.0

        # Mark roughly a third of the positions with -100, the default
        # ``ignore_index`` of ``torch.nn.CrossEntropyLoss``.
        labels[torch.randint(0, 3, labels.shape) == 0] = -100

        hidden.requires_grad = True
        weight.requires_grad = True

        do_grad = True

        print_output = False
        warmups = 1
        n_repeats = 3

        # ---- Reference: plain PyTorch ------------------------------------
        for i in range(n_repeats + warmups):
            if i == warmups:
                measurement = self._begin_measurement()
            logits = torch.nn.functional.linear(
                hidden, weight, None,
            ) * multiplier
            losses_torch = torch.nn.CrossEntropyLoss(reduction='mean')(logits, labels)
            if do_grad:
                losses_torch.backward(out_noise)
                grad_hidden_torch = hidden.grad
                grad_weight_torch = weight.grad
                # Reset so the next iteration does not accumulate.
                hidden.grad = None
                weight.grad = None
        elapsed_torch, peak_mem_torch = self._end_measurement(*measurement)

        # ---- Implementation under test (labelled "triton" in the output) --
        for i in range(n_repeats + warmups):
            if i == warmups:
                measurement = self._begin_measurement()
            losses_triton = memory_efficient_llm_ce(
                hidden, weight, labels, multiplier, 'mean',
                threshold=0
            )
            if do_grad:
                losses_triton.backward(out_noise)
                grad_hidden_triton = hidden.grad
                grad_weight_triton = weight.grad
                hidden.grad = None
                weight.grad = None
        elapsed_triton, peak_mem_triton = self._end_measurement(*measurement)

        print(f"Time: torch={elapsed_torch / n_repeats:.2f}ms, triton={elapsed_triton / n_repeats:.2f}ms")
        print(f"Peak memory: {peak_mem_torch / 10 ** 6:.3f}MB, {peak_mem_triton / 10 ** 6:.3f}MB")

        print(losses_torch.float(), losses_triton.float())

        if print_output:
            output_dir = Path("output")
            output_dir.mkdir(exist_ok=True)

            import numpy as np
            # Dump complete arrays instead of numpy's abbreviated repr.
            np.set_printoptions(threshold=np.inf, linewidth=200)

        atol, rtol = 1e-3, 5e-3
        print("Errors: ")
        print_errors(losses_torch, losses_triton)
        self.assertTrue(
            torch.allclose(losses_torch.float(), losses_triton.float(), atol=atol, rtol=rtol),
            "loss of memory_efficient_llm_ce deviates from the torch reference",
        )
        if do_grad:
            if print_output:
                dumps = {
                    "grad_hidden_torch.txt": grad_hidden_torch,
                    "grad_hidden_triton.txt": grad_hidden_triton,
                    "grad_weight_torch.txt": grad_weight_torch,
                    "grad_weight_triton.txt": grad_weight_triton,
                }
                for file_name, tensor in dumps.items():
                    with open(output_dir / file_name, "w") as f:
                        f.write(str(tensor.float().cpu().numpy()))
            print("Grad errors: ")
            print("- Hidden grad: ")
            print_errors(grad_hidden_torch, grad_hidden_triton)
            print("- Weight grad: ")
            print_errors(grad_weight_torch, grad_weight_triton)
            self.assertTrue(
                torch.allclose(grad_hidden_torch.float(), grad_hidden_triton.float(), atol=atol, rtol=rtol),
                "hidden gradient of memory_efficient_llm_ce deviates from the torch reference",
            )
            self.assertTrue(
                torch.allclose(grad_weight_torch.float(), grad_weight_triton.float(), atol=atol, rtol=rtol),
                "weight gradient of memory_efficient_llm_ce deviates from the torch reference",
            )


# Script entry point: hand control to unittest's command-line runner.
# NOTE(review): this module uses a relative import (`from .test_hip import ...`),
# so it presumably has to be launched as part of its package — confirm.
if __name__ == "__main__":
    unittest.main()
