import unittest
import torch
from chunk_attn import Attention
import random
import math
import time
# Uncomment to pin PyTorch to a single intra-op thread (e.g. for single-thread timing).
#torch.set_num_threads(1)
# Report PyTorch's thread-pool sizes once, when the module is imported.
print('interop_threads:{} intraop_threads:{}'.format(torch.get_num_interop_threads(),
                                                      torch.get_num_threads()))


class TestAttn(unittest.TestCase):
    """Tests for ``chunk_attn.Attention``.

    Each test runs with process-wide torch defaults chosen in ``setUp``:
    CUDA + float16 when a GPU is available, otherwise CPU + float32.
    ``tearDown`` restores whatever defaults were active before the test so the
    float16/CUDA settings do not leak into other tests in the same process.
    """

    def setUp(self) -> None:
        # Snapshot the current defaults. Factory functions honor the default
        # device, so a throwaway scalar reveals it.
        self._saved_dtype = torch.get_default_dtype()
        self._saved_device = torch.empty(()).device
        if torch.cuda.is_available():
            self.device = torch.device('cuda', torch.cuda.current_device())
            self.dtype = torch.float16
        else:
            self.device = torch.device('cpu')
            self.dtype = torch.float32
        torch.set_default_dtype(self.dtype)
        torch.set_default_device(self.device)

    def tearDown(self) -> None:
        # Undo the process-wide changes made in setUp.
        torch.set_default_dtype(self._saved_dtype)
        torch.set_default_device(self._saved_device)

    def _reference_attention(self, q, k, v):
        """Plain PyTorch scaled dot-product attention for one request.

        Scores are scaled by ``1/sqrt(d)`` (``d = q.shape[-1]``) and
        softmax-normalized in float32. They are then cast back to ``self.dtype``
        before the matmul with ``v``.
        """
        score = torch.matmul(q, k.transpose(-1, -2))
        score = torch.softmax(score.to(torch.float32) / math.sqrt(q.shape[-1]), dim=-1)
        return torch.matmul(score.to(self.dtype), v)

    def _assert_close(self, expected, actual, atol=1e-3):
        """Fail unless ``torch.allclose(expected, actual, atol=atol)``; report the max abs diff."""
        if not torch.allclose(expected, actual, atol=atol):
            max_diff = (expected - actual).abs().max().item()
            self.fail(f'tensors differ: max abs diff {max_diff} (atol={atol})')

    @unittest.skip("skip")
    def test_forward(self):
        """Smoke test: run forward() with one query over a 5-token prompt in chunks of 2."""
        n_head, d_embed = 2, 4
        attn = Attention(n_head=n_head, d_embed=d_embed, chunk_size=2,
                         dtype=self.dtype, device=self.device)
        attn.add_prompt(tokens=[1, 2, 3, 4, 5],
                        k=torch.ones((n_head, 5, d_embed)),
                        v=torch.ones((n_head, 5, d_embed)))
        attn.forward(q=torch.ones((n_head, 1, d_embed)))

    @unittest.skip("skip")
    def test_add_prompt_overlapping_prefixes(self):
        """Add prompts whose token prefixes overlap in different ways (no assertions).

        This test was previously also named ``test_add_prompt`` and was shadowed
        by the definition below, so it never ran.
        """
        n_head, d_embed = 2, 4
        attn = Attention(n_head=n_head, d_embed=d_embed, chunk_size=2,
                         dtype=self.dtype, device=self.device)
        for tokens in ([1, 2, 3, 4, 5], [1, 2, 3, 4, 6], [1, 2, 3, 4], [1, 7], [1, 2]):
            # K/V carry one row per token, matching every other add_prompt call
            # in this file. The [1, 2, 3, 4] prompt previously passed a single row.
            attn.add_prompt(tokens=tokens,
                            k=torch.ones((n_head, len(tokens), d_embed)),
                            v=torch.ones((n_head, len(tokens), d_embed)))

    @unittest.skip("skip")
    def test_add_prompt(self):
        """Two prompts sharing the prefix [1, 2, 3, 4]; forward() must match the reference."""
        n_head, d_embed = 2, 4
        k1 = torch.randn((n_head, 5, d_embed))
        v1 = torch.randn((n_head, 5, d_embed))
        k2 = torch.randn((n_head, 6, d_embed))
        v2 = torch.randn((n_head, 6, d_embed))
        # The prompts share their first 4 tokens, so their K/V must agree there.
        k2[:, 0:4, :] = k1[:, 0:4, :]
        v2[:, 0:4, :] = v1[:, 0:4, :]
        q = torch.randn((n_head, 2, d_embed))  # query i is paired with prompt i

        # Same scaled reference that test_check_result_with_pytorch checks forward() against.
        output_ref = torch.cat([self._reference_attention(q[:, 0:1, :], k1, v1),
                                self._reference_attention(q[:, 1:2, :], k2, v2)], dim=1)

        attn = Attention(n_head=n_head, d_embed=d_embed, chunk_size=2,
                         dtype=self.dtype, device=self.device)
        attn.add_prompt(tokens=[1, 2, 3, 4, 5], k=k1, v=v1)
        attn.add_prompt(tokens=[1, 2, 3, 4, 6, 7], k=k2, v=v2)
        attn.print()
        output = attn.forward(q=q)
        self._assert_close(output_ref, output)

    @unittest.skip("skip")
    def test_append_completions(self):
        """Print the cache before and after append_completions on two prompts (no assertions)."""
        n_head, d_embed = 2, 4
        attn = Attention(n_head=n_head, d_embed=d_embed, chunk_size=2,
                         dtype=self.dtype, device=self.device)
        attn.add_prompt(tokens=[1, 2, 3, 4, 5],
                        k=torch.randn((n_head, 5, d_embed)),
                        v=torch.randn((n_head, 5, d_embed)))
        attn.add_prompt(tokens=[1, 2, 3, 4, 6, 7],
                        k=torch.randn((n_head, 6, d_embed)),
                        v=torch.randn((n_head, 6, d_embed)))
        attn.print()
        attn.append_completions(tokens=[8, 9],
                                k=torch.randn((n_head, 2, d_embed)),
                                v=torch.randn((n_head, 2, d_embed)))
        attn.print()

    @unittest.skip("skip")
    def test_duplicate(self):
        """Print the cache after a sequence of duplicate()/remove() calls (no assertions)."""
        n_head, d_embed = 2, 4
        attn = Attention(n_head=n_head, d_embed=d_embed, chunk_size=2,
                         dtype=self.dtype, device=self.device)
        attn.add_prompt(tokens=[1, 2, 3, 4, 5],
                        k=torch.randn((n_head, 5, d_embed)),
                        v=torch.randn((n_head, 5, d_embed)))
        attn.add_prompt(tokens=[1, 2, 3, 4, 6, 7],
                        k=torch.randn((n_head, 6, d_embed)),
                        v=torch.randn((n_head, 6, d_embed)))
        attn.print()
        attn.duplicate(0, 2)
        attn.print()
        attn.duplicate(3, 1)
        attn.print()
        attn.remove(2)
        attn.print()

    #@unittest.skip("skip")
    def test_check_result_with_pytorch(self):
        """Compare Attention.forward against a plain PyTorch reference.

        ``n_requests`` prompts of ``seq_len`` tokens share their first
        ``n_shared`` tokens and the matching K/V rows. Each request contributes
        one query. Both forward() modes exercised below are checked for every
        chunk size.
        """
        seq_len = 8192
        n_shared = round(seq_len * 0.5)
        n_head, d_embed = 32, 128
        n_requests = 32
        print(f'\nseq_len: {seq_len}, n_shared: {n_shared}, n_requests: {n_requests}')
        print(f'{self.device} {self.dtype}')

        # Seed both RNGs so a failure can be reproduced.
        torch.manual_seed(0)
        rng = random.Random(0)

        keys = [torch.randn((n_head, seq_len, d_embed)) for _ in range(n_requests)]
        shared_keys = torch.randn((n_head, n_shared, d_embed))
        for key in keys:
            key[:, :n_shared, :] = shared_keys
        values = [torch.randn((n_head, seq_len, d_embed)) for _ in range(n_requests)]
        shared_values = torch.randn((n_head, n_shared, d_embed))
        for value in values:
            value[:, :n_shared, :] = shared_values
        qs = [torch.randn((n_head, 1, d_embed)) for _ in range(n_requests)]

        # Token ids: the shared prefix, then an id unique to each request at
        # position n_shared, then random ids. The unique id guarantees no two
        # requests agree on more than n_shared leading tokens. A random
        # collision there could let the prefix cache treat differing K/V as shared.
        prompts = [list(range(n_shared)) + [n_shared + i]
                   + [rng.randint(n_shared, seq_len) for _ in range(seq_len - n_shared - 1)]
                   for i in range(n_requests)]

        # Implemented in PyTorch
        output_ref = torch.cat([self._reference_attention(q_i, k_i, v_i)
                                for q_i, k_i, v_i in zip(qs, keys, values)], dim=1)

        q = torch.cat(qs, dim=1)
        chunks = [64] if torch.cuda.is_available() else [32, 64, 128, 256, 512, 1024]
        for chunk_size in chunks:
            print(f'chunk_size: {chunk_size}')
            with self.subTest(chunk_size=chunk_size):
                attn = Attention(n_head=n_head, d_embed=d_embed, chunk_size=chunk_size,
                                 memory_mb=4096, dtype=self.dtype, device=self.device)
                for tokens, k, v in zip(prompts, keys, values):
                    attn.add_prompt(tokens=tokens, k=k, v=v)
                output2 = attn.forward(q=q)  # chunk+seq by default
                self._assert_close(output_ref, output2)

                # seq-first only, for perf testing purpose
                output3 = attn.forward(q=q, partition=2)
                self._assert_close(output_ref, output3)

        
if __name__ == '__main__':
    # Process-wide defaults for a direct run. TestAttn.setUp replaces them for
    # each test (CUDA + float16 when available, otherwise CPU + float32).
    torch.set_default_dtype(torch.float32)
    torch.set_default_device('cpu')  # {cpu, cuda}
    unittest.main()